diff --git a/docs/source/models.rst b/docs/source/models.rst index 9c750908b06..62c104cf927 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor: regnet_y_8gf = models.regnet_y_8gf() regnet_y_16gf = models.regnet_y_16gf() regnet_y_32gf = models.regnet_y_32gf() + regnet_y_128gf = models.regnet_y_128gf() regnet_x_400mf = models.regnet_x_400mf() regnet_x_800mf = models.regnet_x_800mf() regnet_x_1_6gf = models.regnet_x_1_6gf() @@ -439,6 +440,7 @@ RegNet regnet_y_8gf regnet_y_16gf regnet_y_32gf + regnet_y_128gf regnet_x_400mf regnet_x_800mf regnet_x_1_6gf diff --git a/hubconf.py b/hubconf.py index 81b15ff9ff1..2b2eeb1c166 100644 --- a/hubconf.py +++ b/hubconf.py @@ -27,6 +27,7 @@ regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, + regnet_y_128gf, regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, diff --git a/test/expect/ModelTester.test_regnet_y_128gf_expect.pkl b/test/expect/ModelTester.test_regnet_y_128gf_expect.pkl new file mode 100644 index 00000000000..4f6037929cc Binary files /dev/null and b/test/expect/ModelTester.test_regnet_y_128gf_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index 2e0ed783849..f4f1828d8af 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,6 +1,5 @@ import contextlib import functools -import io import operator import os import pkgutil @@ -8,6 +7,7 @@ import traceback import warnings from collections import OrderedDict +from tempfile import TemporaryDirectory import pytest import torch @@ -126,16 +126,16 @@ def assert_export_import_module(m, args): def get_export_import_copy(m): """Save and load a TorchScript model""" - buffer = io.BytesIO() - torch.jit.save(m, buffer) - buffer.seek(0) - imported = torch.jit.load(buffer) + with TemporaryDirectory() as dir: + path = os.path.join(dir, "script.pt") + m.save(path) + imported = torch.jit.load(path) return imported m_import = get_export_import_copy(m) - with freeze_rng_state(): + with torch.no_grad(), freeze_rng_state(): results = m(*args) - with freeze_rng_state(): + with torch.no_grad(), freeze_rng_state(): results_from_imported = m_import(*args) tol = 3e-4 torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) @@ -156,10 +156,10 @@ def get_export_import_copy(m): sm = torch.jit.script(nn_module) - with freeze_rng_state(): + with torch.no_grad(), freeze_rng_state(): eager_out = nn_module(*args) - with freeze_rng_state(): + with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) if unwrapper: script_out = unwrapper(script_out) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 85f53751dd0..1066ade43f4 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -26,6 +26,7 @@ "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf", + "regnet_y_128gf", "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", @@ -505,6 +506,18 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) +def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_128GF architecture from + `"Designing Network Design Spaces" `_. + NOTE: Pretrained weights are not available for this model. + """ + params = BlockParams.from_init_params( + depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) + + def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index db1e86fdcab..c23a971bb5f 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -20,6 +20,7 @@ "RegNet_Y_8GF_Weights", "RegNet_Y_16GF_Weights", "RegNet_Y_32GF_Weights", + "RegNet_Y_128GF_Weights", "RegNet_X_400MF_Weights", "RegNet_X_800MF_Weights", "RegNet_X_1_6GF_Weights", @@ -34,6 +35,7 @@ "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf", + "regnet_y_128gf", "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", @@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum): default = ImageNet1K_V2 +class RegNet_Y_128GF_Weights(WeightsEnum): + # weights are not available yet. + pass + + class RegNet_X_400MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", @@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: return _regnet(params, weights, progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + weights = RegNet_Y_128GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1)) def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_400MF_Weights.verify(weights)