@@ -1,13 +1,12 @@
-from typing import Any, Type, Union, List
+from typing import Any, Type, Union, List, Optional
 
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.ao.quantization import fuse_modules
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 
 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 
 
 __all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
 
@@ -41,10 +40,10 @@ def forward(self, x: Tensor) -> Tensor:
 
         return out
 
-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
 
 
 class QuantizableBottleneck(Bottleneck):
@@ -72,10 +71,12 @@ def forward(self, x: Tensor) -> Tensor:
 
         return out
 
-    def fuse_model(self) -> None:
-        fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(
+            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
+        )
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
 
 
 class QuantizableResNet(ResNet):
@@ -94,18 +95,17 @@ def forward(self, x: Tensor) -> Tensor:
         x = self.dequant(x)
         return x
 
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in resnet models
 
         Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
         Model is modified in place. Note that this operation does not change numerics
         and the model after modification is in floating point
         """
-
-        fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
+        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
         for m in self.modules():
             if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 
 
 def _resnet(
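
The _fuse_modules helper imported from .utils above is not part of this diff. A minimal sketch of what such a helper could look like, assuming it only chooses between PyTorch's eager-mode fusion entry points based on is_qat (the actual helper in torchvision's quantization utils may differ):

    from typing import Any, List, Optional, Union

    import torch
    import torch.ao.quantization
    import torch.nn as nn


    def _fuse_modules(
        model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
    ) -> nn.Module:
        # Sketch, not taken from this diff: if the caller does not say whether
        # fusion is for quantization-aware training, fall back to the module's
        # own training flag.
        if is_qat is None:
            is_qat = model.training
        fuse = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        return fuse(model, modules_to_fuse, **kwargs)

With the flag threaded through fuse_model, callers can request QAT-aware fusion explicitly before the standard eager-mode prepare step, for example (usage sketch; variable names are illustrative):

    import torch
    import torch.ao.quantization
    from torchvision.models.quantization import resnet18

    qat_model = resnet18(pretrained=False, quantize=False)
    qat_model.train()
    qat_model.fuse_model(is_qat=True)  # fuse conv/bn/relu via the QAT-aware path
    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
    torch.ao.quantization.prepare_qat(qat_model, inplace=True)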