Skip to content

Commit 14f0fa0

Browse files
committed
Fixing imports
1 parent 76f017c commit 14f0fa0

38 files changed

+121
-141
lines changed

test/test_prototype_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import test_models as TM
66
import torch
77
from common_utils import cpu_and_gpu, needs_cuda
8-
from torchvision.prototype import models
98
from torchvision.models._api import WeightsEnum, Weights
9+
from torchvision.prototype import models
1010
from torchvision.prototype.models._utils import handle_legacy_interface
1111

1212
run_if_test_with_prototype = pytest.mark.skipif(

torchvision/_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import TypeVar, Type
2+
from typing import Sequence, TypeVar, Type
33

44
T = TypeVar("T", bound=enum.Enum)
55

@@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
1818

1919
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
2020
pass
21+
22+
23+
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
24+
if not seq:
25+
return ""
26+
if len(seq) == 1:
27+
return f"'{seq[0]}'"
28+
29+
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
30+
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
31+
32+
return head + tail

torchvision/models/_utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import functools
2+
import inspect
23
import warnings
34
from collections import OrderedDict
45
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
56

67
from torch import nn
7-
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw
88

9+
from .._utils import sequence_to_str
910
from ._api import WeightsEnum
1011

1112

@@ -88,6 +89,57 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
8889
return new_v
8990

9091

92+
def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
93+
"""Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
94+
95+
For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
96+
97+
.. code::
98+
99+
def old_fn(foo, bar, baz=None):
100+
...
101+
102+
def new_fn(foo, *, bar, baz=None):
103+
...
104+
105+
Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
106+
and at the same time warn the user of the deprecation, this decorator can be used:
107+
108+
.. code::
109+
110+
@kwonly_to_pos_or_kw
111+
def new_fn(foo, *, bar, baz=None):
112+
...
113+
114+
new_fn("foo", "bar, "baz")
115+
"""
116+
params = inspect.signature(fn).parameters
117+
118+
try:
119+
keyword_only_start_idx = next(
120+
idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
121+
)
122+
except StopIteration:
123+
raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
124+
125+
keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
126+
127+
@functools.wraps(fn)
128+
def wrapper(*args: Any, **kwargs: Any) -> D:
129+
args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
130+
if keyword_only_args:
131+
keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
132+
warnings.warn(
133+
f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
134+
f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
135+
)
136+
kwargs.update(keyword_only_kwargs)
137+
138+
return fn(*args, **kwargs)
139+
140+
return wrapper
141+
142+
91143
W = TypeVar("W", bound=WeightsEnum)
92144
M = TypeVar("M", bound=nn.Module)
93145
V = TypeVar("V")

torchvision/models/alexnet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import torch
55
import torch.nn as nn
66

7+
from ..transforms import ImageClassificationEval, InterpolationMode
78
from ..utils import _log_api_usage_once
89
from ._api import WeightsEnum, Weights
910
from ._meta import _IMAGENET_CATEGORIES
1011
from ._utils import handle_legacy_interface, _ovewrite_named_param
1112

12-
from ..transforms import ImageClassificationEval, InterpolationMode
13-
1413

1514
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
1615

@@ -94,4 +93,4 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
9493
if weights is not None:
9594
model.load_state_dict(weights.get_state_dict(progress=progress))
9695

97-
return model
96+
return model

torchvision/models/convnext.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77

88
from ..ops.misc import Conv2dNormActivation
99
from ..ops.stochastic_depth import StochasticDepth
10-
from ..utils import _log_api_usage_once
11-
1210
from ..transforms import ImageClassificationEval, InterpolationMode
13-
11+
from ..utils import _log_api_usage_once
1412
from ._api import WeightsEnum, Weights
1513
from ._meta import _IMAGENET_CATEGORIES
1614
from ._utils import handle_legacy_interface, _ovewrite_named_param

torchvision/models/densenet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
2-
from functools import partial
32
from collections import OrderedDict
3+
from functools import partial
44
from typing import Any, List, Optional, Tuple
55

66
import torch
@@ -11,7 +11,6 @@
1111

1212
from ..transforms import ImageClassificationEval, InterpolationMode
1313
from ..utils import _log_api_usage_once
14-
1514
from ._api import WeightsEnum, Weights
1615
from ._meta import _IMAGENET_CATEGORIES
1716
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -277,6 +276,7 @@ def _densenet(
277276
"recipe": "https://github.com/pytorch/vision/pull/116",
278277
}
279278

279+
280280
class DenseNet121_Weights(WeightsEnum):
281281
IMAGENET1K_V1 = Weights(
282282
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
@@ -398,4 +398,4 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool
398398
"""
399399
weights = DenseNet201_Weights.verify(weights)
400400

401-
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
401+
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)

torchvision/models/efficientnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
1313
from ..transforms import ImageClassificationEval, InterpolationMode
1414
from ..utils import _log_api_usage_once
15-
1615
from ._api import WeightsEnum, Weights
1716
from ._meta import _IMAGENET_CATEGORIES
1817
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible

torchvision/models/googlenet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import warnings
2-
from functools import partial
32
from collections import namedtuple
3+
from functools import partial
44
from typing import Optional, Tuple, List, Callable, Any
55

66
import torch
77
import torch.nn as nn
88
import torch.nn.functional as F
99
from torch import Tensor
1010

11-
from ..utils import _log_api_usage_once
1211
from ..transforms import ImageClassificationEval, InterpolationMode
13-
12+
from ..utils import _log_api_usage_once
1413
from ._api import WeightsEnum, Weights
1514
from ._meta import _IMAGENET_CATEGORIES
1615
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -333,4 +332,4 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
333332
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
334333
)
335334

336-
return model
335+
return model

torchvision/models/inception.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
import torch.nn.functional as F
88
from torch import nn, Tensor
99

10-
from ..utils import _log_api_usage_once
11-
12-
1310
from ..transforms import ImageClassificationEval, InterpolationMode
11+
from ..utils import _log_api_usage_once
1412
from ._api import WeightsEnum, Weights
1513
from ._meta import _IMAGENET_CATEGORIES
1614
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -465,4 +463,4 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
465463
model.aux_logits = False
466464
model.AuxLogits = None
467465

468-
return model
466+
return model

torchvision/models/mnasnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from functools import partial
21
import warnings
2+
from functools import partial
33
from typing import Any, Dict, List, Optional
44

55
import torch

0 commit comments

Comments
 (0)