
Commit 9d9cfab

jdsgomes and datumbox authored
add swin_s and swin_b variants and improved swin_t (#6048)
* add swin_s and swin_b variants
* fix swin_b params
* fix n parameters and acc numbers
* adding missing acc numbers
* apply ufmt
* Updating `_docs` to reflect training recipe
* Fix exted for swin_b

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 52a4480 commit 9d9cfab

File tree

5 files changed (+133, -9 lines)


docs/source/models/swin_transformer.rst (2 additions, 0 deletions)

@@ -23,3 +23,5 @@ more details about this class.
     :template: function.rst

     swin_t
+    swin_s
+    swin_b

references/classification/README.md (4 additions, 4 deletions)

@@ -228,14 +228,14 @@ and `--batch_size 64`.
 ### SwinTransformer
 ```
 torchrun --nproc_per_node=8 train.py\
---model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\
---bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\
---lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\
---clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra
+--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224
 ```
+Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`.
 Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
 
 
+
+
 ### ShuffleNet V2
 ```
 torchrun --nproc_per_node=8 train.py \
2 binary files changed (939 Bytes each): binary file contents not shown

torchvision/models/swin_transformer.py (127 additions, 5 deletions)

@@ -18,7 +18,11 @@
 __all__ = [
     "SwinTransformer",
     "Swin_T_Weights",
+    "Swin_S_Weights",
+    "Swin_B_Weights",
     "swin_t",
+    "swin_s",
+    "swin_b",
 ]
 
 
@@ -408,9 +412,9 @@ def _swin_transformer(
 
 class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_t-81486767.pth",
+        url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
         ),
         meta={
             **_COMMON_META,
@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum):
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
             "_metrics": {
                 "ImageNet-1K": {
-                    "acc@1": 81.358,
-                    "acc@5": 95.526,
+                    "acc@1": 81.474,
+                    "acc@5": 95.776,
+                }
+            },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_S_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_s-30134662.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 49606258,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.196,
+                    "acc@5": 96.360,
+                }
+            },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_B_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 87768224,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.582,
+                    "acc@5": 96.640,
                 }
             },
-            "_docs": """These weights reproduce closely the results of the paper using its training recipe.""",
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
         progress=progress,
         **kwargs,
     )
+
+
+def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+    """
+    Constructs a swin_small architecture from
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.Swin_S_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.Swin_S_Weights
+        :members:
+    """
+    weights = Swin_S_Weights.verify(weights)
+
+    return _swin_transformer(
+        patch_size=4,
+        embed_dim=96,
+        depths=[2, 2, 18, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        stochastic_depth_prob=0.3,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
+def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+    """
+    Constructs a swin_base architecture from
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.Swin_B_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.Swin_B_Weights
+        :members:
+    """
+    weights = Swin_B_Weights.verify(weights)
+
+    return _swin_transformer(
+        patch_size=4,
+        embed_dim=128,
+        depths=[2, 2, 18, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=7,
+        stochastic_depth_prob=0.5,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
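
For context, the new variants are consumed through torchvision's multi-weight API exactly like `swin_t`. Below is a minimal inference sketch; it is not part of this commit, it assumes a torchvision build that already contains this change, and `example.jpg` is a hypothetical input path.

```
import torch
from torchvision.io import read_image
from torchvision.models import swin_s, Swin_S_Weights

# Build the model with the pretrained weights added in this commit.
weights = Swin_S_Weights.IMAGENET1K_V1
model = swin_s(weights=weights)
model.eval()

# weights.transforms() returns the inference preset recorded in the Weights
# entry (crop_size=224, resize_size=246, bicubic interpolation for swin_s).
preprocess = weights.transforms()
img = read_image("example.jpg")  # hypothetical input image
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    probs = model(batch).softmax(dim=-1)
print(weights.meta["categories"][probs.argmax(dim=-1).item()])
```

The same pattern applies to `swin_b` with `Swin_B_Weights`.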
