Skip to content

Commit e866010

Browse files
committed
[Rewriter]: introduce fuse_pad_into_conv (#2301)
1 parent b7a7e14 commit e866010

File tree

3 files changed

+357
-0
lines changed

3 files changed

+357
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
collapse_slices,
2222
no_op,
2323
pattern,
24+
fuse_pad_into_conv,
2425
)
2526

2627
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -30,6 +31,7 @@
3031
*cast_constant_of_shape.rules.rules,
3132
*collapse_slices.rules.rules,
3233
*basic_rules.basic_optimization_rules().rules,
34+
*fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
3335
)
3436

3537

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fuses Pad nodes into preceding nodes. Supported fusion patterns:
4+
- Pad ∘ Conv -> Conv
5+
"""
6+
7+
import typing
8+
9+
import numpy as np
10+
import onnx_ir as ir
11+
12+
from onnxscript.rewriter import pattern as orp
13+
14+
15+
def fill_pads_with_axes(
16+
pads: typing.Sequence[int], axes: typing.Sequence[int], rank: int
17+
) -> typing.List[int]:
18+
new_pads = []
19+
for axis in range(rank):
20+
if axis not in axes:
21+
start_value = end_value = 0
22+
else:
23+
start_value = pads[axes.index(axis)]
24+
end_value = pads[axes.index(axis) + len(axes)]
25+
pad_len = len(new_pads) // 2
26+
new_pads.insert(pad_len + axis, end_value)
27+
new_pads.insert(axis, start_value)
28+
return new_pads
29+
30+
31+
class _FusePadConvBase(orp.RewriteRuleClassBase):
32+
"""Interface for PadConv nodes fusion."""
33+
34+
def __init__(self, name: str, as_function: bool = False):
35+
# Remove nodes is set to False to remove unused nodes after the rewrite.
36+
super().__init__(name=name, remove_nodes=False, as_function=as_function)
37+
38+
def rewrite(
39+
self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value
40+
) -> ir.Value:
41+
pnode = pad.producer()
42+
cnode = conv.producer()
43+
44+
# Retrieve the padding and axes
45+
x_rank = len(x.shape)
46+
pad_pads = pnode.inputs[1].const_value.numpy().tolist()
47+
if len(pnode.inputs) > 3 and (axes := pnode.inputs[3]) is not None:
48+
axes = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
49+
else:
50+
axes = list(range(x_rank))
51+
52+
# Fulfill pad_pads in every dimension (filling with zero the other ones)
53+
pad_pads = fill_pads_with_axes(pad_pads, axes, x_rank)
54+
55+
# Get only spatial pads
56+
new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :]
57+
58+
# Replace conv pads = new + old
59+
conv_attr: typing.Mapping[str, ir.Attr] = cnode.attributes.copy()
60+
if "pads" in conv_attr:
61+
new_pads = [x + y for x, y in zip(conv_attr["pads"].as_ints(), new_pads)]
62+
conv_attr["pads"] = ir.convenience.convert_attribute("pads", new_pads)
63+
64+
return op.op(
65+
cnode.op_type,
66+
inputs=(x, *cnode.inputs[1:]),
67+
attributes=conv_attr,
68+
domain=cnode.domain,
69+
name=cnode.name,
70+
)
71+
72+
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
73+
del context # Unused
74+
check_result = orp.MatchResult()
75+
pnode = pad.producer()
76+
x_rank = len(x.shape)
77+
78+
# Pad constraints: attributes
79+
if (mode := pnode.attributes.get("mode", None)) and mode.as_string() != "constant":
80+
return check_result.fail(f"{pnode.name} mode must be 'constant'.")
81+
82+
# Pad constraints: inputs
83+
if (pads := pnode.inputs[1]).const_value is None:
84+
return check_result.fail(f"{pads.name} is not a constant/initializer.")
85+
if len(pnode.inputs) > 2 and (constant_value := pnode.inputs[2]) is not None:
86+
if constant_value.const_value is None:
87+
return check_result.fail(
88+
f"{constant_value.name} is not a constant/initializer."
89+
)
90+
elif constant_value.const_value.numpy().item() != 0:
91+
return check_result.fail(f"{constant_value.name} must be equal to 0.")
92+
axes = list(range(x_rank))
93+
if len(pnode.inputs) > 3 and (axes := pnode.inputs[3]) is not None:
94+
if axes.const_value is None:
95+
return check_result.fail(f"{axes.name} is not a constant/initializer.")
96+
axes_list = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
97+
else:
98+
axes_list = list(range(x_rank))
99+
100+
# Pad constraints: values
101+
pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank)
102+
if np.any(pads_list[:2] + pads_list[x_rank : x_rank + 2]):
103+
return check_result.fail(f"{pads.name} must be zero in non-spatial dimensions.")
104+
105+
return check_result
106+
107+
108+
class FusePadConv(_FusePadConvBase):
109+
"""Replaces ``Pad(Conv(x))`` with ``Conv(x)``."""
110+
111+
def __init__(self, as_function: bool = False):
112+
super().__init__(name="FusePadConv", as_function=as_function)
113+
114+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
115+
return op.Conv(
116+
op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
117+
_allow_other_inputs=True,
118+
_outputs=["conv"],
119+
)
120+
121+
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
122+
check_result = super().check(context, x, pad, conv)
123+
if check_result.reason:
124+
return check_result
125+
126+
# Conv constraints: attributes
127+
cnode = conv.producer()
128+
if (apad := cnode.attributes.get("auto_pad", None)) and apad.as_string() != "NOTSET":
129+
return check_result.fail(f"{cnode.name} auto_pad must be 'NOTSET'.")
130+
return check_result
131+
132+
133+
fuse_pad_into_conv = FusePadConv.rule()
134+
135+
136+
def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
137+
"""Returns a set of rewrite rules that fuse Pad nodes into preceding:
138+
- Conv
139+
140+
Returns:
141+
RewriteRuleSet
142+
"""
143+
return orp.RewriteRuleSet([fuse_pad_into_conv])
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import typing
4+
import unittest
5+
6+
import numpy as np
7+
import onnx_ir as ir
8+
import parameterized
9+
from onnx_ir.passes.common import onnx_checker
10+
11+
from onnxscript.rewriter import pattern as orp
12+
from onnxscript.rewriter import testing
13+
from onnxscript.rewriter.fuse_pad_into_conv import (
14+
fuse_pad_into_conv,
15+
fuse_pad_into_conv_rule_set,
16+
)
17+
18+
19+
def _clone_model(model: ir.Model) -> ir.Model:
20+
return ir.from_proto(ir.to_proto(model))
21+
22+
23+
class FusePadConvBaseTest(unittest.TestCase):
24+
@property
25+
def rng(self):
26+
return np.random.default_rng(20250522)
27+
28+
def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None):
29+
w = ir.tensor(self.rng.uniform(-0.5, 0.5, shape).astype("float32"), name="W")
30+
if tape is not None:
31+
w = tape.initializer(w)
32+
return w
33+
34+
def build_model(
35+
self,
36+
input_shape: ir.Shape,
37+
weight_shape: typing.Sequence[int],
38+
pad_inputs: typing.Sequence[ir.TensorProtocol | ir.Value | None],
39+
pad_attributes: typing.Mapping[str, ir.Attr] | None = None,
40+
conv_attributes: typing.Mapping[str, ir.Attr] | None = None,
41+
opset_imports: typing.Mapping[str, int] = {"": 20},
42+
) -> ir.Model:
43+
tape = ir.tape.Tape()
44+
inputs = []
45+
output_shape = ir.Shape((input_shape[0],) + ("?",) * (len(input_shape) - 1))
46+
47+
# Convert pad_inputs to initializers (if needed)
48+
pad_inputs = list(pad_inputs)
49+
for idx, x in enumerate(pad_inputs):
50+
if isinstance(x, ir.TensorProtocol):
51+
pad_inputs[idx] = tape.initializer(x)
52+
elif isinstance(x, ir.Value):
53+
inputs.append(x)
54+
elif isinstance(x, float):
55+
pad_inputs[idx] = tape.op("Constant", inputs=[], attributes={"value_float": x})
56+
elif x is not None:
57+
raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
58+
59+
# Register operations in the tape
60+
x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
61+
y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
62+
y = tape.op(
63+
"Conv",
64+
inputs=[y, self.get_conv_weights(weight_shape, tape)],
65+
attributes=conv_attributes,
66+
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
67+
)
68+
69+
# Build the model
70+
ir_model = ir.Model(
71+
ir.Graph(
72+
inputs=[x, *inputs],
73+
outputs=[y],
74+
nodes=tape.nodes,
75+
initializers=tape.initializers,
76+
opset_imports=opset_imports,
77+
name="model",
78+
),
79+
ir_version=9,
80+
)
81+
onnx_checker.CheckerPass(True)(ir_model)
82+
return ir_model
83+
84+
85+
class FusePadConvTest(FusePadConvBaseTest):
86+
@parameterized.parameterized.expand(
87+
[
88+
(pad_pads, const_value, axes, conv_pads)
89+
for pad_pads, axes, conv_pads in [
90+
([0, 0, 2, 2, 0, 0, 2, 2], None, None),
91+
([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0]),
92+
([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1]),
93+
]
94+
for const_value in [None, 0.0]
95+
]
96+
)
97+
def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
98+
pad_inputs = [ir.tensor(pad_pads, name="pads")]
99+
if const_value is not None or axes is not None:
100+
pad_inputs.append(const_value)
101+
if axes is not None:
102+
pad_inputs.append(axes)
103+
base_model = self.build_model(
104+
input_shape=ir.Shape(("N", 32, 14, 16)),
105+
weight_shape=(10, 32, 3, 3),
106+
pad_inputs=pad_inputs,
107+
conv_attributes={"pads": conv_pads},
108+
)
109+
updated_model = _clone_model(base_model)
110+
111+
# Apply rule
112+
count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
113+
114+
# Check that Pad was fused
115+
self.assertEqual(count, 1)
116+
self.assertEqual(updated_model.graph.num_nodes(), 1)
117+
onnx_checker.CheckerPass(True)(updated_model)
118+
119+
# Check inference
120+
inputs = self.rng.random((1, 32, 14, 16), dtype="float32")
121+
testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
122+
123+
@parameterized.parameterized.expand(
124+
[
125+
(
126+
"constant",
127+
ir.tensor([1] * 10, name="pads"),
128+
ir.tensor([0.0], name="const_value"),
129+
None,
130+
"NOTSET",
131+
"must be zero in non-spatial dimensions",
132+
),
133+
(
134+
"constant",
135+
ir.tensor([0, 0, 0, 0], name="pads"),
136+
ir.tensor([1.0], name="const_value"),
137+
ir.tensor([0, -1], name="axes"),
138+
"NOTSET",
139+
"must be equal to 0.",
140+
),
141+
(
142+
"edge",
143+
ir.tensor([0, 0, 0, 0], name="pads"),
144+
ir.tensor([0.0], name="const_value"),
145+
ir.tensor([0, -1], name="axes"),
146+
"NOTSET",
147+
"mode must be 'constant'.",
148+
),
149+
(
150+
"constant",
151+
ir.Value(
152+
name="pads", shape=ir.Shape([4]), type=ir.TensorType(ir.DataType.INT64)
153+
),
154+
None,
155+
ir.tensor([0, -1], name="axes"),
156+
"NOTSET",
157+
"pads is not a constant/initializer.",
158+
),
159+
(
160+
"constant",
161+
ir.tensor([0] * 10, name="pads"),
162+
ir.Value(
163+
name="cval", shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.FLOAT)
164+
),
165+
None,
166+
"NOTSET",
167+
"cval is not a constant",
168+
),
169+
(
170+
"constant",
171+
ir.tensor([0, 0, 0, 0], name="pads"),
172+
None,
173+
ir.Value(
174+
name="axes", shape=ir.Shape([2]), type=ir.TensorType(ir.DataType.INT64)
175+
),
176+
"NOTSET",
177+
"axes is not a constant",
178+
),
179+
(
180+
"constant",
181+
ir.tensor([0, 0, 0, 0], name="pads"),
182+
ir.tensor([0.0], name="const_value"),
183+
ir.tensor([0, -1], name="axes"),
184+
"VALID",
185+
"auto_pad must be 'NOTSET'.",
186+
),
187+
]
188+
)
189+
def test_unsupported_fuse_pad_into_conv(
190+
self, mode, pads, const_value, axes, auto_pad, err_msg
191+
):
192+
base_model = self.build_model(
193+
input_shape=ir.Shape(("N", 32, 14, 16, 12)),
194+
weight_shape=(10, 32, 3, 4, 5),
195+
pad_inputs=[pads, const_value, axes],
196+
pad_attributes={"mode": mode},
197+
conv_attributes={"auto_pad": auto_pad},
198+
)
199+
200+
# Apply rule and check it was not applied
201+
tracer = orp.MatchingTracer()
202+
count = fuse_pad_into_conv.apply_to_model(base_model, tracer=tracer)
203+
self.assertEqual(count, 0)
204+
205+
# Check that the error message is the expected one
206+
tracer_match = tracer.best_matches_map[fuse_pad_into_conv][0]
207+
self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
208+
self.assertRegex(tracer_match.match_result.reason, err_msg)
209+
210+
211+
if __name__ == "__main__":
212+
unittest.main()

0 commit comments

Comments
 (0)