Skip to content

Commit 9473060

Browse files
authored
Fix batch norm folding in prepare_pt2e for multiple conv->BN chains sharing the same conv weights (#2795)
* Fix BN folding for multiple conv->BN chains sharing the same conv weights
* Fix variable names and format
---------
Co-authored-by: Subhankar Pal <[email protected]>
1 parent 72b35bf commit 9473060

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from torchao.quantization.pt2e.quantizer.embedding_quantizer import ( # noqa: F811
5858
EmbeddingQuantizer,
5959
)
60+
from torchao.testing.model_architectures import ConvWithSharedWeightInExportedModel
6061
from torchao.testing.pt2e._xnnpack_quantizer import (
6162
XNNPACKQuantizer,
6263
get_symmetric_quantization_config,
@@ -150,6 +151,34 @@ def validate(self, model: torch.fx.GraphModule) -> None:
150151
node_list,
151152
)
152153

154+
def test_chunked_bn_fusion(self):
    """Check that prepare_pt2e preserves outputs when one conv feeds several BN chains.

    The model applies the same conv -> BN -> ReLU modules to every channel
    chunk of the input. The eager output, the exported-module output and the
    output after ``prepare_pt2e`` are all required to match.
    """
    num_chunks, conv_in, conv_out = 3, 1, 32
    model = ConvWithSharedWeightInExportedModel(num_chunks, conv_in, conv_out)

    # Replace the default running variance with small random values.
    # NOTE(review): presumably this makes the BN scale large enough that
    # folding BN into the shared conv weights more than once becomes
    # numerically visible -- confirm intent.
    small_var = torch.rand(conv_out) * 1e-2
    model.bn.running_var = torch.nn.Parameter(small_var, requires_grad=False)
    model.eval()

    # One input channel per chunk (batch size 1, 32x32 spatial size).
    inputs = (torch.rand(1, num_chunks, 32, 32),)
    eager_out = model(*inputs)
    exported = torch.export.export(model, inputs, strict=True).module()
    exported_out = exported(*inputs)
    prepared = prepare_pt2e(exported, XNNPACKQuantizer())
    prepared_out = prepared(*inputs)

    # Normalise single-tensor and multi-tensor results to one comparison loop.
    if isinstance(eager_out, (tuple, list)):
        triples = zip(eager_out, exported_out, prepared_out)
    else:
        triples = [(eager_out, exported_out, prepared_out)]

    for eager, traced, after_prepare in triples:
        torch.testing.assert_close(eager, traced)
        torch.testing.assert_close(traced, after_prepare)
181+
153182
def test_wo_annotate_conv_output_quantizer(self):
154183
# TODO: use OP_TO_ANNOTATOR
155184
class BackendAQuantizer(Quantizer):

torchao/quantization/pt2e/utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ def fold_bn_weights_into_conv_node(
671671
conv_bias_node: Optional[Node],
672672
bn_node: Node,
673673
m: GraphModule,
674+
fake_fuse: bool = False, # removes the BN nodes but doesn't change the conv weights
674675
) -> None:
675676
# conv args: input, weight, bias, stride, padding, dilation, ...
676677
conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
@@ -703,6 +704,16 @@ def fold_bn_weights_into_conv_node(
703704
if len(conv_args) == 2:
704705
conv_args.append(None)
705706

707+
if fake_fuse:
708+
fused_weight, fused_bias = (
709+
torch.nn.Parameter(conv_w, conv_w.requires_grad),
710+
torch.nn.Parameter(conv_b, conv_b.requires_grad),
711+
)
712+
else:
713+
fused_weight, fused_bias = fuse_conv_bn_weights(
714+
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
715+
)
716+
706717
# calling data since the fused_weight and fused_bias are nn.Parameter
707718
weight_attr_name = conv_weight_node.target
708719
assert isinstance(weight_attr_name, str)
@@ -767,6 +778,9 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
767778
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
768779
if not has_bn:
769780
return
781+
782+
# track which conv weights have been fused to avoid double fusing
783+
fused_convs_weight_nodes = set()
770784
for n in m.graph.nodes:
771785
if n.op != "call_function" or n.target not in (
772786
torch.ops.aten._native_batch_norm_legit_no_training.default,
@@ -781,9 +795,14 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
781795
conv_weight_node = conv_node.args[1]
782796
conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
783797
fold_bn_weights_into_conv_node(
784-
conv_node, conv_weight_node, conv_bias_node, bn_node, m
798+
conv_node,
799+
conv_weight_node,
800+
conv_bias_node,
801+
bn_node,
802+
m,
803+
(conv_weight_node in fused_convs_weight_nodes),
785804
)
786-
805+
fused_convs_weight_nodes.add(conv_weight_node)
787806
m.graph.eliminate_dead_code()
788807
m.recompile()
789808

torchao/testing/model_architectures.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,27 @@ def forward(self, x):
2222
return x
2323

2424

25+
class ConvWithSharedWeightInExportedModel(nn.Module):
    """Apply one conv -> BN -> ReLU stack to every channel chunk of the input.

    The input is split into ``n_chunks`` pieces along dim 1. Each piece goes
    through the same ``conv``, ``bn`` and ``relu`` modules, so the conv and
    batch-norm parameters are reused for every chunk. The per-chunk results
    are concatenated back along dim 1.
    """

    def __init__(
        self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ) -> None:
        super().__init__()
        self.n_chunks = n_chunks
        # Submodules are registered in conv, bn, relu order.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x) -> torch.Tensor:
        """Return the concatenated conv -> BN -> ReLU results of all chunks."""
        pieces = torch.chunk(x, self.n_chunks, dim=1)
        # Chunks are processed in order, each through the shared modules.
        activated = [self.relu(self.bn(self.conv(piece))) for piece in pieces]
        return torch.cat(activated, dim=1)
44+
45+
2546
class LNLinearActivationModel(nn.Module):
2647
def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"):
2748
super().__init__()

0 commit comments

Comments
 (0)