Skip to content

Commit e3767f8

Browse files
kazhangdatumbox
andauthored
add regnet_y_128gf factory function (#5176)
* add regnet_y_128gf * fix test * add expected test file * update regnet factory function, add to prototype as well * write torchscript to temp file instead bytesio in model test * docs * clear GPU memory * no_grad * nit Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent a8f2ded commit e3767f8

File tree

6 files changed

+42
-9
lines changed

6 files changed

+42
-9
lines changed

docs/source/models.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor:
7676
regnet_y_8gf = models.regnet_y_8gf()
7777
regnet_y_16gf = models.regnet_y_16gf()
7878
regnet_y_32gf = models.regnet_y_32gf()
79+
regnet_y_128gf = models.regnet_y_128gf()
7980
regnet_x_400mf = models.regnet_x_400mf()
8081
regnet_x_800mf = models.regnet_x_800mf()
8182
regnet_x_1_6gf = models.regnet_x_1_6gf()
@@ -439,6 +440,7 @@ RegNet
439440
regnet_y_8gf
440441
regnet_y_16gf
441442
regnet_y_32gf
443+
regnet_y_128gf
442444
regnet_x_400mf
443445
regnet_x_800mf
444446
regnet_x_1_6gf

hubconf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
regnet_y_8gf,
2828
regnet_y_16gf,
2929
regnet_y_32gf,
30+
regnet_y_128gf,
3031
regnet_x_400mf,
3132
regnet_x_800mf,
3233
regnet_x_1_6gf,
Binary file not shown.

test/test_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import contextlib
22
import functools
3-
import io
43
import operator
54
import os
65
import pkgutil
76
import sys
87
import traceback
98
import warnings
109
from collections import OrderedDict
10+
from tempfile import TemporaryDirectory
1111

1212
import pytest
1313
import torch
@@ -126,16 +126,16 @@ def assert_export_import_module(m, args):
126126

127127
def get_export_import_copy(m):
128128
"""Save and load a TorchScript model"""
129-
buffer = io.BytesIO()
130-
torch.jit.save(m, buffer)
131-
buffer.seek(0)
132-
imported = torch.jit.load(buffer)
129+
with TemporaryDirectory() as dir:
130+
path = os.path.join(dir, "script.pt")
131+
m.save(path)
132+
imported = torch.jit.load(path)
133133
return imported
134134

135135
m_import = get_export_import_copy(m)
136-
with freeze_rng_state():
136+
with torch.no_grad(), freeze_rng_state():
137137
results = m(*args)
138-
with freeze_rng_state():
138+
with torch.no_grad(), freeze_rng_state():
139139
results_from_imported = m_import(*args)
140140
tol = 3e-4
141141
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
@@ -156,10 +156,10 @@ def get_export_import_copy(m):
156156

157157
sm = torch.jit.script(nn_module)
158158

159-
with freeze_rng_state():
159+
with torch.no_grad(), freeze_rng_state():
160160
eager_out = nn_module(*args)
161161

162-
with freeze_rng_state():
162+
with torch.no_grad(), freeze_rng_state():
163163
script_out = sm(*args)
164164
if unwrapper:
165165
script_out = unwrapper(script_out)

torchvision/models/regnet.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"regnet_y_8gf",
2727
"regnet_y_16gf",
2828
"regnet_y_32gf",
29+
"regnet_y_128gf",
2930
"regnet_x_400mf",
3031
"regnet_x_800mf",
3132
"regnet_x_1_6gf",
@@ -505,6 +506,18 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
505506
return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs)
506507

507508

509+
def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
510+
"""
511+
Constructs a RegNetY_128GF architecture from
512+
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
513+
NOTE: Pretrained weights are not available for this model.
514+
"""
515+
params = BlockParams.from_init_params(
516+
depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
517+
)
518+
return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs)
519+
520+
508521
def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
509522
"""
510523
Constructs a RegNetX_400MF architecture from

torchvision/prototype/models/regnet.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"RegNet_Y_8GF_Weights",
2121
"RegNet_Y_16GF_Weights",
2222
"RegNet_Y_32GF_Weights",
23+
"RegNet_Y_128GF_Weights",
2324
"RegNet_X_400MF_Weights",
2425
"RegNet_X_800MF_Weights",
2526
"RegNet_X_1_6GF_Weights",
@@ -34,6 +35,7 @@
3435
"regnet_y_8gf",
3536
"regnet_y_16gf",
3637
"regnet_y_32gf",
38+
"regnet_y_128gf",
3739
"regnet_x_400mf",
3840
"regnet_x_800mf",
3941
"regnet_x_1_6gf",
@@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
253255
default = ImageNet1K_V2
254256

255257

258+
class RegNet_Y_128GF_Weights(WeightsEnum):
259+
# weights are not available yet.
260+
pass
261+
262+
256263
class RegNet_X_400MF_Weights(WeightsEnum):
257264
ImageNet1K_V1 = Weights(
258265
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
@@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress:
501508
return _regnet(params, weights, progress, **kwargs)
502509

503510

511+
@handle_legacy_interface(weights=("pretrained", None))
512+
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
513+
weights = RegNet_Y_128GF_Weights.verify(weights)
514+
515+
params = BlockParams.from_init_params(
516+
depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
517+
)
518+
return _regnet(params, weights, progress, **kwargs)
519+
520+
504521
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1))
505522
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
506523
weights = RegNet_X_400MF_Weights.verify(weights)

0 commit comments

Comments
 (0)