|
53 | 53 | get_quant_embedding_transform,
|
54 | 54 | get_quant_weight_transform,
|
55 | 55 | )
|
56 |
| -from .source_transformation.quantized_kv_cache import ( |
57 |
| - replace_kv_cache_with_quantized_kv_cache, |
58 |
| -) |
| 56 | + |
| 57 | +# from .source_transformation.quantized_kv_cache import ( |
| 58 | +# replace_kv_cache_with_quantized_kv_cache, |
| 59 | +# ) |
59 | 60 | from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
|
60 | 61 |
|
61 | 62 | from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
|
62 |
| -from .source_transformation.sdpa import ( |
63 |
| - replace_causal_mask, |
64 |
| - replace_kv_cache_with_coreml_kv_cache, |
65 |
| - replace_kv_cache_with_simple_kv_cache, |
66 |
| - replace_sdpa_with_coreml_sdpa, |
67 |
| - replace_sdpa_with_custom_op, |
68 |
| - replace_sdpa_with_flex_sdpa, |
69 |
| - replace_sdpa_with_simple_sdpa, |
70 |
| -) |
| 63 | + |
| 64 | +# from .source_transformation.sdpa import ( |
| 65 | +# replace_causal_mask, |
| 66 | +# replace_kv_cache_with_coreml_kv_cache, |
| 67 | +# replace_kv_cache_with_simple_kv_cache, |
| 68 | +# replace_sdpa_with_coreml_sdpa, |
| 69 | +# replace_sdpa_with_custom_op, |
| 70 | +# replace_sdpa_with_flex_sdpa, |
| 71 | +# replace_sdpa_with_simple_sdpa, |
| 72 | +# ) |
71 | 73 |
|
72 | 74 | IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
|
73 | 75 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
|
@@ -893,23 +895,20 @@ def _get_source_transforms( # noqa
|
893 | 895 | assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
|
894 | 896 | transforms.append(replace_kv_cache_with_quantized_kv_cache)
|
895 | 897 |
|
| 898 | + if args.qnn: |
| 899 | + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` |
| 900 | + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d |
| 901 | + |
| 902 | + # transforms.append(replace_kv_cache_with_simple_kv_cache) |
| 903 | + # transforms.append(replace_sdpa_with_flex_sdpa) |
| 904 | + # transforms.append(replace_causal_mask) |
| 905 | + transforms.append(replace_rms_norm_with_native_rms_norm) |
| 906 | + if args.optimized_rotation_path: |
| 907 | + transforms.append(fuse_layer_norms) |
| 908 | + transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) |
| 909 | + transforms.append(convert_linear_to_conv2d) |
896 | 910 | if args.use_kv_cache:
|
897 |
| - if args.qnn: |
898 |
| - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` |
899 |
| - from executorch.backends.qualcomm.utils.utils import ( |
900 |
| - convert_linear_to_conv2d, |
901 |
| - ) |
902 |
| - |
903 |
| - transforms.append(replace_kv_cache_with_simple_kv_cache) |
904 |
| - transforms.append(replace_sdpa_with_flex_sdpa) |
905 |
| - transforms.append(replace_causal_mask) |
906 |
| - transforms.append(replace_rms_norm_with_native_rms_norm) |
907 |
| - if args.optimized_rotation_path: |
908 |
| - transforms.append(fuse_layer_norms) |
909 |
| - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) |
910 |
| - transforms.append(convert_linear_to_conv2d) |
911 |
| - |
912 |
| - elif args.mps: |
| 911 | + if args.mps: |
913 | 912 | # Currently mps doesn't support sdpa op, use the simpler decomposition
|
914 | 913 | # to get free perf gain.
|
915 | 914 | transforms.append(replace_sdpa_with_simple_sdpa)
|
|
0 commit comments