From 637c5a8856292c68f7670a21054036625f9bf5c1 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 9 Apr 2025 21:15:11 -0700 Subject: [PATCH] [Executorch][llama] Enable quantized sdpa Pull Request resolved: https://github.com/pytorch/executorch/pull/9945 Enable leveraging quantized sdpa op when quantized kv cache is used. Instead of adding yet another arg, at the moment I have chosen to leverage quantize_kv_cache option. ghstack-source-id: 277233485 Differential Revision: [D71833064](https://our.internmc.facebook.com/intern/diff/D71833064/) --- examples/models/llama/TARGETS | 17 ++ examples/models/llama/export_llama_lib.py | 5 + .../source_transformation/custom_kv_cache.py | 63 +++++-- .../llama/source_transformation/sdpa.py | 122 +++++++++++- .../test_quantized_sdpa.py | 173 ++++++++++++++++++ extension/llm/custom_ops/CMakeLists.txt | 4 + extension/llm/custom_ops/custom_ops.py | 124 +++++++++++++ extension/llm/custom_ops/op_sdpa.cpp | 4 - extension/llm/custom_ops/op_sdpa.h | 2 - extension/llm/custom_ops/op_sdpa_aot.cpp | 8 - extension/llm/custom_ops/op_sdpa_impl.h | 16 -- third-party/ao | 2 +- 12 files changed, 489 insertions(+), 51 deletions(-) create mode 100644 examples/models/llama/source_transformation/test_quantized_sdpa.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 12eb5fd13dc..b892613c0f8 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -274,3 +274,20 @@ runtime.python_test( ":export_library", ], ) + +runtime.python_test( + name = "quantized_sdpa_source_transform_test", + srcs = [ + "source_transformation/test_quantized_sdpa.py", + ], + preload_deps = [ + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/extension/llm/custom_ops:custom_ops_aot_py", + ], + deps = [ + ":custom_kv_cache", + ":sdpa", + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 01179e8ee56..64cbc9e23af 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -63,6 +63,7 @@ replace_kv_cache_with_custom_kv_cache, replace_kv_cache_with_quantized_kv_cache, ) + from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, @@ -77,6 +78,7 @@ replace_sdpa_with_coreml_sdpa, replace_sdpa_with_custom_op, replace_sdpa_with_flex_sdpa, + replace_sdpa_with_quantized_sdpa, replace_sdpa_with_simple_sdpa, ) from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb @@ -1222,11 +1224,14 @@ def _get_source_transforms( # noqa if args.use_sdpa_with_kv_cache: transforms.append(replace_kv_cache_with_custom_kv_cache) + # todo: do this optionally transforms.append(replace_sdpa_with_custom_op) if args.quantize_kv_cache: assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" transforms.append(replace_kv_cache_with_quantized_kv_cache) + # Right now + transforms.append(replace_sdpa_with_quantized_sdpa) if args.use_kv_cache: if args.qnn: diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index e7138622ed9..1158a8ba7a6 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -52,6 +52,8 @@ def __init__( self.use_custom_update_cache_op = use_custom_update_cache_op self.quantized_cache_dtype = 
torch.int8 self.cache_fp_type = torch.float32 + self.return_float_values = True + self.max_context_length = max_context_length cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) scale_shape = (max_batch_size, max_context_length, n_heads, 1) self.register_buffer( @@ -61,17 +63,17 @@ def __init__( "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) ) self.register_buffer( - "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64) + "k_cache_scales", torch.ones(scale_shape, dtype=torch.float32) ) self.register_buffer( - "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64) + "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32) ) if cache_type == QuantizedCacheType.AffineAsymmetric: self.register_buffer( - "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) + "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8) ) self.register_buffer( - "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) + "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8) ) def _quantize(self, value): @@ -91,20 +93,15 @@ def _quantize(self, value): ) return quantized_value, scales, zero_points - def update(self, input_pos, k_val, v_val): - """ - k_val, v_val: [B, H, S, D] - return: [B, H, S, D] - However the storage is [B, S, H, D] so we incur transpose in, transpose out - This shall be removed by subsequent post-export graph pass - """ - k_val = k_val.transpose(1, 2) - v_val = v_val.transpose(1, 2) - # quantize current k_val and store it in the cache + def _quantize_and_update(self, input_pos, k_val, v_val): quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) - quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) + k_scales = k_scales.to(torch.float32) + k_zero_points = k_zero_points.to(self.quantized_cache_dtype) + v_scales = v_scales.to(torch.float32) + v_zero_points = v_zero_points.to(self.quantized_cache_dtype) + if self.use_custom_update_cache_op: start_pos = input_pos[0].item() _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) @@ -125,10 +122,13 @@ def update(self, input_pos, k_val, v_val): self.v_cache_scales[:, input_pos] = v_scales self.v_cache_zero_points[:, input_pos] = v_zero_points + def _update_and_return_float_values(self, input_pos, k_val, v_val): + self._quantize_and_update(input_pos, k_val, v_val) + k_out = torch.ops.quantized_decomposed.dequantize_per_token( self.k_cache, - self.k_cache_scales, - self.k_cache_zero_points, + self.k_cache_scales.to(torch.float64), + self.k_cache_zero_points.to(torch.int64), torch.iinfo(self.quantized_cache_dtype).min, torch.iinfo(self.quantized_cache_dtype).max, self.quantized_cache_dtype, @@ -136,14 +136,16 @@ def update(self, input_pos, k_val, v_val): ) v_out = torch.ops.quantized_decomposed.dequantize_per_token( self.v_cache, - self.v_cache_scales, - self.v_cache_zero_points, + self.v_cache_scales.to(torch.float64), + self.v_cache_zero_points.to(torch.int64), torch.iinfo(self.quantized_cache_dtype).min, torch.iinfo(self.quantized_cache_dtype).max, self.quantized_cache_dtype, self.cache_fp_type, ) + # When returning float values we jsut use the last value + # instead of dequantized value. 
start_pos = input_pos[0].item() if self.use_custom_update_cache_op: _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) @@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val): k_out[:, input_pos] = k_val v_out[:, input_pos] = v_val + return k_out, v_out + + def _update_and_return_quantized_values(self, input_pos, k_val, v_val): + self._quantize_and_update(input_pos, k_val, v_val) + + return self.k_cache, self.v_cache + + def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) + + if self.return_float_values: + k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val) + else: + k_out, v_out = self._update_and_return_quantized_values( + input_pos, k_val, v_val + ) return k_out.transpose(1, 2), v_out.transpose(1, 2) @classmethod diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 1bb7d277545..a50c6aeea22 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -13,7 +13,9 @@ import torch -from executorch.examples.models.llama.attention import KVCache, SDPA +from executorch.examples.models.llama.attention import Attention, KVCache, SDPA + +from .custom_kv_cache import QuantizedKVCache class SDPACustom(torch.nn.Module): @@ -76,6 +78,124 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: return module +class QuantizedSDPA(torch.nn.Module): + """ + A quantized version of the SDPA (Scaled Dot Product Attention) module. + + This module implements attention computation using quantized key-value pairs + to reduce memory footprint and potentially improve performance. It works with + a QuantizedKVCache to store and retrieve quantized key-value tensors. + + The quantization process converts floating point tensors to int8, which requires + maintaining scale and zero point values for proper dequantization during computation. + + Args: + dim (int): The dimension of the model + kv_cache (QuantizedKVCache): The cache for storing quantized key-value pairs + Note that it needs to own kv_cache to access scales and zero points, and since + SDPA forward signature only accepts q, k and v, to allow accessing scales and + zero points, we need to pass kv_cache to SDPA. 
+ """ + + def __init__(self, dim: int, kv_cache: QuantizedKVCache): + super().__init__() + self.dim = dim + self.quantized_dtype = torch.int8 + self.float_dtype = torch.float32 + self.kv_cache = kv_cache + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k_quantized: torch.Tensor, + v_quantized: torch.Tensor, + bsz, + seqlen, + mask, + ): + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + k_quantized = k_quantized.transpose(1, 2) + v_quantized = v_quantized.transpose(1, 2) + + q_scale, q_zero_point = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + q, self.quantized_dtype + ) + ) + q_quantized = torch.ops.quantized_decomposed.quantize_per_token( + q, + q_scale, + q_zero_point, + torch.iinfo(self.quantized_dtype).min, + torch.iinfo(self.quantized_dtype).max, + self.quantized_dtype, + ) + q_zero_point_int8 = q_zero_point.to(dtype=torch.int8) + q_scale_fp32 = q_scale.to(dtype=torch.float32) + + k_zero_point_int8 = self.kv_cache.k_cache_zero_points + k_scale_fp32 = self.kv_cache.k_cache_scales + v_zero_point_int8 = self.kv_cache.v_cache_zero_points + v_scale_fp32 = self.kv_cache.v_cache_scales + + start_pos = input_pos[0].item() + output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + None, + 0, + True, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) + + return output.view(bsz, seqlen, self.dim) + + +def _update_attention_module_with_quantized_sdpa( + module: torch.nn.Module, kv_cache: QuantizedKVCache +): + sdpa = getattr(module, "SDPA", None) + assert sdpa is not None + # pyre-ignore + setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010 + + +def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module): + for _, child in module.named_children(): + if isinstance(child, Attention): + kv_cache = getattr(child, "kv_cache", None) + if kv_cache is None: + continue + if not isinstance(kv_cache, QuantizedKVCache): + continue + # Only when kv_cache is QuantizedKVCache, we replace SDPA with QuantizedSDPA + sdpa = getattr(child, "SDPA", None) + if sdpa is None: + continue + if not isinstance(sdpa, SDPACustom): + continue + kv_cache.return_float_values = False + _update_attention_module_with_quantized_sdpa(child, kv_cache) + else: + _replace_sdpa_with_quantized_sdpa(child) + + +def replace_sdpa_with_quantized_sdpa(module: torch.nn.Module) -> torch.nn.Module: + from executorch.extension.llm.custom_ops import custom_ops # noqa + + _replace_sdpa_with_quantized_sdpa(module) + return module + + class SDPASimple(torch.nn.Module): def __init__( self, diff --git a/examples/models/llama/source_transformation/test_quantized_sdpa.py b/examples/models/llama/source_transformation/test_quantized_sdpa.py new file mode 100644 index 00000000000..242f3a0876d --- /dev/null +++ b/examples/models/llama/source_transformation/test_quantized_sdpa.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch + +from executorch.examples.models.llama.attention import Attention, KVCache, SDPA +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + QuantizedCacheType, + QuantizedKVCache, +) +from executorch.examples.models.llama.source_transformation.sdpa import ( + QuantizedSDPA, + replace_sdpa_with_custom_op, + replace_sdpa_with_quantized_sdpa, + SDPACustom, +) + + +class MockAttention(Attention): + """Mock Attention class for testing purposes.""" + + def __init__( + self, dim, head_dim, n_rep, max_context_len=100, enable_dynamic_shape=False + ): + super().__init__() + self.dim = dim + self.head_dim = head_dim + self.n_rep = n_rep + self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape) + self.kv_cache = None + + def forward(self, x, freqs_cos, freqs_sin, **kwargs): + # Not used in tests + pass + + +class QuantizedSDPATest(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.max_context_len = 5 + self.n_kv_heads = 4 + self.n_heads = 8 + self.head_dim = 16 + self.dim = self.n_heads * self.head_dim + self.enable_dynamic_shape = False + self.dtype = torch.float32 + + def _create_test_model(self): + """Create a simple model with SDPA modules for testing.""" + model = torch.nn.Module() + attention = MockAttention( + self.dim, self.head_dim, self.n_heads // self.n_kv_heads + ) + # Add KVCache to the attention module + attention.kv_cache = KVCache( + self.max_batch_size, + self.max_context_len, + self.n_kv_heads, + self.head_dim, + self.enable_dynamic_shape, + dtype=self.dtype, + ) + model.attention = attention + return model + + def test_replace_sdpa_with_quantized_sdpa(self): + """Test that replace_sdpa_with_quantized_sdpa correctly transforms SDPA to QuantizedSDPA.""" + # Create a model with SDPA + model = self._create_test_model() + + # First replace standard SDPA with SDPACustom (required before quantization) + model = replace_sdpa_with_custom_op(model) + self.assertIsInstance(model.attention.SDPA, SDPACustom) + + # Replace KVCache with QuantizedKVCache + model.attention.kv_cache = QuantizedKVCache.from_float( + model.attention.kv_cache, + QuantizedCacheType.AffineAsymmetric, + use_custom_update_cache_op=True, + ) + self.assertIsInstance(model.attention.kv_cache, QuantizedKVCache) + + # Set return_float_values to False to enable quantized operation + model.attention.kv_cache.return_float_values = False + + # Apply the transformation + model = replace_sdpa_with_quantized_sdpa(model) + + # Verify that SDPA has been replaced with QuantizedSDPA + self.assertIsInstance(model.attention.SDPA, QuantizedSDPA) + + # Verify that the QuantizedSDPA has the correct properties + self.assertEqual(model.attention.SDPA.dim, self.dim) + self.assertEqual(model.attention.SDPA.quantized_dtype, torch.int8) + self.assertEqual(model.attention.SDPA.float_dtype, torch.float32) + self.assertIs(model.attention.SDPA.kv_cache, model.attention.kv_cache) + + def test_no_replacement_when_no_quantized_kv_cache(self): + """Test that SDPA is not replaced when there's no QuantizedKVCache.""" + # Create a model with SDPA + model = self._create_test_model() + + # First replace standard SDPA with SDPACustom + model = replace_sdpa_with_custom_op(model) + self.assertIsInstance(model.attention.SDPA, SDPACustom) + + # Apply the transformation without replacing KVCache + model = replace_sdpa_with_quantized_sdpa(model) + + # Verify that SDPA has NOT been replaced with QuantizedSDPA + 
self.assertIsInstance(model.attention.SDPA, SDPACustom) + self.assertNotIsInstance(model.attention.SDPA, QuantizedSDPA) + + def test_forward_functionality(self): + """Test that the QuantizedSDPA forward function works correctly.""" + # This test requires the custom ops to be loaded, so we'll check if they're available + try: + from executorch.extension.llm.custom_ops import custom_ops # noqa + except ImportError: + self.skipTest( + "Custom ops not available, skipping forward functionality test" + ) + + # Create a model with SDPA + model = self._create_test_model() + + # First replace standard SDPA with SDPACustom + model = replace_sdpa_with_custom_op(model) + + # Replace KVCache with QuantizedKVCache + model.attention.kv_cache = QuantizedKVCache.from_float( + model.attention.kv_cache, + QuantizedCacheType.AffineAsymmetric, + use_custom_update_cache_op=True, + ) + + # Set return_float_values to False to enable quantized operation + model.attention.kv_cache.return_float_values = False + + # Save the original SDPACustom for comparison + # Apply the transformation + model = replace_sdpa_with_quantized_sdpa(model) + + # Create test inputs + input_pos = torch.tensor([0], dtype=torch.int64) + bsz = 1 + seqlen = 1 + q = torch.randn(bsz, self.n_heads, seqlen, self.head_dim, dtype=self.dtype) + k = torch.randn(bsz, self.n_kv_heads, seqlen, self.head_dim, dtype=self.dtype) + v = torch.randn(bsz, self.n_kv_heads, seqlen, self.head_dim, dtype=self.dtype) + + # Update the KV cache + k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v) + + # Run the forward pass with the quantized SDPA + try: + output = model.attention.SDPA( + input_pos, q, k_quantized, v_quantized, bsz, seqlen, None + ) + + # Verify the output shape + self.assertEqual(output.shape, (bsz, seqlen, self.dim)) + except Exception: + # If the forward pass fails, it might be due to missing custom ops + self.skipTest( + "Custom ops not available, skipping forward functionality test" + ) diff --git a/extension/llm/custom_ops/CMakeLists.txt b/extension/llm/custom_ops/CMakeLists.txt index fd2ead6c8b0..42e82dc360f 100644 --- a/extension/llm/custom_ops/CMakeLists.txt +++ b/extension/llm/custom_ops/CMakeLists.txt @@ -21,6 +21,9 @@ if(NOT EXECUTORCH_ROOT) endif() set(_common_compile_options -Wno-deprecated-declarations -fPIC) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64") + list(APPEND _common_compile_options "-march=armv8.2-a+dotprod") +endif() include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) @@ -38,6 +41,7 @@ include(${EXECUTORCH_SRCS_FILE}) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) +list(APPEND _common_include_directories ${EXECUTORCH_ROOT}/third-party/ao) # Custom op libraries set(custom_ops_libs pthreadpool) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index d299b314816..6d96a926497 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -229,3 +229,127 @@ def update_cache_meta( # workaround. Should we just return cache instead? 
But I am afraid that # will result in extra memory allocation return torch.empty((1,), dtype=value.dtype, device="meta") + + +def _validate_quantized_sdpa_params( + query, + key, + value, + start_pos, + seq_len, + attn_mask, + drpout_p, + is_causal, + scale, + q_scale, + q_zero_point, + k_scale, + k_zero_point, + v_scale, + v_zero_point, + is_seq_at_dim_2, +): + assert ( + query.dim() == 4 + ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions." + assert ( + key.dim() == 4 + ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions." + assert ( + value.dim() == 4 + ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions." + + assert (q_scale is not None) and ( + q_zero_point is not None + ), "q_scale and q_zero_point must be provided" + assert (k_scale is not None) and ( + k_zero_point is not None + ), "k_scale and k_zero_point must be provided" + assert (v_scale is not None) and ( + v_zero_point is not None + ), "v_scale and v_zero_point must be provided" + + assert query.dtype == torch.int8, f"Expected query to be int8 but got {query.dtype}" + assert key.dtype == torch.int8, f"Expected key to be int8 but got {key.dtype}" + assert value.dtype == torch.int8, f"Expected value to be int8 but got {value.dtype}" + + assert ( + q_scale.dtype == torch.float32 + ), f"Expected q_scale to be float32 but got {q_scale.dtype}" + assert ( + q_zero_point.dtype == torch.int8 + ), f"Expected q_zero_point to be int8 but got {q_zero_point.dtype}" + assert ( + k_scale.dtype == torch.float32 + ), f"Expected k_scale to be float32 but got {k_scale.dtype}" + assert ( + k_zero_point.dtype == torch.int8 + ), f"Expected k_zero_point to be int8 but got {k_zero_point.dtype}" + assert ( + v_scale.dtype == torch.float32 + ), f"Expected v_scale to be float32 but got {v_scale.dtype}" + assert ( + v_zero_point.dtype == torch.int8 + ), f"Expected v_zero_point to be int8 but got {v_zero_point.dtype}" + + assert ( + query.size()[:-1] == q_scale.size()[:-1] + ), f"Expected query and q_scale to have same size except last dimensions but got {query.size()} and {q_scale.size()}" + assert ( + query.size()[:-1] == q_zero_point.size()[:-1] + ), f"Expected query and q_zero_point to have same size except last dimensions but got {query.size()} and {q_zero_point.size()}" + + assert ( + key.size()[:-1] == k_scale.size()[:-1] + ), f"Expected key and k_scale to have same size except last dimensions but got {key.size()} and {k_scale.size()}" + assert ( + key.size()[:-1] == k_zero_point.size()[:-1] + ), f"Expected key and k_zero_point to have same size except last dimensions but got {key.size()} and {k_zero_point.size()}" + + assert ( + value.size()[:-1] == v_scale.size()[:-1] + ), f"Expected value and v_scale to have same size except last dimensions but got {value.size()} and {v_scale.size()}" + assert ( + value.size()[:-1] == v_zero_point.size()[:-1] + ), f"Expected value and v_zero_point to have same size except last dimensions but got {value.size()} and {v_zero_point.size()}" + + +@impl(custom_ops_lib, "custom_quantized_sdpa", "Meta") +def custom_quantized_sdpa_meta( + query, + key, + value, + start_pos, + attn_mask=None, + drpout_p=0.0, + is_causal=False, + scale=None, + q_zero_point=None, + q_scale=None, + k_zero_point=None, + k_scale=None, + v_zero_point=None, + v_scale=None, + is_seq_at_dim_2=False, +): + seq_len = query.size(1) + _validate_quantized_sdpa_params( + query, + key, + value, + start_pos, + seq_len, + attn_mask, + drpout_p, + is_causal, + scale, + q_scale, + q_zero_point, + 
k_scale, + k_zero_point, + v_scale, + v_zero_point, + is_seq_at_dim_2, + ) + + return torch.empty(query.size(), dtype=torch.float32, device="meta") diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index c5c9b79b280..a6f80a0d66d 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -477,7 +477,6 @@ Tensor& custom_sdpa_out_impl( return output; } -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& custom_quantized_sdpa_out( RuntimeContext& ctx, const Tensor& q, @@ -516,7 +515,6 @@ Tensor& custom_quantized_sdpa_out( v_scales, is_seq_at_dim_2); } -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA /* Input params @@ -619,9 +617,7 @@ EXECUTORCH_LIBRARY( "custom_sdpa.out", torch::executor::native::custom_sdpa_out); -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA EXECUTORCH_LIBRARY( llama, "custom_quantized_sdpa.out", torch::executor::native::custom_quantized_sdpa_out); -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 3deb27b3989..9d357eb6ea1 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -56,7 +56,6 @@ Tensor& flash_attention_kernel_out( const optional scale, Tensor& output); -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& custom_quantized_sdpa_out( RuntimeContext& ctx, const Tensor& q, @@ -76,7 +75,6 @@ Tensor& custom_quantized_sdpa_out( const optional& v_scales, const bool is_seq_at_dim_1, Tensor& output); -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } // namespace native } // namespace executor } // namespace torch diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 2da915a19b8..ff367c85c8a 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -77,7 +77,6 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -118,7 +117,6 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_zero_points, const std::optional& v_scales, const bool is_seq_at_dim_2); -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& update_cache_out_no_context( const Tensor& value, @@ -241,7 +239,6 @@ at::Tensor custom_sdpa_aten( return output; } -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -322,7 +319,6 @@ at::Tensor custom_quantized_sdpa_aten( output); return output; } -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& update_cache_out_no_context( const Tensor& value, @@ -371,7 +367,6 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "update_cache.out(Tensor value, Tensor(a!) cache, " "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)"); -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -384,7 +379,6 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, " "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, " "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) 
out) -> Tensor(a!)"); -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } // TODO: Rename this file to op_custom_ops_aot.cpp @@ -403,7 +397,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl( "update_cache.out", WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); -#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); @@ -411,5 +404,4 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { "custom_quantized_sdpa.out", WRAP_TO_ATEN( torch::executor::native::custom_quantized_sdpa_out_no_context, 15)); -#endif // ENABLE_CUSTOM_QUANTIZED_SDPA } diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index 7607a1e283d..1f19fa75de7 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -23,9 +23,7 @@ #endif #include -#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) #include -#endif namespace torch { namespace executor { @@ -78,7 +76,6 @@ void _q_at_k_gemm( q_data.dtype == ScalarType::Char || q_data.dtype == ScalarType::Float, "q and k must be either int8 or float"); if (q_data.dtype == ScalarType::Char) { -#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) if constexpr (std::is_same::value) { int a_stride_m_tmp, b_stride_n_tmp; auto kernel = torchao::kernels::cpu::quantized_matmul:: @@ -105,11 +102,6 @@ void _q_at_k_gemm( ET_CHECK_MSG( false, "Accumulation in dtype other than float not supported yet"); } -#else - ET_CHECK_MSG( - false, - "Quantized SDPA is not enabled. Check ENABLE_CUSTOM_QUANTIZED_SDPA compile flag"); -#endif } else { ::executorch::cpublas::gemm( ::executorch::cpublas::TransposeType::Transpose, @@ -141,7 +133,6 @@ void _qk_at_v_gemm( const int64_t o_stride_m, const accum_t beta) { if (v_data.dtype == ScalarType::Char) { -#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) if constexpr (std::is_same::value) { int a_stride_m_tmp, b_stride_n_tmp; auto kernel = torchao::kernels::cpu::quantized_matmul:: @@ -165,11 +156,6 @@ void _qk_at_v_gemm( ET_CHECK_MSG( false, "Accumulation in dtype other than float not supported yet"); } -#else - ET_CHECK_MSG( - false, - "Quantized SDPA is not enabled. Check ENABLE_CUSTOM_QUANTIZED_SDPA compile flag"); -#endif } else { ::executorch::cpublas::gemm( ::executorch::cpublas::TransposeType::NoTranspose, @@ -487,9 +473,7 @@ void cpu_flash_attention( } bool is_quantized_sdpa = false; -#if defined(ENABLE_CUSTOM_QUANTIZED_SDPA) is_quantized_sdpa = query.scalar_type() == ScalarType::Char; -#endif auto strides = query.strides(); int64_t qStrideB = strides[0]; diff --git a/third-party/ao b/third-party/ao index 8b264ce1e15..9516764a971 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit 8b264ce1e1597f4c3f6476609334be05d294f092 +Subproject commit 9516764a97147231c72377bc1067c5e997de02f5