From 166aacb8781d45384a5dec342079a29adf7c8513 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 18:30:05 +0000 Subject: [PATCH 01/10] Refactor model builder --- torchvision/prototype/models/convnext.py | 30 ++++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 788dcbc2cd1..64601e53ef0 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -177,6 +177,24 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) +def _convnext( + block_setting: List[CNBlockConfig], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> ConvNeXt: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + class ConvNeXt_Tiny_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", @@ -200,7 +218,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - r"""ConvNeXt model architecture from the + r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: @@ -209,9 +227,6 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: """ weights = ConvNeXt_Tiny_Weights.verify(weights) - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -219,9 +234,4 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) From e0c56b5f1f2b8ccd01ba48ae9321914edd37ad5b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 18:45:04 +0000 Subject: [PATCH 02/10] Add 3 more convnext variants. --- .../ModelTester.test_convnext_base_expect.pkl | Bin 0 -> 939 bytes ...ModelTester.test_convnext_large_expect.pkl | Bin 0 -> 939 bytes ...ModelTester.test_convnext_small_expect.pkl | Bin 0 -> 939 bytes torchvision/prototype/models/convnext.py | 70 +++++++++++++++++- 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 test/expect/ModelTester.test_convnext_base_expect.pkl create mode 100644 test/expect/ModelTester.test_convnext_large_expect.pkl create mode 100644 test/expect/ModelTester.test_convnext_small_expect.pkl diff --git a/test/expect/ModelTester.test_convnext_base_expect.pkl b/test/expect/ModelTester.test_convnext_base_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..09148743b100e015b5978ce84f588dc38d3691da GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@$*GF1x&q_v|*t9JgKHzj5C|saXE>#vJ;n^QUWhy7W* z_gvgYyP69p?WQm9u~(|svHvSEZQlv*6?;p$`S*t=7TL|_%(Bx=X0iX(m%s0qNT0p= zoTvK^FH+clLgDY;TdDK+A8}!}Tfp_*E^ft!eI5&$>;$ei*_>6Gu|MhcihU`%Ec;(J z&e%Vdzjxp3-d+1%_q6Pu(DRE!8azRVRQU1;V&Pi=V+79$ICoMa95aaC0&zGPsaJ4#G6%0@-|V zUV11K&{hx*@MZ*2@HB}WhXNo86o8&W(RCyHi4R5RD zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5{JX{&pzpeEA69L#qBpFPO)e7bFmj$H(~#aK9>E@Dktuj$a-z3d9BA@W{Smr zu?5l&Kde66K3ipI@3oi5?uX2*y{a)k_UC-{x8L>Ka-Z?VhW$UfC)qdL&D-;XYs>yW z%8m96$vgMISQu-6X#V>B5*7mX3b`-#oY|tif8D$V_Bm!Tb}J9|?A!OP$=+pW>wdPm z=5|3lv+bAUJhy!lws3z=%742P2@Lk-Y=7;SncA(Jey1X#q1TdsP7;2j35f0CXwS%03?9|&{HV7Ze&04q3C=C?(c~y%Ind!t_GJ zAi$fAO$Vw-j#(G39F&+r07h?za2Y0nJqhwI8z^ructRC`GC_bhD;r3R83;k@A!-5T C@A|L+ literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_convnext_small_expect.pkl b/test/expect/ModelTester.test_convnext_small_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f5bf3b800bfbaa5664602fc7d5973b1bff55e005 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=U&gsQte6GxtBQmDs=KNYmaz^P_glO6~Vg;We_qe)*~0jf+0}{9a1gF}cm$ zpL=+&UChRv`^uN5+Hd@?V87+xwEe7h{`=+DDerHaCc95-exiMa>O;E)d-V2hv3Xo&7?_2lnUB*k{|#D`5A1 z##FoY7gpLWP(NfRY;(lUXS%z++&a5`etjbQn-(s!>r2|U|IA7^dxj-UyBi8pty_|L z_Ih3Gvx|C>Z3hZ1{VVmoCx9UZ!ni|=pTQa)T4kw4#lTo_b229~xR62)!ZhXr*?e(c zdMFdnRuB&GW&~02G>IIC0w4(#fSy9pbtC(U4@KuIAP-r$z5%*kWLNQ{=#>Dv5T+Lz z1_9n|Y&uXya?HAL<)Fk20x)_zgv&4q>`9P!*+6-N!4s+glnDa7S=m5h%s>cI4^ayM D-pKlk literal 0 HcmV?d00001 diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 64601e53ef0..650d5e73217 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -15,7 +15,17 @@ from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"] +__all__ = [ + "ConvNeXt", + "ConvNeXt_Tiny_Weights", + "ConvNeXt_Small_Weights", + "ConvNeXt_Base_Weights", + "ConvNeXt_Large_Weights", + "convnext_tiny", + "convnext_small", + "convnext_base", + "convnext_large", +] class LayerNorm2d(nn.LayerNorm): @@ -216,6 +226,18 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): default = ImageNet1K_V1 +class ConvNeXt_Small_Weights(WeightsEnum): + pass + + +class ConvNeXt_Base_Weights(WeightsEnum): + pass + + +class ConvNeXt_Large_Weights(WeightsEnum): + pass + + @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Tiny model architecture from the @@ -235,3 +257,49 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", None)) +def convnext_small( + *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: + weights = ConvNeXt_Small_Weights.verify(weights) + + block_setting = [ + CNBlockConfig(96, 192, 3), + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 27), + CNBlockConfig(768, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", None)) +def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: + weights = ConvNeXt_Base_Weights.verify(weights) + + block_setting = [ + CNBlockConfig(128, 256, 3), + CNBlockConfig(256, 512, 3), + CNBlockConfig(512, 1024, 27), + CNBlockConfig(1024, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", None)) +def convnext_large( + *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: + weights = ConvNeXt_Large_Weights.verify(weights) + + block_setting = [ + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 3), + CNBlockConfig(768, 1536, 27), + CNBlockConfig(1536, None, 3), + ] + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) From 1a0352498ad9749941eff4314027a9384870a9ae Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Jan 2022 14:28:55 +0000 Subject: [PATCH 03/10] Adding weights for convnext_small. --- docs/source/models.rst | 1 + references/classification/README.md | 5 ++-- torchvision/prototype/models/convnext.py | 35 ++++++++++++++++-------- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 82eb3170e78..794d2df2530 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -249,6 +249,7 @@ vit_b_32 75.912 92.466 vit_l_16 79.662 94.638 vit_l_32 76.972 93.070 convnext_tiny (prototype) 82.520 96.146 +convnext_small (prototype) 83.616 96.650 ================================ ============= ============= diff --git a/references/classification/README.md b/references/classification/README.md index 0fb27eac7cc..e75336f23ca 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -201,11 +201,12 @@ and `--batch_size 64`. ### ConvNeXt ``` torchrun --nproc_per_node=8 train.py\ ---model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \ +--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \ --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \ --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \ ---train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4 +--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4 ``` +Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step, see their `Weights` entry for their exact value. Note that the above command corresponds to training on a single node with 8 GPUs. For generatring the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index f9429ce8a40..b368c94cb27 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -205,29 +205,42 @@ def _convnext( return model +_COMMON_META = { + "task": "image_classification", + "architecture": "ConvNeXt", + "publication_year": 2022, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", +} + + class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ - "task": "image_classification", - "architecture": "ConvNeXt", - "publication_year": 2022, "num_params": 28589128, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", "acc@1": 82.520, "acc@5": 96.146, }, ) DEFAULT = IMAGENET1K_V1 - + class ConvNeXt_Small_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", + transforms=partial(ImageNetEval, crop_size=224, resize_size=230), + meta={ + "num_params": 28589128, + "acc@1": 83.616, + "acc@5": 96.650, + }, + ) + DEFAULT = IMAGENET1K_V1 class ConvNeXt_Base_Weights(WeightsEnum): @@ -259,7 +272,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) def convnext_small( *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any ) -> ConvNeXt: From 5c5b1a98263bbe6102b137beaa08a2792346a05c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Jan 2022 14:44:45 +0000 Subject: [PATCH 04/10] Fix minor bug. --- torchvision/prototype/models/convnext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index b368c94cb27..cd6c452b014 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -222,6 +222,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ + **_COMMON_META, "num_params": 28589128, "acc@1": 82.520, "acc@5": 96.146, @@ -235,6 +236,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=230), meta={ + **_COMMON_META, "num_params": 28589128, "acc@1": 83.616, "acc@5": 96.650, From 7652819e30d2efb3d21d7f0a02b509abce065310 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Jan 2022 15:13:28 +0000 Subject: [PATCH 05/10] Fix number of parameters for small model. --- torchvision/prototype/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index cd6c452b014..f9a2b584339 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -237,7 +237,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): transforms=partial(ImageNetEval, crop_size=224, resize_size=230), meta={ **_COMMON_META, - "num_params": 28589128, + "num_params": 50223688, "acc@1": 83.616, "acc@5": 96.650, }, From db63f45dc82e6105390c9c90d52405109404085b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 28 Jan 2022 12:20:18 +0000 Subject: [PATCH 06/10] Adding weights for the base variant. --- docs/source/models.rst | 1 + torchvision/prototype/models/convnext.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 794d2df2530..7b6507c83ee 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -250,6 +250,7 @@ vit_l_16 79.662 94.638 vit_l_32 76.972 93.070 convnext_tiny (prototype) 82.520 96.146 convnext_small (prototype) 83.616 96.650 +convnext_base (prototype) 84.062 96.870 ================================ ============= ============= diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index f9a2b584339..720d592cc04 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -246,7 +246,17 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", + transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88591464, + "acc@1": 84.062, + "acc@5": 96.870, + }, + ) + DEFAULT = IMAGENET1K_V1 class ConvNeXt_Large_Weights(WeightsEnum): @@ -290,7 +300,7 @@ def convnext_small( return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: weights = ConvNeXt_Base_Weights.verify(weights) From 30a6b6d2b63514fc069e1be28b466aa27579d9f9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 31 Jan 2022 13:17:44 +0000 Subject: [PATCH 07/10] Adding weights for the large variant. --- docs/source/models.rst | 1 + torchvision/prototype/models/convnext.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 7b6507c83ee..a37748b0a3b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -251,6 +251,7 @@ vit_l_32 76.972 93.070 convnext_tiny (prototype) 82.520 96.146 convnext_small (prototype) 83.616 96.650 convnext_base (prototype) 84.062 96.870 +convnext_large (prototype) 84.414 96.976 ================================ ============= ============= diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 720d592cc04..2aba8701268 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -260,7 +260,17 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", + transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 197767336, + "acc@1": 84.414, + "acc@5": 96.976, + }, + ) + DEFAULT = IMAGENET1K_V1 @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) @@ -314,7 +324,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) def convnext_large( *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any ) -> ConvNeXt: From cafa02dccebf5e1fbb749cea3cfeab6190e2f463 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 31 Jan 2022 18:20:33 +0000 Subject: [PATCH 08/10] Simplify LayerNorm2d implementation. --- torchvision/prototype/models/convnext.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 2aba8701268..5dd1762727a 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -29,17 +29,11 @@ class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.channels_last = kwargs.pop("channels_last", False) - super().__init__(*args, **kwargs) - def forward(self, x: Tensor) -> Tensor: # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298 - if not self.channels_last: - x = x.permute(0, 2, 3, 1) + x = x.permute(0, 2, 3, 1) x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - if not self.channels_last: - x = x.permute(0, 3, 1, 2) + x = x.permute(0, 3, 1, 2) return x From 290440b0560f8bea65f8824f9e3a1016df88463e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 31 Jan 2022 19:22:58 +0000 Subject: [PATCH 09/10] Optimize speed of CNBlock. --- torchvision/prototype/models/convnext.py | 52 +++++++++++++----------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 5dd1762727a..72a0a338852 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -37,29 +37,35 @@ def forward(self, x: Tensor) -> Tensor: return x +class Permute(nn.Module): + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.permute(x, self.dims) + + class CNBlock(nn.Module): def __init__( - self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module] + self, + dim, + layer_scale: float, + stochastic_depth_prob: float, + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.block = nn.Sequential( - ConvNormActivation( - dim, - dim, - kernel_size=7, - groups=dim, - norm_layer=norm_layer, - activation_layer=None, - bias=True, - ), - ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None), - ConvNormActivation( - 4 * dim, - dim, - kernel_size=1, - norm_layer=None, - activation_layer=None, - ), + nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True), + Permute([0, 2, 3, 1]), + norm_layer(dim), + nn.Linear(in_features=dim, out_features=4 * dim, bias=True), + nn.GELU(), + nn.Linear(in_features=4 * dim, out_features=dim, bias=True), + Permute([0, 3, 1, 2]), ) self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") @@ -142,7 +148,7 @@ def __init__( for _ in range(cnf.num_layers): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) + stage.append(block(cnf.input_channels, layer_scale, sd_prob)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) if cnf.out_channels is not None: @@ -213,7 +219,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", + url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ **_COMMON_META, @@ -227,7 +233,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", + url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=230), meta={ **_COMMON_META, @@ -241,7 +247,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", + url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, @@ -255,7 +261,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", + url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", # TODO: repackage transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, From fd5c99d2e13f4deeb011ad4a85af8c524b9aa44f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Feb 2022 09:56:44 +0000 Subject: [PATCH 10/10] Repackage weights. --- torchvision/prototype/models/convnext.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 72a0a338852..f8f91307ed1 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -30,7 +30,6 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: - # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298 x = x.permute(0, 2, 3, 1) x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) @@ -219,7 +218,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", # TODO: repackage + url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ **_COMMON_META, @@ -233,7 +232,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", # TODO: repackage + url="https://download.pytorch.org/models/convnext_small-0c510722.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=230), meta={ **_COMMON_META, @@ -247,7 +246,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", # TODO: repackage + url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, @@ -261,7 +260,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", # TODO: repackage + url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META,