diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index d89f326a4fc..c8a4218e65a 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -1,2 +1,56 @@
+import os
+
+import pytest
+import torch
+from common_utils import set_rng_seed, cpu_and_gpu
+from test_models import _assert_expected, _model_params
+from torchvision import models as original_models
+from torchvision.prototype import models
+
+
+def get_available_classification_models():
+    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
+
+
+@pytest.mark.parametrize("model_name", get_available_classification_models())
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
+def test_classification_model(model_name, dev):
+    set_rng_seed(0)
+    defaults = {
+        "num_classes": 50,
+        "input_shape": (1, 3, 224, 224),
+    }
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    input_shape = kwargs.pop("input_shape")
+    model = models.__dict__[model_name](**kwargs)
+    model.eval().to(device=dev)
+    x = torch.rand(input_shape).to(device=dev)
+    out = model(x)
+    _assert_expected(out.cpu(), model_name, prec=0.1)
+    assert out.shape[-1] == 50
+
+
+@pytest.mark.parametrize("model_name", get_available_classification_models())
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
+def test_old_vs_new_classification_factory(model_name, dev):
+    defaults = {
+        "pretrained": True,
+        "input_shape": (1, 3, 224, 224),
+    }
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    input_shape = kwargs.pop("input_shape")
+    model_old = original_models.__dict__[model_name](**kwargs)
+    model_old.eval().to(device=dev)
+    x = torch.rand(input_shape).to(device=dev)
+    out_old = model_old(x)
+    # compare with new model builder parameterized in the old fashion way
+    model_new = models.__dict__[model_name](**kwargs)
+    model_new.eval().to(device=dev)
+    out_new = model_new(x)
+    torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False)
+
+
 def test_smoke():
     import torchvision.prototype.models  # noqa: F401
diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py
index aaa02d5d407..7f4158143fc 100644
--- a/torchvision/prototype/models/resnet.py
+++ b/torchvision/prototype/models/resnet.py
@@ -8,7 +8,27 @@
 from ._meta import _IMAGENET_CATEGORIES
 
 
-__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
+__all__ = [
+    "ResNet",
+    "ResNet18Weights",
+    "ResNet34Weights",
+    "ResNet50Weights",
+    "ResNet101Weights",
+    "ResNet152Weights",
+    "ResNeXt50_32x4dWeights",
+    "ResNeXt101_32x8dWeights",
+    "WideResNet50_2Weights",
+    "WideResNet101_2Weights",
+    "resnet18",
+    "resnet34",
+    "resnet50",
+    "resnet101",
+    "resnet152",
+    "resnext50_32x4d",
+    "resnext101_32x8d",
+    "wide_resnet50_2",
+    "wide_resnet101_2",
+]
 
 
 def _resnet(
@@ -35,6 +55,32 @@
 }
 
 
+class ResNet18Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 69.758,
+            "acc@5": 89.078,
+        },
+    )
+
+
+class ResNet34Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 73.314,
+            "acc@5": 91.420,
+        },
+    )
+
+
 class ResNet50Weights(Weights):
     ImageNet1K_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
@@ -58,6 +104,104 @@ class ResNet50Weights(Weights):
     )
 
 
+class ResNet101Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 77.374,
+            "acc@5": 93.546,
+        },
+    )
+
+
+class ResNet152Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 78.312,
+            "acc@5": 94.046,
+        },
+    )
+
+
+class ResNeXt50_32x4dWeights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 77.618,
+            "acc@5": 93.698,
+        },
+    )
+
+
+class ResNeXt101_32x8dWeights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 79.312,
+            "acc@5": 94.526,
+        },
+    )
+
+
+class WideResNet50_2Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 78.468,
+            "acc@5": 94.086,
+        },
+    )
+
+
+class WideResNet101_2Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "",
+            "acc@1": 78.848,
+            "acc@5": 94.284,
+        },
+    )
+
+
+def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNet18Weights.verify(weights)
+
+    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
+
+
+def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNet34Weights.verify(weights)
+
+    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
 def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
@@ -65,3 +209,65 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, *
     weights = ResNet50Weights.verify(weights)
 
     return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNet101Weights.verify(weights)
+
+    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNet152Weights.verify(weights)
+
+    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
+
+
+def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNeXt50_32x4dWeights.verify(weights)
+    kwargs["groups"] = 32
+    kwargs["width_per_group"] = 4
+    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = ResNeXt101_32x8dWeights.verify(weights)
+    kwargs["groups"] = 32
+    kwargs["width_per_group"] = 8
+    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = WideResNet50_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = WideResNet50_2Weights.verify(weights)
+    kwargs["width_per_group"] = 64 * 2
+    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = WideResNet101_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = WideResNet101_2Weights.verify(weights)
+    kwargs["width_per_group"] = 64 * 2
+    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
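A minimal usage sketch (not part of the patch) of the two calling conventions that the new test_old_vs_new_classification_factory test compares. It assumes the weight enums and builders are re-exported at the torchvision.prototype.models package level, as the test imports above imply.

```python
import torch
from torchvision.prototype import models

# New-style builder: the checkpoint is selected explicitly via the weights enum
# introduced in this patch (ResNet50Weights.ImageNet1K_RefV1).
model_new = models.resnet50(weights=models.ResNet50Weights.ImageNet1K_RefV1)
model_new.eval()

# Old-style call: still accepted, but the builder pops the deprecated
# `pretrained` kwarg, emits a warning, and maps True to ImageNet1K_RefV1.
model_old = models.resnet50(pretrained=True)
model_old.eval()

# Both builders load the same checkpoint, so their outputs should match;
# this mirrors the equality check performed by the new test.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0)
```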