diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index fb7815bd1..214a117fc 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -21,6 +21,7 @@
     collapse_slices,
     no_op,
     pattern,
+    fuse_pad_into_conv,
 )
 
 _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
@@ -30,6 +31,7 @@
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
     *basic_rules.basic_optimization_rules().rules,
+    *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
 )
diff --git a/onnxscript/rewriter/fuse_pad_into_conv.py b/onnxscript/rewriter/fuse_pad_into_conv.py
new file mode 100644
index 000000000..5736b9026
--- /dev/null
+++ b/onnxscript/rewriter/fuse_pad_into_conv.py
@@ -0,0 +1,296 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Fuses Pad nodes into the following Conv nodes. Supported fusion patterns:
+- Conv ∘ Pad -> Conv
+- ConvInteger ∘ Pad -> ConvInteger
+
+To make some of these rules applicable, the `auto_pad` attribute is implicitly converted into its explicit `pads` list.
+"""
+
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import numpy as np
+import onnx_ir as ir
+
+from onnxscript.rewriter import pattern as orp
+
+
+def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
+    new_pads = []
+    for axis in range(rank):
+        if axis not in axes:
+            start_value = end_value = 0
+        else:
+            start_value = pads[axes.index(axis)]
+            end_value = pads[axes.index(axis) + len(axes)]
+        pad_len = len(new_pads) // 2
+        new_pads.insert(pad_len + axis, end_value)
+        new_pads.insert(axis, start_value)
+    return new_pads
+
+
+def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
+    # Read attributes
+    attributes = {}
+    ir_attributes = ir_conv.attributes
+    attributes["kernel_shape"] = ir_attributes.get_ints(
+        "kernel_shape", ir_conv.inputs[1].shape[2:]
+    )
+    attributes["strides"] = ir_attributes.get_ints(
+        "strides", [1] * len(ir_conv.inputs[0].shape[2:])
+    )
+    attributes["auto_pad"] = ir_attributes.get_string("auto_pad", "NOTSET")
+    if "pads" in ir_attributes:
+        attributes["pads"] = ir_attributes.get_ints("pads")
+    return attributes
+
+
+class _FusePadConvBase(orp.RewriteRuleClassBase):
+    """Base class for fusing a Pad node into a following Conv node."""
+
+    def __init__(self, as_function: bool = False):
+        # remove_nodes=False: matched nodes are not removed here; nodes left unused are cleaned up after the rewrite.
+        super().__init__(remove_nodes=False, as_function=as_function)
+
+    def rewrite(
+        self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value
+    ) -> ir.Value:
+        pad_node = pad.producer()
+        conv_node = conv.producer()
+
+        # Retrieve the padding and axes
+        x_rank = len(x.shape)
+        pad_pads = pad_node.inputs[1].const_value.numpy().tolist()
+        if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
+            axes = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
+        else:
+            axes = list(range(x_rank))
+
+        # Expand pad_pads to every dimension (filling the unlisted ones with zero)
+        pad_pads = fill_pads_with_axes(pad_pads, axes, x_rank)
+
+        # Get only spatial pads
+        new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :]
+
+        # Merge the Pad amounts into the Conv pads (new + old)
+        conv_attr = conv_node.attributes.copy()
+        if "pads" in conv_attr:
+            new_pads = [x + y for x, y in zip(conv_attr["pads"].as_ints(), new_pads)]
+        conv_attr.add(ir.AttrInt64s("pads", new_pads))
+
+        return op.op(
+            conv_node.op_type,
+            inputs=(x, *conv_node.inputs[1:]),
+            attributes=conv_attr,
+            domain=conv_node.domain,
+            name=conv_node.name,
+        )
+
+    def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
+        del context  # Unused
+        check_result = orp.MatchResult()
+        pad_node = pad.producer()
+        x_rank = len(x.shape)
+
+        # Pad constraints: attributes
+        if (mode := pad_node.attributes.get("mode", None)) and mode.as_string() != "constant":
+            return check_result.fail(f"{pad_node.name} mode must be 'constant'.")
+
+        # Pad constraints: inputs
+        if (pads := pad_node.inputs[1]).const_value is None:
+            return check_result.fail(f"{pads.name} is not a constant/initializer.")
+        if len(pad_node.inputs) > 2 and (constant_value := pad_node.inputs[2]) is not None:
+            if constant_value.const_value is None:
+                return check_result.fail(
+                    f"{constant_value.name} is not a constant/initializer."
+                )
+            elif constant_value.const_value.numpy().item() != 0:
+                return check_result.fail(f"{constant_value.name} must be equal to 0.")
+        if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
+            if axes.const_value is None:
+                return check_result.fail(f"{axes.name} is not a constant/initializer.")
+            axes_list = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
+        else:
+            axes_list = list(range(x_rank))
+
+        # Pad constraints: values
+        pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank)
+        if np.any(pads_list[:2] + pads_list[x_rank : x_rank + 2]):
+            return check_result.fail(f"{pads.name} must be zero in non-spatial dimensions.")
+
+        return check_result
+
+
+class FusePadConv(_FusePadConvBase):
+    """Replaces ``Conv(Pad(x))`` with ``Conv(x)``."""
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.Conv(
+            op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
+            _allow_other_inputs=True,
+            _outputs=["conv"],
+        )
+
+    def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
+        check_result = super().check(context, x, pad, conv)
+        if check_result.reason:
+            return check_result
+
+        # Conv constraints: attributes
+        conv_node = conv.producer()
+        if (
+            apad := conv_node.attributes.get("auto_pad", None)
+        ) and apad.as_string() != "NOTSET":
+            return check_result.fail(f"{conv_node.name} auto_pad must be 'NOTSET'.")
+        return check_result
+
+
+class FusePadConvInteger(FusePadConv):
+    """Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``."""
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.ConvInteger(
+            op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
+            _allow_other_inputs=True,
+            _outputs=["conv"],
+        )
+
+
+class _NormalizePadFormatBase(orp.RewriteRuleClassBase):
+    """Base class for normalizing pad attributes in Conv nodes."""
+
+    @staticmethod
+    def compute_pads(
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        attributes: dict[str, Sequence[int] | str],
+    ) -> Sequence[int]:
+        raise NotImplementedError("Subclasses must implement this function")
+
+    def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
+        conv_node = conv.producer()
+
+        # Read spatial dimensions and attributes
+        input_shape = conv_node.inputs[0].shape[2:]
+        output_shape = conv_node.outputs[0].shape[2:]
+        attributes = read_conv_attributes(conv_node)
+
+        # Convert auto_pad mode into an explicit list
+        pads = self.compute_pads(input_shape, output_shape, attributes)
+
+        # Set auto_pad to 'NOTSET' and use the explicit pads list
+        conv_attr = conv_node.attributes.copy()
+        conv_attr.add(ir.AttrString("auto_pad", "NOTSET"))
+        if any(x != 0 for x in pads):
+            conv_attr.add(ir.AttrInt64s("pads", pads))
+
+        return op.op(
+            conv_node.op_type,
+            inputs=conv_node.inputs,
+            attributes=conv_attr,
+            domain=conv_node.domain,
+            name=conv_node.name,
+        )
+
+    def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
+        del context
+        check_result = orp.MatchResult()
+
+        # Conv constraints: attributes
+        conv_node = conv.producer()
+        auto_pad = conv_node.attributes.get_string("auto_pad", None)
+        if auto_pad in [None, "NOTSET"]:
+            return check_result.fail(
+                f"{conv_node.name} auto_pad must be different from 'NOTSET'."
+            )
+
+        # Conv constraints: inputs/outputs
+        input_shape = conv_node.inputs[0].shape
+        output_shape = conv_node.outputs[0].shape
+        if len(input_shape) <= 2:
+            return check_result.fail(f"Input shapes are not defined on {conv_node.name}.")
+        if len(output_shape) <= 2:
+            return check_result.fail(f"Output shapes are not defined on {conv_node.name}.")
+
+        # Conv constraints: values
+        if auto_pad != "VALID":
+            error_msg = "Expected static spatial {} shapes on " + conv_node.name + "."
+            if not all(isinstance(x, int) for x in input_shape[2:]):
+                return check_result.fail(error_msg.format("input"))
+            if not all(isinstance(x, int) for x in output_shape[2:]):
+                return check_result.fail(error_msg.format("output"))
+        attributes = read_conv_attributes(conv_node)
+        if len(attributes["kernel_shape"]) != len(attributes["strides"]):
+            return check_result.fail(
+                f"strides must have the same length as kernel_shape on {conv_node.name}."
+            )
+        return check_result
+
+
+class NormalizePadFormatConv(_NormalizePadFormatBase):
+    """Convert the auto_pad attribute into explicit pads ('NOTSET') in Conv nodes."""
+
+    @staticmethod
+    def compute_pads(
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        attributes: dict[str, Sequence[int] | str],
+    ) -> Sequence[int]:
+        # Compute pads, following auto_pad/pads attributes
+        if attributes["auto_pad"] in ["NOTSET", "VALID"]:
+            assert len(input_shape) > 0
+            return attributes.get("pads", [0] * len(input_shape) * 2)
+
+        bottom_pads, top_pads = [], []
+        kernel_shape, strides = attributes["kernel_shape"], attributes["strides"]
+        assert len(kernel_shape) == len(strides) == len(input_shape) == len(output_shape)
+        for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides):
+            # Compute the total padding needed to produce the expected output shape
+            total_pads = max(0, (y - 1) * s + k - x)
+
+            # Depending on the mode, apply the extra padding to the lower or upper side
+            pad1 = total_pads // 2
+            pad2 = total_pads - pad1
+            if attributes["auto_pad"] == "SAME_UPPER":
+                bottom_pads.append(pad1)
+                top_pads.append(pad2)
+            else:
+                top_pads.append(pad1)
+                bottom_pads.append(pad2)
+        return bottom_pads + top_pads
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.Conv(x, _allow_other_inputs=True, _outputs=["conv"])
+
+
+class NormalizePadFormatConvInteger(NormalizePadFormatConv):
+    """Convert the auto_pad attribute into explicit pads ('NOTSET') in ConvInteger nodes."""
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"])
+
+
+normalize_pad_format_conv = NormalizePadFormatConv.rule()
+normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule()
+fuse_pad_into_conv = FusePadConv.rule()
+fuse_pad_into_conv_integer = FusePadConvInteger.rule()
+
+
+def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
+    """Returns a set of rewrite rules that fuse a preceding Pad node into:
+    - Conv
+    - ConvInteger
+
+    Returns:
+        RewriteRuleSet
+    """
+    return orp.RewriteRuleSet(
+        [
+            normalize_pad_format_conv,
+            normalize_pad_format_conv_integer,
+            fuse_pad_into_conv,
+            fuse_pad_into_conv_integer,
+        ]
+    )
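As a quick worked example of the SAME_UPPER/SAME_LOWER conversion implemented by `compute_pads` above (illustrative numbers, not taken from the tests): for one spatial dimension with input size 22, output size 11, kernel size 5 and stride 2, total_pads = max(0, (11 - 1) * 2 + 5 - 22) = 3; SAME_UPPER places 1 at the beginning and 2 at the end of that axis, while SAME_LOWER reverses the split.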
diff --git a/onnxscript/rewriter/fuse_pad_into_conv_test.py b/onnxscript/rewriter/fuse_pad_into_conv_test.py
new file mode 100644
index 000000000..0ed5decaa
--- /dev/null
+++ b/onnxscript/rewriter/fuse_pad_into_conv_test.py
@@ -0,0 +1,406 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import unittest
+from typing import Mapping, Sequence
+
+import numpy as np
+import onnx_ir as ir
+import parameterized
+from onnx_ir.passes.common import onnx_checker, shape_inference
+
+from onnxscript.rewriter import pattern as orp
+from onnxscript.rewriter import testing
+from onnxscript.rewriter.fuse_pad_into_conv import (
+    fuse_pad_into_conv,
+    fuse_pad_into_conv_rule_set,
+    normalize_pad_format_conv,
+)
+
+
+def _clone_model(model: ir.Model) -> ir.Model:
+    return ir.from_proto(ir.to_proto(model))
+
+
+class FusePadConvBaseTest(unittest.TestCase):
+    @property
+    def rng(self):
+        return np.random.default_rng(20250522)
+
+    def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
+        w = ir.tensor(self.rng.uniform(-0.5, 0.5, shape).astype("float32"), name="W")
+        if tape is not None:
+            w = tape.initializer(w)
+        return w
+
+    def build_model(
+        self,
+        op_type: str,
+        input_shape: ir.Shape,
+        weight_shape: Sequence[int],
+        pad_inputs: Sequence[ir.TensorProtocol | ir.Value | None],
+        pad_attributes: Mapping[str, ir.Attr] | None = None,
+        conv_attributes: Mapping[str, ir.Attr] | None = None,
+    ) -> ir.Model:
+        tape = ir.tape.Tape()
+        inputs = []
+        output_shape = ir.Shape((input_shape[0],) + ("?",) * (len(input_shape) - 1))
+
+        # Convert pad_inputs to initializers (if needed)
+        pad_inputs = list(pad_inputs)
+        for idx, x in enumerate(pad_inputs):
+            if isinstance(x, ir.TensorProtocol):
+                pad_inputs[idx] = tape.initializer(x)
+            elif isinstance(x, ir.Value):
+                inputs.append(x)
+            elif isinstance(x, float):
+                pad_inputs[idx] = tape.op("Constant", inputs=[], attributes={"value_float": x})
+            elif x is not None:
+                raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
+
+        # Register operations in the tape
+        idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
+        x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype))
+        y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
+        y = tape.op(
+            op_type,
+            inputs=[y, self.get_conv_weights(weight_shape, tape)],
+            attributes=conv_attributes,
+            output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
+        )
+        if op_type == "ConvInteger":
+            y.dtype = ir.DataType.INT32
+
+        # Build the model
+        ir_model = ir.Model(
+            ir.Graph(
+                inputs=[x, *inputs],
+                outputs=[y],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 20},
+                name="model",
+            ),
+            ir_version=10,
+        )
+        onnx_checker.CheckerPass(True)(ir_model)
+        ir_model = shape_inference.infer_shapes(ir_model)
+        return ir_model
+
+
+class FusePadConvTest(FusePadConvBaseTest):
+    @parameterized.parameterized.expand(
+        [
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 2, 2, 0, 0, 2, 2], None, None, None),
+                ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0], None),
+                ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1], None),
+                ([1, 3, 1, 3], ir.tensor([3, 2], name="axes"), None, "VALID"),
+            ]
+            for const_value in [None, 0.0]
+        ]
+    )
+    def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_auto_pad):
+        pad_inputs = [ir.tensor(pad_pads, name="pads")]
+        if const_value is not None or axes is not None:
+            pad_inputs.append(const_value)
+        if axes is not None:
+            pad_inputs.append(axes)
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 14, 16)),
+            weight_shape=(10, 32, 3, 3),
+            pad_inputs=pad_inputs,
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
+        )
+        updated_model = _clone_model(base_model)
+
+        # Apply rule
+        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
+
+        # Check that Pad was fused
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
+        self.assertEqual(updated_model.graph.num_nodes(), 1)
+        onnx_checker.CheckerPass(True)(updated_model)
+
+        # Check inference
+        inputs = self.rng.random((1, 32, 14, 16), dtype="float32")
+        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "constant",
+                ir.tensor([1] * 10, name="pads"),
+                ir.tensor([0.0], name="const_value"),
+                None,
+                "NOTSET",
+                "must be zero in non-spatial dimensions",
+            ),
+            (
+                "constant",
+                ir.tensor([0, 0, 0, 0], name="pads"),
+                ir.tensor([1.0], name="const_value"),
+                ir.tensor([0, -1], name="axes"),
+                "NOTSET",
+                "must be equal to 0.",
+            ),
+            (
+                "edge",
+                ir.tensor([0, 0, 0, 0], name="pads"),
+                ir.tensor([0.0], name="const_value"),
+                ir.tensor([0, -1], name="axes"),
+                "NOTSET",
+                "mode must be 'constant'.",
+            ),
+            (
+                "constant",
+                ir.Value(
+                    name="pads", shape=ir.Shape([4]), type=ir.TensorType(ir.DataType.INT64)
+                ),
+                None,
+                ir.tensor([0, -1], name="axes"),
+                "NOTSET",
+                "pads is not a constant/initializer.",
+            ),
+            (
+                "constant",
+                ir.tensor([0] * 10, name="pads"),
+                ir.Value(
+                    name="cval", shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.FLOAT)
+                ),
+                None,
+                "NOTSET",
+                "cval is not a constant",
+            ),
+            (
+                "constant",
+                ir.tensor([0, 0, 0, 0], name="pads"),
+                None,
+                ir.Value(
+                    name="axes", shape=ir.Shape([2]), type=ir.TensorType(ir.DataType.INT64)
+                ),
+                "NOTSET",
+                "axes is not a constant",
+            ),
+            (
+                "constant",
+                ir.tensor([0, 0, 0, 0], name="pads"),
+                ir.tensor([0.0], name="const_value"),
+                ir.tensor([0, -1], name="axes"),
+                "VALID",
+                "auto_pad must be 'NOTSET'.",
+            ),
+        ]
+    )
+    def test_unsupported_fuse_pad_into_conv(
+        self, mode, pads, const_value, axes, auto_pad, err_msg
+    ):
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 14, 16, 12)),
+            weight_shape=(10, 32, 3, 4, 5),
+            pad_inputs=[pads, const_value, axes],
+            pad_attributes={"mode": mode},
+            conv_attributes={"auto_pad": auto_pad},
+        )
+
+        # Apply rule and check it was not applied
+        tracer = orp.MatchingTracer()
+        count = fuse_pad_into_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[fuse_pad_into_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, err_msg)
+
+
+class FusePadConvIntegerTest(FusePadConvBaseTest):
+    def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
+        w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
+        if tape is not None:
+            w = tape.initializer(w)
+        return w
+
+    @parameterized.parameterized.expand(
+        [
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1], None),
+                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None, None),
+                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1], None),
+                ([3, 3], ir.tensor([2], name="axes"), None, "SAME_UPPER"),
+            ]
+            for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")]
+        ]
+    )
+    def test_fuse_pad_into_conv_integer(
+        self, pad_pads, const_value, axes, conv_pads, conv_auto_pad
+    ):
+        pad_inputs = [ir.tensor(pad_pads, name="pads")]
+        if const_value is not None or axes is not None:
+            pad_inputs.append(const_value)
+        if axes is not None:
+            pad_inputs.append(axes)
+        base_model = self.build_model(
+            op_type="ConvInteger",
+            input_shape=ir.Shape(("N", 24, 19, 23)),
+            weight_shape=(8, 24, 3, 3),
+            pad_inputs=pad_inputs,
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
+        )
+        updated_model = _clone_model(base_model)
+
+        # Apply rule
+        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
+
+        # Check that Pad was fused
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
+        self.assertEqual(updated_model.graph.num_nodes(), 1)
+        onnx_checker.CheckerPass(True)(updated_model)
+
+        # Check inference
+        inputs = self.rng.integers(0, 255, (1, 24, 19, 23), dtype="uint8")
+        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
+
+
+class NormalizePadFormatTest(FusePadConvBaseTest):
+    def build_model(
+        self,
+        input_shape: ir.Shape,
+        conv_inputs: Sequence[ir.TensorProtocol | ir.Value],
+        conv_attributes: Mapping[str, ir.Attr] | None = None,
+        infer_shapes=True,
+    ) -> ir.Model:
+        tape = ir.tape.Tape()
+        inputs = []
+        output_shape = ir.Shape(("?",) * len(input_shape))
+
+        # Convert conv_inputs to initializers (if needed)
+        conv_inputs = list(conv_inputs)
+        for idx, x in enumerate(conv_inputs):
+            if isinstance(x, ir.TensorProtocol):
+                conv_inputs[idx] = tape.initializer(x)
+            elif isinstance(x, ir.Value):
+                inputs.append(x)
+            elif x is not None:
+                raise ValueError(f"Unsupported type for conv input ({x}): {type(x)}.")
+
+        # Register operations in the tape
+        x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        y = tape.op(
+            "Conv",
+            inputs=[x, *conv_inputs],
+            attributes=conv_attributes,
+            output=ir.Input("Y", shape=output_shape, type=x.type),
+        )
+
+        # Build the model
+        ir_model = ir.Model(
+            ir.Graph(
+                inputs=[x, *inputs],
+                outputs=[y],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 20},
+                name="model",
+            ),
+            ir_version=10,
+        )
+        if len(input_shape) > 0 and infer_shapes:
+            onnx_checker.CheckerPass(True)(ir_model)
+            ir_model = shape_inference.infer_shapes(ir_model)
+        else:
+            onnx_checker.CheckerPass(False)(ir_model)
+        return ir_model
+
+    @parameterized.parameterized.expand(
+        [
+            (dynamic_shape, strides, kernel_shape, auto_pad)
+            for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))]
+            for dynamic_shape, auto_pad in [
+                (False, "SAME_UPPER"),
+                (False, "SAME_LOWER"),
+                (True, "VALID"),
+            ]
+        ]
+    )
+    def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_pad):
+        input_shape = (
+            ir.Shape(("N", "A", "B", "C")) if dynamic_shape else ir.Shape(("N", 32, 22, 27))
+        )
+        base_model = self.build_model(
+            input_shape=input_shape,
+            conv_inputs=[ir.tensor(self.get_conv_weights((32, 32, *kernel_shape)), name="W")],
+            conv_attributes={
+                "strides": strides,
+                "auto_pad": auto_pad,
+                "kernel_shape": kernel_shape,
+            },
+        )
+        updated_model = _clone_model(base_model)
+
+        # Apply rule
+        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
+        onnx_checker.CheckerPass(True)(updated_model)
+
+        # Check conv has changed
+        self.assertEqual(count, 1)
+        self.assertEqual(updated_model.graph[0].attributes.get_string("auto_pad"), "NOTSET")
+
+        # Check inference
+        inputs = self.rng.random((1, 32, 22, 27), dtype="float32")
+        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
+
+    @parameterized.parameterized.expand(
+        [
+            (ir.Shape([]), False, "Input shapes are not defined"),
+            (ir.Shape(("N", "C", "A")), False, "Expected static spatial input shapes"),
+            (ir.Shape(("N", "C", 32)), False, "Expected static spatial output shapes"),
+        ]
+    )
+    def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error_msg):
+        base_model = self.build_model(
+            input_shape=input_shape,
+            conv_inputs=[ir.tensor(np.ones((32, 11, 4)), name="W")],
+            conv_attributes={"auto_pad": "SAME_UPPER"},
+            infer_shapes=infer_shapes,
+        )
+
+        # Apply rule and check it was not applied
+        tracer = orp.MatchingTracer()
+        count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, error_msg)
+
+    def test_unsupported_normalize_pad_format_on_weights(self):
+        W = ir.Value(name="W", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.FLOAT))
+        base_model = self.build_model(
+            input_shape=ir.Shape(("N", 2, 32)),
+            conv_inputs=[W],
+            conv_attributes={"auto_pad": "SAME_UPPER"},
+            infer_shapes=False,
+        )
+        # Set the output shape so that the error comes from the weights
+        base_model.graph[0].outputs[0].shape = ir.Shape(("N", 10, 32))
+
+        # Apply rule and check it was not applied
+        tracer = orp.MatchingTracer()
+        count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, "same length as kernel_shape")
+
+
+if __name__ == "__main__":
+    unittest.main()
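For reference, a minimal usage sketch of the new rule set on a standalone model; it assumes a hypothetical `model.onnx` on disk and uses only the entry points exercised by the tests in this diff (`fuse_pad_into_conv_rule_set`, `apply_to_model`, `ir.from_proto`, `ir.to_proto`):

    import onnx
    import onnx_ir as ir

    from onnxscript.rewriter.fuse_pad_into_conv import fuse_pad_into_conv_rule_set

    # Convert the ONNX model to the IR form that the rewrite rules operate on.
    model = ir.from_proto(onnx.load("model.onnx"))

    # Normalize auto_pad and fuse Pad nodes into their Conv/ConvInteger consumers.
    # The returned count is the number of rewrites that were applied.
    count = fuse_pad_into_conv_rule_set().apply_to_model(model)

    onnx.save(ir.to_proto(model), "model_fused.onnx")

The same rules also run as part of the default rule set registered in onnxscript/rewriter/__init__.py above.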