Skip to content

Commit d59ef15

Browse files
akihironitta authored and Borda committed
Restore trainer.current_epoch after tuning (#7434)
* Add a test
* Save and restore current_epoch
* Update CHANGELOG
* alphabetical order

(cherry picked from commit 710b144)
1 parent 1ae191c commit d59ef15

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2525
- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))
2626

2727

28+
- Fixed `Trainer.current_epoch` not getting restored after tuning ([#7434](https://github.com/PyTorchLightning/pytorch-lightning/pull/7434))
29+
30+
2831
## [1.3.0] - 2021-05-06
2932

3033
### Added

pytorch_lightning/tuner/lr_finder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def __lr_finder_dump_params(trainer, model):
288288
'logger': trainer.logger,
289289
'max_steps': trainer.max_steps,
290290
'checkpoint_callback': trainer.checkpoint_callback,
291+
'current_epoch': trainer.current_epoch,
291292
'configure_optimizers': model.configure_optimizers,
292293
}
293294

@@ -297,6 +298,7 @@ def __lr_finder_restore_params(trainer, model):
297298
trainer.logger = trainer.__dumped_params['logger']
298299
trainer.callbacks = trainer.__dumped_params['callbacks']
299300
trainer.max_steps = trainer.__dumped_params['max_steps']
301+
trainer.current_epoch = trainer.__dumped_params['current_epoch']
300302
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
301303
del trainer.__dumped_params
302304

tests/tuner/test_lr_finder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ def test_trainer_reset_correctly(tmpdir):
7777
)
7878

7979
changed_attributes = [
80-
'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback'
80+
'accumulate_grad_batches',
81+
'auto_lr_find',
82+
'callbacks',
83+
'checkpoint_callback',
84+
'current_epoch',
85+
'logger',
86+
'max_steps',
8187
]
8288
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
8389
trainer.tuner.lr_find(model, num_training=5)

tests/tuner/test_scale_batch_size.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def test_trainer_reset_correctly(tmpdir):
111111
)
112112

113113
changed_attributes = [
114-
'max_steps',
115-
'weights_summary',
116-
'logger',
117114
'callbacks',
118115
'checkpoint_callback',
119-
'limit_train_batches',
120116
'current_epoch',
117+
'limit_train_batches',
118+
'logger',
119+
'max_steps',
120+
'weights_summary',
121121
]
122122
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
123123
trainer.tuner.scale_batch_size(model, max_trials=5)

0 commit comments

Comments (0)