Skip to content

Commit 9310325

Browse files
committed
Refactoring weight info.
1 parent aa82cf1 commit 9310325

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

torchvision/models/efficientnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
4343
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
4444
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
45+
# Temporary TF weights
46+
"efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth",
4547
}
4648

4749

@@ -176,7 +178,6 @@ def __init__(
176178
cnf: FusedMBConvConfig,
177179
stochastic_depth_prob: float,
178180
norm_layer: Callable[..., nn.Module],
179-
**kwargs: Any,
180181
) -> None:
181182
super().__init__()
182183

torchvision/prototype/models/efficientnet.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,30 @@ def _efficientnet(
5757
return model
5858

5959

60-
_COMMON_META_V1 = {
60+
_COMMON_META = {
6161
"task": "image_classification",
62-
"architecture": "EfficientNet",
63-
"publication_year": 2019,
64-
"min_size": (1, 1),
6562
"categories": _IMAGENET_CATEGORIES,
6663
"interpolation": InterpolationMode.BICUBIC,
6764
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
6865
}
6966

7067

68+
_COMMON_META_V1 = {
69+
**_COMMON_META,
70+
"architecture": "EfficientNet",
71+
"publication_year": 2019,
72+
"min_size": (1, 1),
73+
}
74+
75+
76+
_COMMON_META_V2 = {
77+
**_COMMON_META,
78+
"architecture": "EfficientNetV2",
79+
"publication_year": 2021,
80+
"min_size": (33, 33),
81+
}
82+
83+
7184
class EfficientNet_B0_Weights(WeightsEnum):
7285
IMAGENET1K_V1 = Weights(
7386
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
@@ -202,7 +215,25 @@ class EfficientNet_B7_Weights(WeightsEnum):
202215

203216

204217
class EfficientNet_V2_S_Weights(WeightsEnum):
205-
pass
218+
IMAGENET1K_V1 = Weights(
219+
url="https://download.pytorch.org/models/efficientnet_v2_s-tmp.pth",
220+
transforms=partial(
221+
ImageNetEval,
222+
crop_size=384,
223+
resize_size=384,
224+
interpolation=InterpolationMode.BICUBIC,
225+
mean=(0.5, 0.5, 0.5),
226+
std=(0.5, 0.5, 0.5),
227+
),
228+
meta={
229+
**_COMMON_META_V2,
230+
"num_params": 21458488,
231+
"size": (384, 384),
232+
"acc@1": 83.152,
233+
"acc@5": 96.400,
234+
},
235+
)
236+
DEFAULT = IMAGENET1K_V1
206237

207238

208239
class EfficientNet_V2_M_Weights(WeightsEnum):
@@ -317,7 +348,7 @@ def efficientnet_b7(
317348
)
318349

319350

320-
@handle_legacy_interface(weights=("pretrained", None))
351+
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
321352
def efficientnet_v2_s(
322353
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
323354
) -> EfficientNet:

0 commit comments

Comments
 (0)