Skip to content

Commit 1c39d9f

Browse files
committed
[Rewriter]: introduce fuse_pad_into_conv_integer (#2301)
1 parent e866010 commit 1c39d9f

File tree

2 files changed

+77
-3
lines changed

2 files changed

+77
-3
lines changed

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Licensed under the MIT License.
33
"""Fuses Pad nodes into preceding nodes. Supported fusion patterns:
44
- Pad ∘ Conv -> Conv
5+
- Pad ∘ ConvInteger -> ConvInteger
56
"""
67

78
import typing
@@ -130,14 +131,35 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
130131
return check_result
131132

132133

134+
class FusePadConvInteger(FusePadConv):
135+
"""Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``."""
136+
137+
def __init__(self, as_function: bool = False):
138+
super(FusePadConv, self).__init__(name="FusePadConvInteger", as_function=as_function)
139+
140+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
141+
return op.ConvInteger(
142+
op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
143+
_allow_other_inputs=True,
144+
_outputs=["conv"],
145+
)
146+
147+
133148
fuse_pad_into_conv = FusePadConv.rule()
149+
fuse_pad_into_conv_integer = FusePadConvInteger.rule()
134150

135151

136152
def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
137153
"""Returns a set of rewrite rules that fuse Pad nodes into preceding:
138154
- Conv
155+
- ConvInteger
139156
140157
Returns:
141158
RewriteRuleSet
142159
"""
143-
return orp.RewriteRuleSet([fuse_pad_into_conv])
160+
return orp.RewriteRuleSet(
161+
[
162+
fuse_pad_into_conv,
163+
fuse_pad_into_conv_integer,
164+
]
165+
)

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = Non
3333

3434
def build_model(
3535
self,
36+
op_type: str,
3637
input_shape: ir.Shape,
3738
weight_shape: typing.Sequence[int],
3839
pad_inputs: typing.Sequence[ir.TensorProtocol | ir.Value | None],
@@ -57,14 +58,17 @@ def build_model(
5758
raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
5859

5960
# Register operations in the tape
60-
x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
61+
idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
62+
x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype))
6163
y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
6264
y = tape.op(
63-
"Conv",
65+
op_type,
6466
inputs=[y, self.get_conv_weights(weight_shape, tape)],
6567
attributes=conv_attributes,
6668
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
6769
)
70+
if op_type == "ConvInteger":
71+
y.dtype = ir.DataType.INT32
6872

6973
# Build the model
7074
ir_model = ir.Model(
@@ -101,6 +105,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
101105
if axes is not None:
102106
pad_inputs.append(axes)
103107
base_model = self.build_model(
108+
op_type="Conv",
104109
input_shape=ir.Shape(("N", 32, 14, 16)),
105110
weight_shape=(10, 32, 3, 3),
106111
pad_inputs=pad_inputs,
@@ -190,6 +195,7 @@ def test_unsupported_fuse_pad_into_conv(
190195
self, mode, pads, const_value, axes, auto_pad, err_msg
191196
):
192197
base_model = self.build_model(
198+
op_type="Conv",
193199
input_shape=ir.Shape(("N", 32, 14, 16, 12)),
194200
weight_shape=(10, 32, 3, 4, 5),
195201
pad_inputs=[pads, const_value, axes],
@@ -208,5 +214,51 @@ def test_unsupported_fuse_pad_into_conv(
208214
self.assertRegex(tracer_match.match_result.reason, err_msg)
209215

210216

217+
class FusePadConvIntegerTest(FusePadConvBaseTest):
218+
def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None):
219+
w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
220+
if tape is not None:
221+
w = tape.initializer(w)
222+
return w
223+
224+
@parameterized.parameterized.expand(
225+
[
226+
(pad_pads, const_value, axes, conv_pads)
227+
for pad_pads, axes, conv_pads in [
228+
([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1]),
229+
([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None),
230+
([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1]),
231+
]
232+
for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")]
233+
]
234+
)
235+
def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads):
236+
pad_inputs = [ir.tensor(pad_pads, name="pads")]
237+
if const_value is not None or axes is not None:
238+
pad_inputs.append(const_value)
239+
if axes is not None:
240+
pad_inputs.append(axes)
241+
base_model = self.build_model(
242+
op_type="ConvInteger",
243+
input_shape=ir.Shape(("N", 24, 19, 23)),
244+
weight_shape=(8, 24, 3, 3),
245+
pad_inputs=pad_inputs,
246+
conv_attributes={"pads": conv_pads},
247+
)
248+
updated_model = _clone_model(base_model)
249+
250+
# Apply rule
251+
count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
252+
253+
# Check that Pad was fused
254+
self.assertEqual(count, 1)
255+
self.assertEqual(updated_model.graph.num_nodes(), 1)
256+
onnx_checker.CheckerPass(True)(updated_model)
257+
258+
# Check inference
259+
inputs = self.rng.integers(0, 255, (1, 24, 19, 23), dtype="uint8")
260+
testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
261+
262+
211263
if __name__ == "__main__":
212264
unittest.main()

0 commit comments

Comments
 (0)