Skip to content

Commit 8a16e12

Browse files
authored
Implement is_qat in TorchVision (#5299)
* Add is_qat support using a method getter. * Switch to an internal _fuse_modules. * Fix linter. * Pass is_qat=False on PTQ. * Fix bug on ra_sampler flag. * Set is_qat=True for QAT.
1 parent 61a52b9 commit 8a16e12

File tree

12 files changed

+64
-55
lines changed

12 files changed

+64
-55
lines changed

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ def load_data(traindir, valdir, args):
178178

179179
print("Creating data loaders")
180180
if args.distributed:
181-
if args.ra_sampler:
181+
if hasattr(args, "ra_sampler") and args.ra_sampler:
182182
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
183183
else:
184184
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

references/classification/train_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def main(args):
6363
model.to(device)
6464

6565
if not (args.test_only or args.post_training_quantize):
66-
model.fuse_model()
66+
model.fuse_model(is_qat=True)
6767
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
6868
torch.ao.quantization.prepare_qat(model, inplace=True)
6969

@@ -97,7 +97,7 @@ def main(args):
9797
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
9898
)
9999
model.eval()
100-
model.fuse_model()
100+
model.fuse_model(is_qat=False)
101101
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
102102
torch.ao.quantization.prepare(model, inplace=True)
103103
# Calibrate first

references/classification/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
344344
345345
# Quantized Classification
346346
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
347-
model.fuse_model()
347+
model.fuse_model(is_qat=True)
348348
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
349349
_ = torch.ao.quantization.prepare_qat(model, inplace=True)
350350
print(store_model_weights(model, './qat.pth'))

test/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -833,7 +833,7 @@ def test_quantized_classification_model(model_fn):
833833
model.train()
834834
model.qconfig = torch.ao.quantization.default_qat_qconfig
835835

836-
model.fuse_model()
836+
model.fuse_model(is_qat=not eval_mode)
837837
if eval_mode:
838838
torch.ao.quantization.prepare(model, inplace=True)
839839
else:

torchvision/models/quantization/googlenet.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any
2+
from typing import Any, Optional
33

44
import torch
55
import torch.nn as nn
@@ -8,7 +8,7 @@
88
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls
99

1010
from ..._internally_replaced_utils import load_state_dict_from_url
11-
from .utils import _replace_relu, quantize_model
11+
from .utils import _fuse_modules, _replace_relu, quantize_model
1212

1313

1414
__all__ = ["QuantizableGoogLeNet", "googlenet"]
@@ -30,8 +30,8 @@ def forward(self, x: Tensor) -> Tensor:
3030
x = self.relu(x)
3131
return x
3232

33-
def fuse_model(self) -> None:
34-
torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
33+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
34+
_fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)
3535

3636

3737
class QuantizableInception(Inception):
@@ -90,7 +90,7 @@ def forward(self, x: Tensor) -> GoogLeNetOutputs:
9090
else:
9191
return self.eager_outputs(x, aux2, aux1)
9292

93-
def fuse_model(self) -> None:
93+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
9494
r"""Fuse conv/bn/relu modules in googlenet model
9595
9696
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -100,7 +100,7 @@ def fuse_model(self) -> None:
100100

101101
for m in self.modules():
102102
if type(m) is QuantizableBasicConv2d:
103-
m.fuse_model()
103+
m.fuse_model(is_qat)
104104

105105

106106
def googlenet(

torchvision/models/quantization/inception.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any, List
2+
from typing import Any, List, Optional
33

44
import torch
55
import torch.nn as nn
@@ -9,7 +9,7 @@
99
from torchvision.models.inception import InceptionOutputs
1010

1111
from ..._internally_replaced_utils import load_state_dict_from_url
12-
from .utils import _replace_relu, quantize_model
12+
from .utils import _fuse_modules, _replace_relu, quantize_model
1313

1414

1515
__all__ = [
@@ -35,8 +35,8 @@ def forward(self, x: Tensor) -> Tensor:
3535
x = self.relu(x)
3636
return x
3737

38-
def fuse_model(self) -> None:
39-
torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
38+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
39+
_fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)
4040

4141

4242
class QuantizableInceptionA(inception_module.InceptionA):
@@ -160,7 +160,7 @@ def forward(self, x: Tensor) -> InceptionOutputs:
160160
else:
161161
return self.eager_outputs(x, aux)
162162

163-
def fuse_model(self) -> None:
163+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
164164
r"""Fuse conv/bn/relu modules in inception model
165165
166166
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -170,7 +170,7 @@ def fuse_model(self) -> None:
170170

171171
for m in self.modules():
172172
if type(m) is QuantizableBasicConv2d:
173-
m.fuse_model()
173+
m.fuse_model(is_qat)
174174

175175

176176
def inception_v3(

torchvision/models/quantization/mobilenetv2.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
from torch import Tensor
44
from torch import nn
5-
from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
5+
from torch.ao.quantization import QuantStub, DeQuantStub
66
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
77

88
from ..._internally_replaced_utils import load_state_dict_from_url
99
from ...ops.misc import ConvNormActivation
10-
from .utils import _replace_relu, quantize_model
10+
from .utils import _fuse_modules, _replace_relu, quantize_model
1111

1212

1313
__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
@@ -28,10 +28,10 @@ def forward(self, x: Tensor) -> Tensor:
2828
else:
2929
return self.conv(x)
3030

31-
def fuse_model(self) -> None:
31+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
3232
for idx in range(len(self.conv)):
3333
if type(self.conv[idx]) is nn.Conv2d:
34-
fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
34+
_fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True)
3535

3636

3737
class QuantizableMobileNetV2(MobileNetV2):
@@ -52,12 +52,12 @@ def forward(self, x: Tensor) -> Tensor:
5252
x = self.dequant(x)
5353
return x
5454

55-
def fuse_model(self) -> None:
55+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
5656
for m in self.modules():
5757
if type(m) is ConvNormActivation:
58-
fuse_modules(m, ["0", "1", "2"], inplace=True)
58+
_fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
5959
if type(m) is QuantizableInvertedResidual:
60-
m.fuse_model()
60+
m.fuse_model(is_qat)
6161

6262

6363
def mobilenet_v2(

torchvision/models/quantization/mobilenetv3.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,12 @@
22

33
import torch
44
from torch import nn, Tensor
5-
from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
5+
from torch.ao.quantization import QuantStub, DeQuantStub
66

77
from ..._internally_replaced_utils import load_state_dict_from_url
88
from ...ops.misc import ConvNormActivation, SqueezeExcitation
99
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
10-
from .utils import _replace_relu
10+
from .utils import _fuse_modules, _replace_relu
1111

1212

1313
__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
@@ -28,8 +28,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
2828
def forward(self, input: Tensor) -> Tensor:
2929
return self.skip_mul.mul(self._scale(input), input)
3030

31-
def fuse_model(self) -> None:
32-
fuse_modules(self, ["fc1", "activation"], inplace=True)
31+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
32+
_fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)
3333

3434
def _load_from_state_dict(
3535
self,
@@ -101,15 +101,15 @@ def forward(self, x: Tensor) -> Tensor:
101101
x = self.dequant(x)
102102
return x
103103

104-
def fuse_model(self) -> None:
104+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
105105
for m in self.modules():
106106
if type(m) is ConvNormActivation:
107107
modules_to_fuse = ["0", "1"]
108108
if len(m) == 3 and type(m[2]) is nn.ReLU:
109109
modules_to_fuse.append("2")
110-
fuse_modules(m, modules_to_fuse, inplace=True)
110+
_fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
111111
elif type(m) is QuantizableSqueezeExcitation:
112-
m.fuse_model()
112+
m.fuse_model(is_qat)
113113

114114

115115
def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
@@ -135,7 +135,7 @@ def _mobilenet_v3_model(
135135
if quantize:
136136
backend = "qnnpack"
137137

138-
model.fuse_model()
138+
model.fuse_model(is_qat=True)
139139
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
140140
torch.ao.quantization.prepare_qat(model, inplace=True)
141141

torchvision/models/quantization/resnet.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
1-
from typing import Any, Type, Union, List
1+
from typing import Any, Type, Union, List, Optional
22

33
import torch
44
import torch.nn as nn
55
from torch import Tensor
6-
from torch.ao.quantization import fuse_modules
76
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
87

98
from ..._internally_replaced_utils import load_state_dict_from_url
10-
from .utils import _replace_relu, quantize_model
9+
from .utils import _fuse_modules, _replace_relu, quantize_model
1110

1211
__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
1312

@@ -41,10 +40,10 @@ def forward(self, x: Tensor) -> Tensor:
4140

4241
return out
4342

44-
def fuse_model(self) -> None:
45-
torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
43+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
44+
_fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
4645
if self.downsample:
47-
torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
46+
_fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
4847

4948

5049
class QuantizableBottleneck(Bottleneck):
@@ -72,10 +71,12 @@ def forward(self, x: Tensor) -> Tensor:
7271

7372
return out
7473

75-
def fuse_model(self) -> None:
76-
fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
74+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
75+
_fuse_modules(
76+
self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
77+
)
7778
if self.downsample:
78-
torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
79+
_fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
7980

8081

8182
class QuantizableResNet(ResNet):
@@ -94,18 +95,17 @@ def forward(self, x: Tensor) -> Tensor:
9495
x = self.dequant(x)
9596
return x
9697

97-
def fuse_model(self) -> None:
98+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
9899
r"""Fuse conv/bn/relu modules in resnet models
99100
100101
Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
101102
Model is modified in place. Note that this operation does not change numerics
102103
and the model after modification is in floating point
103104
"""
104-
105-
fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
105+
_fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
106106
for m in self.modules():
107107
if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
108-
m.fuse_model()
108+
m.fuse_model(is_qat)
109109

110110

111111
def _resnet(

torchvision/models/quantization/shufflenetv2.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from torchvision.models import shufflenetv2
77

88
from ..._internally_replaced_utils import load_state_dict_from_url
9-
from .utils import _replace_relu, quantize_model
9+
from .utils import _fuse_modules, _replace_relu, quantize_model
1010

1111
__all__ = [
1212
"QuantizableShuffleNetV2",
@@ -50,24 +50,24 @@ def forward(self, x: Tensor) -> Tensor:
5050
x = self.dequant(x)
5151
return x
5252

53-
def fuse_model(self) -> None:
53+
def fuse_model(self, is_qat: Optional[bool] = None) -> None:
5454
r"""Fuse conv/bn/relu modules in shufflenetv2 model
5555
5656
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
5757
Model is modified in place. Note that this operation does not change numerics
5858
and the model after modification is in floating point
5959
"""
60-
6160
for name, m in self._modules.items():
62-
if name in ["conv1", "conv5"]:
63-
torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
61+
if name in ["conv1", "conv5"] and m is not None:
62+
_fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True)
6463
for m in self.modules():
6564
if type(m) is QuantizableInvertedResidual:
6665
if len(m.branch1._modules.items()) > 0:
67-
torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
68-
torch.ao.quantization.fuse_modules(
66+
_fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True)
67+
_fuse_modules(
6968
m.branch2,
7069
[["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
70+
is_qat,
7171
inplace=True,
7272
)
7373

torchvision/models/quantization/utils.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, List, Optional, Union
2+
13
import torch
24
from torch import nn
35

@@ -39,4 +41,11 @@ def quantize_model(model: nn.Module, backend: str) -> None:
3941
model(_dummy_input_data)
4042
torch.ao.quantization.convert(model, inplace=True)
4143

42-
return
44+
45+
def _fuse_modules(
46+
model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
47+
):
48+
if is_qat is None:
49+
is_qat = model.training
50+
method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
51+
return method(model, modules_to_fuse, **kwargs)

torchvision/prototype/models/quantization/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def _mobilenet_v3_model(
4242
_replace_relu(model)
4343

4444
if quantize:
45-
model.fuse_model()
45+
model.fuse_model(is_qat=True)
4646
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
4747
torch.ao.quantization.prepare_qat(model, inplace=True)
4848

0 commit comments

Comments (0)