Skip to content

Adding multiweight support to Quantized GoogLeNet #4848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions references/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters.

### GoogLeNet

The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch.

### Inception V3

The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
Expand Down
1 change: 1 addition & 0 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
},
"quantization": {
"input_shape": (1, 3, 224, 224),
"quantize": True,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to actually run the tests on the quantized weights.

},
"segmentation": {
"input_shape": (1, 3, 520, 520),
Expand Down
125 changes: 62 additions & 63 deletions torchvision/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,69 +19,6 @@
}


def googlenet(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> "QuantizableGoogLeNet":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moving this builder below to the class definition to use proper typing.


r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False

model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)

if quantize:
# TODO use pretrained as a string to specify the backend
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]

if pretrained:
if quantize:
model_url = quant_model_urls["googlenet_" + backend]
else:
model_url = model_urls["googlenet"]

state_dict = load_state_dict_from_url(model_url, progress=progress)

model.load_state_dict(state_dict)

if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model


class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -164,3 +101,65 @@ def fuse_model(self) -> None:
for m in self.modules():
if type(m) is QuantizableBasicConv2d:
m.fuse_model()


def googlenet(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableGoogLeNet:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy-pasted, no need for review.

r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False

model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)

if quantize:
# TODO use pretrained as a string to specify the backend
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]

if pretrained:
if quantize:
model_url = quant_model_urls["googlenet_" + backend]
else:
model_url = model_urls["googlenet"]

state_dict = load_state_dict_from_url(model_url, progress=progress)

model.load_state_dict(state_dict)

if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model
6 changes: 3 additions & 3 deletions torchvision/prototype/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@


class GoogLeNetWeights(Weights):
ImageNet1K_Community = WeightEntry(
ImageNet1K_TFV1 = WeightEntry(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was previously incorrectly tagged as community contribution while in reality the weights were ported from TF. Proof: #678 (comment)

url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
"acc@1": 69.778,
"acc@5": 89.530,
},
Expand All @@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights)

original_aux_logits = kwargs.get("aux_logits", False)
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .googlenet import *
from .resnet import *
88 changes: 88 additions & 0 deletions torchvision/prototype/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.googlenet import (
QuantizableGoogLeNet,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..googlenet import GoogLeNetWeights


__all__ = [
"QuantizableGoogLeNet",
"QuantizedGoogLeNetWeights",
"googlenet",
]


class QuantizedGoogLeNetWeights(Weights):
ImageNet1K_FBGEMM_TFV1 = WeightEntry(
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": GoogLeNetWeights.ImageNet1K_TFV1,
"acc@1": 69.826,
"acc@5": 89.404,
},
)


def googlenet(
weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableGoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
else:
weights = None

if quantize:
weights = QuantizedGoogLeNetWeights.verify(weights)
else:
weights = GoogLeNetWeights.verify(weights)

original_aux_logits = kwargs.get("aux_logits", False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We simplify similarly to the unquantized builder.

if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
backend = kwargs.pop("backend", "fbgemm")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest is similar to the quantized implementation for resnet.


model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]

return model
7 changes: 7 additions & 0 deletions torchvision/prototype/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from typing import Any, List, Optional, Type, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.resnet import (
QuantizableBasicBlock,
QuantizableBottleneck,
Expand Down Expand Up @@ -54,7 +56,9 @@ def _resnet(
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding some additional meta-data such as interpolation, type of quantization and a reference to the unquantized enum of the the weights to all previous models.

"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}

Expand All @@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNet18Weights.ImageNet1K_RefV1,
"acc@1": 69.494,
"acc@5": 88.882,
},
Expand All @@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNet50Weights.ImageNet1K_RefV1,
"acc@1": 75.920,
"acc@5": 92.814,
},
Expand All @@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"acc@1": 78.986,
"acc@5": 94.480,
},
Expand Down