|
19 | 19 | from executorch.backends.cadence.aot.replace_ops import (
|
20 | 20 | ForceChannelLastForConvPass,
|
21 | 21 | MakeSliceAndCatDimOutermostPass,
|
| 22 | + ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, |
22 | 23 | ReplaceAddMMWithLinearPass,
|
23 | 24 | ReplaceAtenApproxGeluWithApproxGeluPass,
|
24 | 25 | ReplaceAtenConvolutionWithJarvisConvolutionPass,
|
@@ -1936,3 +1937,102 @@ def test_extract_mul_argument_to_full(
|
1936 | 1937 | },
|
1937 | 1938 | )
|
1938 | 1939 | )
|
| 1940 | + |
| 1941 | + |
| 1942 | +class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase): |
| 1943 | + def _get_adaptive_avg_pool_gm( |
| 1944 | + self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int] |
| 1945 | + ) -> torch.fx.GraphModule: |
| 1946 | + builder = GraphBuilder() |
| 1947 | + x = builder.placeholder("x", torch.randn(*input_shape)) |
| 1948 | + adaptive_avg_pool2d = builder.call_operator( |
| 1949 | + exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape) |
| 1950 | + ) |
| 1951 | + builder.output([adaptive_avg_pool2d]) |
| 1952 | + return builder.get_graph_module() |
| 1953 | + |
| 1954 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: |
| 1955 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) |
| 1956 | + self.assertEqual( |
| 1957 | + len( |
| 1958 | + gm.graph.find_nodes( |
| 1959 | + op="call_function", |
| 1960 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 1961 | + ) |
| 1962 | + ), |
| 1963 | + 1, |
| 1964 | + ) |
| 1965 | + self.assertEqual( |
| 1966 | + len( |
| 1967 | + gm.graph.find_nodes( |
| 1968 | + op="call_function", |
| 1969 | + target=exir_ops.edge.aten.avg_pool2d.default, |
| 1970 | + ) |
| 1971 | + ), |
| 1972 | + 0, |
| 1973 | + ) |
| 1974 | + p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() |
| 1975 | + updated_gm = p.call(gm).graph_module |
| 1976 | + self.assertEqual( |
| 1977 | + len( |
| 1978 | + updated_gm.graph.find_nodes( |
| 1979 | + op="call_function", |
| 1980 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 1981 | + ) |
| 1982 | + ), |
| 1983 | + 0, |
| 1984 | + ) |
| 1985 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 1986 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 1987 | + ) |
| 1988 | + self.assertEqual( |
| 1989 | + len(avg_pool2d_nodes), |
| 1990 | + 1, |
| 1991 | + ) |
| 1992 | + avg_pool2d_node = avg_pool2d_nodes[0] |
| 1993 | + |
| 1994 | + self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16 |
| 1995 | + self.assertEqual(avg_pool2d_node.args[2], [16, 16]) # stride is 16, 16 |
| 1996 | + self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0 |
| 1997 | + self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False |
| 1998 | + self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True |
| 1999 | + self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None |
| 2000 | + |
| 2001 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: |
| 2002 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) |
| 2003 | + self.assertEqual( |
| 2004 | + len( |
| 2005 | + gm.graph.find_nodes( |
| 2006 | + op="call_function", |
| 2007 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 2008 | + ) |
| 2009 | + ), |
| 2010 | + 1, |
| 2011 | + ) |
| 2012 | + self.assertEqual( |
| 2013 | + len( |
| 2014 | + gm.graph.find_nodes( |
| 2015 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2016 | + ) |
| 2017 | + ), |
| 2018 | + 0, |
| 2019 | + ) |
| 2020 | + # Shapes are not multiples of each other, so pass will not trigger |
| 2021 | + p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() |
| 2022 | + updated_gm = p.call(gm).graph_module |
| 2023 | + self.assertEqual( |
| 2024 | + len( |
| 2025 | + updated_gm.graph.find_nodes( |
| 2026 | + op="call_function", |
| 2027 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 2028 | + ) |
| 2029 | + ), |
| 2030 | + 1, |
| 2031 | + ) |
| 2032 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 2033 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2034 | + ) |
| 2035 | + self.assertEqual( |
| 2036 | + len(avg_pool2d_nodes), |
| 2037 | + 0, |
| 2038 | + ) |
0 commit comments