
Commit ccc0a92

YosuaMichael authored and facebook-github-bot committed
[fbsync] Add shufflenetv2 1.5 and 2.0 weights (#5906)
Summary:

* Add shufflenetv2 1.5 and 2.0 weights
* Update recipe
* Add to docs
* Use resize_size=232 for eval and update the result
* Add quantized shufflenetv2 large
* Update docs and readme
* Format with ufmt
* Add to hubconf.py
* Update readme for classification reference
* Fix reference classification readme
* Fix typo on readme
* Update reference/classification/readme

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095677

fbshipit-source-id: 74a575c6272df397852dba325f9c1b1e5a1c0231
1 parent 3156a3e commit ccc0a92
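
For context, a minimal usage sketch of the new (non-quantized) weights added by this commit, assuming a torchvision build that includes the multi-weight API; the random input tensor and printed shape are illustrative only:

```python
import torch
from torchvision.models import shufflenet_v2_x1_5, ShuffleNet_V2_X1_5_Weights

# Build the 1.5x model with the weights introduced in this commit.
weights = ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1
model = shufflenet_v2_x1_5(weights=weights).eval()

# The bundled transforms use resize_size=232 / crop_size=224, matching the eval recipe.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 300, 300)).unsqueeze(0)  # dummy image, illustrative only

with torch.inference_mode():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```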

File tree

5 files changed: +163 -6 lines changed

docs/source/models.rst

Lines changed: 6 additions & 2 deletions
@@ -176,8 +176,10 @@ Densenet-201 76.896 93.370
 Densenet-161 77.138 93.560
 Inception v3 77.294 93.450
 GoogleNet 69.778 89.530
-ShuffleNet V2 x1.0 69.362 88.316
 ShuffleNet V2 x0.5 60.552 81.746
+ShuffleNet V2 x1.0 69.362 88.316
+ShuffleNet V2 x1.5 72.996 91.086
+ShuffleNet V2 x2.0 76.230 93.006
 MobileNet V2 71.878 90.286
 MobileNet V3 Large 74.042 91.340
 MobileNet V3 Small 67.668 87.402
@@ -499,8 +501,10 @@ Model Acc@1 Acc@5
 ================================ ============= =============
 MobileNet V2 71.658 90.150
 MobileNet V3 Large 73.004 90.858
-ShuffleNet V2 x1.0 68.360 87.582
 ShuffleNet V2 x0.5 57.972 79.780
+ShuffleNet V2 x1.0 68.360 87.582
+ShuffleNet V2 x1.5 72.052 90.700
+ShuffleNet V2 x2.0 75.354 92.488
 ResNet 18 69.494 88.882
 ResNet 50 75.920 92.814
 ResNext 101 32x8d 78.986 94.480

hubconf.py

Lines changed: 6 additions & 1 deletion
@@ -59,7 +59,12 @@
     deeplabv3_mobilenet_v3_large,
     lraspp_mobilenet_v3_large,
 )
-from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
+from torchvision.models.shufflenetv2 import (
+    shufflenet_v2_x0_5,
+    shufflenet_v2_x1_0,
+    shufflenet_v2_x1_5,
+    shufflenet_v2_x2_0,
+)
 from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.vision_transformer import (
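
Since hubconf.py now re-exports the x1.5 and x2.0 builders, they are also reachable through torch.hub. A small sketch (passing the weight name as a string is accepted by recent torchvision releases; exact behavior depends on the installed version):

```python
import torch

# The shufflenet entrypoints exported by hubconf.py, including the new x1.5/x2.0 builders.
print([name for name in torch.hub.list("pytorch/vision") if "shufflenet" in name])

# Load the 2.0x model with its pretrained weights through the hub entrypoint.
model = torch.hub.load("pytorch/vision", "shufflenet_v2_x2_0", weights="IMAGENET1K_V1")
model.eval()
```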

references/classification/README.md

Lines changed: 29 additions & 0 deletions
@@ -236,6 +236,20 @@ torchrun --nproc_per_node=8 train.py\
 Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
 
 
+### ShuffleNet V2
+```
+torchrun --nproc_per_node=8 train.py \
+--batch-size=128 \
+--lr=0.5 --lr-scheduler=cosineannealinglr --lr-warmup-epochs=5 --lr-warmup-method=linear \
+--auto-augment=ta_wide --epochs=600 --random-erase=0.1 --weight-decay=0.00002 \
+--norm-weight-decay=0.0 --label-smoothing=0.1 --mixup-alpha=0.2 --cutmix-alpha=1.0 \
+--train-crop-size=176 --model-ema --val-resize-size=232 --ra-sampler --ra-reps=4
+```
+Here `$MODEL` is either `shufflenet_v2_x1_5` or `shufflenet_v2_x2_0`.
+
+The models `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0` were contributed by the community. See [PR-849](https://github.com/pytorch/vision/pull/849#issuecomment-483391686) for details.
+
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
 
@@ -263,6 +277,21 @@ python train_quantization.py --device='cpu' --post-training-quantize --backend='
 ```
 Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
 
+### Quantized ShuffleNet V2
+
+Here are the commands that we use to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models.
+```
+# For shufflenet_v2_x1_5
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+--model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \
+--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
+
+# For shufflenet_v2_x2_0
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+--model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \
+--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
+```
+
 ### QAT MobileNetV2
 
 For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:

torchvision/models/quantization/shufflenetv2.py

Lines changed: 94 additions & 1 deletion
@@ -10,16 +10,25 @@
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
+from ..shufflenetv2 import (
+    ShuffleNet_V2_X0_5_Weights,
+    ShuffleNet_V2_X1_0_Weights,
+    ShuffleNet_V2_X1_5_Weights,
+    ShuffleNet_V2_X2_0_Weights,
+)
 from .utils import _fuse_modules, _replace_relu, quantize_model
 
 
 __all__ = [
     "QuantizableShuffleNetV2",
     "ShuffleNet_V2_X0_5_QuantizedWeights",
     "ShuffleNet_V2_X1_0_QuantizedWeights",
+    "ShuffleNet_V2_X1_5_QuantizedWeights",
+    "ShuffleNet_V2_X2_0_QuantizedWeights",
     "shufflenet_v2_x0_5",
     "shufflenet_v2_x1_0",
+    "shufflenet_v2_x1_5",
+    "shufflenet_v2_x2_0",
 ]
 
 
@@ -143,6 +152,42 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
     DEFAULT = IMAGENET1K_FBGEMM_V1
 
 
+class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 3503624,
+            "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
+            "metrics": {
+                "acc@1": 72.052,
+                "acc@5": 90.700,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 7393996,
+            "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
+            "metrics": {
+                "acc@1": 75.354,
+                "acc@5": 92.488,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
 @handle_legacy_interface(
     weights=(
         "pretrained",
@@ -205,3 +250,51 @@ def shufflenet_v2_x1_0(
     return _shufflenetv2(
         [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
     )
+
+
+def shufflenet_v2_x1_5(
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        weights (ShuffleNet_V2_X1_5_QuantizedWeights or ShuffleNet_V2_X1_5_Weights, optional): The pretrained
+            weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+    """
+    weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights)
+    return _shufflenetv2(
+        [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
+    )
+
+
+def shufflenet_v2_x2_0(
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        weights (ShuffleNet_V2_X2_0_QuantizedWeights or ShuffleNet_V2_X2_0_Weights, optional): The pretrained
+            weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+    """
+    weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights)
+    return _shufflenetv2(
+        [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
+    )
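
A minimal sketch of how the quantized builders above would be used, assuming a CPU build with the FBGEMM backend available; the random input and its size are illustrative only:

```python
import torch
from torchvision.models.quantization import (
    shufflenet_v2_x2_0,
    ShuffleNet_V2_X2_0_QuantizedWeights,
)

# quantize=True selects the post-training-quantized (FBGEMM) weights defined above.
weights = ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
model = shufflenet_v2_x2_0(weights=weights, quantize=True).eval()

# Preprocess a dummy image with the bundled transforms (resize 232, crop 224) and run it.
batch = weights.transforms()(torch.rand(3, 256, 256)).unsqueeze(0)
with torch.inference_mode():
    print(model(batch).argmax(dim=1))
```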

torchvision/models/shufflenetv2.py

Lines changed: 28 additions & 2 deletions
@@ -223,11 +223,37 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
 
 
 class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
-    pass
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 3503624,
+            "metrics": {
+                "acc@1": 72.996,
+                "acc@5": 91.086,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
 
 
 class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
-    pass
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 7393996,
+            "metrics": {
+                "acc@1": 76.230,
+                "acc@5": 93.006,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
 
 
 @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
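
The new enums also carry the recipe metadata shown in the diff above; a short sketch of reading it back (field names as defined in this commit):

```python
from torchvision.models import ShuffleNet_V2_X2_0_Weights

weights = ShuffleNet_V2_X2_0_Weights.DEFAULT  # alias for IMAGENET1K_V1
print(weights.url)                            # download URL of the checkpoint
print(weights.meta["num_params"])             # 7393996
print(weights.meta["metrics"]["acc@1"])       # 76.230
```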
