-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Adding multiweight support to Quantized GoogLeNet #4848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e1baa2e
cd01389
0e6d7d0
851ddcb
efa445d
0a3ca96
1a150d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,69 +19,6 @@ | |
} | ||
|
||
|
||
def googlenet( | ||
pretrained: bool = False, | ||
progress: bool = True, | ||
quantize: bool = False, | ||
**kwargs: Any, | ||
) -> "QuantizableGoogLeNet": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just moving this builder below the class definition so that its return type can be annotated with the class itself rather than a string forward reference. |
||
|
||
r"""GoogLeNet (Inception v1) model architecture from | ||
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_. | ||
|
||
Note that quantize = True returns a quantized model with 8 bit | ||
weights. Quantized models only support inference and run on CPUs. | ||
GPU inference is not yet supported | ||
|
||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
quantize (bool): If True, return a quantized version of the model | ||
aux_logits (bool): If True, adds two auxiliary branches that can improve training. | ||
Default: *False* when pretrained is True otherwise *True* | ||
transform_input (bool): If True, preprocesses the input according to the method with which it | ||
was trained on ImageNet. Default: *False* | ||
""" | ||
if pretrained: | ||
if "transform_input" not in kwargs: | ||
kwargs["transform_input"] = True | ||
if "aux_logits" not in kwargs: | ||
kwargs["aux_logits"] = False | ||
if kwargs["aux_logits"]: | ||
warnings.warn( | ||
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" | ||
) | ||
original_aux_logits = kwargs["aux_logits"] | ||
kwargs["aux_logits"] = True | ||
kwargs["init_weights"] = False | ||
|
||
model = QuantizableGoogLeNet(**kwargs) | ||
_replace_relu(model) | ||
|
||
if quantize: | ||
# TODO use pretrained as a string to specify the backend | ||
backend = "fbgemm" | ||
quantize_model(model, backend) | ||
else: | ||
assert pretrained in [True, False] | ||
|
||
if pretrained: | ||
if quantize: | ||
model_url = quant_model_urls["googlenet_" + backend] | ||
else: | ||
model_url = model_urls["googlenet"] | ||
|
||
state_dict = load_state_dict_from_url(model_url, progress=progress) | ||
|
||
model.load_state_dict(state_dict) | ||
|
||
if not original_aux_logits: | ||
model.aux_logits = False | ||
model.aux1 = None # type: ignore[assignment] | ||
model.aux2 = None # type: ignore[assignment] | ||
return model | ||
|
||
|
||
class QuantizableBasicConv2d(BasicConv2d): | ||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
@@ -164,3 +101,65 @@ def fuse_model(self) -> None: | |
for m in self.modules(): | ||
if type(m) is QuantizableBasicConv2d: | ||
m.fuse_model() | ||
|
||
|
||
def googlenet( | ||
pretrained: bool = False, | ||
progress: bool = True, | ||
quantize: bool = False, | ||
**kwargs: Any, | ||
) -> QuantizableGoogLeNet: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Copy-pasted from the removed definition above; no need for review. |
||
r"""GoogLeNet (Inception v1) model architecture from | ||
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_. | ||
|
||
Note that quantize = True returns a quantized model with 8 bit | ||
weights. Quantized models only support inference and run on CPUs. | ||
GPU inference is not yet supported | ||
|
||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
quantize (bool): If True, return a quantized version of the model | ||
aux_logits (bool): If True, adds two auxiliary branches that can improve training. | ||
Default: *False* when pretrained is True otherwise *True* | ||
transform_input (bool): If True, preprocesses the input according to the method with which it | ||
was trained on ImageNet. Default: *False* | ||
""" | ||
if pretrained: | ||
if "transform_input" not in kwargs: | ||
kwargs["transform_input"] = True | ||
if "aux_logits" not in kwargs: | ||
kwargs["aux_logits"] = False | ||
if kwargs["aux_logits"]: | ||
warnings.warn( | ||
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" | ||
) | ||
original_aux_logits = kwargs["aux_logits"] | ||
kwargs["aux_logits"] = True | ||
kwargs["init_weights"] = False | ||
|
||
model = QuantizableGoogLeNet(**kwargs) | ||
_replace_relu(model) | ||
|
||
if quantize: | ||
# TODO use pretrained as a string to specify the backend | ||
backend = "fbgemm" | ||
quantize_model(model, backend) | ||
else: | ||
assert pretrained in [True, False] | ||
|
||
if pretrained: | ||
if quantize: | ||
model_url = quant_model_urls["googlenet_" + backend] | ||
else: | ||
model_url = model_urls["googlenet"] | ||
|
||
state_dict = load_state_dict_from_url(model_url, progress=progress) | ||
|
||
model.load_state_dict(state_dict) | ||
|
||
if not original_aux_logits: | ||
model.aux_logits = False | ||
model.aux1 = None # type: ignore[assignment] | ||
model.aux2 = None # type: ignore[assignment] | ||
return model |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,14 +14,14 @@ | |
|
||
|
||
class GoogLeNetWeights(Weights): | ||
ImageNet1K_Community = WeightEntry( | ||
ImageNet1K_TFV1 = WeightEntry( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was previously incorrectly tagged as a community contribution, while in reality the weights were ported from TF. Proof: #678 (comment) |
||
url="https://download.pytorch.org/models/googlenet-1378be20.pth", | ||
transforms=partial(ImageNetEval, crop_size=224), | ||
meta={ | ||
"size": (224, 224), | ||
"categories": _IMAGENET_CATEGORIES, | ||
"interpolation": InterpolationMode.BILINEAR, | ||
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet", | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", | ||
"acc@1": 69.778, | ||
"acc@5": 89.530, | ||
}, | ||
|
@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights): | |
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: | ||
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None | ||
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None | ||
weights = GoogLeNetWeights.verify(weights) | ||
|
||
original_aux_logits = kwargs.get("aux_logits", False) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .googlenet import * | ||
from .resnet import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import warnings | ||
from functools import partial | ||
from typing import Any, Optional, Union | ||
|
||
from torchvision.transforms.functional import InterpolationMode | ||
|
||
from ....models.quantization.googlenet import ( | ||
QuantizableGoogLeNet, | ||
_replace_relu, | ||
quantize_model, | ||
) | ||
from ...transforms.presets import ImageNetEval | ||
from .._api import Weights, WeightEntry | ||
from .._meta import _IMAGENET_CATEGORIES | ||
from ..googlenet import GoogLeNetWeights | ||
|
||
|
||
__all__ = [ | ||
"QuantizableGoogLeNet", | ||
"QuantizedGoogLeNetWeights", | ||
"googlenet", | ||
] | ||
|
||
|
||
class QuantizedGoogLeNetWeights(Weights): | ||
ImageNet1K_FBGEMM_TFV1 = WeightEntry( | ||
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", | ||
transforms=partial(ImageNetEval, crop_size=224), | ||
meta={ | ||
"size": (224, 224), | ||
"categories": _IMAGENET_CATEGORIES, | ||
"interpolation": InterpolationMode.BILINEAR, | ||
"backend": "fbgemm", | ||
"quantization": "ptq", | ||
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", | ||
"unquantized": GoogLeNetWeights.ImageNet1K_TFV1, | ||
"acc@1": 69.826, | ||
"acc@5": 89.404, | ||
}, | ||
) | ||
|
||
|
||
def googlenet( | ||
weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None, | ||
progress: bool = True, | ||
quantize: bool = False, | ||
**kwargs: Any, | ||
) -> QuantizableGoogLeNet: | ||
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
if kwargs.pop("pretrained"): | ||
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 | ||
else: | ||
weights = None | ||
|
||
if quantize: | ||
weights = QuantizedGoogLeNetWeights.verify(weights) | ||
else: | ||
weights = GoogLeNetWeights.verify(weights) | ||
|
||
original_aux_logits = kwargs.get("aux_logits", False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We simplify this in the same way as the unquantized builder. |
||
if weights is not None: | ||
if "transform_input" not in kwargs: | ||
kwargs["transform_input"] = True | ||
if original_aux_logits: | ||
warnings.warn( | ||
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" | ||
) | ||
kwargs["aux_logits"] = True | ||
kwargs["init_weights"] = False | ||
kwargs["num_classes"] = len(weights.meta["categories"]) | ||
if "backend" in weights.meta: | ||
kwargs["backend"] = weights.meta["backend"] | ||
backend = kwargs.pop("backend", "fbgemm") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The rest is similar to the quantized implementation for ResNet. |
||
|
||
model = QuantizableGoogLeNet(**kwargs) | ||
_replace_relu(model) | ||
if quantize: | ||
quantize_model(model, backend) | ||
|
||
if weights is not None: | ||
model.load_state_dict(weights.state_dict(progress=progress)) | ||
if not original_aux_logits: | ||
model.aux_logits = False | ||
model.aux1 = None # type: ignore[assignment] | ||
model.aux2 = None # type: ignore[assignment] | ||
|
||
return model |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
from functools import partial | ||
from typing import Any, List, Optional, Type, Union | ||
|
||
from torchvision.transforms.functional import InterpolationMode | ||
|
||
from ....models.quantization.resnet import ( | ||
QuantizableBasicBlock, | ||
QuantizableBottleneck, | ||
|
@@ -54,7 +56,9 @@ def _resnet( | |
_common_meta = { | ||
"size": (224, 224), | ||
"categories": _IMAGENET_CATEGORIES, | ||
"interpolation": InterpolationMode.BILINEAR, | ||
"backend": "fbgemm", | ||
"quantization": "ptq", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adding some additional metadata, such as the interpolation mode, the type of quantization, and a reference to the unquantized weights enum, to all previous models. |
||
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", | ||
} | ||
|
||
|
@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights): | |
transforms=partial(ImageNetEval, crop_size=224), | ||
meta={ | ||
**_common_meta, | ||
"unquantized": ResNet18Weights.ImageNet1K_RefV1, | ||
"acc@1": 69.494, | ||
"acc@5": 88.882, | ||
}, | ||
|
@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights): | |
transforms=partial(ImageNetEval, crop_size=224), | ||
meta={ | ||
**_common_meta, | ||
"unquantized": ResNet50Weights.ImageNet1K_RefV1, | ||
"acc@1": 75.920, | ||
"acc@5": 92.814, | ||
}, | ||
|
@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): | |
transforms=partial(ImageNetEval, crop_size=224), | ||
meta={ | ||
**_common_meta, | ||
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1, | ||
"acc@1": 78.986, | ||
"acc@5": 94.480, | ||
}, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needed to actually run the tests on the quantized weights.