Skip to content

Commit 1164d4e

Browse files
Eashan Gargfacebook-github-bot
authored andcommitted
Pass to replace Adaptive Avg. Pool with Aten Avg. Pool
Summary: Seeing exir_ops.edge.aten._adaptive_avg_pool2d.default nodes in some graphs, pass to replace these with exir_ops.edge.aten.avg_pool2d.default Differential Revision: D74559775
1 parent adde519 commit 1164d4e

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,54 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24012401
return result
24022402

24032403

2404+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2405+
class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
2406+
"""
2407+
Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
2408+
"""
2409+
2410+
def call_operator(self, op, args, kwargs, meta):
2411+
# Only continue for avg_pool op
2412+
if op not in {
2413+
exir_ops.edge.aten._adaptive_avg_pool2d.default
2414+
}:
2415+
return super().call_operator(op, args, kwargs, meta)
2416+
2417+
# Get the input tensor
2418+
in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
2419+
# Permute NCHW to NHWC for computation
2420+
in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
2421+
in_tensor_shape = in_tensor_permuted.shape
2422+
2423+
output_size = args[1]
2424+
num_dims = len(output_size)
2425+
2426+
# Compute stride and kernel_size, then set default values for other arguments
2427+
stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
2428+
kernel_size = [in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i] for i in range(num_dims)]
2429+
padding = [0] * num_dims
2430+
ceil_mode = False
2431+
count_include_pad = True
2432+
divisor_override = None
2433+
2434+
# Create a new avg_pool node with the updated args
2435+
new_args = (
2436+
args[0],
2437+
kernel_size,
2438+
stride,
2439+
padding,
2440+
ceil_mode,
2441+
count_include_pad,
2442+
divisor_override,
2443+
)
2444+
return super().call_operator(
2445+
exir_ops.edge.aten.avg_pool2d.default,
2446+
new_args,
2447+
kwargs,
2448+
meta,
2449+
)
2450+
2451+
24042452
# This class encapsulates all the functions that replace/switch one op in the
24052453
# graph with another.
24062454
class CadenceReplaceOpsInGraph:
@@ -2438,6 +2486,7 @@ class CadenceReplaceOpsInGraph:
24382486
ReplacePT2QuantWithCadenceQuantPass,
24392487
ReplacePT2DequantWithCadenceDequantPass,
24402488
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
2489+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
24412490
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24422491
ReplaceWhereWithFullArgsWithWhereScalar,
24432492
ReplaceGeluWithApproximateGeluPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ForceChannelLastForConvPass,
2727
MakeSliceAndCatDimOutermostPass,
2828
ReplaceAddMMWithLinearPass,
29+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
2930
ReplaceAtenConvolutionWithJarvisConvolutionPass,
3031
ReplaceConstantPadNdWithSlicePass,
3132
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
@@ -1971,3 +1972,99 @@ def test_empty_slice(self):
19711972
),
19721973
1,
19731974
)
1975+
1976+
class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
1977+
def _get_adaptive_avg_pool_gm(self, input_shape: Tuple[int], output_shape: Tuple[int]) -> torch.fx.GraphModule:
1978+
builder = GraphBuilder()
1979+
x = builder.placeholder("x", torch.randn(*input_shape))
1980+
adaptive_avg_pool2d = builder.call_operator(
1981+
exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
1982+
)
1983+
builder.output([adaptive_avg_pool2d])
1984+
return builder.get_graph_module()
1985+
1986+
def test_replace_adaptive_avg_pool_with_aten_avg_pool(self):
1987+
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
1988+
self.assertEqual(
1989+
len(
1990+
gm.graph.find_nodes(
1991+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
1992+
)
1993+
),
1994+
1,
1995+
)
1996+
self.assertEqual(
1997+
len(
1998+
gm.graph.find_nodes(
1999+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2000+
)
2001+
),
2002+
0,
2003+
)
2004+
updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
2005+
self.assertEqual(
2006+
len(
2007+
updated_gm.graph.find_nodes(
2008+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2009+
)
2010+
),
2011+
0,
2012+
)
2013+
avg_pool2d_nodes = updated_gm.graph.find_nodes(
2014+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2015+
)
2016+
self.assertEqual(
2017+
len(avg_pool2d_nodes),
2018+
1,
2019+
)
2020+
avg_pool2d_node = avg_pool2d_nodes[0]
2021+
2022+
self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16
2023+
self.assertEqual(avg_pool2d_node.args[2], [16, 16]) # stride is 16, 16
2024+
self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0
2025+
self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False
2026+
self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True
2027+
self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None
2028+
2029+
def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self):
2030+
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
2031+
self.assertEqual(
2032+
len(
2033+
gm.graph.find_nodes(
2034+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2035+
)
2036+
),
2037+
1,
2038+
)
2039+
self.assertEqual(
2040+
len(
2041+
gm.graph.find_nodes(
2042+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2043+
)
2044+
),
2045+
0,
2046+
)
2047+
updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
2048+
self.assertEqual(
2049+
len(
2050+
updated_gm.graph.find_nodes(
2051+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2052+
)
2053+
),
2054+
0,
2055+
)
2056+
avg_pool2d_nodes = updated_gm.graph.find_nodes(
2057+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2058+
)
2059+
self.assertEqual(
2060+
len(avg_pool2d_nodes),
2061+
1,
2062+
)
2063+
avg_pool2d_node = avg_pool2d_nodes[0]
2064+
2065+
self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16
2066+
self.assertEqual(avg_pool2d_node.args[2], [14, 14]) # stride is 14, 14
2067+
self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0
2068+
self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False
2069+
self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True
2070+
self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None

0 commit comments

Comments
 (0)