Commit 8496b55

[pt2e] Fix QAT annotations for special qspecs (#3337)
**Summary:** In pt2e QAT, we first annotate the nodes to be quantized and then perform the pattern replacement (i.e. QAT fusion) in the prepare step. After this pattern replacement, old nodes are replaced with new nodes, and any references to the old nodes must also be updated to refer to the new nodes. This commit fixes a bug where, for special qspecs like `SharedQuantizationSpec` and `DerivedQuantizationSpec`, we only update the values of a node's `input_qspec_map`, not the keys. As a result, the keys still refer to old nodes that no longer exist after the QAT fusion.

**Test Plan:**

```
python test/quantization/pt2e/test_quantize_pt2e_qat.py -k test_qat_shared_qspec
```
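To make the failure mode concrete, here is a minimal standalone sketch (plain Python objects stand in for `torch.fx` nodes; the names are hypothetical and this is not the torchao implementation): rewriting only the values of a node-keyed dict leaves its keys pointing at nodes that the fusion removed, so the dict has to be rebuilt with remapped keys.

```python
class Node:
    """Stand-in for a torch.fx.Node (illustrative only)."""

    def __init__(self, name: str):
        self.name = name


old_conv, fused_conv = Node("conv_old"), Node("conv_fused")
original_to_replacement = {old_conv: fused_conv}  # built during QAT fusion

input_qspec_map = {old_conv: "act_qspec"}

# Buggy update: only the values are rewritten, so the key still points
# at old_conv, which no longer exists in the graph after fusion.
for node, qspec in input_qspec_map.items():
    input_qspec_map[node] = qspec

# Fixed update: rebuild the dict so the keys are remapped as well.
input_qspec_map = {
    original_to_replacement.get(node, node): qspec
    for node, qspec in input_qspec_map.items()
}
assert fused_conv in input_qspec_map and old_conv not in input_qspec_map
```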
1 parent ff0e461

2 files changed: +159 −2

test/quantization/pt2e/test_quantize_pt2e_qat.py (154 additions, 1 deletion)
```diff
@@ -8,7 +8,7 @@
 import copy
 import operator
 import unittest
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 from torch.ao.quantization import QConfigMapping
@@ -46,6 +46,7 @@
     QuantizationAnnotation,
     QuantizationSpec,
     Quantizer,
+    SharedQuantizationSpec,
 )
 from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
@@ -878,6 +879,57 @@ class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
     conv_transpose_class = torch.nn.ConvTranspose2d
     bn_class = torch.nn.BatchNorm2d
 
+    def test_qat_shared_qspec(self):
+        """
+        Test that nodes used in the keys of `input_qspec_map` refer to the
+        new nodes after QAT fusion, not the old nodes that no longer exist.
+        """
+        m = DoubleConvBnModel()
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
+        old_nodes = set(m.graph.nodes)
+        m = prepare_qat_pt2e(m, DoubleConvBnQuantizer())
+        new_nodes = set(m.graph.nodes)
+        old_nodes = old_nodes - new_nodes
+        assert old_nodes.isdisjoint(new_nodes), "bad test setup"
+        assert len(old_nodes) == 4, (
+            f"bad test setup, old nodes should have 2 convs and 2 bns: {old_nodes}"
+        )
+
+        # first, gather a list of nodes to check from input and output qspecs
+        nodes_to_check = set()
+        for n in m.graph.nodes:
+            annotations = n.meta.get("quantization_annotation")
+            if annotations is None:
+                continue
+            nodes_to_check.update(list(annotations.input_qspec_map.keys()))
+            for qspec in list(annotations.input_qspec_map.values()) + [
+                annotations.output_qspec
+            ]:
+                if isinstance(qspec, SharedQuantizationSpec):
+                    if isinstance(qspec.edge_or_node, torch.fx.Node):
+                        nodes_to_check.add(qspec.edge_or_node)
+                    else:
+                        (src, dest) = qspec.edge_or_node
+                        nodes_to_check.update([src, dest])
+
+        # assert that none of the nodes refer to old nodes
+        self.assertEqual(len(nodes_to_check), 5)
+        num_batch_norm_nodes_checked = 0
+        for n in nodes_to_check:
+            if n.target == torch.ops.aten.batch_norm.default:
+                num_batch_norm_nodes_checked += 1
+            self.assertTrue(
+                n not in old_nodes,
+                f"found old node {n} in qspec, old nodes: {old_nodes}",
+            )
+            self.assertTrue(
+                n in new_nodes, f"found node {n} in qspec not in new nodes: {new_nodes}"
+            )
+        assert num_batch_norm_nodes_checked == 2, (
+            f"bad test setup, didn't check 2 bns, only checked these: {nodes_to_check}"
+        )
+
 
 def _is_conv_node(n: torch.fx.Node):
     return n.op == "call_function" and n.target in [
@@ -913,6 +965,107 @@ def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
     return (conv_node, bn_node, getitem_node)
 
 
+class DoubleConvBnModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, 3, 3, bias=False)
+        self.bn1 = torch.nn.BatchNorm2d(3)
+        self.conv2 = torch.nn.Conv2d(3, 3, 3, bias=False)
+        self.bn2 = torch.nn.BatchNorm2d(3)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x1 = self.bn1(x1)
+        x2 = self.conv2(x)
+        x2 = self.bn2(x2)
+        return torch.cat((x1, x2))
+
+
+class DoubleConvBnQuantizer(Quantizer):
+    """
+    Dummy quantizer that annotates a model with two conv-bn pairs, followed
+    by a torch.cat of the two conv-bn outputs.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.act_qspec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=default_fake_quant,
+        )
+        self.weight_qspec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=default_fake_quant,
+        )
+
+    def _get_all_nodes(self, model: torch.nn.Module) -> Tuple:
+        """
+        Return a 5-tuple of (conv1, bn1, conv2, bn2, cat) nodes.
+        """
+        conv1, bn1, conv2, bn2, cat = None, None, None, None, None
+        for n in model.graph.nodes:
+            if _is_conv_node(n):
+                if conv1 is None:
+                    conv1 = n
+                else:
+                    conv2 = n
+            if n.target == torch.ops.aten.batch_norm.default:
+                if bn1 is None:
+                    bn1 = n
+                else:
+                    bn2 = n
+            if n.target == torch.ops.aten.cat.default:
+                cat = n
+        assert conv1 is not None and bn1 is not None, "bad test setup"
+        assert conv2 is not None and bn2 is not None, "bad test setup"
+        assert cat is not None, "bad test setup"
+        return (conv1, bn1, conv2, bn2, cat)
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        (conv1, bn1, conv2, bn2, cat) = self._get_all_nodes(model)
+        conv1.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                conv1.args[0]: self.act_qspec,
+                conv1.args[1]: self.weight_qspec,
+            },
+            _annotated=True,
+        )
+        bn1.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=self.act_qspec,
+            _annotated=True,
+        )
+
+        conv2.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                conv2.args[0]: self.act_qspec,
+                conv2.args[1]: self.weight_qspec,
+            },
+            _annotated=True,
+        )
+        bn2.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=self.act_qspec,
+            _annotated=True,
+        )
+        cat.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                bn1: SharedQuantizationSpec(bn1),
+                bn2: SharedQuantizationSpec(bn2),
+            },
+            output_qspec=self.act_qspec,
+            _annotated=True,
+        )
+        return model
+
+    def validate(self, model: torch.fx.GraphModule):
+        pass
+
+
 class ConvBnInt32WeightQuantizer(Quantizer):
     """
     Dummy quantizer that annotates conv bn in such a way that the weights
```
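A side note on the test above: `SharedQuantizationSpec.edge_or_node` can be either a single node or an (input node, user node) edge, which is why the test walks both branches when collecting nodes to check. A small sketch of the two forms (assuming the upstream `torch.ao.quantization.quantizer` import path; torchao may re-export it elsewhere):

```python
import torch
from torch.ao.quantization.quantizer import SharedQuantizationSpec

# Trace a tiny graph so we have real fx nodes to point at.
gm = torch.fx.symbolic_trace(lambda x: torch.cat((x, x)))
placeholder, cat_node = list(gm.graph.nodes)[:2]

# Form 1: share quantization parameters with another node's output.
share_with_node = SharedQuantizationSpec(placeholder)
# Form 2: share with a specific (input node, user node) edge.
share_with_edge = SharedQuantizationSpec((placeholder, cat_node))
```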

torchao/quantization/pt2e/qat_utils.py (5 additions, 1 deletion)
```diff
@@ -594,8 +594,12 @@ def _get_new_qspec(qspec: QuantizationSpecBase):
     if "quantization_annotation" not in node.meta:
         return
     annotation = node.meta["quantization_annotation"]
+    # Update both keys and values of input_qspec_map
+    new_input_qspec_map = {}
     for input_node, qspec in annotation.input_qspec_map.items():
-        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
+        new_input_node = original_to_replacement_node.get(input_node, input_node)
+        new_input_qspec_map[new_input_node] = _get_new_qspec(qspec)
+    annotation.input_qspec_map = new_input_qspec_map
     annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
```
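For the qspec values, `_get_new_qspec` already remaps the node references held inside special qspecs; the hunk above extends the same old-to-new mapping to the dict keys. A hedged sketch of what that value-side remapping has to do for a `SharedQuantizationSpec` (a hypothetical helper, not the actual torchao code):

```python
import dataclasses

from torch.ao.quantization.quantizer import SharedQuantizationSpec


def remap_shared_qspec(
    qspec: SharedQuantizationSpec, original_to_replacement_node: dict
) -> SharedQuantizationSpec:
    """Point qspec.edge_or_node at replacement nodes (illustrative only)."""

    def remap(n):
        return original_to_replacement_node.get(n, n)

    edge_or_node = qspec.edge_or_node
    if isinstance(edge_or_node, tuple):
        # An (input node, user node) edge: remap both endpoints.
        new_edge_or_node = (remap(edge_or_node[0]), remap(edge_or_node[1]))
    else:
        # A single node.
        new_edge_or_node = remap(edge_or_node)
    return dataclasses.replace(qspec, edge_or_node=new_edge_or_node)
```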