diff --git a/references/classification/README.md b/references/classification/README.md
index 71692b61386..302e9d57562 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
 that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
 normalization and thus are trained with the default parameters.
 
+### GoogLeNet
+
+The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch.
+
 ### Inception V3
 
 The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index e781653f073..3bdab30c794 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
     },
     "quantization": {
         "input_shape": (1, 3, 224, 224),
+        "quantize": True,
     },
     "segmentation": {
         "input_shape": (1, 3, 520, 520),
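With `"quantize": True` added to the per-module kwargs table, the old-vs-new factory comparison now exercises the quantized code path as well. A minimal sketch of the kind of check this enables; this is a hand-rolled approximation for illustration, not the actual `test_old_vs_new_factory` body:

    import torch
    from torchvision import models as old_models
    from torchvision.prototype import models as new_models

    # Mirrors the "quantization" entry in the table above.
    input_shape, quantize = (1, 3, 224, 224), True

    # Both factories load the same checkpoint, so outputs should match.
    old = old_models.quantization.googlenet(pretrained=True, quantize=quantize).eval()
    new = new_models.quantization.googlenet(pretrained=True, quantize=quantize).eval()

    x = torch.rand(input_shape)
    torch.testing.assert_close(old(x), new(x))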
diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py
index d25bf6f3e4d..fac2a738cba 100644
--- a/torchvision/models/quantization/googlenet.py
+++ b/torchvision/models/quantization/googlenet.py
@@ -19,69 +19,6 @@
 }
 
 
-def googlenet(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> "QuantizableGoogLeNet":
-
-    r"""GoogLeNet (Inception v1) model architecture from
-    `"Going Deeper with Convolutions" <https://arxiv.org/abs/1409.4842>`_.
-
-    Note that quantize = True returns a quantized model with 8 bit
-    weights. Quantized models only support inference and run on CPUs.
-    GPU inference is not yet supported.
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
-            Default: *False* when pretrained is True otherwise *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: *False*
-    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" not in kwargs:
-            kwargs["aux_logits"] = False
-        if kwargs["aux_logits"]:
-            warnings.warn(
-                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
-            )
-        original_aux_logits = kwargs["aux_logits"]
-        kwargs["aux_logits"] = True
-        kwargs["init_weights"] = False
-
-    model = QuantizableGoogLeNet(**kwargs)
-    _replace_relu(model)
-
-    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
-        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]
-
-    if pretrained:
-        if quantize:
-            model_url = quant_model_urls["googlenet_" + backend]
-        else:
-            model_url = model_urls["googlenet"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
-
-        if not original_aux_logits:
-            model.aux_logits = False
-            model.aux1 = None  # type: ignore[assignment]
-            model.aux2 = None  # type: ignore[assignment]
-    return model
-
-
 class QuantizableBasicConv2d(BasicConv2d):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -164,3 +101,65 @@ def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
                 m.fuse_model()
+
+
+def googlenet(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableGoogLeNet:
+    r"""GoogLeNet (Inception v1) model architecture from
+    `"Going Deeper with Convolutions" <https://arxiv.org/abs/1409.4842>`_.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
+            Default: *False* when pretrained is True otherwise *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "aux_logits" not in kwargs:
+            kwargs["aux_logits"] = False
+        if kwargs["aux_logits"]:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )
+        original_aux_logits = kwargs["aux_logits"]
+        kwargs["aux_logits"] = True
+        kwargs["init_weights"] = False
+
+    model = QuantizableGoogLeNet(**kwargs)
+    _replace_relu(model)
+
+    if quantize:
+        # TODO use pretrained as a string to specify the backend
+        backend = "fbgemm"
+        quantize_model(model, backend)
+    else:
+        assert pretrained in [True, False]
+
+    if pretrained:
+        if quantize:
+            model_url = quant_model_urls["googlenet_" + backend]
+        else:
+            model_url = model_urls["googlenet"]
+
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+
+        model.load_state_dict(state_dict)
+
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
+    return model
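Relocating the builder below the class definitions lets the return annotation reference `QuantizableGoogLeNet` directly instead of a quoted forward reference; the public behavior is unchanged. For reference, the typical call pattern this function supports:

    import torch
    from torchvision.models.quantization import googlenet

    # int8 weights; quantized models are inference-only and CPU-only.
    model = googlenet(pretrained=True, quantize=True)
    model.eval()

    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])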
diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py
index 9b28ea8450a..0fb45e103e3 100644
--- a/torchvision/prototype/models/googlenet.py
+++ b/torchvision/prototype/models/googlenet.py
@@ -14,14 +14,14 @@
 
 
 class GoogLeNetWeights(Weights):
-    ImageNet1K_Community = WeightEntry(
+    ImageNet1K_TFV1 = WeightEntry(
         url="https://download.pytorch.org/models/googlenet-1378be20.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
             "acc@1": 69.778,
             "acc@5": 89.530,
         },
@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights):
 def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-        weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
     weights = GoogLeNetWeights.verify(weights)
 
     original_aux_logits = kwargs.get("aux_logits", False)
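The rename from `ImageNet1K_Community` to `ImageNet1K_TFV1` presumably marks these as the weights ported from the original TensorFlow release (see the README note above), and the recipe now points at the canonical references page instead of a personal fork of pytorch/examples. Selecting weights explicitly under the new API, assuming the prototype namespace re-exports the enum:

    from torchvision.prototype.models import googlenet, GoogLeNetWeights

    # Explicit, self-documenting weight selection:
    model = googlenet(weights=GoogLeNetWeights.ImageNet1K_TFV1)

    # Deprecated spelling; resolves to the same entry and warns:
    model = googlenet(pretrained=True)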
"QuantizedGoogLeNetWeights", + "googlenet", +] + + +class QuantizedGoogLeNetWeights(Weights): + ImageNet1K_FBGEMM_TFV1 = WeightEntry( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNetWeights.ImageNet1K_TFV1, + "acc@1": 69.826, + "acc@5": 89.404, + }, + ) + + +def googlenet( + weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableGoogLeNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + if kwargs.pop("pretrained"): + weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 + else: + weights = None + + if quantize: + weights = QuantizedGoogLeNetWeights.verify(weights) + else: + weights = GoogLeNetWeights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if original_aux_logits: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + kwargs["aux_logits"] = True + kwargs["init_weights"] = False + kwargs["num_classes"] = len(weights.meta["categories"]) + if "backend" in weights.meta: + kwargs["backend"] = weights.meta["backend"] + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableGoogLeNet(**kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + if not original_aux_logits: + model.aux_logits = False + model.aux1 = None # type: ignore[assignment] + model.aux2 = None # type: ignore[assignment] + + return model diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 15a583f3f00..361bf3dc385 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -2,6 +2,8 @@ from functools import partial from typing import Any, List, Optional, Type, Union +from torchvision.transforms.functional import InterpolationMode + from ....models.quantization.resnet import ( QuantizableBasicBlock, QuantizableBottleneck, @@ -54,7 +56,9 @@ def _resnet( _common_meta = { "size": (224, 224), "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, "backend": "fbgemm", + "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", } @@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights): transforms=partial(ImageNetEval, crop_size=224), meta={ **_common_meta, + "unquantized": ResNet18Weights.ImageNet1K_RefV1, "acc@1": 69.494, "acc@5": 88.882, }, @@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights): transforms=partial(ImageNetEval, crop_size=224), meta={ **_common_meta, + "unquantized": ResNet50Weights.ImageNet1K_RefV1, "acc@1": 75.920, "acc@5": 92.814, }, @@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): 
diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py
index 15a583f3f00..361bf3dc385 100644
--- a/torchvision/prototype/models/quantization/resnet.py
+++ b/torchvision/prototype/models/quantization/resnet.py
@@ -2,6 +2,8 @@
 from functools import partial
 from typing import Any, List, Optional, Type, Union
 
+from torchvision.transforms.functional import InterpolationMode
+
 from ....models.quantization.resnet import (
     QuantizableBasicBlock,
     QuantizableBottleneck,
@@ -54,7 +56,9 @@ def _resnet(
 
 _common_meta = {
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
     "backend": "fbgemm",
+    "quantization": "ptq",
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
 }
@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNet18Weights.ImageNet1K_RefV1,
             "acc@1": 69.494,
             "acc@5": 88.882,
         },
@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNet50Weights.ImageNet1K_RefV1,
             "acc@1": 75.920,
             "acc@5": 92.814,
         },
@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
+            "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
             "acc@1": 78.986,
             "acc@5": 94.480,
         },
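With `interpolation`, `quantization`, and the `unquantized` back-reference now in the shared meta, tooling can hop from a quantized entry to its float counterpart programmatically. A small sketch using the GoogLeNet entries added above:

    from torchvision.prototype.models.quantization import QuantizedGoogLeNetWeights

    w = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1
    print(w.meta["quantization"])  # "ptq" (post-training quantization)
    print(w.meta["backend"])       # "fbgemm"

    # Follow the link back to the float weights and compare accuracies:
    float_w = w.meta["unquantized"]
    print(float_w.meta["acc@1"])   # 69.778 float vs 69.826 quantized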