|
25 | 25 | from executorch.backends.cadence.aot.replace_ops import ( |
26 | 26 | ForceChannelLastForConvPass, |
27 | 27 | MakeSliceAndCatDimOutermostPass, |
| 28 | + ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, |
28 | 29 | ReplaceAddMMWithLinearPass, |
29 | 30 | ReplaceAtenConvolutionWithJarvisConvolutionPass, |
30 | 31 | ReplaceConstantPadNdWithSlicePass, |
@@ -1939,3 +1940,100 @@ def test_empty_slice(self): |
1939 | 1940 | ), |
1940 | 1941 | 1, |
1941 | 1942 | ) |
| 1943 | + |
| 1944 | + |
| 1945 | +class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase): |
| 1946 | + def _get_adaptive_avg_pool_gm( |
| 1947 | + self, input_shape: Tuple[int], output_shape: Tuple[int] |
| 1948 | + ) -> torch.fx.GraphModule: |
| 1949 | + builder = GraphBuilder() |
| 1950 | + x = builder.placeholder("x", torch.randn(*input_shape)) |
| 1951 | + adaptive_avg_pool2d = builder.call_operator( |
| 1952 | + exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape) |
| 1953 | + ) |
| 1954 | + builder.output([adaptive_avg_pool2d]) |
| 1955 | + return builder.get_graph_module() |
| 1956 | + |
| 1957 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool(self): |
| 1958 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) |
| 1959 | + self.assertEqual( |
| 1960 | + len( |
| 1961 | + gm.graph.find_nodes( |
| 1962 | + op="call_function", |
| 1963 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 1964 | + ) |
| 1965 | + ), |
| 1966 | + 1, |
| 1967 | + ) |
| 1968 | + self.assertEqual( |
| 1969 | + len( |
| 1970 | + gm.graph.find_nodes( |
| 1971 | + op="call_function", |
| 1972 | + target=exir_ops.edge.aten.avg_pool2d.default, |
| 1973 | + ) |
| 1974 | + ), |
| 1975 | + 0, |
| 1976 | + ) |
| 1977 | + updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module |
| 1978 | + self.assertEqual( |
| 1979 | + len( |
| 1980 | + updated_gm.graph.find_nodes( |
| 1981 | + op="call_function", |
| 1982 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 1983 | + ) |
| 1984 | + ), |
| 1985 | + 0, |
| 1986 | + ) |
| 1987 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 1988 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 1989 | + ) |
| 1990 | + self.assertEqual( |
| 1991 | + len(avg_pool2d_nodes), |
| 1992 | + 1, |
| 1993 | + ) |
| 1994 | + avg_pool2d_node = avg_pool2d_nodes[0] |
| 1995 | + |
| 1996 | + self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16 |
| 1997 | + self.assertEqual(avg_pool2d_node.args[2], [16, 16]) # stride is 16, 16 |
| 1998 | + self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0 |
| 1999 | + self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False |
| 2000 | + self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True |
| 2001 | + self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None |
| 2002 | + |
| 2003 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self): |
| 2004 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) |
| 2005 | + self.assertEqual( |
| 2006 | + len( |
| 2007 | + gm.graph.find_nodes( |
| 2008 | + op="call_function", |
| 2009 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 2010 | + ) |
| 2011 | + ), |
| 2012 | + 1, |
| 2013 | + ) |
| 2014 | + self.assertEqual( |
| 2015 | + len( |
| 2016 | + gm.graph.find_nodes( |
| 2017 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2018 | + ) |
| 2019 | + ), |
| 2020 | + 0, |
| 2021 | + ) |
| 2022 | + # Shapes are not multiples of each other, so pass will not trigger |
| 2023 | + updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module |
| 2024 | + self.assertEqual( |
| 2025 | + len( |
| 2026 | + updated_gm.graph.find_nodes( |
| 2027 | + op="call_function", |
| 2028 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 2029 | + ) |
| 2030 | + ), |
| 2031 | + 1, |
| 2032 | + ) |
| 2033 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 2034 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2035 | + ) |
| 2036 | + self.assertEqual( |
| 2037 | + len(avg_pool2d_nodes), |
| 2038 | + 0, |
| 2039 | + ) |
0 commit comments