@@ -18,7 +18,11 @@
 __all__ = [
     "SwinTransformer",
     "Swin_T_Weights",
+    "Swin_S_Weights",
+    "Swin_B_Weights",
     "swin_t",
+    "swin_s",
+    "swin_b",
 ]
 
 
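The expanded `__all__` makes the new builders and weight enums importable directly from `torchvision.models`. A quick sanity-check sketch (assumes a torchvision build that already includes this change):

```python
# Sketch: confirm the newly exported Swin entry points resolve.
# Assumes a torchvision build that already contains this diff.
from torchvision.models import Swin_S_Weights, Swin_B_Weights, swin_s, swin_b

print(Swin_S_Weights.DEFAULT)  # Swin_S_Weights.IMAGENET1K_V1
print(Swin_B_Weights.DEFAULT)  # Swin_B_Weights.IMAGENET1K_V1
```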
@@ -408,9 +412,9 @@ def _swin_transformer(
 
 class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_t-81486767.pth",
+        url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
         ),
         meta={
             **_COMMON_META,
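The updated preset resizes the shorter image side to 232 with bicubic interpolation before the 224 center crop. Written out by hand, the preprocessing is roughly the following (a sketch; the `ImageClassification` preset also rescales pixel values to [0, 1] and normalizes with the standard ImageNet statistics):

```python
# Sketch of the updated swin_t inference preprocessing, spelled out with
# torchvision.transforms. Mean/std are the standard ImageNet values used
# by the classification presets.
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

preprocess = T.Compose([
    T.Resize(232, interpolation=InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```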
@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum):
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
             "_metrics": {
                 "ImageNet-1K": {
-                    "acc@1": 81.358,
-                    "acc@5": 95.526,
+                    "acc@1": 81.474,
+                    "acc@5": 95.776,
+                }
+            },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_S_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_s-30134662.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 49606258,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.196,
+                    "acc@5": 96.360,
+                }
+            },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_B_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 87768224,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.582,
+                    "acc@5": 96.640,
                 }
             },
-            "_docs": """These weights reproduce closely the results of the paper using its training recipe.""",
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
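Each `WeightsEnum` entry bundles its own preprocessing and metadata, so the resize/crop values and accuracy figures above can be read back programmatically rather than hard-coded. A minimal sketch:

```python
# Sketch: read the bundled transforms and metadata off a weights enum.
from torchvision.models import Swin_B_Weights

weights = Swin_B_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # the ImageClassification preset defined above
print(weights.meta["num_params"])  # 87768224
print(weights.meta["_metrics"]["ImageNet-1K"]["acc@1"])  # 83.582
```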
@@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
         progress=progress,
         **kwargs,
     )
+
+
+def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+    """
+    Constructs a swin_small architecture from
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.Swin_S_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.Swin_S_Weights
+        :members:
+    """
+    weights = Swin_S_Weights.verify(weights)
+
+    return _swin_transformer(
+        patch_size=4,
+        embed_dim=96,
+        depths=[2, 2, 18, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        stochastic_depth_prob=0.3,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
+def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+    """
+    Constructs a swin_base architecture from
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.Swin_B_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.Swin_B_Weights
+        :members:
+    """
+    weights = Swin_B_Weights.verify(weights)
+
+    return _swin_transformer(
+        patch_size=4,
+        embed_dim=128,
+        depths=[2, 2, 18, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=7,
+        stochastic_depth_prob=0.5,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
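The new builders follow the same calling convention as `swin_t`. An end-to-end sketch of inference with `swin_s` (the image path and the `meta["categories"]` lookup from `_COMMON_META` are assumptions for illustration, not part of this diff):

```python
# Sketch: single-image inference with the new swin_s builder.
# "dog.jpg" is a placeholder path; "categories" comes from _COMMON_META.
import torch
from torchvision.io import read_image
from torchvision.models import swin_s, Swin_S_Weights

weights = Swin_S_Weights.IMAGENET1K_V1
model = swin_s(weights=weights).eval()

img = read_image("dog.jpg")                     # uint8 CHW image tensor
batch = weights.transforms()(img).unsqueeze(0)  # resize 246 -> crop 224 -> normalize

with torch.no_grad():
    probs = model(batch).softmax(dim=1)
top = probs.argmax().item()
print(weights.meta["categories"][top], probs[0, top].item())
```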