
Commit f006431

datumbox and pmeier authored and committed
[fbsync] Add typing annotations to models/quantization (#4232)
Summary:

* fix
* add typings
* fixup some more types
* Type more
* remove mypy ignore
* add missing typings
* fix a few mypy errors
* fix mypy errors
* fix mypy
* ignore types
* fixup annotation
* fix remaining types
* cleanup #TODO comments

Reviewed By: fmassa

Differential Revision: D30793343

fbshipit-source-id: 0448f6f24f406827abc9e1825489c786b6f0eb11

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 7786076 commit f006431

File tree: 8 files changed, +236 −102 lines changed

mypy.ini

Lines changed: 0 additions & 4 deletions
@@ -20,10 +20,6 @@ ignore_errors=True
 
 ignore_errors = True
 
-[mypy-torchvision.models.quantization.*]
-
-ignore_errors = True
-
 [mypy-torchvision.ops.*]
 
 ignore_errors = True
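
Note: removing this override is what drives the rest of the commit — mypy now type-checks torchvision.models.quantization instead of skipping it. A minimal sketch of what the checked annotations buy a caller (illustrative only, not part of the commit):

# Illustrative only: with the override gone, the annotated factories below
# type-check at call sites instead of collapsing to Any.
from torchvision.models.quantization import googlenet

model = googlenet(pretrained=False, quantize=False)
# mypy now infers QuantizableGoogLeNet for `model`, so a misspelled
# attribute such as `model.fuse_modell` is flagged statically.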

torchvision/models/quantization/googlenet.py

Lines changed: 30 additions & 19 deletions
@@ -2,6 +2,8 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from typing import Any
+from torch import Tensor
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 from torchvision.models.googlenet import (
@@ -18,7 +20,13 @@
 }
 
 
-def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
+def googlenet(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableGoogLeNet":
+
     r"""GoogLeNet (Inception v1) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
 
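
Note: the quoted return type above is a forward reference — QuantizableGoogLeNet is defined further down the module, so its name is not yet bound when the signature is evaluated. A minimal sketch of the pattern with hypothetical names:

def make_widget() -> "Widget":  # quoted: Widget is defined only below
    return Widget()


class Widget:
    pass


print(type(make_widget()).__name__)  # Widget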
@@ -70,48 +78,51 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
 
     if not original_aux_logits:
         model.aux_logits = False
-        model.aux1 = None
-        model.aux2 = None
+        model.aux1 = None  # type: ignore[assignment]
+        model.aux2 = None  # type: ignore[assignment]
     return model
 
 
 class QuantizableBasicConv2d(BasicConv2d):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInception(Inception):
 
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInception, self).__init__(
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInception, self).__init__(  # type: ignore[misc]
             conv_block=QuantizableBasicConv2d, *args, **kwargs)
         self.cat = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.cat.cat(outputs, 1)
 
 
 class QuantizableInceptionAux(InceptionAux):
-
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(
-            conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.7)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
         x = F.adaptive_avg_pool2d(x, (4, 4))
         # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
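
Note: every TODO/`# type: ignore[misc]` pair in this commit points at the same mypy limitation, discussed in the linked review. A hedged reconstruction of the silenced error, with toy classes rather than torchvision code: placing a keyword argument before *args is valid at runtime, but mypy reports that the keyword may receive multiple values.

from typing import Any


class Base:
    def __init__(self, conv_block: Any = None, channels: int = 0) -> None:
        self.conv_block = conv_block
        self.channels = channels


class Derived(Base):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # mypy: '"__init__" of "Base" gets multiple values for keyword
        # argument "conv_block"  [misc]' -- fine at runtime as long as
        # conv_block is never also passed positionally.
        super().__init__(conv_block="custom", *args, **kwargs)  # type: ignore[misc]


print(Derived(channels=8).channels)  # prints 8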
@@ -130,17 +141,17 @@ def forward(self, x):
 
 
 class QuantizableGoogLeNet(GoogLeNet):
-
-    def __init__(self, *args, **kwargs):
-        super(QuantizableGoogLeNet, self).__init__(
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableGoogLeNet, self).__init__(  # type: ignore[misc]
             blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
             *args,
             **kwargs
         )
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux1, aux2 = self._forward(x)
@@ -153,7 +164,7 @@ def forward(self, x):
         else:
             return self.eager_outputs(x, aux2, aux1)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model
 
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
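
The googlenet changes only tighten signatures; runtime behaviour is unchanged. A short usage sketch under the new annotations (illustrative, not from the diff):

import torch
from torchvision.models.quantization import googlenet

model = googlenet(pretrained=False, quantize=False)  # -> QuantizableGoogLeNet
model.eval()
model.fuse_model()  # annotated -> None; fuses conv+bn+relu triples in place
out = model(torch.randn(1, 3, 224, 224))
# forward is annotated -> GoogLeNetOutputs; in eval mode eager_outputs
# returns just the main logits tensor.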

torchvision/models/quantization/inception.py

Lines changed: 69 additions & 25 deletions
@@ -3,6 +3,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from typing import Any, List
+
 from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -22,7 +25,13 @@
 }
 
 
-def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
+def inception_v3(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableInception3":
+
     r"""Inception v3 model architecture from
     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
 
@@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
 
 
 class QuantizableBasicConv2d(inception_module.BasicConv2d):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInceptionA(inception_module.InceptionA):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionA, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionA, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionB(inception_module.InceptionB):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionB, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionB, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionC(inception_module.InceptionC):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionC, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionC, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionD(inception_module.InceptionD):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionD, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionD, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionE(inception_module.InceptionE):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionE, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop1 = nn.quantized.FloatFunctional()
         self.myop2 = nn.quantized.FloatFunctional()
         self.myop3 = nn.quantized.FloatFunctional()
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> List[Tensor]:
         branch1x1 = self.branch1x1(x)
 
         branch3x3 = self.branch3x3_1(x)
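
Note: every Inception block routes its concatenation through nn.quantized.FloatFunctional rather than calling torch.cat directly. A hedged background sketch of why (standard eager-mode quantization behaviour, not stated in the diff): the functional module carries its own activation observer, so prepare()/convert() can assign the op a scale and zero point.

import torch
import torch.nn as nn

f = nn.quantized.FloatFunctional()
a, b = torch.randn(1, 2, 4, 4), torch.randn(1, 3, 4, 4)
out = f.cat([a, b], 1)  # observed counterpart of torch.cat([a, b], 1)
print(out.shape)        # torch.Size([1, 5, 4, 4])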
@@ -166,18 +200,28 @@ def _forward(self, x):
         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
         return outputs
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop3.cat(outputs, 1)
 
 
 class QuantizableInceptionAux(inception_module.InceptionAux):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
 
 
 class QuantizableInception3(inception_module.Inception3):
-    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+    ) -> None:
         super(QuantizableInception3, self).__init__(
             num_classes=num_classes,
             aux_logits=aux_logits,
@@ -195,7 +239,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux = self._forward(x)
@@ -208,7 +252,7 @@ def forward(self, x):
         else:
             return self.eager_outputs(x, aux)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in inception model
 
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
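
The annotated model still follows the usual eager-mode post-training quantization flow. A hedged end-to-end sketch (random tensors stand in for real calibration data):

import torch
from torchvision.models.quantization import inception_v3

model = inception_v3(pretrained=False, quantize=False)  # -> QuantizableInception3
model.eval()
model.fuse_model()  # -> None
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    model(torch.randn(2, 3, 299, 299))  # calibration pass
torch.quantization.convert(model, inplace=True)  # int8 weights from here on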

torchvision/models/quantization/mobilenetv2.py

Lines changed: 17 additions & 7 deletions
@@ -1,5 +1,10 @@
 from torch import nn
+from torch import Tensor
+
 from ..._internally_replaced_utils import load_state_dict_from_url
+
+from typing import Any
+
 from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from .utils import _replace_relu, quantize_model
@@ -14,24 +19,24 @@
 
 
 class QuantizableInvertedResidual(InvertedResidual):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
         self.skip_add = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.conv(x))
         else:
             return self.conv(x)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for idx in range(len(self.conv)):
             if type(self.conv[idx]) == nn.Conv2d:
                 fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
 
 
 class QuantizableMobileNetV2(MobileNetV2):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """
         MobileNet V2 main class
 
@@ -42,21 +47,26 @@ def __init__(self, *args, **kwargs):
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         x = self._forward_impl(x)
         x = self.dequant(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) == ConvBNReLU:
                 fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == QuantizableInvertedResidual:
                 m.fuse_model()
 
 
-def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs):
+def mobilenet_v2(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV2:
     """
     Constructs a MobileNetV2 architecture from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
