diff --git a/backends/xnnpack/operators/op_dequantize_per_tensor.py b/backends/xnnpack/operators/op_dequantize_per_tensor.py index 604fc859eda..52c9ff1124b 100644 --- a/backends/xnnpack/operators/op_dequantize_per_tensor.py +++ b/backends/xnnpack/operators/op_dequantize_per_tensor.py @@ -11,6 +11,7 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( XNNConvert, XNNGraph, @@ -23,12 +24,7 @@ @register_node_visitor class OpDeQuantizePerTensor(NodeVisitor): """ - Dequantize Per Tensor Node visitor. We only insert an XNNPACK node if - this op was found as a graph input or graph output. This is so we - dequantize the input going in. Every other instance of quantize per - tensor is only used as signaling for q params of node inputs, so - we ignore those. This is because xnnpack only supports entire graph - quantization + Dequantize Per Tensor Node visitor """ target = "quantized_decomposed.dequantize_per_tensor.default" @@ -44,10 +40,9 @@ def define_node( debug_handle: int, ) -> None: """ - We only define a node if it is a graph output + We only define a node if it is not an implicit dq node """ - # TODO:@maxren better handle in-graph quantization conversions, this is hacky - if self.is_graph_output(node): + if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node): dq_input = get_input_node(node, 0) input_quant_params = QuantParams.from_q_dq_node(node) # fp32 output diff --git a/backends/xnnpack/operators/op_quantize_per_tensor.py b/backends/xnnpack/operators/op_quantize_per_tensor.py index 90effb2afbc..82d391a0a98 100644 --- a/backends/xnnpack/operators/op_quantize_per_tensor.py +++ b/backends/xnnpack/operators/op_quantize_per_tensor.py @@ -11,6 +11,7 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( XNNConvert, XNNGraph, @@ -23,12 +24,7 @@ @register_node_visitor class OpQuantizePerTensor(NodeVisitor): """ - Quantize Per Tensor Node visitor. We only insert an XNNPACK node if - this op was found as a graph input or graph output. This is so we - quantize the input going in. Every other instance of quantize per - tensor is only used as signaling for q params of node inputs, so - we ignore those.
This is because xnnpack only supports entire graph - quantization + Quantize Per Tensor Node visitor """ target = "quantized_decomposed.quantize_per_tensor.default" @@ -44,11 +40,10 @@ def define_node( debug_handle: int, ) -> None: """ - We only define a node if it is a graph input + We only define a node if it is not an implicit q node """ - # TODO:@maxren better handle in-graph quantization conversions, this is hacky q_input = get_input_node(node, 0) - if self.is_graph_input(q_input): + if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node): input_quant_params = QuantParams.from_q_dq_node(node) # fp32 input self.define_tensor(q_input, xnn_graph, vals_to_ids) diff --git a/backends/xnnpack/partition/TARGETS b/backends/xnnpack/partition/TARGETS index cf954148dbc..105a4f7d6d3 100644 --- a/backends/xnnpack/partition/TARGETS +++ b/backends/xnnpack/partition/TARGETS @@ -25,6 +25,7 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ + ":configs", ":support_patterns", "//executorch/backends/xnnpack:xnnpack_preprocess", "//executorch/exir:delegate", @@ -34,3 +35,17 @@ runtime.python_library( "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", ], ) + +runtime.python_library( + name = "configs", + srcs = [ + "configs.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/exir:lib", + ], +) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py new file mode 100644 index 00000000000..a3b66d3fcb4 --- /dev/null +++ b/backends/xnnpack/partition/configs.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + +### +### Module based partitioners +### + +SUPPORTED_OPS = [ + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.constant_pad_nd.default, + exir_ops.edge.aten.upsample_bilinear2d.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.max.dim, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.hardswish.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.pow.Tensor_Scalar, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten._prelu_kernel.default, + exir_ops.edge.aten.slice_copy.Tensor, +] + +SUPPORTED_MODULES = [ + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.ReLU, + torch.nn.Sigmoid, + torch.nn.Softmax, + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.Linear, + torch.nn.functional.linear, + torch.nn.Hardtanh, + torch.nn.MaxPool2d, + torch.nn.LeakyReLU, + torch.nn.ELU, + torch.nn.AvgPool2d, + torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr + torch.cat, + torch.concat, + torch.concatenate, +] + +# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support +SUPPORTED_QUANT_OPS = [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.hardtanh.default, # TODO - which one module or op or both?
+ exir_ops.edge.aten.slice_copy.Tensor, +] + +SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = { + op.name() + for op in ( + SUPPORTED_QUANT_OPS + + [ + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.linear.default, + ] + ) +} + +# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support +SUPPORTED_QUANT_MODULES = [ + torch.clamp, + torch.mean, + torch.permute, + torch.permute_copy, + torch.cat, + torch.concat, + torch.concatenate, + torch.nn.Linear, + torch.nn.functional.linear, + # TODO - T158982884 + # torch.ao.nn.quantized.reference.modules.linear.Linear, + torch.nn.MaxPool2d, + torch.nn.Conv1d, + torch.nn.functional.conv1d, + torch.ao.nn.quantized.reference.modules.conv.Conv1d, + torch.nn.Conv2d, + torch.nn.functional.conv2d, + torch.nn.functional.pad, + torch.nn.functional.elu, + torch.ao.nn.quantized.reference.modules.conv.Conv2d, + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.ConstantPad2d, + torch.nn.ELU, + torch.nn.Hardtanh, + torch.nn.ReLU, + torch.nn.functional.relu, + torch.nn.functional.relu_, + torch.nn.functional.leaky_relu, + torch.nn.functional.leaky_relu_, + torch.nn.LeakyReLU, +] + +SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES) + +# Modules which support dynamic quantization +SUPPORTED_DYN_QUANT_MODULES = [ + torch.nn.Linear, + torch.nn.functional.linear, +] diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index beb960472ec..59319358993 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -10,6 +10,14 @@ from typing import Any, Callable, cast, Dict, List, Optional, Union import torch + +from executorch.backends.xnnpack.partition.configs import ( + SUPPORTED_DYN_QUANT_MODULES, + SUPPORTED_MODULES, + SUPPORTED_OPS, + SUPPORTED_QUANT_MODULES, + SUPPORTED_QUANT_OPS, +) from executorch.backends.xnnpack.partition.support_patterns import ( get_add_graphs, get_all_dynamically_quantized_linear_pattern, @@ -522,107 +530,6 @@ def __init__(self): ) -### -### Module based partitioners -### - -SUPPORTED_OPS = [ - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.floor.default, - exir_ops.edge.aten.maximum.default, - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.constant_pad_nd.default, - exir_ops.edge.aten.upsample_bilinear2d.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.max.dim, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.sqrt.default, - exir_ops.edge.aten.ceil.default, - exir_ops.edge.aten.hardswish.default, - exir_ops.edge.aten.neg.default, - exir_ops.edge.aten.pow.Tensor_Scalar, - exir_ops.edge.aten.abs.default, - exir_ops.edge.aten._prelu_kernel.default, - exir_ops.edge.aten.slice_copy.Tensor, -] - -SUPPORTED_MODULES = [ - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.ReLU, - torch.nn.Sigmoid, - torch.nn.Softmax, - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.Linear, - torch.nn.functional.linear, - torch.nn.Hardtanh, - torch.nn.MaxPool2d, - torch.nn.LeakyReLU, - torch.nn.ELU, - torch.nn.AvgPool2d, - torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr - torch.cat, - torch.concat, - torch.concatenate, -] - -# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support -SUPPORTED_QUANT_OPS = [ - 
exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.hardtanh.default, # TODO - which one module or op or both? - exir_ops.edge.aten.slice_copy.Tensor, -] - -# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support -SUPPORTED_QUANT_MODULES = [ - torch.clamp, - torch.mean, - torch.permute, - torch.permute_copy, - torch.cat, - torch.concat, - torch.concatenate, - torch.nn.Linear, - torch.nn.functional.linear, - # TODO - T158982884 - # torch.ao.nn.quantized.reference.modules.linear.Linear, - torch.nn.MaxPool2d, - torch.nn.Conv1d, - torch.nn.functional.conv1d, - torch.ao.nn.quantized.reference.modules.conv.Conv1d, - torch.nn.Conv2d, - torch.nn.functional.conv2d, - torch.nn.functional.pad, - torch.nn.functional.elu, - torch.ao.nn.quantized.reference.modules.conv.Conv2d, - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.ConstantPad2d, - torch.nn.ELU, - torch.nn.Hardtanh, - torch.nn.ReLU, - torch.nn.functional.relu, - torch.nn.functional.relu_, - torch.nn.functional.leaky_relu, - torch.nn.functional.leaky_relu_, - torch.nn.LeakyReLU, -] - -# Modules which support dynamic quantization -SUPPORTED_DYN_QUANT_MODULES = [ - torch.nn.Linear, - torch.nn.functional.linear, -] - - class XnnpackFloatingPointPartitioner(Partitioner): """ Module and Opname based partitioner for FP32 modules/ops listed in diff --git a/backends/xnnpack/passes/TARGETS b/backends/xnnpack/passes/TARGETS index 070211bdc15..bacf4761a28 100644 --- a/backends/xnnpack/passes/TARGETS +++ b/backends/xnnpack/passes/TARGETS @@ -10,10 +10,12 @@ python_library( "fuse_batch_norm_with_conv.py", "prelu_reshape_pass.py", "remove_getitem_op.py", + "tag_implicit_q_dq_pass.py", ], deps = [ "//caffe2:torch", "//executorch/backends/transforms:lib", + "//executorch/backends/xnnpack/partition:configs", "//executorch/backends/xnnpack/utils:xnnpack_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", diff --git a/backends/xnnpack/passes/__init__.py b/backends/xnnpack/passes/__init__.py index e6c9b580d0a..eb5a2e5e58b 100644 --- a/backends/xnnpack/passes/__init__.py +++ b/backends/xnnpack/passes/__init__.py @@ -14,6 +14,7 @@ ) from executorch.backends.xnnpack.passes.prelu_reshape_pass import PReLUReshapePass from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass +from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass from executorch.exir.passes import PassManager from executorch.exir.passes.const_prop_pass import ConstPropPass @@ -27,5 +28,6 @@ Conv1dUnsqueezePass(), PReLUReshapePass(), ChannelsLastTaggedReshapePass(), + TagImplicitQDqPass(), ] ) diff --git a/backends/xnnpack/passes/tag_implicit_q_dq_pass.py b/backends/xnnpack/passes/tag_implicit_q_dq_pass.py new file mode 100644 index 00000000000..c333b1fbc42 --- /dev/null +++ b/backends/xnnpack/passes/tag_implicit_q_dq_pass.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import cast, List, Optional + +import torch +from executorch.backends.xnnpack.partition.configs import ( + SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET, + SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET, +) +from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class TagImplicitQDqPass(ExportPass): + """ + This pass is used to tag "implicit" q/dq nodes, which should be ignored + during preprocessing. + + A q or dq node is deemed to be "implicit" if any of the following hold: + a) All of its inputs are constants (get_attr nodes), since (de)quantizing + constants is done outside of executing the graph + b) It is the q or dq surrounding a "supported" group of nodes, ordered as + dq -> [supported group] -> q. A "supported" group is comprised of one of + the following: + ( i) A single supported op, from SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET, + ( ii) A single supported module, from SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET, or + (iii) A chain of nodes matching a supported chain from + SUPPORTED_QUANT_CHAINS. + q/dq nodes which match this condition should be + ignored during preprocessing because they are only used as signaling for q + params of node inputs + c) It is a dq followed by aten.linear.default and then an output node. This + is because aten.linear.default is a special op corresponding to + dqlinear, which doesn't necessarily have a q after it + """ + + _END_OF_CHAIN_MARKER = "END_OF_CHAIN" + # TODO: @salilsdesai Avoid hardcoding quant module chains here (instead get from quantizer) + SUPPORTED_QUANT_CHAINS = { + exir_ops.edge.aten.add.Tensor.name(): { + exir_ops.edge.aten.relu.default.name(): { + _END_OF_CHAIN_MARKER: True, + } + }, + exir_ops.edge.aten.convolution.default.name(): { + exir_ops.edge.aten.relu.default.name(): { + _END_OF_CHAIN_MARKER: True, + } + }, + exir_ops.edge.aten.mul.Tensor.name(): { + exir_ops.edge.aten.relu.default.name(): { + _END_OF_CHAIN_MARKER: True, + } + }, + exir_ops.edge.aten.sub.Tensor.name(): { + exir_ops.edge.aten.relu.default.name(): { + _END_OF_CHAIN_MARKER: True, + } + }, + } + IS_IMPLICIT_Q_DQ_TAG = "IS_IMPLICIT_Q_DQ_TAG" + + def is_param_node(self, node: torch.fx.Node) -> bool: + return node.op == "get_attr" + + def is_output_node(self, node: torch.fx.Node) -> bool: + return node.op == "output" + + def is_dynamically_quantized(self, node: torch.fx.Node) -> bool: + return any( + is_dequant(input_node) + and cast( + torch._ops.OpOverload, input_node.target + )._schema.schema.overload_name + == "tensor" + for input_node in node.all_input_nodes + ) + + def is_supported_quant_op(self, node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and cast(torch._ops.OpOverload, node.target).name() + in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET + ) + + def is_supported_quant_module(self, node: torch.fx.Node) -> bool: + is_supported = ( + "source_fn" in node.meta + and node.meta["source_fn"][1] in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET + ) + if is_supported and self.is_supported_quant_op(node): + raise RuntimeError( + f"The same node should not be both a supported quant op and supported quant module: {node}" + ) + return is_supported + + def tag_as_implicit_q_dq(self, node: torch.fx.Node) -> None: + node.meta[TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG] = True + + @staticmethod + def is_tagged_as_implicit_q_dq(node: torch.fx.Node) -> bool: + return node.meta.get(TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG, False) + + def 
get_ending_implicit_q_nodes( + self, start_node: torch.fx.Node + ) -> Optional[List[torch.fx.Node]]: + """ + Returns a list of implicit q nodes which end the potential "supported" + group of nodes starting with start_node (which came after a dq), or None + if no such "supported" group exists. This list will either contain + one or zero elements. + """ + # If the node after the dq has multiple users then the dq can't be + # implicit + if len(start_node.users) != 1: + return None + + next_node = list(start_node.users)[0] + + if is_quant(next_node): + # Check if start_node (which sits between the dq and q nodes) is in + # the supported quant ops or modules set + if self.is_supported_quant_op(start_node) or self.is_supported_quant_module( + start_node + ): + return [next_node] + elif self.is_output_node(next_node): + # Check if start_node (which sits between the dq and output nodes) + # is dynamically quantized, e.g. the linear of a dynamically quantized linear + if self.is_dynamically_quantized(start_node): + return [] + else: + # Check if nodes between the dq node and the next q match + # a supported quant chain + available_chains = TagImplicitQDqPass.SUPPORTED_QUANT_CHAINS + current_node = start_node + while ( + # Not yet at end of chain in graph + not is_quant(current_node) + # Right number of users to continue chain + and len(current_node.users) == 1 + # Can continue following an available chain + and ( + current_node.op == "call_function" + and cast(torch._ops.OpOverload, current_node.target).name() + in available_chains + ) + ): + available_chains = available_chains[ + cast(torch._ops.OpOverload, current_node.target).name() + ] + current_node = list(current_node.users)[0] + + if ( + is_quant(current_node) + and TagImplicitQDqPass._END_OF_CHAIN_MARKER in available_chains + ): + # The chain of nodes between the dq and q nodes matches + # a supported quant chain + return [current_node] + + return None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for first_node in graph_module.graph.nodes: + if (is_dequant(first_node) or is_quant(first_node)) and all( + self.is_param_node(n) for n in first_node.all_input_nodes + ): + # All of the q or dq node's inputs are constants + self.tag_as_implicit_q_dq(first_node) + continue + + if not is_dequant(first_node): + continue + + if len(first_node.users) == 0: + continue + + ending_implicit_q_nodes = [] + for user in first_node.users: + user_end_nodes = self.get_ending_implicit_q_nodes(user) + if user_end_nodes is None: + # This user isn't part of a "supported" group + ending_implicit_q_nodes = None + break + ending_implicit_q_nodes.extend(user_end_nodes) + + if ending_implicit_q_nodes is None: + # There was a user which isn't part of a "supported" group + # Don't tag anything as implicit for this iteration + continue + + self.tag_as_implicit_q_dq(first_node) + for node in ending_implicit_q_nodes: + self.tag_as_implicit_q_dq(node) + + # Since we are overriding "call", we need to call the parent's "call" + # to retrace the graph and regenerate metadata + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS index 42ca210eeef..3c9647ef5ff 100644 --- a/backends/xnnpack/test/TARGETS +++ b/backends/xnnpack/test/TARGETS @@ -102,6 +102,7 @@ python_unittest( "//executorch/backends/xnnpack/utils:xnnpack_utils", "//executorch/exir:lib", "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", + "//executorch/exir/dialects:lib", ], ) diff --git 
a/backends/xnnpack/test/test_xnnpack_passes.py b/backends/xnnpack/test/test_xnnpack_passes.py index abe349bedb4..1ee81636e55 100644 --- a/backends/xnnpack/test/test_xnnpack_passes.py +++ b/backends/xnnpack/test/test_xnnpack_passes.py @@ -17,6 +17,7 @@ FuseBatchNormWithConvPass, ) from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass +from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import ( OpSequencesAddConv2d, @@ -26,6 +27,7 @@ from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import ( DuplicateDequantNodePass, ) +from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) @@ -446,3 +448,60 @@ def test_convert_to_linear(self): 1, expected_node="executorch_exir_dialects_edge__ops_aten_linear_default", ) + + def test_tag_implicit_q_dq_pass(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = torch.add(x, x) + x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = torch.mul(x, x) + x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + x = torch.add(x, x) + x = torch.mul(x, x) + return x + + test_model = TestModule() + test_model.eval() + + sample_inputs = (torch.randn(2, 3),) + + edge_program = capture_graph_for_xnnpack(test_model, sample_inputs) + + tag_pass = TagImplicitQDqPass() + tagged_graph = tag_pass( + edge_program.exported_program.graph_module + ).graph_module.graph + + # The six tagged nodes are: + # 1) The dq of the first add input + # 2) The dq of the second add input + # 3) The q of the add output + # 4) The dq of the first mul input + # 5) The dq of the second mul input + # 6) The q of the mul output + self.assertEqual( + sum( + node.meta.get(TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG, False) + for node in tagged_graph.nodes + ), + 6, + )
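
Note on SUPPORTED_QUANT_CHAINS: the nested-dict layout added in tag_implicit_q_dq_pass.py behaves like a small trie. The standalone sketch below is not part of the patch; plain strings stand in for the edge-op names that .name() returns, and the torch.fx graph walking is omitted. It only shows the lookup that get_ending_implicit_q_nodes performs: a sequence of ops between a dq and a q is a supported chain only if walking the dict with each op name in turn lands on the END_OF_CHAIN marker (the real pass additionally requires each node in the chain to have exactly one user and the walk to end at a q node).

from typing import Dict, List

_END_OF_CHAIN_MARKER = "END_OF_CHAIN"

# Hypothetical stand-in for TagImplicitQDqPass.SUPPORTED_QUANT_CHAINS,
# keyed by plain op-name strings instead of exir_ops ... .name() values.
SUPPORTED_QUANT_CHAINS: Dict[str, dict] = {
    "aten.add.Tensor": {"aten.relu.default": {_END_OF_CHAIN_MARKER: True}},
    "aten.mul.Tensor": {"aten.relu.default": {_END_OF_CHAIN_MARKER: True}},
}

def chain_is_supported(op_names: List[str]) -> bool:
    # Walk one dict level per op; bail out as soon as the chain diverges.
    level: dict = SUPPORTED_QUANT_CHAINS
    for name in op_names:
        if name not in level:
            return False
        level = level[name]
    # Only a complete chain (a level containing the marker) is supported.
    return level.get(_END_OF_CHAIN_MARKER, False)

assert chain_is_supported(["aten.add.Tensor", "aten.relu.default"])
assert not chain_is_supported(["aten.add.Tensor"])  # incomplete chain
assert not chain_is_supported(["aten.div.Tensor", "aten.relu.default"])  # not listed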