From eb448990f59299b078c5fd4d278b6e3c6afeaa25 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Mon, 7 Apr 2025 10:34:51 -0700
Subject: [PATCH] per_channel_group can't be dynamic

Summary:
There are some dynamism issues that arise when checking the semantics of
quantize_affine nodes. We avoid them by accounting for free_symbols.

Differential Revision: D72488540
---
 backends/xnnpack/utils/quant_utils.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py
index db1914e3910..cb91b78c123 100644
--- a/backends/xnnpack/utils/quant_utils.py
+++ b/backends/xnnpack/utils/quant_utils.py
@@ -12,6 +12,7 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
+from torch.fx.experimental.symbolic_shapes import free_symbols, has_free_symbols
 
 _Q_OPS = {
     "quantize_per_tensor.tensor",
     "quantize_per_tensor.default",
@@ -126,8 +127,8 @@ def is_affine_qdq(node: torch.fx.Node) -> bool:
 
 def _get_block_size_input_scale(node: torch.fx.Node):
     assert is_affine_qdq(node)
     block_size = node.args[1]
-    input_val = node.all_input_nodes[0].meta["val"]
-    scale_val = node.all_input_nodes[1].meta["val"]
+    input_val = cast(torch.fx.Node, node.args[0]).meta["val"]
+    scale_val = cast(torch.fx.Node, node.args[2]).meta["val"]
     return block_size, input_val, scale_val
 
@@ -145,7 +146,21 @@ def is_per_token(node: torch.fx.Node):
             flag &= block_size[i] == 1
             scale_numel_expected *= input_val.shape[i]
 
-        flag &= block_size[-1] == input_val.shape[-1]
+        ic_block_size = block_size[-1]
+        if isinstance(ic_block_size, torch.fx.Node):
+            ic_block_size = ic_block_size.meta["val"]
+            assert free_symbols(
+                ic_block_size
+            ), f"block_size: {block_size} given, but {block_size[-1]} is not a dynamic symint"
+
+        ic_dim = input_val.shape[-1]
+        if isinstance(ic_dim, torch.fx.Node):
+            ic_dim = ic_dim.meta["val"]
+            assert free_symbols(
+                ic_dim
+            ), f"input_shape: {input_val.shape} given, but {input_val.shape[-1]} is not a dynamic symint"
+
+        flag &= ic_dim == ic_block_size
         flag &= scale_val.numel() == scale_numel_expected
 
         return flag
@@ -160,6 +175,11 @@ def is_per_channel_group(node: torch.fx.Node):
         return True
     elif is_affine_qdq(node):
         block_size, input_val, scale_val = _get_block_size_input_scale(node)
+        # per channel group is only valid on static weights
+        # so scales and weights can't have dynamic shape
+        if has_free_symbols(input_val.shape) or has_free_symbols(scale_val.shape):
+            return False
+
         flag = True
         flag &= len(block_size) == 2
         flag &= block_size[0] == 1
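
Note on the _get_block_size_input_scale change: quantize_affine takes
(input, block_size, scale, ...) positionally, and when a dimension in
block_size is dynamic, that dimension is itself an fx.Node. fx flattens
nested args into all_input_nodes, so the symbolic node gets interleaved
and all_input_nodes[1] can be the SymInt node rather than the scale;
indexing node.args by schema position avoids the shift. A minimal sketch
of the failure mode, using a hypothetical stand-in for quantize_affine:

import torch.fx as fx


def fake_quantize_affine(inp, block_size, scale):
    # Hypothetical stand-in; only the (input, block_size, scale) arg
    # layout matters for this illustration.
    return inp


g = fx.Graph()
inp = g.placeholder("inp")
sym_dim = g.placeholder("sym_dim")  # stands in for a dynamic SymInt node
scale = g.placeholder("scale")

# block_size embeds the symbolic node, as it would for a dynamic last dim.
qa = g.call_function(fake_quantize_affine, args=(inp, [sym_dim, 4], scale))

print([n.name for n in qa.all_input_nodes])  # ['inp', 'sym_dim', 'scale']
assert qa.all_input_nodes[1] is sym_dim  # old indexing would read this as scale
assert qa.args[0] is inp and qa.args[2] is scale  # args indexing stays correct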
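
For reference, the shape invariants these predicates verify, assuming
torchao's affine-quantization convention that block_size spans the elements
sharing one scale (the concrete numbers below are illustrative only):

import math

# Per-token on an input of shape (B, T, C): every leading dim has block 1
# and the last block covers the whole channel axis, so there is one scale
# per token.
B, T, C = 2, 8, 64
block_size = (1, 1, C)
assert math.prod((B, T)) == 16  # scale.numel() expected by is_per_token

# Per-channel-group on a static weight of shape (OC, IC): each group of
# `group_size` input channels within a row shares one scale.
OC, IC, group_size = 8, 64, 32
block_size = (1, group_size)
assert (OC, IC // group_size) == (8, 2)  # expected scale shape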
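
The new guards hinge on free_symbols / has_free_symbols from
torch.fx.experimental.symbolic_shapes: a dimension exported as dynamic is a
SymInt carrying free sympy symbols, while static sizes carry none. Since
per-channel-group quantization only applies to static weights, any free
symbol in the weight or scale shape disqualifies the node. A small sketch
of the distinction (the module and export setup are illustrative
assumptions, not part of this patch):

import torch
from torch.export import Dim, export
from torch.fx.experimental.symbolic_shapes import has_free_symbols


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 4))

    def forward(self, x):
        return x @ self.weight.t()


# Export with a dynamic batch dimension on the activation only.
ep = export(
    TinyLinear(),
    (torch.randn(2, 4),),
    dynamic_shapes={"x": {0: Dim("batch")}},
)

for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        # The activation's shape carries a free symbol for the dynamic
        # batch; the static weight's shape does not, so has_free_symbols
        # returns False for it.
        print(node.name, tuple(val.shape), has_free_symbols(val.shape))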