Skip to content

Commit c72b284

Browse files
authored
Expose on Hub the public methods of the registration API (#6364)
* Expose on Hub the public methods of the registration API * Limit methods and update docs.
1 parent 8446983 commit c72b284

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

docs/source/models.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,15 @@ Most pre-trained models can be accessed directly via PyTorch Hub without having
176176
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
177177
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
178178
179+
You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:
180+
181+
.. code:: python
182+
183+
import torch
184+
185+
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
186+
print([weight for weight in weight_enum])
187+
179188
The only exception to the above are the detection models included on
180189
:mod:`torchvision.models.detection`. These models require TorchVision
181190
to be installed because they depend on custom C++ operators.

hubconf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Optional list of dependencies required by the package
22
dependencies = ["torch"]
33

4-
from torchvision.models import get_weight
4+
from torchvision.models import get_model_weights, get_weight
55
from torchvision.models.alexnet import alexnet
66
from torchvision.models.convnext import convnext_base, convnext_large, convnext_small, convnext_tiny
77
from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201

torchvision/models/_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def get_weight(name: str) -> WeightsEnum:
115115
W = TypeVar("W", bound=WeightsEnum)
116116

117117

118-
def get_model_weights(model: Union[Callable, str]) -> W:
118+
def get_model_weights(name: Union[Callable, str]) -> W:
119119
"""
120120
Retuns the weights enum class associated to the given model.
121121
@@ -127,8 +127,7 @@ def get_model_weights(model: Union[Callable, str]) -> W:
127127
Returns:
128128
weights_enum (W): The weights enum class associated with the model.
129129
"""
130-
if isinstance(model, str):
131-
model = find_model(model)
130+
model = find_model(name) if isinstance(name, str) else name
132131
return cast(W, _get_enum_from_fn(model))
133132

134133

0 commit comments

Comments
 (0)