From e48df51ae925e119cd98620573ca779cc92b8783 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 19 Oct 2021 15:08:14 +0100 Subject: [PATCH 01/17] adding Weights classes for Resnet classification models --- torchvision/prototype/models/resnet.py | 208 ++++++++++++++++++++++++- 1 file changed, 207 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index aaa02d5d407..fc6eb82764d 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -8,7 +8,27 @@ from ._meta import _IMAGENET_CATEGORIES -__all__ = ["ResNet", "ResNet50Weights", "resnet50"] +__all__ = [ + "ResNet", + "ResNet18Weights", + "ResNet34Weights", + "ResNet50Weights", + "ResNet101Weights", + "ResNet152Weights", + "ResNeXt50_32x4dWeights", + "ResNeXt101_32x8dWeights", + "WideResNet50_2Weights", + "WideResNet101_2Weights", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "wide_resnet50_2", + "wide_resnet101_2", +] def _resnet( @@ -35,6 +55,32 @@ def _resnet( } +class ResNet18Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 69.758, + "acc@5": 89.078, + }, + ) + + +class ResNet34Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 73.314, + "acc@5": 91.420, + }, + ) + + class ResNet50Weights(Weights): ImageNet1K_RefV1 = WeightEntry( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", @@ -58,6 +104,104 @@ class ResNet50Weights(Weights): ) +class ResNet101Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 77.374, + "acc@5": 93.546, + }, + ) + + +class ResNet152Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 78.312, + "acc@5": 94.046, + }, + ) + + +class ResNeXt50_32x4dWeights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 77.618, + "acc@5": 93.698, + }, + ) + + +class ResNeXt101_32x8dWeights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 79.312, + "acc@5": 94.526, + }, + ) + + +class WideResNet50_2Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 78.468, + "acc@5": 94.086, + }, + ) + + +class WideResNet101_2Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 78.848, + "acc@5": 94.284, + }, + ) + + +def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNet18Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) + + +def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNet34Weights.verify(weights) + + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + + def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if "pretrained" in kwargs: warnings.warn("The argument pretrained is deprecated, please use weights instead.") @@ -65,3 +209,65 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, * weights = ResNet50Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNet101Weights.verify(weights) + + return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) + + +def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNet152Weights.verify(weights) + + return _resnet(BasicBlock, [3, 8, 36, 3], weights, progress, **kwargs) + + +def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNeXt50_32x4dWeights.verify(weights) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 4 + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + + +def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = ResNeXt101_32x8dWeights.verify(weights) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) + + +def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = WideResNet50_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = WideResNet50_2Weights.verify(weights) + kwargs["width_per_group"] = 64 * 2 + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + + +def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = WideResNet101_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + + weights = WideResNet101_2Weights.verify(weights) + kwargs["width_per_group"] = 64 * 2 + return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) From 73543815c35fc2c97a1f0dca8a8a44655711898e Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 19 Oct 2021 15:35:02 +0100 Subject: [PATCH 02/17] Replacing BasicBlock by Bottleneck in all but 3 model contructors --- torchvision/prototype/models/resnet.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index fc6eb82764d..7f4158143fc 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -218,7 +218,7 @@ def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, weights = ResNet101Weights.verify(weights) - return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @@ -228,7 +228,7 @@ def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, weights = ResNet152Weights.verify(weights) - return _resnet(BasicBlock, [3, 8, 36, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @@ -239,7 +239,7 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: weights = ResNeXt50_32x4dWeights.verify(weights) kwargs["groups"] = 32 kwargs["width_per_group"] = 4 - return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @@ -250,7 +250,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress weights = ResNeXt101_32x8dWeights.verify(weights) kwargs["groups"] = 32 kwargs["width_per_group"] = 8 - return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @@ -260,7 +260,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b weights = WideResNet50_2Weights.verify(weights) kwargs["width_per_group"] = 64 * 2 - return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @@ -270,4 +270,4 @@ def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: weights = WideResNet101_2Weights.verify(weights) kwargs["width_per_group"] = 64 * 2 - return _resnet(BasicBlock, [3, 4, 23, 3], weights, progress, **kwargs) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) From 9264a52d3e56531d91736adf7b80a7e597385325 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 19 Oct 2021 21:42:38 +0100 Subject: [PATCH 03/17] adding tests for prototype models --- .circleci/unittest/linux/scripts/run_test.sh | 1 + .../unittest/windows/scripts/run_test.sh | 1 + test/test_prototype_models.py | 79 +++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 test/test_prototype_models.py diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh index 419b9eb562c..0049ec0a2a5 100755 --- a/.circleci/unittest/linux/scripts/run_test.sh +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -6,5 +6,6 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' +export PYTORCH_TEST_WITH_PROTOTYP='0' python -m torch.utils.collect_env pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py diff --git a/.circleci/unittest/windows/scripts/run_test.sh b/.circleci/unittest/windows/scripts/run_test.sh index 58e4b0d7141..205fad2a37f 100644 --- a/.circleci/unittest/windows/scripts/run_test.sh +++ b/.circleci/unittest/windows/scripts/run_test.sh @@ -9,5 +9,6 @@ this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$this_dir/set_cuda_envs.sh" export PYTORCH_TEST_WITH_SLOW='1' +export PYTORCH_TEST_WITH_PROTOTYPE='0' python -m torch.utils.collect_env pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py new file mode 100644 index 00000000000..24b2e2fcfdf --- /dev/null +++ b/test/test_prototype_models.py @@ -0,0 +1,79 @@ +import functools +import io +import operator +import os +import traceback +import warnings +from collections import OrderedDict + +import pytest +import torch +import torch.fx +import torch.nn as nn +import torchvision +from _utils_internal import get_relative_path +from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda +from torchvision.prototype import models +from test_models import _assert_expected, _model_params + + +model_to_default_weights_mapping = { + "resnet18": ["ImageNet1K_RefV1", models.ResNet18Weights.ImageNet1K_RefV1], + "resnet34": ["ImageNet1K_RefV1", models.ResNet34Weights.ImageNet1K_RefV1], + "resnet50":["ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1], + "resnet101": ["ImageNet1K_RefV1", models.ResNet101Weights.ImageNet1K_RefV1], + "resnet152": ["ImageNet1K_RefV1", models.ResNet152Weights.ImageNet1K_RefV1], + "resnext50_32x4d": ["ImageNet1K_RefV1", models.ResNeXt50_32x4dWeights.ImageNet1K_RefV1], + "resnext101_32x8d": ["ImageNet1K_RefV1", models.ResNeXt101_32x8dWeights.ImageNet1K_RefV1], + "wide_resnet50_2": ["ImageNet1K_RefV1", models.WideResNet50_2Weights.ImageNet1K_RefV1], + "wide_resnet101_2": ["ImageNet1K_RefV1", models.WideResNet101_2Weights.ImageNet1K_RefV1], +} + +def get_available_classification_models(): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + +@pytest.mark.parametrize("model_name", get_available_classification_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "1") == "0", reason="Prototype code tests are disabled") +def test_classification_model(model_name, dev): + set_rng_seed(0) + defaults = { + "num_classes": 50, + "input_shape": (1, 3, 224, 224), + } + kwargs = {**defaults, **_model_params.get(model_name, {})} + input_shape = kwargs.pop("input_shape") + model = models.__dict__[model_name](**kwargs) + model.eval().to(device=dev) + # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests + x = torch.rand(input_shape).to(device=dev) + out = model(x) + _assert_expected(out.cpu(), model_name, prec=0.1) + assert out.shape[-1] == 50 + + +@pytest.mark.parametrize("model_name", get_available_classification_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "1") == "0", reason="Prototype code tests are disabled") +def test_old_vs_new_classification_factory(model_name, dev): + set_rng_seed(0) + defaults = { + "num_classes": 50, + "pretrained": True, + } + input_shape = (1, 3, 224, 224) + kwargs = {**defaults, **_model_params.get(model_name, {})} + model_old = models.__dict__[model_name](**kwargs) + model_old.eval().to(device=dev) + # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests + x = torch.rand(input_shape).to(device=dev) + out_old = model_old(x) + defaults.pop("pretrained") + for weights_val in model_to_default_weights_mapping[model_name]: + kwargs = {**defaults, **_model_params.get(model_name, {}), "weights": weights_val} + model_new = models.__dict__[model_name](**kwargs) + model_new.eval().to(device=dev) + # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests + out_new = model_new(x) + torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False) From 17fa569e67e541b26399545ac0dd2ac5423852c5 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 19 Oct 2021 21:46:47 +0100 Subject: [PATCH 04/17] fixing typo in environment variable --- .circleci/unittest/linux/scripts/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh index 0049ec0a2a5..851e08d52a0 100755 --- a/.circleci/unittest/linux/scripts/run_test.sh +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -6,6 +6,6 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' -export PYTORCH_TEST_WITH_PROTOTYP='0' +export PYTORCH_TEST_WITH_PROTOTYPE='0' python -m torch.utils.collect_env pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py From 72ba4e3f21ea870133f7629737c137a0afa132f6 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 09:24:11 +0100 Subject: [PATCH 05/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 24b2e2fcfdf..583e16f3962 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -35,7 +35,7 @@ def get_available_classification_models(): @pytest.mark.parametrize("model_name", get_available_classification_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) -@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "1") == "0", reason="Prototype code tests are disabled") +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "1", reason="Prototype code tests are disabled") def test_classification_model(model_name, dev): set_rng_seed(0) defaults = { From ae412e18ea9fa39b289c753332c9772aaed9c1b0 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 09:29:16 +0100 Subject: [PATCH 06/17] changing default value for PYTORCH_TEST_WITH_PROTOTYPE --- test/test_prototype_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 583e16f3962..a1147d90d83 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -55,7 +55,7 @@ def test_classification_model(model_name, dev): @pytest.mark.parametrize("model_name", get_available_classification_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) -@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "1") == "0", reason="Prototype code tests are disabled") +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") def test_old_vs_new_classification_factory(model_name, dev): set_rng_seed(0) defaults = { From df6cd1fecfeb2249b7baa021f13a1f56f09dbb82 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 09:59:49 +0100 Subject: [PATCH 07/17] adding checks to compare outputs of the prototype vs old models --- test/test_prototype_models.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index a1147d90d83..5595313fde9 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -14,6 +14,7 @@ from _utils_internal import get_relative_path from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda from torchvision.prototype import models +from torchvision import models as original_models from test_models import _assert_expected, _model_params @@ -35,7 +36,7 @@ def get_available_classification_models(): @pytest.mark.parametrize("model_name", get_available_classification_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) -@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "1", reason="Prototype code tests are disabled") +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") def test_classification_model(model_name, dev): set_rng_seed(0) defaults = { @@ -59,21 +60,26 @@ def test_classification_model(model_name, dev): def test_old_vs_new_classification_factory(model_name, dev): set_rng_seed(0) defaults = { - "num_classes": 50, + "num_classes": 1000, "pretrained": True, } input_shape = (1, 3, 224, 224) kwargs = {**defaults, **_model_params.get(model_name, {})} - model_old = models.__dict__[model_name](**kwargs) + model_old = original_models.__dict__[model_name](**kwargs) model_old.eval().to(device=dev) # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests x = torch.rand(input_shape).to(device=dev) out_old = model_old(x) + # comapre with new model builder parameterized in the old fashion way + model_new = models.__dict__[model_name](**kwargs) + model_new.eval().to(device=dev) + out_new = model_new(x) + torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False) + # compare with new model builder parameterized in the new way defaults.pop("pretrained") for weights_val in model_to_default_weights_mapping[model_name]: kwargs = {**defaults, **_model_params.get(model_name, {}), "weights": weights_val} model_new = models.__dict__[model_name](**kwargs) model_new.eval().to(device=dev) - # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests out_new = model_new(x) torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False) From 9e4200d23e16d027075ebc09026695e9fc4098af Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:35:00 +0100 Subject: [PATCH 08/17] refactoring prototype tests --- .circleci/unittest/linux/scripts/run_test.sh | 1 - .../unittest/windows/scripts/run_test.sh | 1 - test/test_prototype_models.py | 29 ++++--------------- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh index 851e08d52a0..419b9eb562c 100755 --- a/.circleci/unittest/linux/scripts/run_test.sh +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -6,6 +6,5 @@ eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env export PYTORCH_TEST_WITH_SLOW='1' -export PYTORCH_TEST_WITH_PROTOTYPE='0' python -m torch.utils.collect_env pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py diff --git a/.circleci/unittest/windows/scripts/run_test.sh b/.circleci/unittest/windows/scripts/run_test.sh index 205fad2a37f..58e4b0d7141 100644 --- a/.circleci/unittest/windows/scripts/run_test.sh +++ b/.circleci/unittest/windows/scripts/run_test.sh @@ -9,6 +9,5 @@ this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$this_dir/set_cuda_envs.sh" export PYTORCH_TEST_WITH_SLOW='1' -export PYTORCH_TEST_WITH_PROTOTYPE='0' python -m torch.utils.collect_env pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 5595313fde9..183aedc32f3 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -13,27 +13,16 @@ import torchvision from _utils_internal import get_relative_path from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda -from torchvision.prototype import models -from torchvision import models as original_models from test_models import _assert_expected, _model_params +from torchvision import models as original_models +from torchvision.prototype import models -model_to_default_weights_mapping = { - "resnet18": ["ImageNet1K_RefV1", models.ResNet18Weights.ImageNet1K_RefV1], - "resnet34": ["ImageNet1K_RefV1", models.ResNet34Weights.ImageNet1K_RefV1], - "resnet50":["ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1], - "resnet101": ["ImageNet1K_RefV1", models.ResNet101Weights.ImageNet1K_RefV1], - "resnet152": ["ImageNet1K_RefV1", models.ResNet152Weights.ImageNet1K_RefV1], - "resnext50_32x4d": ["ImageNet1K_RefV1", models.ResNeXt50_32x4dWeights.ImageNet1K_RefV1], - "resnext101_32x8d": ["ImageNet1K_RefV1", models.ResNeXt101_32x8dWeights.ImageNet1K_RefV1], - "wide_resnet50_2": ["ImageNet1K_RefV1", models.WideResNet50_2Weights.ImageNet1K_RefV1], - "wide_resnet101_2": ["ImageNet1K_RefV1", models.WideResNet101_2Weights.ImageNet1K_RefV1], -} - def get_available_classification_models(): # TODO add a registration mechanism to torchvision.models return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + @pytest.mark.parametrize("model_name", get_available_classification_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") @@ -60,11 +49,11 @@ def test_classification_model(model_name, dev): def test_old_vs_new_classification_factory(model_name, dev): set_rng_seed(0) defaults = { - "num_classes": 1000, "pretrained": True, + "input_shape": (1, 3, 224, 224), } - input_shape = (1, 3, 224, 224) kwargs = {**defaults, **_model_params.get(model_name, {})} + input_shape = kwargs.pop("input_shape") model_old = original_models.__dict__[model_name](**kwargs) model_old.eval().to(device=dev) # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests @@ -75,11 +64,3 @@ def test_old_vs_new_classification_factory(model_name, dev): model_new.eval().to(device=dev) out_new = model_new(x) torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False) - # compare with new model builder parameterized in the new way - defaults.pop("pretrained") - for weights_val in model_to_default_weights_mapping[model_name]: - kwargs = {**defaults, **_model_params.get(model_name, {}), "weights": weights_val} - model_new = models.__dict__[model_name](**kwargs) - model_new.eval().to(device=dev) - out_new = model_new(x) - torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False) From 0b49b703e05c6dbae8123eec217c7d61146bd3f2 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:42:43 +0100 Subject: [PATCH 09/17] removing unused imports --- test/test_prototype_models.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 183aedc32f3..5252c5f5cfc 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,18 +1,9 @@ -import functools -import io -import operator +# tests import os -import traceback -import warnings -from collections import OrderedDict - import pytest import torch import torch.fx -import torch.nn as nn -import torchvision -from _utils_internal import get_relative_path -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda +from common_utils import set_rng_seed, cpu_and_gpu from test_models import _assert_expected, _model_params from torchvision import models as original_models from torchvision.prototype import models From b877d6b5e45be4ed52a92e7e7c89e7a36a7903dd Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:46:52 +0100 Subject: [PATCH 10/17] applying ufmt --- test/test_prototype_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 5252c5f5cfc..dd0b0b5a765 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,5 +1,6 @@ # tests import os + import pytest import torch import torch.fx From 069c83243a40095ec56177aa7a1207ee3fd6a2a7 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:11 +0100 Subject: [PATCH 11/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index dd0b0b5a765..8e84b16b7d1 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -11,7 +11,6 @@ def get_available_classification_models(): - # TODO add a registration mechanism to torchvision.models return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] From 20fa4dacaf5e026f9351b6adf1001689c921d761 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:18 +0100 Subject: [PATCH 12/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 8e84b16b7d1..afc3006aba8 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -27,7 +27,6 @@ def test_classification_model(model_name, dev): input_shape = kwargs.pop("input_shape") model = models.__dict__[model_name](**kwargs) model.eval().to(device=dev) - # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests x = torch.rand(input_shape).to(device=dev) out = model(x) _assert_expected(out.cpu(), model_name, prec=0.1) From d59497353dc5b7b3a34a9536c3b488b82c4d2bcb Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:25 +0100 Subject: [PATCH 13/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index afc3006aba8..f677563f5c7 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -46,7 +46,6 @@ def test_old_vs_new_classification_factory(model_name, dev): input_shape = kwargs.pop("input_shape") model_old = original_models.__dict__[model_name](**kwargs) model_old.eval().to(device=dev) - # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests x = torch.rand(input_shape).to(device=dev) out_old = model_old(x) # comapre with new model builder parameterized in the old fashion way From da969ce98e671835bc6159d63ae2133534861217 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:33 +0100 Subject: [PATCH 14/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index f677563f5c7..77528da484d 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -48,7 +48,7 @@ def test_old_vs_new_classification_factory(model_name, dev): model_old.eval().to(device=dev) x = torch.rand(input_shape).to(device=dev) out_old = model_old(x) - # comapre with new model builder parameterized in the old fashion way + # compare with new model builder parameterized in the old fashion way model_new = models.__dict__[model_name](**kwargs) model_new.eval().to(device=dev) out_new = model_new(x) From b737bf6227c390ea44fb322a2bd10bdf01785814 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:39 +0100 Subject: [PATCH 15/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 77528da484d..8f89f99c6ec 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,4 +1,3 @@ -# tests import os import pytest From 814aeb8c5f84897bd396e3089dad705e67cf92b4 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 10:51:50 +0100 Subject: [PATCH 16/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 8f89f99c6ec..960925f0170 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -36,7 +36,6 @@ def test_classification_model(model_name, dev): @pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") def test_old_vs_new_classification_factory(model_name, dev): - set_rng_seed(0) defaults = { "pretrained": True, "input_shape": (1, 3, 224, 224), From 344b84c9032f1894d428190bd05e02aa85b2f254 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 20 Oct 2021 13:19:02 +0100 Subject: [PATCH 17/17] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 3909d0b83ae..c8a4218e65a 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.fx from common_utils import set_rng_seed, cpu_and_gpu from test_models import _assert_expected, _model_params from torchvision import models as original_models