7
7
from torchvision .transforms .functional import InterpolationMode
8
8
9
9
from ...models .densenet import DenseNet
10
- from ._api import Weights , WeightEntry
10
+ from ._api import WeightsEnum , Weights
11
11
from ._meta import _IMAGENET_CATEGORIES
12
12
from ._utils import _deprecated_param , _deprecated_positional , _ovewrite_named_param
13
13
14
14
15
15
__all__ = [
16
16
"DenseNet" ,
17
- "DenseNet121Weights " ,
18
- "DenseNet161Weights " ,
19
- "DenseNet169Weights " ,
20
- "DenseNet201Weights " ,
17
+ "DenseNet121_Weights " ,
18
+ "DenseNet161_Weights " ,
19
+ "DenseNet169_Weights " ,
20
+ "DenseNet201_Weights " ,
21
21
"densenet121" ,
22
22
"densenet161" ,
23
23
"densenet169" ,
24
24
"densenet201" ,
25
25
]
26
26
27
27
28
- def _load_state_dict (model : nn .Module , weights : Weights , progress : bool ) -> None :
28
+ def _load_state_dict (model : nn .Module , weights : WeightsEnum , progress : bool ) -> None :
29
29
# '.'s are no longer allowed in module names, but previous _DenseLayer
30
30
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
31
31
# They are also in the checkpoints in model_urls. This pattern is used
@@ -48,7 +48,7 @@ def _densenet(
48
48
growth_rate : int ,
49
49
block_config : Tuple [int , int , int , int ],
50
50
num_init_features : int ,
51
- weights : Optional [Weights ],
51
+ weights : Optional [WeightsEnum ],
52
52
progress : bool ,
53
53
** kwargs : Any ,
54
54
) -> DenseNet :
@@ -71,8 +71,8 @@ def _densenet(
71
71
}
72
72
73
73
74
- class DenseNet121Weights ( Weights ):
75
- ImageNet1K_Community = WeightEntry (
74
+ class DenseNet121_Weights ( WeightsEnum ):
75
+ ImageNet1K_V1 = Weights (
76
76
url = "https://download.pytorch.org/models/densenet121-a639ec97.pth" ,
77
77
transforms = partial (ImageNetEval , crop_size = 224 ),
78
78
meta = {
@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights):
84
84
)
85
85
86
86
87
- class DenseNet161Weights ( Weights ):
88
- ImageNet1K_Community = WeightEntry (
87
+ class DenseNet161_Weights ( WeightsEnum ):
88
+ ImageNet1K_V1 = Weights (
89
89
url = "https://download.pytorch.org/models/densenet161-8d451a50.pth" ,
90
90
transforms = partial (ImageNetEval , crop_size = 224 ),
91
91
meta = {
@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights):
97
97
)
98
98
99
99
100
- class DenseNet169Weights ( Weights ):
101
- ImageNet1K_Community = WeightEntry (
100
+ class DenseNet169_Weights ( WeightsEnum ):
101
+ ImageNet1K_V1 = Weights (
102
102
url = "https://download.pytorch.org/models/densenet169-b2777c0a.pth" ,
103
103
transforms = partial (ImageNetEval , crop_size = 224 ),
104
104
meta = {
@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights):
110
110
)
111
111
112
112
113
- class DenseNet201Weights ( Weights ):
114
- ImageNet1K_Community = WeightEntry (
113
+ class DenseNet201_Weights ( WeightsEnum ):
114
+ ImageNet1K_V1 = Weights (
115
115
url = "https://download.pytorch.org/models/densenet201-c1103571.pth" ,
116
116
transforms = partial (ImageNetEval , crop_size = 224 ),
117
117
meta = {
@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights):
123
123
)
124
124
125
125
126
- def densenet121 (weights : Optional [DenseNet121Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
126
+ def densenet121 (weights : Optional [DenseNet121_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
127
127
if type (weights ) == bool and weights :
128
128
_deprecated_positional (kwargs , "pretrained" , "weights" , True )
129
129
if "pretrained" in kwargs :
130
- weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet121Weights . ImageNet1K_Community )
131
- weights = DenseNet121Weights .verify (weights )
130
+ weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet121_Weights . ImageNet1K_V1 )
131
+ weights = DenseNet121_Weights .verify (weights )
132
132
133
133
return _densenet (32 , (6 , 12 , 24 , 16 ), 64 , weights , progress , ** kwargs )
134
134
135
135
136
- def densenet161 (weights : Optional [DenseNet161Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
136
+ def densenet161 (weights : Optional [DenseNet161_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
137
137
if type (weights ) == bool and weights :
138
138
_deprecated_positional (kwargs , "pretrained" , "weights" , True )
139
139
if "pretrained" in kwargs :
140
- weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet161Weights . ImageNet1K_Community )
141
- weights = DenseNet161Weights .verify (weights )
140
+ weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet161_Weights . ImageNet1K_V1 )
141
+ weights = DenseNet161_Weights .verify (weights )
142
142
143
143
return _densenet (48 , (6 , 12 , 36 , 24 ), 96 , weights , progress , ** kwargs )
144
144
145
145
146
- def densenet169 (weights : Optional [DenseNet169Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
146
+ def densenet169 (weights : Optional [DenseNet169_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
147
147
if type (weights ) == bool and weights :
148
148
_deprecated_positional (kwargs , "pretrained" , "weights" , True )
149
149
if "pretrained" in kwargs :
150
- weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet169Weights . ImageNet1K_Community )
151
- weights = DenseNet169Weights .verify (weights )
150
+ weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet169_Weights . ImageNet1K_V1 )
151
+ weights = DenseNet169_Weights .verify (weights )
152
152
153
153
return _densenet (32 , (6 , 12 , 32 , 32 ), 64 , weights , progress , ** kwargs )
154
154
155
155
156
- def densenet201 (weights : Optional [DenseNet201Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
156
+ def densenet201 (weights : Optional [DenseNet201_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
157
157
if type (weights ) == bool and weights :
158
158
_deprecated_positional (kwargs , "pretrained" , "weights" , True )
159
159
if "pretrained" in kwargs :
160
- weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet201Weights . ImageNet1K_Community )
161
- weights = DenseNet201Weights .verify (weights )
160
+ weights = _deprecated_param (kwargs , "pretrained" , "weights" , DenseNet201_Weights . ImageNet1K_V1 )
161
+ weights = DenseNet201_Weights .verify (weights )
162
162
163
163
return _densenet (32 , (6 , 12 , 48 , 32 ), 64 , weights , progress , ** kwargs )
0 commit comments