Skip to content

Commit 290440b

Browse files
committed
Optimize speed of CNBlock.
1 parent 2bbb112 commit 290440b

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

torchvision/prototype/models/convnext.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,35 @@ def forward(self, x: Tensor) -> Tensor:
3737
return x
3838

3939

40+
class Permute(nn.Module):
41+
def __init__(self, dims: List[int]):
42+
super().__init__()
43+
self.dims = dims
44+
45+
def forward(self, x):
46+
return torch.permute(x, self.dims)
47+
48+
4049
class CNBlock(nn.Module):
4150
def __init__(
42-
self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
51+
self,
52+
dim,
53+
layer_scale: float,
54+
stochastic_depth_prob: float,
55+
norm_layer: Optional[Callable[..., nn.Module]] = None,
4356
) -> None:
4457
super().__init__()
58+
if norm_layer is None:
59+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
60+
4561
self.block = nn.Sequential(
46-
ConvNormActivation(
47-
dim,
48-
dim,
49-
kernel_size=7,
50-
groups=dim,
51-
norm_layer=norm_layer,
52-
activation_layer=None,
53-
bias=True,
54-
),
55-
ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
56-
ConvNormActivation(
57-
4 * dim,
58-
dim,
59-
kernel_size=1,
60-
norm_layer=None,
61-
activation_layer=None,
62-
),
62+
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
63+
Permute([0, 2, 3, 1]),
64+
norm_layer(dim),
65+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
66+
nn.GELU(),
67+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
68+
Permute([0, 3, 1, 2]),
6369
)
6470
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
6571
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -142,7 +148,7 @@ def __init__(
142148
for _ in range(cnf.num_layers):
143149
# adjust stochastic depth probability based on the depth of the stage block
144150
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
145-
stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
151+
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
146152
stage_block_id += 1
147153
layers.append(nn.Sequential(*stage))
148154
if cnf.out_channels is not None:
@@ -213,7 +219,7 @@ def _convnext(
213219

214220
class ConvNeXt_Tiny_Weights(WeightsEnum):
215221
IMAGENET1K_V1 = Weights(
216-
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
222+
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", # TODO: repackage
217223
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
218224
meta={
219225
**_COMMON_META,
@@ -227,7 +233,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
227233

228234
class ConvNeXt_Small_Weights(WeightsEnum):
229235
IMAGENET1K_V1 = Weights(
230-
url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth",
236+
url="https://download.pytorch.org/models/convnext_small-9aa23d28.pth", # TODO: repackage
231237
transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
232238
meta={
233239
**_COMMON_META,
@@ -241,7 +247,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
241247

242248
class ConvNeXt_Base_Weights(WeightsEnum):
243249
IMAGENET1K_V1 = Weights(
244-
url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth",
250+
url="https://download.pytorch.org/models/convnext_base-3b9f985d.pth", # TODO: repackage
245251
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
246252
meta={
247253
**_COMMON_META,
@@ -255,7 +261,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
255261

256262
class ConvNeXt_Large_Weights(WeightsEnum):
257263
IMAGENET1K_V1 = Weights(
258-
url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth",
264+
url="https://download.pytorch.org/models/convnext_large-d73f62ac.pth", # TODO: repackage
259265
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
260266
meta={
261267
**_COMMON_META,

0 commit comments

Comments
 (0)