Skip to content

Commit 1c175d8

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Cleanup namings of Multi-weights classes and enums (#5003)
Summary: * Rename classes Weights => WeightsEnum and WeightEntry => Weights. * Make enum values follow the naming convention `_V1`, `_V2` etc * Cleanup the Enum class naming conventions. * Add a test to check naming conventions. Reviewed By: NicolasHug Differential Revision: D32759196 fbshipit-source-id: 5348a432dc439cad21fbb1db507b21edcbeb7ece
1 parent ff7d301 commit 1c175d8

33 files changed

+730
-716
lines changed

test/test_prototype_models.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def _get_original_model(model_fn):
1818
return module.__dict__[model_fn.__name__]
1919

2020

21+
def _get_parent_module(model_fn):
22+
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
23+
module = importlib.import_module(parent_module_name)
24+
return module
25+
26+
2127
def _build_model(fn, **kwargs):
2228
try:
2329
model = fn(**kwargs)
@@ -29,27 +35,42 @@ def _build_model(fn, **kwargs):
2935
return model.eval()
3036

3137

32-
def get_models_with_module_names(module):
33-
module_name = module.__name__.split(".")[-1]
34-
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
35-
36-
3738
@pytest.mark.parametrize(
3839
"model_fn, name, weight",
3940
[
40-
(models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
41-
(models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
41+
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
42+
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
43+
(
44+
models.quantization.resnet50,
45+
"default",
46+
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
47+
),
4248
(
4349
models.quantization.resnet50,
44-
"ImageNet1K_FBGEMM_RefV1",
45-
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
50+
"ImageNet1K_FBGEMM_V1",
51+
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
4652
),
4753
],
4854
)
4955
def test_get_weight(model_fn, name, weight):
5056
assert models._api.get_weight(model_fn, name) == weight
5157

5258

59+
@pytest.mark.parametrize(
60+
"model_fn",
61+
TM.get_models_from_module(models)
62+
+ TM.get_models_from_module(models.detection)
63+
+ TM.get_models_from_module(models.quantization)
64+
+ TM.get_models_from_module(models.segmentation)
65+
+ TM.get_models_from_module(models.video),
66+
)
67+
def test_naming_conventions(model_fn):
68+
model_name = model_fn.__name__
69+
module = _get_parent_module(model_fn)
70+
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
71+
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
72+
73+
5374
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
5475
@pytest.mark.parametrize("dev", cpu_and_gpu())
5576
@run_if_test_with_prototype
@@ -85,16 +106,16 @@ def test_video_model(model_fn, dev):
85106

86107

87108
@pytest.mark.parametrize(
88-
"model_fn, module_name",
89-
get_models_with_module_names(models)
90-
+ get_models_with_module_names(models.detection)
91-
+ get_models_with_module_names(models.quantization)
92-
+ get_models_with_module_names(models.segmentation)
93-
+ get_models_with_module_names(models.video),
109+
"model_fn",
110+
TM.get_models_from_module(models)
111+
+ TM.get_models_from_module(models.detection)
112+
+ TM.get_models_from_module(models.quantization)
113+
+ TM.get_models_from_module(models.segmentation)
114+
+ TM.get_models_from_module(models.video),
94115
)
95116
@pytest.mark.parametrize("dev", cpu_and_gpu())
96117
@run_if_test_with_prototype
97-
def test_old_vs_new_factory(model_fn, module_name, dev):
118+
def test_old_vs_new_factory(model_fn, dev):
98119
defaults = {
99120
"models": {
100121
"input_shape": (1, 3, 224, 224),
@@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
114135
},
115136
}
116137
model_name = model_fn.__name__
138+
module_name = model_fn.__module__.split(".")[-2]
117139
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
118140
input_shape = kwargs.pop("input_shape")
119141
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models

torchvision/prototype/models/_api.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from ..._internally_replaced_utils import load_state_dict_from_url
88

99

10-
__all__ = ["Weights", "WeightEntry", "get_weight"]
10+
__all__ = ["WeightsEnum", "Weights", "get_weight"]
1111

1212

1313
@dataclass
14-
class WeightEntry:
14+
class Weights:
1515
"""
1616
This class is used to group important attributes associated with the pre-trained weights.
1717
@@ -33,17 +33,17 @@ class WeightEntry:
3333
default: bool
3434

3535

36-
class Weights(Enum):
36+
class WeightsEnum(Enum):
3737
"""
3838
This class is the parent class of all model weights. Each model building method receives an optional `weights`
3939
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
40-
`WeightEntry`.
40+
`Weights`.
4141
4242
Args:
43-
value (WeightEntry): The data class entry with the weight information.
43+
value (Weights): The data class entry with the weight information.
4444
"""
4545

46-
def __init__(self, value: WeightEntry):
46+
def __init__(self, value: Weights):
4747
self._value_ = value
4848

4949
@classmethod
@@ -58,7 +58,7 @@ def verify(cls, obj: Any) -> Any:
5858
return obj
5959

6060
@classmethod
61-
def from_str(cls, value: str) -> "Weights":
61+
def from_str(cls, value: str) -> "WeightsEnum":
6262
for v in cls:
6363
if v._name_ == value or (value == "default" and v.default):
6464
return v
@@ -71,14 +71,14 @@ def __repr__(self):
7171
return f"{self.__class__.__name__}.{self._name_}"
7272

7373
def __getattr__(self, name):
74-
# Be able to fetch WeightEntry attributes directly
75-
for f in fields(WeightEntry):
74+
# Be able to fetch Weights attributes directly
75+
for f in fields(Weights):
7676
if f.name == name:
7777
return object.__getattribute__(self.value, name)
7878
return super().__getattr__(name)
7979

8080

81-
def get_weight(fn: Callable, weight_name: str) -> Weights:
81+
def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
8282
"""
8383
Gets the weight enum of a specific model builder method and weight name combination.
8484
@@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
8787
weight_name (str): The name of the weight enum entry of the specific model.
8888
8989
Returns:
90-
Weights: The requested weight enum.
90+
WeightsEnum: The requested weight enum.
9191
"""
9292
sig = signature(fn)
9393
if "weights" not in sig.parameters:
9494
raise ValueError("The method is missing the 'weights' parameter.")
9595

9696
ann = signature(fn).parameters["weights"].annotation
97-
weights_class = None
98-
if isinstance(ann, type) and issubclass(ann, Weights):
99-
weights_class = ann
97+
weights_enum = None
98+
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
99+
weights_enum = ann
100100
else:
101101
# handle cases like Union[Optional, T]
102102
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
103103
for t in ann.__args__: # type: ignore[union-attr]
104-
if isinstance(t, type) and issubclass(t, Weights):
104+
if isinstance(t, type) and issubclass(t, WeightsEnum):
105105
# ensure the name exists. handles builders with multiple types of weights like in quantization
106106
try:
107107
t.from_str(weight_name)
108108
except ValueError:
109109
continue
110-
weights_class = t
110+
weights_enum = t
111111
break
112112

113-
if weights_class is None:
113+
if weights_enum is None:
114114
raise ValueError(
115115
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
116116
)
117117

118-
return weights_class.from_str(weight_name)
118+
return weights_enum.from_str(weight_name)

torchvision/prototype/models/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
22
from typing import Any, Dict, Optional, TypeVar
33

4-
from ._api import Weights
4+
from ._api import WeightsEnum
55

66

7-
W = TypeVar("W", bound=Weights)
7+
W = TypeVar("W", bound=WeightsEnum)
88
V = TypeVar("V")
99

1010

torchvision/prototype/models/alexnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from torchvision.transforms.functional import InterpolationMode
66

77
from ...models.alexnet import AlexNet
8-
from ._api import Weights, WeightEntry
8+
from ._api import WeightsEnum, Weights
99
from ._meta import _IMAGENET_CATEGORIES
1010
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
1111

1212

13-
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
13+
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
1414

1515

16-
class AlexNetWeights(Weights):
17-
ImageNet1K_RefV1 = WeightEntry(
16+
class AlexNet_Weights(WeightsEnum):
17+
ImageNet1K_V1 = Weights(
1818
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
1919
transforms=partial(ImageNetEval, crop_size=224),
2020
meta={
@@ -29,12 +29,12 @@ class AlexNetWeights(Weights):
2929
)
3030

3131

32-
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
32+
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
3333
if type(weights) == bool and weights:
3434
_deprecated_positional(kwargs, "pretrained", "weights", True)
3535
if "pretrained" in kwargs:
36-
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
37-
weights = AlexNetWeights.verify(weights)
36+
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
37+
weights = AlexNet_Weights.verify(weights)
3838

3939
if weights is not None:
4040
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

torchvision/prototype/models/densenet.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,25 @@
77
from torchvision.transforms.functional import InterpolationMode
88

99
from ...models.densenet import DenseNet
10-
from ._api import Weights, WeightEntry
10+
from ._api import WeightsEnum, Weights
1111
from ._meta import _IMAGENET_CATEGORIES
1212
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
1313

1414

1515
__all__ = [
1616
"DenseNet",
17-
"DenseNet121Weights",
18-
"DenseNet161Weights",
19-
"DenseNet169Weights",
20-
"DenseNet201Weights",
17+
"DenseNet121_Weights",
18+
"DenseNet161_Weights",
19+
"DenseNet169_Weights",
20+
"DenseNet201_Weights",
2121
"densenet121",
2222
"densenet161",
2323
"densenet169",
2424
"densenet201",
2525
]
2626

2727

28-
def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
28+
def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
2929
# '.'s are no longer allowed in module names, but previous _DenseLayer
3030
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
3131
# They are also in the checkpoints in model_urls. This pattern is used
@@ -48,7 +48,7 @@ def _densenet(
4848
growth_rate: int,
4949
block_config: Tuple[int, int, int, int],
5050
num_init_features: int,
51-
weights: Optional[Weights],
51+
weights: Optional[WeightsEnum],
5252
progress: bool,
5353
**kwargs: Any,
5454
) -> DenseNet:
@@ -71,8 +71,8 @@ def _densenet(
7171
}
7272

7373

74-
class DenseNet121Weights(Weights):
75-
ImageNet1K_Community = WeightEntry(
74+
class DenseNet121_Weights(WeightsEnum):
75+
ImageNet1K_V1 = Weights(
7676
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
7777
transforms=partial(ImageNetEval, crop_size=224),
7878
meta={
@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights):
8484
)
8585

8686

87-
class DenseNet161Weights(Weights):
88-
ImageNet1K_Community = WeightEntry(
87+
class DenseNet161_Weights(WeightsEnum):
88+
ImageNet1K_V1 = Weights(
8989
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
9090
transforms=partial(ImageNetEval, crop_size=224),
9191
meta={
@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights):
9797
)
9898

9999

100-
class DenseNet169Weights(Weights):
101-
ImageNet1K_Community = WeightEntry(
100+
class DenseNet169_Weights(WeightsEnum):
101+
ImageNet1K_V1 = Weights(
102102
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
103103
transforms=partial(ImageNetEval, crop_size=224),
104104
meta={
@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights):
110110
)
111111

112112

113-
class DenseNet201Weights(Weights):
114-
ImageNet1K_Community = WeightEntry(
113+
class DenseNet201_Weights(WeightsEnum):
114+
ImageNet1K_V1 = Weights(
115115
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
116116
transforms=partial(ImageNetEval, crop_size=224),
117117
meta={
@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights):
123123
)
124124

125125

126-
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
126+
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
127127
if type(weights) == bool and weights:
128128
_deprecated_positional(kwargs, "pretrained", "weights", True)
129129
if "pretrained" in kwargs:
130-
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
131-
weights = DenseNet121Weights.verify(weights)
130+
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
131+
weights = DenseNet121_Weights.verify(weights)
132132

133133
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
134134

135135

136-
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
136+
def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
137137
if type(weights) == bool and weights:
138138
_deprecated_positional(kwargs, "pretrained", "weights", True)
139139
if "pretrained" in kwargs:
140-
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
141-
weights = DenseNet161Weights.verify(weights)
140+
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
141+
weights = DenseNet161_Weights.verify(weights)
142142

143143
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
144144

145145

146-
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
146+
def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
147147
if type(weights) == bool and weights:
148148
_deprecated_positional(kwargs, "pretrained", "weights", True)
149149
if "pretrained" in kwargs:
150-
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
151-
weights = DenseNet169Weights.verify(weights)
150+
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
151+
weights = DenseNet169_Weights.verify(weights)
152152

153153
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
154154

155155

156-
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
156+
def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
157157
if type(weights) == bool and weights:
158158
_deprecated_positional(kwargs, "pretrained", "weights", True)
159159
if "pretrained" in kwargs:
160-
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
161-
weights = DenseNet201Weights.verify(weights)
160+
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
161+
weights = DenseNet201_Weights.verify(weights)
162162

163163
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)

0 commit comments

Comments
 (0)