3
3
import torch
4
4
import torch .nn as nn
5
5
import torch .nn .functional as F
6
+ from torch import Tensor
7
+ from typing import Any , List
8
+
6
9
from torchvision .models import inception as inception_module
7
10
from torchvision .models .inception import InceptionOutputs
8
11
from ..._internally_replaced_utils import load_state_dict_from_url
22
25
}
23
26
24
27
25
- def inception_v3 (pretrained = False , progress = True , quantize = False , ** kwargs ):
28
+ def inception_v3 (
29
+ pretrained : bool = False ,
30
+ progress : bool = True ,
31
+ quantize : bool = False ,
32
+ ** kwargs : Any ,
33
+ ) -> "QuantizableInception3" :
34
+
26
35
r"""Inception v3 model architecture from
27
36
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
28
37
@@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
84
93
85
94
86
95
class QuantizableBasicConv2d (inception_module .BasicConv2d ):
87
- def __init__ (self , * args , ** kwargs ) :
96
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
88
97
super (QuantizableBasicConv2d , self ).__init__ (* args , ** kwargs )
89
98
self .relu = nn .ReLU ()
90
99
91
- def forward (self , x ) :
100
+ def forward (self , x : Tensor ) -> Tensor :
92
101
x = self .conv (x )
93
102
x = self .bn (x )
94
103
x = self .relu (x )
95
104
return x
96
105
97
- def fuse_model (self ):
106
+ def fuse_model (self ) -> None :
98
107
torch .quantization .fuse_modules (self , ["conv" , "bn" , "relu" ], inplace = True )
99
108
100
109
101
110
class QuantizableInceptionA (inception_module .InceptionA ):
102
- def __init__ (self , * args , ** kwargs ):
103
- super (QuantizableInceptionA , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
111
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
112
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
113
+ super (QuantizableInceptionA , self ).__init__ ( # type: ignore[misc]
114
+ conv_block = QuantizableBasicConv2d ,
115
+ * args ,
116
+ ** kwargs
117
+ )
104
118
self .myop = nn .quantized .FloatFunctional ()
105
119
106
- def forward (self , x ) :
120
+ def forward (self , x : Tensor ) -> Tensor :
107
121
outputs = self ._forward (x )
108
122
return self .myop .cat (outputs , 1 )
109
123
110
124
111
125
class QuantizableInceptionB (inception_module .InceptionB ):
112
- def __init__ (self , * args , ** kwargs ):
113
- super (QuantizableInceptionB , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
126
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
127
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
128
+ super (QuantizableInceptionB , self ).__init__ ( # type: ignore[misc]
129
+ conv_block = QuantizableBasicConv2d ,
130
+ * args ,
131
+ ** kwargs
132
+ )
114
133
self .myop = nn .quantized .FloatFunctional ()
115
134
116
- def forward (self , x ) :
135
+ def forward (self , x : Tensor ) -> Tensor :
117
136
outputs = self ._forward (x )
118
137
return self .myop .cat (outputs , 1 )
119
138
120
139
121
140
class QuantizableInceptionC (inception_module .InceptionC ):
122
- def __init__ (self , * args , ** kwargs ):
123
- super (QuantizableInceptionC , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
141
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
142
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
143
+ super (QuantizableInceptionC , self ).__init__ ( # type: ignore[misc]
144
+ conv_block = QuantizableBasicConv2d ,
145
+ * args ,
146
+ ** kwargs
147
+ )
124
148
self .myop = nn .quantized .FloatFunctional ()
125
149
126
- def forward (self , x ) :
150
+ def forward (self , x : Tensor ) -> Tensor :
127
151
outputs = self ._forward (x )
128
152
return self .myop .cat (outputs , 1 )
129
153
130
154
131
155
class QuantizableInceptionD (inception_module .InceptionD ):
132
- def __init__ (self , * args , ** kwargs ):
133
- super (QuantizableInceptionD , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
156
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
157
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
158
+ super (QuantizableInceptionD , self ).__init__ ( # type: ignore[misc]
159
+ conv_block = QuantizableBasicConv2d ,
160
+ * args ,
161
+ ** kwargs
162
+ )
134
163
self .myop = nn .quantized .FloatFunctional ()
135
164
136
- def forward (self , x ) :
165
+ def forward (self , x : Tensor ) -> Tensor :
137
166
outputs = self ._forward (x )
138
167
return self .myop .cat (outputs , 1 )
139
168
140
169
141
170
class QuantizableInceptionE (inception_module .InceptionE ):
142
- def __init__ (self , * args , ** kwargs ):
143
- super (QuantizableInceptionE , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
171
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
172
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
173
+ super (QuantizableInceptionE , self ).__init__ ( # type: ignore[misc]
174
+ conv_block = QuantizableBasicConv2d ,
175
+ * args ,
176
+ ** kwargs
177
+ )
144
178
self .myop1 = nn .quantized .FloatFunctional ()
145
179
self .myop2 = nn .quantized .FloatFunctional ()
146
180
self .myop3 = nn .quantized .FloatFunctional ()
147
181
148
- def _forward (self , x ) :
182
+ def _forward (self , x : Tensor ) -> List [ Tensor ] :
149
183
branch1x1 = self .branch1x1 (x )
150
184
151
185
branch3x3 = self .branch3x3_1 (x )
@@ -166,18 +200,28 @@ def _forward(self, x):
166
200
outputs = [branch1x1 , branch3x3 , branch3x3dbl , branch_pool ]
167
201
return outputs
168
202
169
- def forward (self , x ) :
203
+ def forward (self , x : Tensor ) -> Tensor :
170
204
outputs = self ._forward (x )
171
205
return self .myop3 .cat (outputs , 1 )
172
206
173
207
174
208
class QuantizableInceptionAux (inception_module .InceptionAux ):
175
- def __init__ (self , * args , ** kwargs ):
176
- super (QuantizableInceptionAux , self ).__init__ (conv_block = QuantizableBasicConv2d , * args , ** kwargs )
209
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
210
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
211
+ super (QuantizableInceptionAux , self ).__init__ ( # type: ignore[misc]
212
+ conv_block = QuantizableBasicConv2d ,
213
+ * args ,
214
+ ** kwargs
215
+ )
177
216
178
217
179
218
class QuantizableInception3 (inception_module .Inception3 ):
180
- def __init__ (self , num_classes = 1000 , aux_logits = True , transform_input = False ):
219
+ def __init__ (
220
+ self ,
221
+ num_classes : int = 1000 ,
222
+ aux_logits : bool = True ,
223
+ transform_input : bool = False ,
224
+ ) -> None :
181
225
super (QuantizableInception3 , self ).__init__ (
182
226
num_classes = num_classes ,
183
227
aux_logits = aux_logits ,
@@ -195,7 +239,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
195
239
self .quant = torch .quantization .QuantStub ()
196
240
self .dequant = torch .quantization .DeQuantStub ()
197
241
198
- def forward (self , x ) :
242
+ def forward (self , x : Tensor ) -> InceptionOutputs :
199
243
x = self ._transform_input (x )
200
244
x = self .quant (x )
201
245
x , aux = self ._forward (x )
@@ -208,7 +252,7 @@ def forward(self, x):
208
252
else :
209
253
return self .eager_outputs (x , aux )
210
254
211
- def fuse_model (self ):
255
+ def fuse_model (self ) -> None :
212
256
r"""Fuse conv/bn/relu modules in inception model
213
257
214
258
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
0 commit comments