
Commit 75d4b2e

Pass to replace Adaptive Avg. Pool with Aten Avg. Pool
Differential Revision: D74559775
Pull Request resolved: #10818
1 parent: adb5318

File tree: 2 files changed, +162 −0 lines changed


backends/cadence/aot/replace_ops.py

Lines changed: 62 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 # pyre-unsafe
 
+import logging
 import math
 import operator
 from operator import neg
@@ -2346,6 +2347,66 @@ def resolve_full_arg(self, x_arg, const_arg):
         return const_arg
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
+    """
+    Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        # Only continue for the adaptive avg_pool op
+        if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the input tensor
+        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        # Permute NCHW to NHWC for computation
+        in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
+        in_tensor_shape = in_tensor_permuted.shape
+
+        output_size = args[1]
+        num_dims = len(output_size)
+
+        # TODO: If in_tensor_shape is not a multiple of output size,
+        # this pass will not work. T224984800
+        dim_multiples = [
+            (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims)
+        ]
+        if not all(dim_multiples):
+            logging.info(
+                f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}"
+            )
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute stride and kernel_size, then set default values for the other arguments
+        stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
+        kernel_size = [
+            in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i]
+            for i in range(num_dims)
+        ]
+        padding = [0] * num_dims
+        ceil_mode = False
+        count_include_pad = True
+        divisor_override = None
+
+        # Create a new avg_pool2d node with the updated args
+        new_args = (
+            args[0],
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+        )
+        return super().call_operator(
+            exir_ops.edge.aten.avg_pool2d.default,
+            new_args,
+            kwargs,
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:

@@ -2382,6 +2443,7 @@ class CadenceReplaceOpsInGraph:
             ReplacePT2QuantWithCadenceQuantPass,
             ReplacePT2DequantWithCadenceDequantPass,
             ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
+            ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
             ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
             ReplaceWhereWithFullArgsWithWhereScalar,
             ReplaceAtenApproxGeluWithApproxGeluPass,
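The rewrite is valid because, when each input spatial dimension is an exact multiple of the corresponding output dimension, adaptive average pooling degenerates to a fixed-window average pool with stride = in // out and kernel = in - (out - 1) * stride (which equals the stride in the evenly divisible case). A minimal standalone sanity check in plain PyTorch, not part of this commit (variable names here are illustrative), showing the equivalence the pass relies on:

import torch
import torch.nn.functional as F

# Standalone sanity check (not part of this commit): with an evenly
# divisible reduction, adaptive average pooling equals a fixed-window
# avg_pool2d built from the same derivation the pass uses.
x = torch.randn(1, 64, 128, 128)
out_h, out_w = 8, 8
# The pass reads H and W through an NHWC permute; indexing the NCHW
# shape directly gives the same values.
_, _, in_h, in_w = x.shape

stride = (in_h // out_h, in_w // out_w)  # (16, 16)
kernel = (in_h - (out_h - 1) * stride[0], in_w - (out_w - 1) * stride[1])  # (16, 16)

assert torch.allclose(
    F.adaptive_avg_pool2d(x, (out_h, out_w)),
    F.avg_pool2d(x, kernel_size=kernel, stride=stride),
)

Note also the registration order in the hunk above: the new pass runs just ahead of ReplaceAtenAvgPoolWithJarvisAvgPoolPass, so the aten avg_pool2d node it emits can presumably be lowered by that subsequent replacement.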

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 100 additions & 0 deletions
@@ -19,6 +19,7 @@
 from executorch.backends.cadence.aot.replace_ops import (
     ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
+    ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
     ReplaceAddMMWithLinearPass,
     ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
@@ -1936,3 +1937,102 @@ def test_extract_mul_argument_to_full(
             },
         )
     )
+
+
+class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
+    def _get_adaptive_avg_pool_gm(
+        self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int]
+    ) -> torch.fx.GraphModule:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*input_shape))
+        adaptive_avg_pool2d = builder.call_operator(
+            exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
+        )
+        builder.output([adaptive_avg_pool2d])
+        return builder.get_graph_module()
+
+    def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
+        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten.avg_pool2d.default,
+                )
+            ),
+            0,
+        )
+        p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
+        updated_gm = p.call(gm).graph_module
+        self.assertEqual(
+            len(
+                updated_gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            0,
+        )
+        avg_pool2d_nodes = updated_gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+        )
+        self.assertEqual(
+            len(avg_pool2d_nodes),
+            1,
+        )
+        avg_pool2d_node = avg_pool2d_nodes[0]
+
+        self.assertEqual(avg_pool2d_node.args[1], [16, 16])  # kernel_size is 16x16
+        self.assertEqual(avg_pool2d_node.args[2], [16, 16])  # stride is 16, 16
+        self.assertEqual(avg_pool2d_node.args[3], [0, 0])  # padding is 0, 0
+        self.assertEqual(avg_pool2d_node.args[4], False)  # ceil_mode is False
+        self.assertEqual(avg_pool2d_node.args[5], True)  # count_include_pad is True
+        self.assertEqual(avg_pool2d_node.args[6], None)  # divisor_override is None
+
+    def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
+        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+                )
+            ),
+            0,
+        )
+        # Shapes are not multiples of each other, so the pass will not trigger
+        p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
+        updated_gm = p.call(gm).graph_module
+        self.assertEqual(
+            len(
+                updated_gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        avg_pool2d_nodes = updated_gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+        )
+        self.assertEqual(
+            len(avg_pool2d_nodes),
+            0,
+        )
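For intuition on why the irregular (128 → 9) case cannot be rewritten: per PyTorch's definition of adaptive pooling, output element i along a dimension averages over the window [floor(i·in/out), ceil((i+1)·in/out)), and for non-divisible sizes those windows vary in both width and spacing, so no single (kernel_size, stride) pair for avg_pool2d can reproduce them. A small illustration, not part of this commit:

import math

# Adaptive pooling window boundaries along one dimension of size 128
# reduced to 9 outputs: start = floor(i*in/out), end = ceil((i+1)*in/out).
in_size, out_size = 128, 9
windows = [
    (math.floor(i * in_size / out_size), math.ceil((i + 1) * in_size / out_size))
    for i in range(out_size)
]
print([end - start for start, end in windows])
# [15, 15, 15, 15, 16, 15, 15, 15, 15] -- widths are not uniform, and the
# window starts are not evenly strided, so the pass correctly bails out.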
