
Commit ac4eaab

Merge branch 'main' into add_flow_datasets
2 parents 3b8ba30 + f5af07b commit ac4eaab

6 files changed: +144 additions, -52 deletions


docs/source/models.rst

Lines changed: 2 additions & 1 deletion
@@ -479,7 +479,8 @@ Model Acc@1 Acc@5
 ================================ ============= =============
 MobileNet V2 71.658 90.150
 MobileNet V3 Large 73.004 90.858
-ShuffleNet V2 68.360 87.582
+ShuffleNet V2 x1.0 68.360 87.582
+ShuffleNet V2 x0.5 57.972 79.780
 ResNet 18 69.494 88.882
 ResNet 50 75.920 92.814
 ResNext 101 32x8d 78.986 94.480

references/classification/README.md

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ For all post training quantized models, the settings are:
 ```
 python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
 ```
-Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d` and `shufflenet_v2_x1_0`.
+Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
 
 ### QAT MobileNetV2
 
test/test_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -1364,7 +1364,7 @@ def test_to_grayscale():
 @pytest.mark.parametrize("p", (0, 1))
 def test_random_apply(p, seed):
     torch.manual_seed(seed)
-    random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p)
+    random_apply_transform = transforms.RandomApply([transforms.RandomRotation((45, 50))], p=p)
     img = transforms.ToPILImage()(torch.rand(3, 30, 40))
     out = random_apply_transform(img)
     if p == 0:
@@ -1384,7 +1384,7 @@ def test_random_choice(proba_passthrough, seed):
     random_choice_transform = transforms.RandomChoice(
         [
             lambda x: x,  # passthrough
-            transforms.RandomRotation((1, 45)),
+            transforms.RandomRotation((45, 50)),
         ],
         p=[proba_passthrough, 1 - proba_passthrough],
     )

torchvision/models/quantization/shufflenetv2.py

Lines changed: 3 additions & 48 deletions
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
@@ -12,15 +12,11 @@
     "QuantizableShuffleNetV2",
     "shufflenet_v2_x0_5",
     "shufflenet_v2_x1_0",
-    "shufflenet_v2_x1_5",
-    "shufflenet_v2_x2_0",
 ]
 
 quant_model_urls = {
-    "shufflenetv2_x0.5_fbgemm": None,
+    "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
     "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
-    "shufflenetv2_x1.5_fbgemm": None,
-    "shufflenetv2_x2.0_fbgemm": None,
 }
 
 
@@ -96,6 +92,7 @@ def _shufflenetv2(
     assert pretrained in [True, False]
 
     if pretrained:
+        model_url: Optional[str] = None
         if quantize:
             model_url = quant_model_urls[arch + "_" + backend]
         else:
@@ -147,45 +144,3 @@ def shufflenet_v2_x1_0(
     return _shufflenetv2(
         "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs
     )
-
-
-def shufflenet_v2_x1_5(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> QuantizableShuffleNetV2:
-    """
-    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
-    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
-    <https://arxiv.org/abs/1807.11164>`_.
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-    """
-    return _shufflenetv2(
-        "shufflenetv2_x1.5", pretrained, progress, quantize, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs
-    )
-
-
-def shufflenet_v2_x2_0(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> QuantizableShuffleNetV2:
-    """
-    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
-    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
-    <https://arxiv.org/abs/1807.11164>`_.
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-    """
-    return _shufflenetv2(
-        "shufflenetv2_x2.0", pretrained, progress, quantize, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs
-    )
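With the x0.5 checkpoint URL now populated in `quant_model_urls`, the existing `torchvision.models.quantization` entry point can load it directly. A minimal usage sketch, not part of the diff; it assumes a build containing this commit and the default `fbgemm` CPU quantization backend:

```python
import torch
from torchvision.models.quantization import shufflenet_v2_x0_5

# Download and load the post-training-quantized ShuffleNet V2 x0.5 weights
# referenced by the quant_model_urls entry added above.
model = shufflenet_v2_x0_5(pretrained=True, quantize=True)
model.eval()

# Quantized inference runs on CPU; the model quantizes/dequantizes internally.
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```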
Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .googlenet import *
 from .inception import *
 from .resnet import *
+from .shufflenetv2 import *
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+import warnings
+from functools import partial
+from typing import Any, List, Optional, Union
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.shufflenetv2 import (
+    QuantizableShuffleNetV2,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
+
+
+__all__ = [
+    "QuantizableShuffleNetV2",
+    "QuantizedShuffleNetV2_x0_5Weights",
+    "QuantizedShuffleNetV2_x1_0Weights",
+    "shufflenet_v2_x0_5",
+    "shufflenet_v2_x1_0",
+]
+
+
+def _shufflenetv2(
+    stages_repeats: List[int],
+    stages_out_channels: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+    "backend": "fbgemm",
+    "quantization": "ptq",
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+}
+
+
+class QuantizedShuffleNetV2_x0_5Weights(Weights):
+    ImageNet1K_FBGEMM_Community = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community,
+            "acc@1": 57.972,
+            "acc@5": 79.780,
+        },
+    )
+
+
+class QuantizedShuffleNetV2_x1_0Weights(Weights):
+    ImageNet1K_FBGEMM_Community = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community,
+            "acc@1": 68.360,
+            "acc@5": 87.582,
+        },
+    )
+
+
+def shufflenet_v2_x0_5(
+    weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
+                if quantize
+                else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights)
+    else:
+        weights = ShuffleNetV2_x0_5Weights.verify(weights)
+
+    return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs)
+
+
+def shufflenet_v2_x1_0(
+    weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
+                if quantize
+                else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights)
+    else:
+        weights = ShuffleNetV2_x1_0Weights.verify(weights)
+
+    return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs)
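For reference, a minimal sketch of exercising the new prototype builder. The import path `torchvision.prototype.models.quantization` is inferred from the relative imports above and is an assumption; the prototype weights API of this vintage may change:

```python
import torch
from torchvision.prototype.models.quantization import (
    QuantizedShuffleNetV2_x0_5Weights,
    shufflenet_v2_x0_5,
)

# Select the fbgemm post-training-quantized checkpoint defined above explicitly,
# instead of passing the deprecated pretrained=True flag.
weights = QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
model = shufflenet_v2_x0_5(weights=weights, quantize=True).eval()

# 224x224 matches the "size" recorded in the weight metadata.
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))

# Map the top prediction to an ImageNet label via the bundled metadata.
print(weights.meta["categories"][logits.argmax().item()])
```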
