|
8 | 8 | import copy |
9 | 9 | import operator |
10 | 10 | import unittest |
11 | | -from typing import Any, Optional |
| 11 | +from typing import Any, Optional, Tuple |
12 | 12 |
|
13 | 13 | import torch |
14 | 14 | from torch.ao.quantization import QConfigMapping |
|
46 | 46 | QuantizationAnnotation, |
47 | 47 | QuantizationSpec, |
48 | 48 | Quantizer, |
| 49 | + SharedQuantizationSpec, |
49 | 50 | ) |
50 | 51 | from torchao.testing.pt2e._xnnpack_quantizer import ( |
51 | 52 | XNNPACKQuantizer, |
@@ -878,6 +879,57 @@ class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base): |
878 | 879 | conv_transpose_class = torch.nn.ConvTranspose2d |
879 | 880 | bn_class = torch.nn.BatchNorm2d |
880 | 881 |
|
def test_qat_shared_qspec(self):
    """
    Check that every node referenced from a quantization annotation after
    QAT fusion (as a key of `input_qspec_map`, or through a
    `SharedQuantizationSpec`) belongs to the prepared graph, and is not one
    of the pre-fusion nodes that no longer exist.
    """
    model = DoubleConvBnModel()
    example_inputs = (torch.randn(1, 3, 5, 5),)
    exported = torch.export.export_for_training(model, example_inputs, strict=True)
    model = exported.module()
    nodes_before_prepare = set(model.graph.nodes)
    model = prepare_qat_pt2e(model, DoubleConvBnQuantizer())
    new_nodes = set(model.graph.nodes)
    # Keep only the nodes that are absent from the prepared graph.
    old_nodes = nodes_before_prepare - new_nodes
    assert old_nodes.isdisjoint(new_nodes), "bad test setup"
    assert len(old_nodes) == 4, (
        f"bad test setup, old nodes should have 2 convs and 2 bns: {old_nodes}"
    )

    # Step 1: collect every node reachable from input and output qspecs.
    nodes_to_check = set()
    for node in model.graph.nodes:
        annotation = node.meta.get("quantization_annotation")
        if annotation is None:
            continue
        nodes_to_check.update(annotation.input_qspec_map.keys())
        qspecs = [*annotation.input_qspec_map.values(), annotation.output_qspec]
        for qspec in qspecs:
            if not isinstance(qspec, SharedQuantizationSpec):
                continue
            shared_with = qspec.edge_or_node
            if isinstance(shared_with, torch.fx.Node):
                nodes_to_check.add(shared_with)
            else:
                # Not a single node, so it is unpacked as a (src, dest) edge.
                src, dest = shared_with
                nodes_to_check.update([src, dest])

    # Step 2: none of the collected nodes may be a stale pre-fusion node.
    self.assertEqual(len(nodes_to_check), 5)
    num_batch_norm_nodes_checked = 0
    for n in nodes_to_check:
        if n.target == torch.ops.aten.batch_norm.default:
            num_batch_norm_nodes_checked += 1
        self.assertTrue(
            n not in old_nodes,
            f"found old node {n} in qspec, old nodes: {old_nodes}",
        )
        self.assertTrue(
            n in new_nodes, f"found node {n} in qspec not in new nodes: {new_nodes}"
        )
    assert num_batch_norm_nodes_checked == 2, (
        f"bad test setup, didn't check 2 bns, only checked these: {nodes_to_check}"
    )
| 932 | + |
881 | 933 |
|
882 | 934 | def _is_conv_node(n: torch.fx.Node): |
883 | 935 | return n.op == "call_function" and n.target in [ |
@@ -913,6 +965,107 @@ def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule): |
913 | 965 | return (conv_node, bn_node, getitem_node) |
914 | 966 |
|
915 | 967 |
|
class DoubleConvBnModel(torch.nn.Module):
    """
    Toy model with two parallel conv-bn branches that both read the same
    input; the two branch outputs are joined with `torch.cat`.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions are 3 -> 3 channels, kernel size 3, without bias.
        self.conv1 = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(num_features=3)
        self.conv2 = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(num_features=3)

    def forward(self, x):
        # The first branch is evaluated completely before the second one.
        first_branch = self.bn1(self.conv1(x))
        second_branch = self.bn2(self.conv2(x))
        return torch.cat((first_branch, second_branch))
| 982 | + |
| 983 | + |
class DoubleConvBnQuantizer(Quantizer):
    """
    Dummy quantizer that annotates a model with two conv-bn branches
    followed by a torch.cat of the two branch outputs.
    """

    def __init__(self):
        super().__init__()
        # Activations and weights use the same settings, but each gets its
        # own QuantizationSpec object.
        uint8_per_tensor = dict(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        self.act_qspec = QuantizationSpec(**uint8_per_tensor)
        self.weight_qspec = QuantizationSpec(**uint8_per_tensor)

    def _get_all_nodes(self, model: torch.nn.Module) -> Tuple:
        """
        Return a 5-tuple of (conv1, bn1, conv2, bn2, cat) nodes.

        `model` must expose `graph.nodes`. The first conv / batch norm node
        found becomes conv1 / bn1, and the last later match becomes
        conv2 / bn2. The last cat node found is returned as cat.
        """
        conv_nodes, bn_nodes, cat_nodes = [], [], []
        for node in model.graph.nodes:
            if _is_conv_node(node):
                conv_nodes.append(node)
            if node.target == torch.ops.aten.batch_norm.default:
                bn_nodes.append(node)
            if node.target == torch.ops.aten.cat.default:
                cat_nodes.append(node)

        def _first_and_last(found):
            # The "last" slot stays empty unless there are at least two matches.
            first = found[0] if found else None
            last = found[-1] if len(found) > 1 else None
            return first, last

        conv1, conv2 = _first_and_last(conv_nodes)
        bn1, bn2 = _first_and_last(bn_nodes)
        cat = cat_nodes[-1] if cat_nodes else None
        assert conv1 is not None and bn1 is not None, "bad test setup"
        assert conv2 is not None and bn2 is not None, "bad test setup"
        assert cat is not None, "bad test setup"
        return (conv1, bn1, conv2, bn2, cat)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Attach a `QuantizationAnnotation` to both convs, both batch norms and
        the cat node, then return the same model.
        """
        conv1, bn1, conv2, bn2, cat = self._get_all_nodes(model)

        for conv, bn in ((conv1, bn1), (conv2, bn2)):
            # Conv: annotate its first two args (activation input and weight).
            activation, weight = conv.args[0], conv.args[1]
            conv.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map={
                    activation: self.act_qspec,
                    weight: self.weight_qspec,
                },
                _annotated=True,
            )
            # Batch norm: annotate only its output.
            bn.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=self.act_qspec,
                _annotated=True,
            )

        # Each cat input shares quantization parameters with the batch norm
        # node that produces it.
        cat.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={bn: SharedQuantizationSpec(bn) for bn in (bn1, bn2)},
            output_qspec=self.act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        """No validation is performed by this dummy quantizer."""
        pass
| 1067 | + |
| 1068 | + |
916 | 1069 | class ConvBnInt32WeightQuantizer(Quantizer): |
917 | 1070 | """ |
918 | 1071 | Dummy quantizer that annotates conv bn in such a way that the weights |
|
0 commit comments