diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index d7710bc1989..0ce11b620a6 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -11,7 +11,7 @@ # Utility functions for TOSAQuantizer # -from typing import cast +from typing import cast, Sequence import torch from torch._subclasses import FakeTensor @@ -76,9 +76,32 @@ def is_large_scalar(node: Node, gm: GraphModule): def is_non_float_tensor(node: Node) -> bool: - """Check if the input is not a float tensor, so that we can skip quantization for the node - since observers only works with float Tensors + """Check if the output of a node has a data type other than `torch.float32`. + + If the output is not `torch.float32`, quantization cannot be performed, as + observers only work with floating-point tensors. + + Args: + node (Node): The node to check the output(s) for. + + Returns: + bool: `True` if the data type is not float32, otherwise `False`. + + Note: + - If `node.meta["val"]` is a `list`, the function returns `True` if **any** + element is **not** an instance of `FakeTensor` or does **not** have + `torch.float32` as its data type. + - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the + function returns True. """ + if "val" in node.meta and isinstance(node.meta["val"], Sequence): + return any( + not isinstance(fake_tensor, FakeTensor) + or fake_tensor.dtype != torch.float32 + for fake_tensor in node.meta["val"] + ) + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return True + return node.meta["val"].dtype != torch.float32 diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 1c6d05f2557..e9ed6be81f3 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import operator from dataclasses import dataclass from typing import Callable, List, Optional @@ -11,6 +12,7 @@ import torch.fx from executorch.backends.arm.quantizer import arm_quantizer_utils from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.arm.tosa_utils import get_node_debug_info from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, @@ -18,6 +20,8 @@ ) from torch.fx import Node +logger = logging.getLogger(__name__) + @dataclass(frozen=True) class _QuantProperty: @@ -45,19 +49,52 @@ def _as_list(x): def _is_ok_for_quantization( - node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule + node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule ) -> bool: - if quant_property.optional and ( - quant_property.index >= len(node.args) - or node.args[quant_property.index] is None - ): - return True + """Check if a node can be quantized. + + A node can be quantized if: + - All inputs that are required for quantization are of type `float32` + and are not large scalar values. + - The output of the node itself is of type `float32` and is not a large scalar. + + Args: + node (Node): The node being analyzed. + quant_properties (_OpQuantProperties): Contains quantization properties for + the node, including input and output quantization specifications. + gm (torch.fx.GraphModule): The graph module containing the computational graph. + + Returns: + bool: `True` if the node can be quantized, otherwise `False`. + """ + # Check output + if quant_properties.quant_output is not None: + if not arm_quantizer_utils.is_ok_for_quantization(node, gm): # type: ignore[attr-defined] + logger.debug( + f"Could not quantize node due to output: " + f"{get_node_debug_info(node, gm)}" + ) - for n_arg in _as_list(node.args[quant_property.index]): - assert isinstance(n_arg, Node) - if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] return False + # Check inputs + for quant_property in quant_properties.quant_inputs: + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + continue + + for n_arg in _as_list(node.args[quant_property.index]): + assert isinstance(n_arg, Node) + if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] + logger.debug( + f'could not quantize node due to input "{node}": ' + f"{get_node_debug_info(node, gm)}" + ) + + return False + return True @@ -355,14 +392,9 @@ def any_or_hardtanh_min_zero(n: Node): return quant_properties # Check that each inputs/outputs can be quantized properly with the - # provided QuantProperties - for quant_property in quant_properties.quant_inputs: - if not _is_ok_for_quantization(node, quant_property, gm): - return None # type: ignore[return-value] - - if quant_properties.quant_output is not None: - if not _is_ok_for_quantization(node, quant_properties.quant_output, gm): - return None # type: ignore[return-value] + # provided quantization properties. + if not _is_ok_for_quantization(node, quant_properties, gm): + return None # type: ignore[return-value] return quant_properties