diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index ec7d59817ab..3a314c867ca 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -316,7 +316,7 @@ def shufflenet_v2_x1_0( return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1)) def shufflenet_v2_x1_5( *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: @@ -346,7 +346,7 @@ def shufflenet_v2_x1_5( return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1)) def shufflenet_v2_x2_0( *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: