Commit 8ba6304

Increment the total batch idx before the accumulation early exit (#7692)
* Increment the total batch idx before the accumulation early exit
* Update CHANGELOG
1 parent fe1c4ca commit 8ba6304

File tree

3 files changed: +15 -19 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
 
 
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+
+
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 
 

pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 2 deletions
@@ -529,6 +529,8 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True
 
+            self.total_batch_idx += 1
+
             # max steps reached, end training
             if (
                 self.max_steps is not None and self.max_steps <= self.global_step + 1
@@ -542,8 +544,6 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break
 
-            self.total_batch_idx += 1
-
             # stop epoch if we limited the number of training batches
             if self._num_training_batches_reached(is_last_batch):
                 break
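
Why the move matters: with gradient accumulation, the per-batch loop can hit an early break on the very batch that completes an optimizer step, and incrementing the counter only after those checks drops that batch from total_batch_idx. The following is a minimal standalone sketch, not the Lightning source; the exit condition is a stand-in for the max-steps check, and the numbers mirror the test expectations changed below.

# A minimal sketch, not the Lightning source: it only mimics where the counter
# increment sits relative to an early exit from the per-batch loop.

def count_batches(num_batches, increment_before_exit):
    total_batch_idx = 0
    for batch_idx in range(num_batches):
        if increment_before_exit:
            total_batch_idx += 1            # ordering after this commit
        if batch_idx == num_batches - 1:    # stand-in for the max-steps early exit
            break
        if not increment_before_exit:
            total_batch_idx += 1            # old ordering: the exiting batch is never counted
    return total_batch_idx

# With accumulate_grad_batches=2, the LR finder's 100 steps consume 200 batches:
assert count_batches(200, increment_before_exit=False) == 199   # the old test's 200 - 1
assert count_batches(200, increment_before_exit=True) == 200    # the updated expectation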

tests/tuner/test_lr_finder.py

Lines changed: 10 additions & 17 deletions
@@ -197,31 +197,24 @@ def test_datamodule_parameter(tmpdir):
 
 
 def test_accumulation_and_early_stopping(tmpdir):
-    """ Test that early stopping of learning rate finder works, and that
-    accumulation also works for this feature """
+    """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """
 
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
+    class TestModel(BoringModel):
 
-    before_lr = hparams.get('learning_rate')
-    # logger file to get meta
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
        accumulate_grad_batches=2,
     )
-
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
-    after_lr = lrfinder.suggestion()
 
-    expected_num_lrs = 100
-    expected_batch_idx = 200 - 1
-
-    assert before_lr != after_lr, \
-        'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == expected_num_lrs, \
-        'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == expected_batch_idx, \
-        'Accumulation parameter did not work'
+    assert lrfinder.suggestion() != 1e-3
+    assert len(lrfinder.results['lr']) == 100
+    assert lrfinder._total_batch_idx == 200
 
 
 def test_suggestion_parameters_work(tmpdir):
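
For reference, a usage sketch along the lines of the updated test, assuming the same era of the API (BoringModel test helper, Trainer, trainer.tuner.lr_find). The BoringModel import path and the default of 100 LR-finder steps are assumptions, and _total_batch_idx is a private attribute inspected here only to mirror the test.

import pytorch_lightning as pl
from tests.helpers.boring_model import BoringModel   # assumed import path for the test helper


class DemoModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.lr = 1e-3   # attribute the LR finder reads and overwrites


model = DemoModel()
trainer = pl.Trainer(default_root_dir="lr_find_demo", accumulate_grad_batches=2)

# Run the finder with early stopping disabled, as in the test above.
lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)

print(lrfinder.suggestion())            # suggested LR, different from the initial 1e-3
print(len(lrfinder.results['lr']))      # 100 candidate learning rates by default
print(lrfinder._total_batch_idx)        # 200 with this fix (2 accumulated batches per step)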
