From 4a665e04fc0f43b54f4b4aa9fe0c083f807ed3cb Mon Sep 17 00:00:00 2001 From: AyoubMDL Date: Sun, 18 May 2025 08:32:52 +0200 Subject: [PATCH 1/2] feat(rewriter): introduce fuse batchnorm - Fuses Batchnorm node into the following nodes (Conv, ConvTranspose, Gemm) --- onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/fuse_batchnorm.py | 183 ++++++++++++++++ onnxscript/rewriter/fuse_batchnorm_test.py | 238 +++++++++++++++++++++ 3 files changed, 423 insertions(+) create mode 100644 onnxscript/rewriter/fuse_batchnorm.py create mode 100644 onnxscript/rewriter/fuse_batchnorm_test.py diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 5efaf784b..0f2868257 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -18,6 +18,7 @@ broadcast_to_matmul, cast_constant_of_shape, collapse_slices, + fuse_batchnorm, gemm_to_matmul_add, llama_rule_sets, no_op, @@ -28,6 +29,7 @@ _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( *no_op.rules.rules, # TODO: merge this rule into constant folding? *broadcast_to_matmul.rules.rules, + *fuse_batchnorm.fuse_batchnorm_rule_set().rules, gemm_to_matmul_add.rule, # type: ignore[has-type] *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/fuse_batchnorm.py new file mode 100644 index 000000000..8b836fc85 --- /dev/null +++ b/onnxscript/rewriter/fuse_batchnorm.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns: +- BatchNormalization + Conv -> Conv +- BatchNormalization + ConvTranpose -> ConvTranpose +- BatchNormalization + Gemm -> Gemm + +Approach: + Given an inbound operation output: Y = W * X + B + And a BatchNormalization outputs: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps) + + The fusion updates the inbound weights as follows: + - W_fused = W * (gamma / std) + - B_fused = (B - μ) * (gamma / std) + β +""" + +from abc import ABC, abstractmethod + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import pattern as orp + + +class FuseBatchNormBase(orp.RewriteRuleClassBase, ABC): + """Interface for BatchNormalization nodes fusion.""" + + def __init__( + self, + op_type: str, + name: str | None = None, + remove_nodes: bool = True, + as_function: bool = False, + ) -> None: + super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function) + self.op_type = op_type + + @abstractmethod + def get_filters_axis(self, attributes) -> int: + """Return the axis along which BatchNorm scale should be broadcasted.""" + + def _reshape_for_broadcast(self, x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray: + # Convert axis to positive + if axis < 0: + axis += rank + + # Build shape: 1s everywhere except -1 at the target axis + broadcast_shape = [1 if axis != i else -1 for i in range(rank)] + return np.reshape(x, broadcast_shape) + + def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value): + batchnorm_node = batchnorm_out.producer() + # Get BatchNorm parameters + gamma, beta, input_mean, input_var = [ + inp.const_value.numpy() for inp in batchnorm_node.inputs[1:] + ] + + # 1e-5 is the default value for epsilon according to + # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes + default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5) + eps = batchnorm_node.attributes.get("epsilon", 
default_eps).as_float() + + # Compute the scale_factor to update the inbound weights and bias + scale_factor = gamma / np.sqrt(input_var + eps) + + # Update inbound weights + inbound_node = inbound_out.producer() + weights = inbound_node.inputs[1].const_value.numpy() + + # Reshape scale factor so it is broadcastable + axis = self.get_filters_axis(inbound_node.attributes) + fused_weights = ir.tensor( + weights * self._reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) + ) + + # Update bias + if len(inbound_node.inputs) > 2: + original_bias = inbound_node.inputs[2].const_value.numpy() + bias_name = inbound_node.inputs[2].name + else: + original_bias = np.zeros_like(input_mean) + bias_name = x.name + "_bias" + fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta) + + return op.op( + self.op_type, + inputs=[ + x, + op.initializer(fused_weights, name=inbound_node.inputs[1].name), + op.initializer(fused_bias, name=bias_name), + ], + attributes=inbound_node.attributes, + ) + + def check(self, context, x, inbound_out, batchnorm_out) -> orp.MatchResult: + del context # Unused + check_result = orp.MatchResult() + + inbound_node = inbound_out.producer() + batchnorm_node = batchnorm_out.producer() + + # Check that inbound weights + (inbound bias) + batchnorm params are initializers + initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]] + if len(inbound_node.inputs) > 2: + initializers.append(inbound_node.inputs[2]) + + for initializer in initializers: + if not initializer.is_initializer() or initializer.const_value is None: + return check_result.fail(f"{initializer.name} is not a constant initializer") + + return check_result + + +class FuseBatchNormIntoConv(FuseBatchNormBase): + """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" + + def __init__(self): + super().__init__("Conv") + + def get_filters_axis(self, attributes) -> int: + return 0 + + def pattern(self, op, x): + return op.BatchNormalization( + op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +class FuseBatchNormIntoConvTranspose(FuseBatchNormBase): + """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" + + def __init__(self): + super().__init__("ConvTranspose") + + def get_filters_axis(self, attributes) -> int: + return 1 + + def pattern(self, op, x): + return op.BatchNormalization( + op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +class FuseBatchNormIntoGemm(FuseBatchNormBase): + """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" + + def __init__(self): + super().__init__("Gemm") + + def get_filters_axis(self, attributes) -> int: + return 0 if attributes.get("transB") is not None and attributes["transB"].value else 1 + + def pattern(self, op, x): + return op.BatchNormalization( + op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule() +fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule() +fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule() + + +def fuse_batchnorm_rule_set() -> orp.RewriteRuleSet: + """Returns a set of rewrite rules that fuse BatchNormalization nodes + into preceding nodes such as Conv, ConvTranspose, and Gemm. 
+ + Returns: + RewriteRuleSet + """ + return orp.RewriteRuleSet( + [ + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_convtranspose_rule, + fuse_batchnorm_into_gemm_rule, + ] + ) diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/fuse_batchnorm_test.py new file mode 100644 index 000000000..d13a27914 --- /dev/null +++ b/onnxscript/rewriter/fuse_batchnorm_test.py @@ -0,0 +1,238 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx.checker +import onnx.parser +import parameterized + +from onnxscript import ir +from onnxscript.rewriter import fuse_batchnorm, testing + + +class FuseBatchnormTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("bias_false", False), + ("bias_true", True), + ] + ) + def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): + convtranspose_inputs = "X, W" + parameters = ( + "float[32, 64, 3, 3] W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + if convtranspose_bias: + parameters += ", float[64] B" + convtranspose_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y) + <{parameters}> + {{ + X1 = ConvTranspose({convtranspose_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + # Add initializers + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" + ), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="input_mean" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(64)).astype(np.float32), name="input_var" + ), + ] + if convtranspose_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + @parameterized.parameterized.expand( + [ + ("bias_false", False), + ("bias_true", True), + ] + ) + def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): + conv_inputs = "X, W" + parameters = ( + "float[64, 32, 3, 3] W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + if conv_bias: + parameters += ", float[64] B" + conv_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] 
Y) + <{parameters}> + {{ + X1 = Conv({conv_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + # Add initializers + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" + ), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="input_mean" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(64)).astype(np.float32), name="input_var" + ), + ] + if conv_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + @parameterized.parameterized.expand( + [ + ("bias_false_transB_0", False, 0), + ("bias_true_transB_0", True, 0), + ("bias_false_transB_1", False, 1), + ("bias_true_transB_1", True, 1), + ] + ) + def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): + gemm_inputs = "X, W" + parameters = ( + f"float{'[64, 32]' if transB else '[32, 64]'} W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + + if gemm_bias: + parameters += ", float[64] B" + gemm_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32] X) => (float [N, ?] 
Y) + <{parameters}> + {{ + X1 = Gemm({gemm_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + weights = np.random.randn(32, 64).astype(np.float32) + if transB: + weights = weights.T + + # Add initializers + initializers = [ + onnx.numpy_helper.from_array(weights, name="W"), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"), + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="input_mean" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(64)).astype(np.float32), name="input_var" + ), + ] + if gemm_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + def test_fuse_batchnorm_non_initializers(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W, float[64] B, + float[64] gamma, float[64] beta, float[64] input_var, + float[64] input_mean) => (float [N, ?, ?, ?] Y) + { + X1 = Conv(X, W, B) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + } + """) + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + + # No changes were applied + self.assertEqual(count, 0) + + +if __name__ == "__main__": + unittest.main() From a0b7beefa98a5a27d07160e68a032b2c12050817 Mon Sep 17 00:00:00 2001 From: AyoubMDL Date: Sat, 24 May 2025 20:59:21 +0200 Subject: [PATCH 2/2] review: apply requested changes - Make the rule optional - Improve code/test (checks, type-checking) --- onnxscript/rewriter/__init__.py | 2 - onnxscript/rewriter/fuse_batchnorm.py | 53 +++++++++-------- onnxscript/rewriter/fuse_batchnorm_test.py | 67 ++++++++++++++-------- 3 files changed, 72 insertions(+), 50 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 0f2868257..5efaf784b 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -18,7 +18,6 @@ broadcast_to_matmul, cast_constant_of_shape, collapse_slices, - fuse_batchnorm, gemm_to_matmul_add, llama_rule_sets, no_op, @@ -29,7 +28,6 @@ _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( *no_op.rules.rules, # TODO: merge this rule into constant folding? *broadcast_to_matmul.rules.rules, - *fuse_batchnorm.fuse_batchnorm_rule_set().rules, gemm_to_matmul_add.rule, # type: ignore[has-type] *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/fuse_batchnorm.py index 8b836fc85..b8b5c143d 100644 --- a/onnxscript/rewriter/fuse_batchnorm.py +++ b/onnxscript/rewriter/fuse_batchnorm.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns: -- BatchNormalization + Conv -> Conv -- BatchNormalization + ConvTranpose -> ConvTranpose -- BatchNormalization + Gemm -> Gemm +- BatchNormalization ∘ Conv -> Conv +- BatchNormalization ∘ ConvTranpose -> ConvTranpose +- BatchNormalization ∘ Gemm -> Gemm Approach: Given an inbound operation output: Y = W * X + B @@ -15,6 +15,7 @@ """ from abc import ABC, abstractmethod +from typing import Mapping import numpy as np @@ -22,7 +23,13 @@ from onnxscript.rewriter import pattern as orp -class FuseBatchNormBase(orp.RewriteRuleClassBase, ABC): +def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray: + # Build shape: 1s everywhere except -1 at the target axis + broadcast_shape = [1 if axis != i else -1 for i in range(rank)] + return np.reshape(x, broadcast_shape) + + +class _FuseBatchNormBase(orp.RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" def __init__( @@ -36,18 +43,9 @@ def __init__( self.op_type = op_type @abstractmethod - def get_filters_axis(self, attributes) -> int: + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" - def _reshape_for_broadcast(self, x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray: - # Convert axis to positive - if axis < 0: - axis += rank - - # Build shape: 1s everywhere except -1 at the target axis - broadcast_shape = [1 if axis != i else -1 for i in range(rank)] - return np.reshape(x, broadcast_shape) - def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value): batchnorm_node = batchnorm_out.producer() # Get BatchNorm parameters @@ -70,7 +68,7 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu # Reshape scale factor so it is broadcastable axis = self.get_filters_axis(inbound_node.attributes) fused_weights = ir.tensor( - weights * self._reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) + weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) ) # Update bias @@ -92,7 +90,9 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu attributes=inbound_node.attributes, ) - def check(self, context, x, inbound_out, batchnorm_out) -> orp.MatchResult: + def check( + self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value + ) -> orp.MatchResult: del context # Unused check_result = orp.MatchResult() @@ -100,24 +100,27 @@ def check(self, context, x, inbound_out, batchnorm_out) -> orp.MatchResult: batchnorm_node = batchnorm_out.producer() # Check that inbound weights + (inbound bias) + batchnorm params are initializers + # and that they are not graph inputs initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]] if len(inbound_node.inputs) > 2: initializers.append(inbound_node.inputs[2]) for initializer in initializers: if not initializer.is_initializer() or initializer.const_value is None: - return check_result.fail(f"{initializer.name} is not a constant initializer") + return check_result.fail(f"{initializer.name} is not a constant initializer.") + if initializer.is_graph_input(): + return check_result.fail(f"{initializer.name} is a graph input.") return check_result -class FuseBatchNormIntoConv(FuseBatchNormBase): +class FuseBatchNormIntoConv(_FuseBatchNormBase): """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" def __init__(self): super().__init__("Conv") - def get_filters_axis(self, attributes) 
-> int: + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 0 def pattern(self, op, x): @@ -128,13 +131,13 @@ def pattern(self, op, x): ) -class FuseBatchNormIntoConvTranspose(FuseBatchNormBase): +class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" def __init__(self): super().__init__("ConvTranspose") - def get_filters_axis(self, attributes) -> int: + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 1 def pattern(self, op, x): @@ -145,14 +148,16 @@ def pattern(self, op, x): ) -class FuseBatchNormIntoGemm(FuseBatchNormBase): +class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" def __init__(self): super().__init__("Gemm") - def get_filters_axis(self, attributes) -> int: - return 0 if attributes.get("transB") is not None and attributes["transB"].value else 1 + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: + return ( + 0 if attributes.get("transB") is not None and attributes["transB"].as_int() else 1 + ) def pattern(self, op, x): return op.BatchNormalization( diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/fuse_batchnorm_test.py index d13a27914..20d272abd 100644 --- a/onnxscript/rewriter/fuse_batchnorm_test.py +++ b/onnxscript/rewriter/fuse_batchnorm_test.py @@ -12,6 +12,22 @@ class FuseBatchnormTest(unittest.TestCase): + def _create_batchnorm_params(self, size: int): + return [ + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="gamma" + ), + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="beta" + ), + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="input_mean" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(size)).astype(np.float32), name="input_var" + ), + ] + @parameterized.parameterized.expand( [ ("bias_false", False), @@ -45,14 +61,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): onnx.numpy_helper.from_array( np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" ), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), - onnx.numpy_helper.from_array( - np.random.randn(64).astype(np.float32), name="input_mean" - ), - onnx.numpy_helper.from_array( - np.abs(np.random.randn(64)).astype(np.float32), name="input_var" - ), + *self._create_batchnorm_params(size=64), ] if convtranspose_bias: initializers.append( @@ -111,14 +120,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): onnx.numpy_helper.from_array( np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" ), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="gamma"), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), - onnx.numpy_helper.from_array( - np.random.randn(64).astype(np.float32), name="input_mean" - ), - onnx.numpy_helper.from_array( - np.abs(np.random.randn(64)).astype(np.float32), name="input_var" - ), + *self._create_batchnorm_params(size=64), ] if conv_bias: initializers.append( @@ -182,14 +184,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): # Add initializers initializers = [ onnx.numpy_helper.from_array(weights, name="W"), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), 
name="gamma"), - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="beta"), - onnx.numpy_helper.from_array( - np.random.randn(64).astype(np.float32), name="input_mean" - ), - onnx.numpy_helper.from_array( - np.abs(np.random.randn(64)).astype(np.float32), name="input_var" - ), + *self._create_batchnorm_params(size=64), ] if gemm_bias: initializers.append( @@ -233,6 +228,30 @@ def test_fuse_batchnorm_non_initializers(self): # No changes were applied self.assertEqual(count, 0) + def test_fuse_batchnorm_graph_inputs(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W) => (float [N, ?, ?, ?] Y) + { + X1 = Conv(X, W) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + } + """) + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" + ), + *self._create_batchnorm_params(size=64), + ] + model_proto.graph.initializer.extend(initializers) + onnx.checker.check_model(model_proto, True) + + model = ir.serde.deserialize_model(model_proto) + count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + + # No changes were applied as W is a graph input + self.assertEqual(count, 0) + if __name__ == "__main__": unittest.main()