diff --git a/mypy.ini b/mypy.ini index 52ddce8ec51..de04d6d173e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -20,10 +20,6 @@ ignore_errors=True ignore_errors = True -[mypy-torchvision.models.quantization.*] - -ignore_errors = True - [mypy-torchvision.ops.*] ignore_errors = True diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index bc1477d8f65..685815ac676 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn from torch.nn import functional as F +from typing import Any +from torch import Tensor from ..._internally_replaced_utils import load_state_dict_from_url from torchvision.models.googlenet import ( @@ -18,7 +20,13 @@ } -def googlenet(pretrained=False, progress=True, quantize=False, **kwargs): +def googlenet( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> "QuantizableGoogLeNet": + r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. @@ -70,48 +78,51 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs): if not original_aux_logits: model.aux_logits = False - model.aux1 = None - model.aux2 = None + model.aux1 = None # type: ignore[assignment] + model.aux2 = None # type: ignore[assignment] return model class QuantizableBasicConv2d(BasicConv2d): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableBasicConv2d, self).__init__(*args, **kwargs) self.relu = nn.ReLU() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) x = self.relu(x) return x - def fuse_model(self): + def fuse_model(self) -> None: torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) class QuantizableInception(Inception): - def __init__(self, *args, **kwargs): - super(QuantizableInception, self).__init__( + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInception, self).__init__( # type: ignore[misc] conv_block=QuantizableBasicConv2d, *args, **kwargs) self.cat = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.cat.cat(outputs, 1) class QuantizableInceptionAux(InceptionAux): - - def __init__(self, *args, **kwargs): - super(QuantizableInceptionAux, self).__init__( - conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.7) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 x = F.adaptive_avg_pool2d(x, (4, 4)) # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 @@ -130,9 +141,9 @@ def forward(self, x): class QuantizableGoogLeNet(GoogLeNet): - - def __init__(self, *args, **kwargs): - super(QuantizableGoogLeNet, self).__init__( + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc] blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs @@ -140,7 +151,7 @@ def __init__(self, *args, **kwargs): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> GoogLeNetOutputs: x = self._transform_input(x) x = self.quant(x) x, aux1, aux2 = self._forward(x) @@ -153,7 +164,7 @@ def forward(self, x): else: return self.eager_outputs(x, aux2, aux1) - def fuse_model(self): + def fuse_model(self) -> None: r"""Fuse conv/bn/relu modules in googlenet model Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 833d8fb8b75..6c6384c295a 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -3,6 +3,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from typing import Any, List + from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs from ..._internally_replaced_utils import load_state_dict_from_url @@ -22,7 +25,13 @@ } -def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs): +def inception_v3( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> "QuantizableInception3": + r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. @@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs): class QuantizableBasicConv2d(inception_module.BasicConv2d): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableBasicConv2d, self).__init__(*args, **kwargs) self.relu = nn.ReLU() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) x = self.relu(x) return x - def fuse_model(self): + def fuse_model(self) -> None: torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) class QuantizableInceptionA(inception_module.InceptionA): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionA, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionA, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.myop = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.myop.cat(outputs, 1) class QuantizableInceptionB(inception_module.InceptionB): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionB, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionB, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.myop = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.myop.cat(outputs, 1) class QuantizableInceptionC(inception_module.InceptionC): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionC, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionC, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.myop = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.myop.cat(outputs, 1) class QuantizableInceptionD(inception_module.InceptionD): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionD, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionD, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.myop = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.myop.cat(outputs, 1) class QuantizableInceptionE(inception_module.InceptionE): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionE, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) self.myop1 = nn.quantized.FloatFunctional() self.myop2 = nn.quantized.FloatFunctional() self.myop3 = nn.quantized.FloatFunctional() - def _forward(self, x): + def _forward(self, x: Tensor) -> List[Tensor]: branch1x1 = self.branch1x1(x) branch3x3 = self.branch3x3_1(x) @@ -166,18 +200,28 @@ def _forward(self, x): outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] return outputs - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: outputs = self._forward(x) return self.myop3.cat(outputs, 1) class QuantizableInceptionAux(inception_module.InceptionAux): - def __init__(self, *args, **kwargs): - super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] + conv_block=QuantizableBasicConv2d, + *args, + **kwargs + ) class QuantizableInception3(inception_module.Inception3): - def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): + def __init__( + self, + num_classes: int = 1000, + aux_logits: bool = True, + transform_input: bool = False, + ) -> None: super(QuantizableInception3, self).__init__( num_classes=num_classes, aux_logits=aux_logits, @@ -195,7 +239,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> InceptionOutputs: x = self._transform_input(x) x = self.quant(x) x, aux = self._forward(x) @@ -208,7 +252,7 @@ def forward(self, x): else: return self.eager_outputs(x, aux) - def fuse_model(self): + def fuse_model(self) -> None: r"""Fuse conv/bn/relu modules in inception model Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 857d919b1fa..f914fdc7815 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,5 +1,10 @@ from torch import nn +from torch import Tensor + from ..._internally_replaced_utils import load_state_dict_from_url + +from typing import Any + from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls from torch.quantization import QuantStub, DeQuantStub, fuse_modules from .utils import _replace_relu, quantize_model @@ -14,24 +19,24 @@ class QuantizableInvertedResidual(InvertedResidual): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInvertedResidual, self).__init__(*args, **kwargs) self.skip_add = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if self.use_res_connect: return self.skip_add.add(x, self.conv(x)) else: return self.conv(x) - def fuse_model(self): + def fuse_model(self) -> None: for idx in range(len(self.conv)): if type(self.conv[idx]) == nn.Conv2d: fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True) class QuantizableMobileNetV2(MobileNetV2): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ MobileNet V2 main class @@ -42,13 +47,13 @@ def __init__(self, *args, **kwargs): self.quant = QuantStub() self.dequant = DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.quant(x) x = self._forward_impl(x) x = self.dequant(x) return x - def fuse_model(self): + def fuse_model(self) -> None: for m in self.modules(): if type(m) == ConvBNReLU: fuse_modules(m, ['0', '1', '2'], inplace=True) @@ -56,7 +61,12 @@ def fuse_model(self): m.fuse_model() -def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs): +def mobilenet_v2( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV2: """ Constructs a MobileNetV2 architecture from `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 5462af89127..a1aa9d7d4bd 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -17,23 +17,28 @@ class QuantizableSqueezeExcitation(SqueezeExcitation): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.skip_mul = nn.quantized.FloatFunctional() def forward(self, input: Tensor) -> Tensor: return self.skip_mul.mul(self._scale(input, False), input) - def fuse_model(self): + def fuse_model(self) -> None: fuse_modules(self, ['fc1', 'relu'], inplace=True) class QuantizableInvertedResidual(InvertedResidual): - def __init__(self, *args, **kwargs): - super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__( # type: ignore[misc] + se_layer=QuantizableSqueezeExcitation, + *args, + **kwargs + ) self.skip_add = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if self.use_res_connect: return self.skip_add.add(x, self.block(x)) else: @@ -41,7 +46,7 @@ def forward(self, x): class QuantizableMobileNetV3(MobileNetV3): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ MobileNet V3 main class @@ -52,13 +57,13 @@ def __init__(self, *args, **kwargs): self.quant = QuantStub() self.dequant = DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.quant(x) x = self._forward_impl(x) x = self.dequant(x) return x - def fuse_model(self): + def fuse_model(self) -> None: for m in self.modules(): if type(m) == ConvBNActivation: modules_to_fuse = ['0', '1'] @@ -74,7 +79,7 @@ def _load_weights( model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool, -): +) -> None: if model_url is None: raise ValueError("No checkpoint is available for {}".format(arch)) state_dict = load_state_dict_from_url(model_url, progress=progress) @@ -88,8 +93,9 @@ def _mobilenet_v3_model( pretrained: bool, progress: bool, quantize: bool, - **kwargs: Any -): + **kwargs: Any, +) -> QuantizableMobileNetV3: + model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) @@ -112,7 +118,12 @@ def _mobilenet_v3_model( return model -def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs): +def mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV3: """ Constructs a MobileNetV3 Large architecture from `"Searching for MobileNetV3" `_. diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 2f3f50e8013..8f87e40ec3d 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,6 +1,9 @@ import torch from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls import torch.nn as nn +from torch import Tensor +from typing import Any, Type, Union, List + from ..._internally_replaced_utils import load_state_dict_from_url from torch.quantization import fuse_modules from .utils import _replace_relu, quantize_model @@ -20,11 +23,11 @@ class QuantizableBasicBlock(BasicBlock): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableBasicBlock, self).__init__(*args, **kwargs) self.add_relu = torch.nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -41,7 +44,7 @@ def forward(self, x): return out - def fuse_model(self): + def fuse_model(self) -> None: torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True) if self.downsample: @@ -49,13 +52,13 @@ def fuse_model(self): class QuantizableBottleneck(Bottleneck): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableBottleneck, self).__init__(*args, **kwargs) self.skip_add_relu = nn.quantized.FloatFunctional() self.relu1 = nn.ReLU(inplace=False) self.relu2 = nn.ReLU(inplace=False) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) @@ -73,7 +76,7 @@ def forward(self, x): return out - def fuse_model(self): + def fuse_model(self) -> None: fuse_modules(self, [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2', 'relu2'], ['conv3', 'bn3']], inplace=True) @@ -83,13 +86,13 @@ def fuse_model(self): class QuantizableResNet(ResNet): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableResNet, self).__init__(*args, **kwargs) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.quant(x) # Ensure scriptability # super(QuantizableResNet,self).forward(x) @@ -98,7 +101,7 @@ def forward(self, x): x = self.dequant(x) return x - def fuse_model(self): + def fuse_model(self) -> None: r"""Fuse conv/bn/relu modules in resnet models Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization. @@ -112,7 +115,16 @@ def fuse_model(self): m.fuse_model() -def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs): +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableResNet: + model = QuantizableResNet(block, layers, **kwargs) _replace_relu(model) if quantize: @@ -135,7 +147,12 @@ def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs): return model -def resnet18(pretrained=False, progress=True, quantize=False, **kwargs): +def resnet18( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ @@ -148,7 +165,13 @@ def resnet18(pretrained=False, progress=True, quantize=False, **kwargs): quantize, **kwargs) -def resnet50(pretrained=False, progress=True, quantize=False, **kwargs): +def resnet50( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ @@ -161,7 +184,12 @@ def resnet50(pretrained=False, progress=True, quantize=False, **kwargs): quantize, **kwargs) -def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs): +def resnext101_32x8d( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 17885015772..4f0861dcb30 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,12 +1,12 @@ import torch import torch.nn as nn +from torch import Tensor +from typing import Any + from ..._internally_replaced_utils import load_state_dict_from_url -import torchvision.models.shufflenetv2 -import sys +from torchvision.models import shufflenetv2 from .utils import _replace_relu, quantize_model -shufflenetv2 = sys.modules['torchvision.models.shufflenetv2'] - __all__ = [ 'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' @@ -22,16 +22,16 @@ class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInvertedResidual, self).__init__(*args, **kwargs) self.cat = nn.quantized.FloatFunctional() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: if self.stride == 1: x1, x2 = x.chunk(2, dim=1) - out = self.cat.cat((x1, self.branch2(x2)), dim=1) + out = self.cat.cat([x1, self.branch2(x2)], dim=1) else: - out = self.cat.cat((self.branch1(x), self.branch2(x)), dim=1) + out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1) out = shufflenetv2.channel_shuffle(out, 2) @@ -39,18 +39,23 @@ def forward(self, x): class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2): - def __init__(self, *args, **kwargs): - super(QuantizableShuffleNetV2, self).__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(QuantizableShuffleNetV2, self).__init__( # type: ignore[misc] + *args, + inverted_residual=QuantizableInvertedResidual, + **kwargs + ) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.quant(x) x = self._forward_impl(x) x = self.dequant(x) return x - def fuse_model(self): + def fuse_model(self) -> None: r"""Fuse conv/bn/relu modules in shufflenetv2 model Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. @@ -74,7 +79,15 @@ def fuse_model(self): ) -def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs): +def _shufflenetv2( + arch: str, + pretrained: bool, + progress: bool, + quantize: bool, + *args: Any, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + model = QuantizableShuffleNetV2(*args, **kwargs) _replace_relu(model) @@ -98,7 +111,12 @@ def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs): return model -def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs): +def shufflenet_v2_x0_5( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -113,7 +131,12 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) -def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs): +def shufflenet_v2_x1_0( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -128,7 +151,12 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs): +def shufflenet_v2_x1_5( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" @@ -143,7 +171,12 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs): +def shufflenet_v2_x2_0( + pretrained: bool = False, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py index bf23c9a9332..c195d162482 100644 --- a/torchvision/models/quantization/utils.py +++ b/torchvision/models/quantization/utils.py @@ -2,7 +2,7 @@ from torch import nn -def _replace_relu(module): +def _replace_relu(module: nn.Module) -> None: reassign = {} for name, mod in module.named_children(): _replace_relu(mod) @@ -16,7 +16,7 @@ def _replace_relu(module): module._modules[key] = value -def quantize_model(model, backend): +def quantize_model(model: nn.Module, backend: str) -> None: _dummy_input_data = torch.rand(1, 3, 299, 299) if backend not in torch.backends.quantized.supported_engines: raise RuntimeError("Quantized backend not supported ") @@ -24,15 +24,16 @@ def quantize_model(model, backend): model.eval() # Make sure that weight qconfig matches that of the serialized models if backend == 'fbgemm': - model.qconfig = torch.quantization.QConfig( + model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] activation=torch.quantization.default_observer, weight=torch.quantization.default_per_channel_weight_observer) elif backend == 'qnnpack': - model.qconfig = torch.quantization.QConfig( + model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer) - model.fuse_model() + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + model.fuse_model() # type: ignore[operator] torch.quantization.prepare(model, inplace=True) model(_dummy_input_data) torch.quantization.convert(model, inplace=True)