From 8509c6ee36c33dde71d21d710f6085c2163c6cb8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 16 May 2022 14:42:57 +0100 Subject: [PATCH 1/2] Prefixing `_get_enum_from_fn` with underscore --- torchvision/models/_api.py | 2 +- torchvision/models/detection/backbone_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 05b35fe87f0..a202ed625b5 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -def get_enum_from_fn(fn: Callable) -> WeightsEnum: +def _get_enum_from_fn(fn: Callable) -> WeightsEnum: """ Internal method that gets the weight enum of a specific model builder method. Might be removed after the handle_legacy_interface is removed. diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 65fe45c4cbd..fbef524b99c 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,7 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._api import WeightsEnum, get_enum_from_fn +from .._api import WeightsEnum, _get_enum_from_fn from .._utils import IntermediateLayerGetter, handle_legacy_interface @@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: @handle_legacy_interface( weights=( "pretrained", - lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), ), ) def resnet_fpn_backbone( @@ -177,7 +177,7 @@ def _validate_trainable_layers( @handle_legacy_interface( weights=( "pretrained", - lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), ), ) def mobilenet_backbone( From c3b7c6d628ea83c56da188dbba9a5416358764db Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 16 May 2022 14:43:35 +0100 Subject: [PATCH 2/2] Exposing `get_weight` to Torch Hub --- hubconf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hubconf.py b/hubconf.py index 28d2c5a5d01..a229ab07667 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,7 @@ # Optional list of dependencies required by the package dependencies = ["torch"] +from torchvision.models import get_weight from torchvision.models.alexnet import alexnet from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161