Skip to content

Commit 0a919db

Browse files
authored
Add registration mechanism for models (#6333)
* Model registration mechanism. * Add overwrite options to the dataset prototype registration mechanism. * Adding example models. * Fix module filtering * Fix linter * Fix docs * Make name optional if same as model builder * Apply updates from code-review. * fix minor bug * Adding getter for model weight enum * Support both strings and callables on get_model_weight. * linter fixes * Fixing mypy. * Renaming `get_model_weight` to `get_model_weights` * Registering all classification models. * Registering all video models. * Registering all detection models. * Registering all optical flow models. * Fixing mypy. * Registering all segmentation models. * Registering all quantization models. * Fixing linter * Registering all prototype depth perception models. * Adding tests and updating existing tests. * Fix linters * Fix tests. * Add beta annotation on docs. * Fix tests. * Apply changes from code-review. * Adding documentation. * Fix docs.
1 parent 6387051 commit 0a919db

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+374
-120
lines changed

docs/source/models.rst

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use
120120
# Set model to eval mode
121121
model.eval()
122122
123+
Model Registration Mechanism
124+
----------------------------
125+
126+
.. betastatus:: registration mechanism
127+
128+
As of v0.14, TorchVision offers a new model registration mechanism which allows retreaving models
129+
and weights by their names. Here are a few examples on how to use them:
130+
131+
.. code:: python
132+
133+
# List available models
134+
all_models = list_models()
135+
classification_models = list_models(module=torchvision.models)
136+
137+
# Initialize models
138+
m1 = get_model("mobilenet_v3_large", weights=None)
139+
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
140+
141+
# Fetch weights
142+
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
143+
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
144+
145+
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
146+
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
147+
148+
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
149+
assert weights_enum == weights_enum2
150+
151+
Here are the available public methods of the model registration mechanism:
152+
153+
.. currentmodule:: torchvision.models
154+
.. autosummary::
155+
:toctree: generated/
156+
:template: function.rst
157+
158+
get_model
159+
get_model_weights
160+
get_weight
161+
list_models
162+
123163
Using models from Hub
124164
---------------------
125165

test/test_backbone_utils.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,6 @@
1111
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
1212

1313

14-
def get_available_models():
15-
# TODO add a registration mechanism to torchvision.models
16-
return [
17-
k
18-
for k, v in models.__dict__.items()
19-
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
20-
]
21-
22-
2314
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
2415
def test_resnet_fpn_backbone(backbone_name):
2516
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
@@ -135,10 +126,10 @@ def _get_return_nodes(self, model):
135126
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
136127
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
137128

138-
@pytest.mark.parametrize("model_name", get_available_models())
129+
@pytest.mark.parametrize("model_name", models.list_models(models))
139130
def test_build_fx_feature_extractor(self, model_name):
140131
set_rng_seed(0)
141-
model = models.__dict__[model_name](**self.model_defaults).eval()
132+
model = models.get_model(model_name, **self.model_defaults).eval()
142133
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
143134
# Check that it works with both a list and dict for return nodes
144135
self._create_feature_extractor(
@@ -172,9 +163,9 @@ def test_node_name_conventions(self):
172163
train_nodes, _ = get_graph_node_names(model)
173164
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
174165

175-
@pytest.mark.parametrize("model_name", get_available_models())
166+
@pytest.mark.parametrize("model_name", models.list_models(models))
176167
def test_forward_backward(self, model_name):
177-
model = models.__dict__[model_name](**self.model_defaults).train()
168+
model = models.get_model(model_name, **self.model_defaults).train()
178169
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
179170
model = self._create_feature_extractor(
180171
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
@@ -211,10 +202,10 @@ def test_feature_extraction_methods_equivalence(self):
211202
for k in ilg_out.keys():
212203
assert ilg_out[k].equal(fgn_out[k])
213204

214-
@pytest.mark.parametrize("model_name", get_available_models())
205+
@pytest.mark.parametrize("model_name", models.list_models(models))
215206
def test_jit_forward_backward(self, model_name):
216207
set_rng_seed(0)
217-
model = models.__dict__[model_name](**self.model_defaults).train()
208+
model = models.get_model(model_name, **self.model_defaults).train()
218209
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
219210
model = self._create_feature_extractor(
220211
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes

test/test_extended_models.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import importlib
21
import os
32

43
import pytest
54
import test_models as TM
65
import torch
76
from torchvision import models
8-
from torchvision.models._api import Weights, WeightsEnum
7+
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
98
from torchvision.models._utils import handle_legacy_interface
109

1110

@@ -15,23 +14,52 @@
1514
)
1615

1716

18-
def _get_parent_module(model_fn):
19-
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
20-
module = importlib.import_module(parent_module_name)
21-
return module
17+
@pytest.mark.parametrize(
18+
"name, model_class",
19+
[
20+
("resnet50", models.ResNet),
21+
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
22+
("raft_large", models.optical_flow.RAFT),
23+
("quantized_resnet50", models.quantization.QuantizableResNet),
24+
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
25+
("mvit_v1_b", models.video.MViT),
26+
],
27+
)
28+
def test_get_model(name, model_class):
29+
assert isinstance(models.get_model(name), model_class)
30+
31+
32+
@pytest.mark.parametrize(
33+
"name, weight",
34+
[
35+
("resnet50", models.ResNet50_Weights),
36+
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
37+
("raft_large", models.optical_flow.Raft_Large_Weights),
38+
("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
39+
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
40+
("mvit_v1_b", models.video.MViT_V1_B_Weights),
41+
],
42+
)
43+
def test_get_model_weights(name, weight):
44+
assert models.get_model_weights(name) == weight
2245

2346

24-
def _get_model_weights(model_fn):
25-
module = _get_parent_module(model_fn)
26-
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
27-
try:
28-
return next(
29-
v
47+
@pytest.mark.parametrize(
48+
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
49+
)
50+
def test_list_models(module):
51+
def get_models_from_module(module):
52+
return [
53+
v.__name__
3054
for k, v in module.__dict__.items()
31-
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
32-
)
33-
except StopIteration:
34-
return None
55+
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
56+
]
57+
58+
a = set(get_models_from_module(module))
59+
b = set(x.replace("quantized_", "") for x in models.list_models(module))
60+
61+
assert len(b) > 0
62+
assert a == b
3563

3664

3765
@pytest.mark.parametrize(
@@ -55,27 +83,27 @@ def test_get_weight(name, weight):
5583

5684
@pytest.mark.parametrize(
5785
"model_fn",
58-
TM.get_models_from_module(models)
59-
+ TM.get_models_from_module(models.detection)
60-
+ TM.get_models_from_module(models.quantization)
61-
+ TM.get_models_from_module(models.segmentation)
62-
+ TM.get_models_from_module(models.video)
63-
+ TM.get_models_from_module(models.optical_flow),
86+
TM.list_model_fns(models)
87+
+ TM.list_model_fns(models.detection)
88+
+ TM.list_model_fns(models.quantization)
89+
+ TM.list_model_fns(models.segmentation)
90+
+ TM.list_model_fns(models.video)
91+
+ TM.list_model_fns(models.optical_flow),
6492
)
6593
def test_naming_conventions(model_fn):
66-
weights_enum = _get_model_weights(model_fn)
94+
weights_enum = get_model_weights(model_fn)
6795
assert weights_enum is not None
6896
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
6997

7098

7199
@pytest.mark.parametrize(
72100
"model_fn",
73-
TM.get_models_from_module(models)
74-
+ TM.get_models_from_module(models.detection)
75-
+ TM.get_models_from_module(models.quantization)
76-
+ TM.get_models_from_module(models.segmentation)
77-
+ TM.get_models_from_module(models.video)
78-
+ TM.get_models_from_module(models.optical_flow),
101+
TM.list_model_fns(models)
102+
+ TM.list_model_fns(models.detection)
103+
+ TM.list_model_fns(models.quantization)
104+
+ TM.list_model_fns(models.segmentation)
105+
+ TM.list_model_fns(models.video)
106+
+ TM.list_model_fns(models.optical_flow),
79107
)
80108
@run_if_test_with_extended
81109
def test_schema_meta_validation(model_fn):
@@ -112,7 +140,7 @@ def test_schema_meta_validation(model_fn):
112140
module_name = model_fn.__module__.split(".")[-2]
113141
expected_fields = defaults["all"] | defaults[module_name]
114142

115-
weights_enum = _get_model_weights(model_fn)
143+
weights_enum = get_model_weights(model_fn)
116144
if len(weights_enum) == 0:
117145
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
118146

@@ -153,17 +181,17 @@ def test_schema_meta_validation(model_fn):
153181

154182
@pytest.mark.parametrize(
155183
"model_fn",
156-
TM.get_models_from_module(models)
157-
+ TM.get_models_from_module(models.detection)
158-
+ TM.get_models_from_module(models.quantization)
159-
+ TM.get_models_from_module(models.segmentation)
160-
+ TM.get_models_from_module(models.video)
161-
+ TM.get_models_from_module(models.optical_flow),
184+
TM.list_model_fns(models)
185+
+ TM.list_model_fns(models.detection)
186+
+ TM.list_model_fns(models.quantization)
187+
+ TM.list_model_fns(models.segmentation)
188+
+ TM.list_model_fns(models.video)
189+
+ TM.list_model_fns(models.optical_flow),
162190
)
163191
@run_if_test_with_extended
164192
def test_transforms_jit(model_fn):
165193
model_name = model_fn.__name__
166-
weights_enum = _get_model_weights(model_fn)
194+
weights_enum = get_model_weights(model_fn)
167195
if len(weights_enum) == 0:
168196
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
169197

test/test_models.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,15 @@
1616
from _utils_internal import get_relative_path
1717
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
1818
from torchvision import models
19+
from torchvision.models._api import find_model, list_models
20+
1921

2022
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
2123
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
2224

2325

24-
def get_models_from_module(module):
25-
# TODO add a registration mechanism to torchvision.models
26-
return [
27-
v
28-
for k, v in module.__dict__.items()
29-
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
30-
]
26+
def list_model_fns(module):
27+
return [find_model(name) for name in list_models(module)]
3128

3229

3330
@pytest.fixture
@@ -597,7 +594,7 @@ def test_vitc_models(model_fn, dev):
597594
test_classification_model(model_fn, dev)
598595

599596

600-
@pytest.mark.parametrize("model_fn", get_models_from_module(models))
597+
@pytest.mark.parametrize("model_fn", list_model_fns(models))
601598
@pytest.mark.parametrize("dev", cpu_and_gpu())
602599
def test_classification_model(model_fn, dev):
603600
set_rng_seed(0)
@@ -633,7 +630,7 @@ def test_classification_model(model_fn, dev):
633630
_check_input_backprop(model, x)
634631

635632

636-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.segmentation))
633+
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
637634
@pytest.mark.parametrize("dev", cpu_and_gpu())
638635
def test_segmentation_model(model_fn, dev):
639636
set_rng_seed(0)
@@ -695,7 +692,7 @@ def check_out(out):
695692
_check_input_backprop(model, x)
696693

697694

698-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
695+
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
699696
@pytest.mark.parametrize("dev", cpu_and_gpu())
700697
def test_detection_model(model_fn, dev):
701698
set_rng_seed(0)
@@ -793,7 +790,7 @@ def compute_mean_std(tensor):
793790
_check_input_backprop(model, model_input)
794791

795792

796-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
793+
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
797794
def test_detection_model_validation(model_fn):
798795
set_rng_seed(0)
799796
model = model_fn(num_classes=50, weights=None, weights_backbone=None)
@@ -822,7 +819,7 @@ def test_detection_model_validation(model_fn):
822819
model(x, targets=targets)
823820

824821

825-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
822+
@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
826823
@pytest.mark.parametrize("dev", cpu_and_gpu())
827824
def test_video_model(model_fn, dev):
828825
set_rng_seed(0)
@@ -868,7 +865,7 @@ def test_video_model(model_fn, dev):
868865
),
869866
reason="This Pytorch Build has not been built with fbgemm and qnnpack",
870867
)
871-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.quantization))
868+
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
872869
def test_quantized_classification_model(model_fn):
873870
set_rng_seed(0)
874871
defaults = {
@@ -917,7 +914,7 @@ def test_quantized_classification_model(model_fn):
917914
torch.ao.quantization.convert(model, inplace=True)
918915

919916

920-
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
917+
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
921918
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
922919
model_name = model_fn.__name__
923920
max_trainable = _model_tests_values[model_name]["max_trainable"]
@@ -930,9 +927,9 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
930927

931928

932929
@needs_cuda
933-
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
930+
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
934931
@pytest.mark.parametrize("scripted", (False, True))
935-
def test_raft(model_builder, scripted):
932+
def test_raft(model_fn, scripted):
936933

937934
torch.manual_seed(0)
938935

@@ -942,7 +939,7 @@ def test_raft(model_builder, scripted):
942939
# reduced to 1)
943940
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
944941

945-
model = model_builder(corr_block=corr_block).eval().to("cuda")
942+
model = model_fn(corr_block=corr_block).eval().to("cuda")
946943
if scripted:
947944
model = torch.jit.script(model)
948945

@@ -954,7 +951,7 @@ def test_raft(model_builder, scripted):
954951
flow_pred = preds[-1]
955952
# Tolerance is fairly high, but there are 2 * H * W outputs to check
956953
# The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different
957-
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)
954+
_assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1)
958955

959956

960957
if __name__ == "__main__":

test/test_models_detection_negative_samples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def test_assign_targets_to_proposals(self):
9999
],
100100
)
101101
def test_forward_negative_sample_frcnn(self, name):
102-
model = torchvision.models.detection.__dict__[name](
103-
weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
102+
model = torchvision.models.get_model(
103+
name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
104104
)
105105

106106
images, targets = self._make_empty_sample()

0 commit comments

Comments
 (0)