From c10ba1525db53f3f261cd2cf899417e2775d3d45 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 1 Apr 2025 12:58:21 -0700 Subject: [PATCH 1/7] Add gguf q4_k_s quantization Summary: Didn't implement the algorithm to choose_qparams from gguf, since it's complicated, e.g. https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L744 and https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L827C14-L827C28 but implemented a simple choose_qparams that can fit the gguf format: Q4_K: w = q * block_scale(6-bit) + block_min(6-bit) Test Plan: python test/prototype/test_gguf_quant.py Reviewers: Subscribers: Tasks: Tags: --- test/prototype/test_gguf_quant.py | 53 ++++ .../prototype/quantization/gguf/__init__.py | 11 + .../gguf/gguf_quantized_tensor.py | 237 ++++++++++++++++++ torchao/quantization/quant_primitives.py | 171 +++++++++++++ 4 files changed, 472 insertions(+) create mode 100644 test/prototype/test_gguf_quant.py create mode 100644 torchao/prototype/quantization/gguf/__init__.py create mode 100644 torchao/prototype/quantization/gguf/gguf_quantized_tensor.py diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py new file mode 100644 index 0000000000..f469931792 --- /dev/null +++ b/test/prototype/test_gguf_quant.py @@ -0,0 +1,53 @@ +import unittest + +import torch + +from torchao.prototype.quantization.gguf import ( + GGUFQuantizedTensor, + GGUFWeightOnlyConfig, + choose_qparams_gguf, +) +from torchao.quantization import quantize_ +from torchao.quantization.utils import compute_error + + +class TestGGUFQuantization(unittest.TestCase): + def setUp(self): + torch.manual_seed(123) + self.input = torch.randn(2, 256, dtype=torch.float32) + self.n_super_blocks = 8 + self.block_size = (1, 32) + self.dtype = torch.uint4 + + def test_choose_qparams_gguf(self): + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) = choose_qparams_gguf(self.input, self.block_size, self.dtype) + + assert super_block_scale_scale.shape, (2, 8) + assert super_block_min_scale.shape, (2, 8) + assert quantized_block_scale.shape, (2, 32) + + def test_gguf_quantized_tensor_from_float(self): + gqt = GGUFQuantizedTensor.from_float( + self.input, + self.n_super_blocks, + self.dtype, + ) + + dequant = gqt.dequantize() + + sqnr = compute_error(dequant, self.input) + self.assertGreater(sqnr, 30) + + def test_quantize_api(self): + m = torch.nn.Sequential(torch.nn.Linear(256, 64)) + quantize_(m, GGUFWeightOnlyConfig()) + assert type(m[0].weight) == GGUFQuantizedTensor + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/prototype/quantization/gguf/__init__.py b/torchao/prototype/quantization/gguf/__init__.py new file mode 100644 index 0000000000..4ada094dac --- /dev/null +++ b/torchao/prototype/quantization/gguf/__init__.py @@ -0,0 +1,11 @@ +from .gguf_quantized_tensor import ( + GGUFQuantizedTensor, + GGUFWeightOnlyConfig, + choose_qparams_gguf, +) + +__all__ = [ + "GGUFQuantizedTensor", + "choose_qparams_gguf", + "GGUFWeightOnlyConfig", +] diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py new file mode 100644 index 0000000000..8391835b9c --- /dev/null +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional + +import torch + +from torchao.core.config import AOBaseConfig +from torchao.quantization.quant_primitives import ( + choose_qparams_gguf, + dequantize_gguf, + quantize_gguf, +) +from torchao.quantization.transform_module import register_quantize_module_handler +from torchao.utils import TorchAOBaseTensor + +_QK_K = 256 + +__all__ = [ + "GGUFQuantizedTensor", + "choose_qparams_gguf", + "quantize_gguf", + "dequantize_gguf", + "GGUFWeightOnlyConfig", +] + + +class GGUFQuantizedTensor(TorchAOBaseTensor): + """ + A Tensor subclass that when applied to a weight used in a linear op/module, + changes that linear op to a weight-only int4 quantized linear op with groupwise + affine quantization on the weight. + """ + + @staticmethod + def __new__( + cls, + n_super_blocks, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + kwargs["device"] = kwargs.get("device", super_block_scale_scale.device) + kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + n_super_blocks, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + self.n_super_blocks = n_super_blocks + self.super_block_scale_scale = super_block_scale_scale + self.super_block_min_scale = super_block_min_scale + self.quantized_block_scale = quantized_block_scale + self.quantized_block_min = quantized_block_min + self.int_data = int_data + + def _apply_fn_to_data(self, fn): + return self.__class__( + self.n_super_blocks, + fn(self.super_block_scale_scale), + fn(self.super_block_min_sclae), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), + self.shape, + dtype=self.dtype, + ) + + def __tensor_flatten__(self): + return [ + "super_block_scale_scale", + "super_block_min_scale", + "quantized_block_scale", + "quantized_block_min", + "int_data", + ], ( + self.n_super_blocks, + self.dtype, + self.shape, + ) + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None + ): + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + ) = ( + tensor_data_dict["super_block_scale_scale"], + tensor_data_dict["super_block_min_scale"], + tensor_data_dict["quantized_block_scale"], + tensor_data_dict["quantized_block_min"], + tensor_data_dict["int_data"], + ) + n_super_blocks, dtype, shape = attributes + return cls( + n_super_blocks, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape if outer_size is None else outer_size, + dtype=dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + block_size = tuple( + [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks] + ) + return dequantize_gguf( + self.int_data, + block_size, + self.dtype, + self.super_block_scale_scale, + self.super_block_min_scale, + self.quantized_block_scale, + self.quantized_block_min, + ) + + def detach(self): + """ + Returns a new `CodebookQuantizedTensor`. 
+ """ + return self.__class__( + self.n_super_blocks, + self.super_block_scale_scale.detach(), + self.super_block_min_scale.detach(), + self.quantized_block_scale.detach(), + self.quantized_block_min.detach(), + self.int_data.detach(), + self.shape, + dtype=self.dtype, + ) + + def requires_grad_(self, requires_grad=False): + """ + Modifies the tensor's `requires_grad` status in-place. + """ + assert not requires_grad, "Only requires_grad == False is supported" + return self + + @classmethod + def from_float(cls, input_float, n_super_blocks, target_dtype): + """ + Method used to convert a linear weight tensor to an instance of the + GGMLInt4LinearWeight subclass. + + Example usage:: + + model.lin_mod.weight = ( + GGMLInt4LinearWeight.from_float(model.lin_mod.weight) + ) + """ + assert ( + target_dtype == torch.uint4 + ), "only uint4 quantization is supported right now" + block_size = (1, _QK_K // n_super_blocks) + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) = choose_qparams_gguf(input_float, block_size, target_dtype) + + int_data = quantize_gguf( + input_float, + block_size, + target_dtype, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) + return cls( + n_super_blocks, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + input_float.shape, + dtype=torch.float16, + ) + + +@dataclass +class GGUFWeightOnlyConfig(AOBaseConfig): + dtype: torch.dtype = torch.uint4 + n_super_blocks: int = 8 + + +@register_quantize_module_handler(GGUFWeightOnlyConfig) +def _gguf_weight_only_transform( + module: torch.nn.Module, + config: GGUFWeightOnlyConfig, +): + """ + Applies gguf weight-only quantization to linear layers. + + Args: + dtype: torch.uint1 to torch.uint8, torch.int32 supported. + n_super_blocks: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 + it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. + Returns: + Callable for quantization transformation. 
+ """ + weight = module.weight + if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0): + return module + + quantized_weight = GGUFQuantizedTensor.from_float( + weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + return module diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 05be8c5c30..252f6540ec 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -42,6 +42,9 @@ "choose_qparams_affine_float8", "quantize_affine_float8", "dequantize_affine_float8", + "choose_qparams_gguf", + "quantize_gguf", + "dequantize_gguf", ] @@ -195,6 +198,8 @@ class TorchAODType(Enum): _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys() +_GGUF_QK_K = 256 + _ONES_TABLE = [_n_ones(i) for i in range(8)] quant_lib = torch.library.Library("quant", "FRAGMENT") @@ -1039,6 +1044,172 @@ def reshape_w(w): return q_w, s_group, s_channel, w_ref +def choose_qparams_gguf( + input: Optional[torch.Tensor], + block_size: List[int], + target_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # there are two sets of qparams: quantized_block_scale, quantized_block_min and super_block_scale_scale and super_block_min_scale + # the relationship is the following: + # block_scale = quantized_block_scale * super_block_sclae + # block_min = quantized_block_min * super_block_min + # quantized_val = (float_val - block_min) / block_scale + quant_min + # first we calculate block_scale and block_min + # then we calculate super_block_scale_scale and super_block_min_scale + # after that we can calculate quantized_block_scale and quantized_min_scale + # the returned values are: super_block_scale_scale, super_block_min_scale, quantized_block_scale + # and quantized_min_scale + + # 1. get block_scale block_min + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + quant_max = 15 + quant_min = 0 + # asymmetric quant to fully utilize the range + block_scale = max_val / (float(quant_max - quant_min) / 2) + block_scale = (max_val - min_val) / float(quant_max - quant_min) + block_min = min_val + + # 2. get super_block_scale_scale and super_block_min_scale + block_scale_clone = block_scale.clone() + block_min_clone = block_min.clone() + + assert _GGUF_QK_K % block_size[-1] == 0 + super_block_size = (1, _GGUF_QK_K // block_size[-1]) + shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, block_scale_clone.size() + ) + block_scale_clone = block_scale_clone.view(shape_for_reduction) + block_min_clone = block_min_clone.view(shape_for_reduction) + + shape_after_reduction = shape_for_reduction.copy() + for i in reduction_dims: + shape_after_reduction[i] = 1 + + block_scale_absmax = torch.amax( + torch.abs(block_scale_clone), dim=reduction_dims, keepdim=False + ) + block_min_absmax = torch.amax( + torch.abs(block_min_clone), dim=reduction_dims, keepdim=False + ) + + # 2. 
get super_block_scale_scale and super_block_min_scale + quant_max = 2**6 - 1 + quant_min = 0 + super_block_scale_scale = block_scale_absmax / float(quant_max - quant_min) + super_block_min_scale = block_min_absmax / float(quant_max - quant_min) + super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction) + super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction) + + # 3. quantize block scale and min are stored in 6 bits using super_block_scale_scale and super_block_min_scale + quantized_block_scale = torch.clamp( + block_scale / super_block_scale_scale_view, quant_min, quant_max + ) + quantized_block_min = torch.clamp( + block_min / super_block_min_scale_view, quant_min, quant_max + ) + return ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) + + +def quantize_gguf( + input: torch.Tensor, + block_size: List[int], + target_dtype: torch.dtype, + super_block_scale_scale: torch.Tensor, + super_block_min_scale: torch.Tensor, + quantized_block_scale: torch.Tensor, + quantized_block_min: torch.Tensor, +) -> torch.Tensor: + assert target_dtype == torch.uint4 + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + + block_shape_after_reduction = shape_for_reduction.copy() + for i in reduction_dims: + block_shape_after_reduction[i] = 1 + + quantized_block_scale = quantized_block_scale.view(block_shape_after_reduction) + + super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) + shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, quantized_block_scale.size() + ) + super_block_shape_after_reduction = shape_for_reduction.copy() + for i in reduction_dims: + super_block_shape_after_reduction[i] = 1 + + super_block_scale_scale = super_block_scale_scale.view( + super_block_shape_after_reduction + ) + super_block_min_scale = super_block_min_scale.view( + super_block_shape_after_reduction + ) + + quantized_block_scale = quantized_block_scale.view(block_shape_after_reduction) + quantized_block_min = quantized_block_min.view(block_shape_after_reduction) + + block_scale = super_block_scale_scale * quantized_block_scale + block_min = super_block_min_scale * quantized_block_min + int_data = (input - block_min) / block_scale + int_data = int_data.view(original_shape) + return int_data + + +def dequantize_gguf( + input: torch.Tensor, + block_size: List[int], + target_dtype: torch.dtype, + super_block_scale_scale: torch.Tensor, + super_block_min_scale: torch.Tensor, + quantized_block_scale: torch.Tensor, + quantized_block_min: torch.Tensor, +) -> torch.Tensor: + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + + quantized_block_scale = quantized_block_scale.view(shape_after_reduction) + quantized_block_min = quantized_block_min.view(shape_after_reduction) + + super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) + super_block_shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, quantized_block_scale.size() + ) + super_block_shape_after_reduction = super_block_shape_for_reduction + for i in reduction_dims: + super_block_shape_after_reduction[i] = 1 + + super_block_scale_scale = super_block_scale_scale.view( + 
super_block_shape_after_reduction + ) + super_block_min_scale = super_block_min_scale.view( + super_block_shape_after_reduction + ) + + block_scale = super_block_scale_scale * quantized_block_scale + block_min = super_block_min_scale * quantized_block_min + dequant = input * block_scale + block_min + dequant = dequant.view(original_shape) + return dequant + + def dequantize_affine_qqq( w: torch.Tensor, s_group: torch.Tensor, From 163267d2fe2b756abdc33c285215bdf739a35be6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 4 Apr 2025 15:29:21 -0700 Subject: [PATCH 2/7] fix --- test/prototype/test_gguf_quant.py | 4 +-- .../gguf/gguf_quantized_tensor.py | 32 ++++++++++--------- torchao/quantization/quant_primitives.py | 19 +++++++---- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py index f469931792..1f11ef560d 100644 --- a/test/prototype/test_gguf_quant.py +++ b/test/prototype/test_gguf_quant.py @@ -15,7 +15,7 @@ class TestGGUFQuantization(unittest.TestCase): def setUp(self): torch.manual_seed(123) self.input = torch.randn(2, 256, dtype=torch.float32) - self.n_super_blocks = 8 + self.n_blocks_per_superblock = 8 self.block_size = (1, 32) self.dtype = torch.uint4 @@ -34,7 +34,7 @@ def test_choose_qparams_gguf(self): def test_gguf_quantized_tensor_from_float(self): gqt = GGUFQuantizedTensor.from_float( self.input, - self.n_super_blocks, + self.n_blocks_per_superblock, self.dtype, ) diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index 8391835b9c..f79c246cb4 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -39,7 +39,7 @@ class GGUFQuantizedTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - n_super_blocks, + n_blocks_per_superblock, super_block_scale_scale, super_block_min_scale, quantized_block_scale, @@ -55,7 +55,7 @@ def __new__( def __init__( self, - n_super_blocks, + n_blocks_per_superblock, super_block_scale_scale, super_block_min_scale, quantized_block_scale, @@ -64,7 +64,7 @@ def __init__( shape, **kwargs, ): - self.n_super_blocks = n_super_blocks + self.n_blocks_per_superblock = n_blocks_per_superblock self.super_block_scale_scale = super_block_scale_scale self.super_block_min_scale = super_block_min_scale self.quantized_block_scale = quantized_block_scale @@ -73,7 +73,7 @@ def __init__( def _apply_fn_to_data(self, fn): return self.__class__( - self.n_super_blocks, + self.n_blocks_per_superblock, fn(self.super_block_scale_scale), fn(self.super_block_min_sclae), fn(self.quantized_block_scale), @@ -91,7 +91,7 @@ def __tensor_flatten__(self): "quantized_block_min", "int_data", ], ( - self.n_super_blocks, + self.n_blocks_per_superblock, self.dtype, self.shape, ) @@ -113,9 +113,9 @@ def __tensor_unflatten__( tensor_data_dict["quantized_block_min"], tensor_data_dict["int_data"], ) - n_super_blocks, dtype, shape = attributes + n_blocks_per_superblock, dtype, shape = attributes return cls( - n_super_blocks, + n_blocks_per_superblock, super_block_scale_scale, super_block_min_scale, quantized_block_scale, @@ -127,7 +127,7 @@ def __tensor_unflatten__( def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: block_size = tuple( - [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks] + [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] ) return dequantize_gguf( 
self.int_data, @@ -144,7 +144,7 @@ def detach(self): Returns a new `CodebookQuantizedTensor`. """ return self.__class__( - self.n_super_blocks, + self.n_blocks_per_superblock, self.super_block_scale_scale.detach(), self.super_block_min_scale.detach(), self.quantized_block_scale.detach(), @@ -162,7 +162,7 @@ def requires_grad_(self, requires_grad=False): return self @classmethod - def from_float(cls, input_float, n_super_blocks, target_dtype): + def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): """ Method used to convert a linear weight tensor to an instance of the GGMLInt4LinearWeight subclass. @@ -176,7 +176,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype): assert ( target_dtype == torch.uint4 ), "only uint4 quantization is supported right now" - block_size = (1, _QK_K // n_super_blocks) + block_size = (1, _QK_K // n_blocks_per_superblock) ( super_block_scale_scale, super_block_min_scale, @@ -194,7 +194,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype): quantized_block_min, ) return cls( - n_super_blocks, + n_blocks_per_superblock, super_block_scale_scale, super_block_min_scale, quantized_block_scale, @@ -208,7 +208,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype): @dataclass class GGUFWeightOnlyConfig(AOBaseConfig): dtype: torch.dtype = torch.uint4 - n_super_blocks: int = 8 + n_blocks_per_superblock: int = 8 @register_quantize_module_handler(GGUFWeightOnlyConfig) @@ -221,7 +221,7 @@ def _gguf_weight_only_transform( Args: dtype: torch.uint1 to torch.uint8, torch.int32 supported. - n_super_blocks: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 + n_blocks_per_superblock: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. Returns: Callable for quantization transformation. @@ -231,7 +231,9 @@ def _gguf_weight_only_transform( return module quantized_weight = GGUFQuantizedTensor.from_float( - weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype + weight, + n_blocks_per_superblock=config.n_blocks_per_superblock, + target_dtype=config.dtype, ) module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) return module diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 252f6540ec..fdf6305bdb 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1098,19 +1098,26 @@ def choose_qparams_gguf( ) # 2. get super_block_scale_scale and super_block_min_scale - quant_max = 2**6 - 1 - quant_min = 0 - super_block_scale_scale = block_scale_absmax / float(quant_max - quant_min) - super_block_min_scale = block_min_absmax / float(quant_max - quant_min) + # TODO: make this configurable + # we also quantize the quantization parameters (scale and min) for each block to 6 bit + # for Q4_K + qparam_quant_max = 2**6 - 1 + qparam_quant_min = 0 + super_block_scale_scale = block_scale_absmax / float( + qparam_quant_max - qparam_quant_min + ) + super_block_min_scale = block_min_absmax / float( + qparam_quant_max - qparam_quant_min + ) super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction) super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction) # 3. 
quantize block scale and min are stored in 6 bits using super_block_scale_scale and super_block_min_scale quantized_block_scale = torch.clamp( - block_scale / super_block_scale_scale_view, quant_min, quant_max + block_scale / super_block_scale_scale_view, qparam_quant_min, qparam_quant_max ) quantized_block_min = torch.clamp( - block_min / super_block_min_scale_view, quant_min, quant_max + block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max ) return ( super_block_scale_scale, From 7e1e0197dedc431e777b79fb9fc510a517df5fb7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 5 Apr 2025 22:22:38 -0700 Subject: [PATCH 3/7] test with phi4 --- .../prototype/quantization/gguf/__init__.py | 4 +- torchao/prototype/quantization/gguf/api.py | 42 ++++++ .../gguf/gguf_quantized_tensor.py | 78 +++++------- torchao/quantization/quant_primitives.py | 120 ++++++++++-------- 4 files changed, 148 insertions(+), 96 deletions(-) create mode 100644 torchao/prototype/quantization/gguf/api.py diff --git a/torchao/prototype/quantization/gguf/__init__.py b/torchao/prototype/quantization/gguf/__init__.py index 4ada094dac..6784779b1c 100644 --- a/torchao/prototype/quantization/gguf/__init__.py +++ b/torchao/prototype/quantization/gguf/__init__.py @@ -1,11 +1,9 @@ from .gguf_quantized_tensor import ( GGUFQuantizedTensor, - GGUFWeightOnlyConfig, - choose_qparams_gguf, ) +from .api import GGUFWeightOnlyConfig __all__ = [ "GGUFQuantizedTensor", - "choose_qparams_gguf", "GGUFWeightOnlyConfig", ] diff --git a/torchao/prototype/quantization/gguf/api.py b/torchao/prototype/quantization/gguf/api.py new file mode 100644 index 0000000000..d277db50cd --- /dev/null +++ b/torchao/prototype/quantization/gguf/api.py @@ -0,0 +1,42 @@ +import torch +from dataclasses import dataclass +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import register_quantize_module_handler +from .gguf_quantized_tensor import GGUFQuantizedTensor + +__all__ = [ + "GGUFWeightOnlyConfig", +] + +@dataclass +class GGUFWeightOnlyConfig(AOBaseConfig): + dtype: torch.dtype = torch.uint4 + n_blocks_per_superblock: int = 8 + + +@register_quantize_module_handler(GGUFWeightOnlyConfig) +def _gguf_weight_only_transform( + module: torch.nn.Module, + config: GGUFWeightOnlyConfig, +): + """ + Applies gguf weight-only quantization to linear layers. + + Args: + dtype: torch.uint1 to torch.uint8, torch.int32 supported. + n_blocks_per_superblock: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 + it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. + Returns: + Callable for quantization transformation. + """ + weight = module.weight + if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0): + return module + + quantized_weight = GGUFQuantizedTensor.from_float( + weight, + n_blocks_per_superblock=config.n_blocks_per_superblock, + target_dtype=config.dtype, + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + return module diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index f79c246cb4..b94c38027b 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -4,28 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass from typing import Optional import torch -from torchao.core.config import AOBaseConfig from torchao.quantization.quant_primitives import ( choose_qparams_gguf, dequantize_gguf, quantize_gguf, ) -from torchao.quantization.transform_module import register_quantize_module_handler from torchao.utils import TorchAOBaseTensor +from torch.utils._python_dispatch import return_and_correct_aliasing _QK_K = 256 +aten = torch.ops.aten __all__ = [ "GGUFQuantizedTensor", - "choose_qparams_gguf", - "quantize_gguf", - "dequantize_gguf", - "GGUFWeightOnlyConfig", ] @@ -126,6 +121,9 @@ def __tensor_unflatten__( ) def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + block_size = tuple( [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] ) @@ -137,19 +135,20 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self.super_block_min_scale, self.quantized_block_scale, self.quantized_block_min, + output_dtype=output_dtype, ) - def detach(self): + def _apply_fn_to_data(self, fn): """ Returns a new `CodebookQuantizedTensor`. """ return self.__class__( self.n_blocks_per_superblock, - self.super_block_scale_scale.detach(), - self.super_block_min_scale.detach(), - self.quantized_block_scale.detach(), - self.quantized_block_min.detach(), - self.int_data.detach(), + fn(self.super_block_scale_scale), + fn(self.super_block_min_scale), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), self.shape, dtype=self.dtype, ) @@ -201,39 +200,32 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): quantized_block_min, int_data, input_float.shape, - dtype=torch.float16, ) -@dataclass -class GGUFWeightOnlyConfig(AOBaseConfig): - dtype: torch.dtype = torch.uint4 - n_blocks_per_superblock: int = 8 +implements = GGUFQuantizedTensor.implements +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) -@register_quantize_module_handler(GGUFWeightOnlyConfig) -def _gguf_weight_only_transform( - module: torch.nn.Module, - config: GGUFWeightOnlyConfig, -): - """ - Applies gguf weight-only quantization to linear layers. - - Args: - dtype: torch.uint1 to torch.uint8, torch.int32 supported. - n_blocks_per_superblock: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 - it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. - Returns: - Callable for quantization transformation. 
- """ - weight = module.weight - if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0): - return module - - quantized_weight = GGUFQuantizedTensor.from_float( - weight, - n_blocks_per_superblock=config.n_blocks_per_superblock, - target_dtype=config.dtype, + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - return module + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + dtype = input_tensor.dtype + if hasattr(weight_tensor, "dequantize"): + weight_tensor = weight_tensor.dequantize(output_dtype=dtype) + + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index fdf6305bdb..812c29f869 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1049,16 +1049,19 @@ def choose_qparams_gguf( block_size: List[int], target_dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # there are two sets of qparams: quantized_block_scale, quantized_block_min and super_block_scale_scale and super_block_min_scale - # the relationship is the following: - # block_scale = quantized_block_scale * super_block_sclae - # block_min = quantized_block_min * super_block_min - # quantized_val = (float_val - block_min) / block_scale + quant_min - # first we calculate block_scale and block_min - # then we calculate super_block_scale_scale and super_block_min_scale - # after that we can calculate quantized_block_scale and quantized_min_scale - # the returned values are: super_block_scale_scale, super_block_min_scale, quantized_block_scale - # and quantized_min_scale + """ + There are two sets of qparams: quantized_block_scale, quantized_block_min and super_block_scale_scale and super_block_min_scale + the relationship is the following: + block_scale = quantized_block_scale * super_block_sclae + block_min = quantized_block_min * super_block_min + quantized_val = (float_val - block_min) / block_scale + quant_min + first we calculate block_scale and block_min + then we calculate super_block_scale_scale and super_block_min_scale + after that we can calculate quantized_block_scale and quantized_min_scale + the returned values are: super_block_scale_scale, super_block_min_scale, quantized_block_scale + and quantized_min_scale + """ + dtype = input.dtype # 1. get block_scale block_min shape_for_reduction, reduction_dims = _get_reduction_params( @@ -1075,26 +1078,23 @@ def choose_qparams_gguf( block_min = min_val # 2. 
get super_block_scale_scale and super_block_min_scale - block_scale_clone = block_scale.clone() - block_min_clone = block_min.clone() - assert _GGUF_QK_K % block_size[-1] == 0 super_block_size = (1, _GGUF_QK_K // block_size[-1]) shape_for_reduction, reduction_dims = _get_reduction_params( - super_block_size, block_scale_clone.size() + super_block_size, block_scale.size() ) - block_scale_clone = block_scale_clone.view(shape_for_reduction) - block_min_clone = block_min_clone.view(shape_for_reduction) + block_scale = block_scale.view(shape_for_reduction) + block_min = block_min.view(shape_for_reduction) shape_after_reduction = shape_for_reduction.copy() for i in reduction_dims: shape_after_reduction[i] = 1 block_scale_absmax = torch.amax( - torch.abs(block_scale_clone), dim=reduction_dims, keepdim=False + torch.abs(block_scale), dim=reduction_dims, keepdim=False ) block_min_absmax = torch.amax( - torch.abs(block_min_clone), dim=reduction_dims, keepdim=False + torch.abs(block_min), dim=reduction_dims, keepdim=False ) # 2. get super_block_scale_scale and super_block_min_scale @@ -1120,10 +1120,10 @@ def choose_qparams_gguf( block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max ) return ( - super_block_scale_scale, - super_block_min_scale, - quantized_block_scale, - quantized_block_min, + super_block_scale_scale.to(dtype), + super_block_min_scale.to(dtype), + quantized_block_scale.to(dtype), + quantized_block_min.to(dtype), ) @@ -1137,40 +1137,48 @@ def quantize_gguf( quantized_block_min: torch.Tensor, ) -> torch.Tensor: assert target_dtype == torch.uint4 - shape_for_reduction, reduction_dims = _get_reduction_params( + + # step 1: first order quantization + # just going through shape calculation for block_scale and block_min to get the correct shape + input_shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input.size() ) - original_shape = input.shape - input = input.view(shape_for_reduction) - - block_shape_after_reduction = shape_for_reduction.copy() + block_qparam_shape_after_reduction = input_shape_for_reduction.copy() for i in reduction_dims: - block_shape_after_reduction[i] = 1 + block_qparam_shape_after_reduction[i] = 1 + original_shape = input.shape + input = input.view(input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view(block_qparam_shape_after_reduction) + quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) - quantized_block_scale = quantized_block_scale.view(block_shape_after_reduction) + # step 2: second order quantization, recover unquantized block_scale and block_min super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) - shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( super_block_size, quantized_block_scale.size() ) - super_block_shape_after_reduction = shape_for_reduction.copy() + super_block_qparam_shape_after_reduction = super_block_input_shape_for_reduction.copy() for i in reduction_dims: - super_block_shape_after_reduction[i] = 1 + super_block_qparam_shape_after_reduction[i] = 1 + quantized_block_scale = quantized_block_scale.view(super_block_input_shape_for_reduction) + quantized_block_min = quantized_block_min.view(super_block_input_shape_for_reduction) super_block_scale_scale = super_block_scale_scale.view( - super_block_shape_after_reduction + super_block_qparam_shape_after_reduction ) super_block_min_scale = super_block_min_scale.view( - super_block_shape_after_reduction + 
super_block_qparam_shape_after_reduction ) - quantized_block_scale = quantized_block_scale.view(block_shape_after_reduction) - quantized_block_min = quantized_block_min.view(block_shape_after_reduction) - block_scale = super_block_scale_scale * quantized_block_scale block_min = super_block_min_scale * quantized_block_min + + # step 3: quantization with the unquantized block_scale and block_min + block_scale = block_scale.view(block_qparam_shape_after_reduction) + block_min = block_min.view(block_qparam_shape_after_reduction) int_data = (input - block_min) / block_scale int_data = int_data.view(original_shape) + return int_data @@ -1182,38 +1190,50 @@ def dequantize_gguf( super_block_min_scale: torch.Tensor, quantized_block_scale: torch.Tensor, quantized_block_min: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - shape_for_reduction, reduction_dims = _get_reduction_params( + # step 1. reshape input and quantized block scale and min to the shape + # after first quantization + input_shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input.size() ) - original_shape = input.shape - input = input.view(shape_for_reduction) - shape_after_reduction = shape_for_reduction + block_qparam_shape_after_reduction = input_shape_for_reduction.copy() for i in reduction_dims: - shape_after_reduction[i] = 1 + block_qparam_shape_after_reduction[i] = 1 - quantized_block_scale = quantized_block_scale.view(shape_after_reduction) - quantized_block_min = quantized_block_min.view(shape_after_reduction) + original_shape = input.shape + input = input.view(input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view(block_qparam_shape_after_reduction) + quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) + # step 2. calculate and reshape block_qparams for second quantization step super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) - super_block_shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( super_block_size, quantized_block_scale.size() ) - super_block_shape_after_reduction = super_block_shape_for_reduction + super_block_qparam_shape_after_reduction = super_block_input_shape_for_reduction.copy() for i in reduction_dims: - super_block_shape_after_reduction[i] = 1 - + super_block_qparam_shape_after_reduction[i] = 1 + quantized_block_scale = quantized_block_scale.view(super_block_input_shape_for_reduction) + quantized_block_min = quantized_block_min.view(super_block_input_shape_for_reduction) super_block_scale_scale = super_block_scale_scale.view( - super_block_shape_after_reduction + super_block_qparam_shape_after_reduction ) super_block_min_scale = super_block_min_scale.view( - super_block_shape_after_reduction + super_block_qparam_shape_after_reduction ) block_scale = super_block_scale_scale * quantized_block_scale block_min = super_block_min_scale * quantized_block_min + + # step 3. 
dequantize with block_scale and block_min + block_scale = block_scale.view(block_qparam_shape_after_reduction) + block_min = block_min.view(block_qparam_shape_after_reduction) dequant = input * block_scale + block_min dequant = dequant.view(original_shape) + if output_dtype is not None: + dequant = dequant.to(output_dtype) + return dequant From 36432d32dc44c50975c10bfe60f61ca19ad8ca52 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 5 Apr 2025 22:29:52 -0700 Subject: [PATCH 4/7] pre-commit run --- torchao/prototype/quantization/gguf/__init__.py | 2 +- torchao/prototype/quantization/gguf/api.py | 12 +++++++++++- .../quantization/gguf/gguf_quantized_tensor.py | 4 +++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/quantization/gguf/__init__.py b/torchao/prototype/quantization/gguf/__init__.py index 6784779b1c..3e43e1f3dc 100644 --- a/torchao/prototype/quantization/gguf/__init__.py +++ b/torchao/prototype/quantization/gguf/__init__.py @@ -1,7 +1,7 @@ +from .api import GGUFWeightOnlyConfig from .gguf_quantized_tensor import ( GGUFQuantizedTensor, ) -from .api import GGUFWeightOnlyConfig __all__ = [ "GGUFQuantizedTensor", diff --git a/torchao/prototype/quantization/gguf/api.py b/torchao/prototype/quantization/gguf/api.py index d277db50cd..bc4b46992a 100644 --- a/torchao/prototype/quantization/gguf/api.py +++ b/torchao/prototype/quantization/gguf/api.py @@ -1,13 +1,23 @@ -import torch +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from dataclasses import dataclass + +import torch + from torchao.core.config import AOBaseConfig from torchao.quantization.transform_module import register_quantize_module_handler + from .gguf_quantized_tensor import GGUFQuantizedTensor __all__ = [ "GGUFWeightOnlyConfig", ] + @dataclass class GGUFWeightOnlyConfig(AOBaseConfig): dtype: torch.dtype = torch.uint4 diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index b94c38027b..4d9e19992a 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -7,6 +7,7 @@ from typing import Optional import torch +from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import ( choose_qparams_gguf, @@ -14,7 +15,6 @@ quantize_gguf, ) from torchao.utils import TorchAOBaseTensor -from torch.utils._python_dispatch import return_and_correct_aliasing _QK_K = 256 aten = torch.ops.aten @@ -205,6 +205,7 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): implements = GGUFQuantizedTensor.implements + @implements([aten.detach.default, aten.alias.default]) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -225,6 +226,7 @@ def _(func, types, args, kwargs): ) dtype = input_tensor.dtype + if hasattr(weight_tensor, "dequantize"): weight_tensor = weight_tensor.dequantize(output_dtype=dtype) From afff7122d70c50f52981d8c6cb45c89212c7103f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 7 Apr 2025 12:54:08 -0700 Subject: [PATCH 5/7] update --- test/prototype/test_gguf_quant.py | 4 +- torchao/core/config.py | 2 +- torchao/prototype/quantization/__init__.py | 5 +++ .../gguf/gguf_quantized_tensor.py | 38 ++++++++++++++++++- 4 files changed, 45 insertions(+), 4 deletions(-) 
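Usage sketch (illustrative only, not part of the diffs below; it mirrors test/prototype/test_gguf_quant.py and the defaults defined in api.py, with import paths as introduced in this series):

    import torch
    from torchao.prototype.quantization.gguf import GGUFWeightOnlyConfig
    from torchao.quantization import quantize_

    # in_features must be a multiple of 256 (the GGUF superblock size)
    model = torch.nn.Sequential(torch.nn.Linear(256, 64))

    # weight-only Q4_K-style quantization: each 256-element superblock is split into
    # n_blocks_per_superblock blocks, each with a 6-bit quantized scale and min
    quantize_(model, GGUFWeightOnlyConfig(dtype=torch.uint4, n_blocks_per_superblock=8))

    # model[0].weight is now a GGUFQuantizedTensor; the linear dispatch dequantizes
    # it to the activation dtype before calling the dense F.linear
    x = torch.randn(2, 256)
    y = model(x)

Note that the linear override added in this series dequantizes the weight and falls back to the dense linear, so this is a weight-only memory/format change rather than a fused int4 kernel.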
diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py index 1f11ef560d..e03aa7050c 100644 --- a/test/prototype/test_gguf_quant.py +++ b/test/prototype/test_gguf_quant.py @@ -5,10 +5,10 @@ from torchao.prototype.quantization.gguf import ( GGUFQuantizedTensor, GGUFWeightOnlyConfig, - choose_qparams_gguf, ) -from torchao.quantization import quantize_ +from torchao.quantization.quant_primitives import choose_qparams_gguf from torchao.quantization.utils import compute_error +from torchao.quantization import quantize_ class TestGGUFQuantization(unittest.TestCase): diff --git a/torchao/core/config.py b/torchao/core/config.py index 4a5a4c5720..920764ba25 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -171,7 +171,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: return json.loads(json.dumps(config, cls=ConfigJSONEncoder)) -ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"} +ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api", "torchao.prototype.quantization"} def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: diff --git a/torchao/prototype/quantization/__init__.py b/torchao/prototype/quantization/__init__.py index e69de29bb2..bf49e2717b 100644 --- a/torchao/prototype/quantization/__init__.py +++ b/torchao/prototype/quantization/__init__.py @@ -0,0 +1,5 @@ +from .gguf import GGUFWeightOnlyConfig + +__all__ = [ + "GGUFWeightOnlyConfig", +] diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index 4d9e19992a..e8a6900ac3 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -14,7 +14,10 @@ dequantize_gguf, quantize_gguf, ) -from torchao.utils import TorchAOBaseTensor +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) _QK_K = 256 aten = torch.ops.aten @@ -138,6 +141,20 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype=output_dtype, ) + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.n_blocks_per_superblock, + self.super_block_scale_scale.to(device), + self.super_block_min_scale.to(device), + self.quantized_block_scale.to(device), + self.quantized_block_min.to(device), + self.int_data.to(device), + self.shape, + **kwargs, + ) + def _apply_fn_to_data(self, fn): """ Returns a new `CodebookQuantizedTensor`. 
@@ -212,6 +229,21 @@ def _(func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): @@ -231,3 +263,7 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize(output_dtype=dtype) return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([GGUFQuantizedTensor]) From 63d8d5a7a53b9131b1119e0136b227673f9724e6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 7 Apr 2025 21:13:55 -0700 Subject: [PATCH 6/7] run precommit --- test/prototype/test_gguf_quant.py | 8 +++++++- .../prototype/quantization/gguf/gguf_quantized_tensor.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py index e03aa7050c..b68d84b101 100644 --- a/test/prototype/test_gguf_quant.py +++ b/test/prototype/test_gguf_quant.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + import unittest import torch @@ -6,9 +12,9 @@ GGUFQuantizedTensor, GGUFWeightOnlyConfig, ) +from torchao.quantization import quantize_ from torchao.quantization.quant_primitives import choose_qparams_gguf from torchao.quantization.utils import compute_error -from torchao.quantization import quantize_ class TestGGUFQuantization(unittest.TestCase): diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index e8a6900ac3..0bb7b9a623 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -229,6 +229,7 @@ def _(func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + @implements(aten.clone.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -245,6 +246,7 @@ def _(func, types, args, kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -264,6 +266,7 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` torch.serialization.add_safe_globals([GGUFQuantizedTensor]) From d4bb04da10f52b6f1cfe8b1efe76573e7b6091bc Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 7 Apr 2025 22:07:49 -0700 Subject: [PATCH 7/7] format --- torchao/core/config.py | 6 ++++- torchao/quantization/quant_primitives.py | 33 +++++++++++++++++------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/torchao/core/config.py b/torchao/core/config.py index 
920764ba25..fe03ac225b 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -171,7 +171,11 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: return json.loads(json.dumps(config, cls=ConfigJSONEncoder)) -ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api", "torchao.prototype.quantization"} +ALLOWED_AO_MODULES = { + "torchao.quantization", + "torchao.sparsity.sparse_api", + "torchao.prototype.quantization", +} def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 812c29f869..bc176c9d17 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1148,21 +1148,28 @@ def quantize_gguf( block_qparam_shape_after_reduction[i] = 1 original_shape = input.shape input = input.view(input_shape_for_reduction) - quantized_block_scale = quantized_block_scale.view(block_qparam_shape_after_reduction) + quantized_block_scale = quantized_block_scale.view( + block_qparam_shape_after_reduction + ) quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) - # step 2: second order quantization, recover unquantized block_scale and block_min super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( super_block_size, quantized_block_scale.size() ) - super_block_qparam_shape_after_reduction = super_block_input_shape_for_reduction.copy() + super_block_qparam_shape_after_reduction = ( + super_block_input_shape_for_reduction.copy() + ) for i in reduction_dims: super_block_qparam_shape_after_reduction[i] = 1 - quantized_block_scale = quantized_block_scale.view(super_block_input_shape_for_reduction) - quantized_block_min = quantized_block_min.view(super_block_input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view( + super_block_input_shape_for_reduction + ) + quantized_block_min = quantized_block_min.view( + super_block_input_shape_for_reduction + ) super_block_scale_scale = super_block_scale_scale.view( super_block_qparam_shape_after_reduction ) @@ -1203,7 +1210,9 @@ def dequantize_gguf( original_shape = input.shape input = input.view(input_shape_for_reduction) - quantized_block_scale = quantized_block_scale.view(block_qparam_shape_after_reduction) + quantized_block_scale = quantized_block_scale.view( + block_qparam_shape_after_reduction + ) quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) # step 2. calculate and reshape block_qparams for second quantization step @@ -1211,11 +1220,17 @@ def dequantize_gguf( super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( super_block_size, quantized_block_scale.size() ) - super_block_qparam_shape_after_reduction = super_block_input_shape_for_reduction.copy() + super_block_qparam_shape_after_reduction = ( + super_block_input_shape_for_reduction.copy() + ) for i in reduction_dims: super_block_qparam_shape_after_reduction[i] = 1 - quantized_block_scale = quantized_block_scale.view(super_block_input_shape_for_reduction) - quantized_block_min = quantized_block_min.view(super_block_input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view( + super_block_input_shape_for_reduction + ) + quantized_block_min = quantized_block_min.view( + super_block_input_shape_for_reduction + ) super_block_scale_scale = super_block_scale_scale.view( super_block_qparam_shape_after_reduction )