@@ -24,19 +24,22 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
+    from torch.optim.swa_utils import SWALR
 
 
 class SwaTestModel(BoringModel):
 
-    def __init__(self, batchnorm: bool = True):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
+        self.interval = interval
 
     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
@@ -46,6 +49,14 @@ def training_step(self, batch, batch_idx):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=2)
 
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return {
+            "optimizer": optimizer,
+            "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+            "interval": self.interval,
+        }
+
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
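The new configure_optimizers above returns a single optimizer/scheduler dict whose "interval" key controls whether the StepLR scheduler is stepped once per epoch or once per optimizer step. A minimal plain-PyTorch sketch of the two stepping behaviours the test parametrizes over; the loop below is illustrative only and not part of this diff:

import torch
from torch import nn

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
interval = "step"  # the test exercises both "epoch" and "step"

for epoch in range(5):  # mirrors max_epochs = 5 below
    for _ in range(32):  # 64 random samples / batch_size 2
        optimizer.zero_grad()
        loss = model(torch.randn(2, 32)).sum()
        loss.backward()
        optimizer.step()
        if interval == "step":
            scheduler.step()  # decay the learning rate after every batch
    if interval == "epoch":
        scheduler.step()  # decay the learning rate once per epoch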
@@ -61,6 +72,10 @@ def transfer_weights(self, *args, **kwargs):
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
         assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+        if self.swa_start <= trainer.current_epoch:
+            assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+            assert trainer.lr_schedulers[0]["interval"] == "epoch"
+            assert trainer.lr_schedulers[0]["frequency"] == 1
 
     def on_train_epoch_end(self, trainer, *args):
         super().on_train_epoch_end(trainer, *args)
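The assertions added above verify that, from swa_start onward, the callback has swapped the user's scheduler for torch.optim.swa_utils.SWALR and registered it with interval "epoch" and frequency 1, regardless of the "interval" the user configured. For context, a minimal sketch of the manual SWA recipe (PyTorch >= 1.6) that StochasticWeightAveraging automates; the data and model below are placeholders, not the test's:

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

dataset = [(torch.randn(32), torch.randn(2)) for _ in range(64)]
loader = torch.utils.data.DataLoader(dataset, batch_size=2)
model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)  # maintains the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.1)  # anneals the lr towards the SWA value
swa_start, max_epochs = 2, 5
loss_fn = nn.MSELoss()

for epoch in range(max_epochs):
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # accumulate the weight average
        swa_scheduler.step()  # SWALR is stepped once per epoch

update_bn(loader, swa_model)  # recompute BatchNorm statistics for the averaged model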
@@ -89,8 +104,8 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
-    model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -147,7 +162,13 @@ def test_swa_callback(tmpdir, batchnorm):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: str):
+    train_with_swa(tmpdir, interval=interval)
+
+
+@RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
         StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
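With this change in place, running something like pytest -k test_swa_callback_scheduler_step from the repository root (the test file's path is not shown in this diff) should exercise both the "epoch" and "step" variants of the new test.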