
Commit 9942f3e

Fix on_train_batch_start hook to end epoch early (#3700)

Jeff Yang and Borda authored

* init
* add test
* changelog and docs
* fix test
* Apply suggestion from code review

Co-authored-by: Jirka Borovec <[email protected]>

1 parent 3ab730e commit 9942f3e

File tree: 4 files changed (+28, −4 lines)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `on_train_batch_start` hook to end epoch early ([#3700](https://github.com/PyTorchLightning/pytorch-lightning/pull/3700))
+
 - Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

 - Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))

docs/source/early_stopping.rst

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ Early stopping

 Stopping an epoch early
 -----------------------
-You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return ``-1`` when some condition is met.
+You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.

 If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run.
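The documented behavior can be sketched without a PyTorch Lightning installation. The snippet below is a minimal, self-contained illustration: `MyModel` stands in for a `LightningModule` subclass, `run_epoch` is a toy stand-in for the trainer's per-epoch loop (not the real trainer), and the `batch_idx >= 3` condition is a hypothetical placeholder for "some condition is met".

```python
# Hedged sketch of the documented usage: returning -1 from
# on_train_batch_start ends the current epoch early.
# MyModel stands in for a pytorch_lightning.LightningModule subclass;
# the stopping condition below is hypothetical.

class MyModel:
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        if batch_idx >= 3:  # hypothetical condition
            return -1  # signal the trainer to end this epoch early


def run_epoch(model, num_batches):
    """Toy stand-in for the trainer's per-epoch loop."""
    seen = []
    for batch_idx in range(num_batches):
        if model.on_train_batch_start(None, batch_idx, 0) == -1:
            break  # epoch ends before this batch is trained on
        seen.append(batch_idx)
    return seen


print(run_epoch(MyModel(), 10))  # → [0, 1, 2]
```

Only batches 0 through 2 are processed; once the hook returns `-1`, the remaining batches of that epoch are skipped.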

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 3 deletions

@@ -515,6 +515,10 @@ def run_training_epoch(self):
             # ------------------------------------
             batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

+            # when returning -1 from train_step, we end epoch early
+            if batch_output.signal == -1:
+                break
+
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             epoch_end_outputs = self.process_train_step_outputs(
@@ -527,9 +531,6 @@ def run_training_epoch(self):
             # TODO: add outputs to batches
             self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)

-            # when returning -1 from train_step, we end epoch early
-            self.trainer.should_stop = batch_output.signal == -1
-
             # -----------------------------------------
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
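The essence of this change is control flow: the old code set a stop flag after the batch had already gone through the rest of the loop body, while the new code breaks out of the epoch loop immediately. A pure-Python sketch of the new behavior (no PyTorch Lightning dependency; `run_training_epoch`, `signal`, and the toy batch data are illustrative stand-ins for the names in the diff):

```python
# Minimal sketch of the new control flow: the epoch loop breaks as soon
# as a batch signals -1, so later per-batch work (e.g. on_train_batch_end)
# is not run for that batch. The stand-in "signal" logic is hypothetical.

def run_training_epoch(batches, stop_at):
    """Process batches in order; a -1 signal ends the epoch early."""
    processed = []
    for batch_idx, batch in enumerate(batches):
        # stand-in for run_training_batch(...).signal
        signal = -1 if batch_idx == stop_at else 0
        if signal == -1:
            break  # end epoch early
        processed.append(batch)
    return processed


print(run_training_epoch(['a', 'b', 'c', 'd'], stop_at=2))  # → ['a', 'b']
```

Breaking immediately, rather than assigning `self.trainer.should_stop` as the removed lines did, also avoids stopping the entire run when only the current epoch should end.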

tests/models/test_hooks.py

Lines changed: 21 additions & 0 deletions

@@ -108,3 +108,24 @@ def transfer_batch_to_device(self, data, device):
         expected = torch.device('cuda', 0)
         assert model.hook_called
         assert batch_gpu.samples.device == batch_gpu.targets.device == expected
+
+
[email protected](
+    'max_epochs,batch_idx_',
+    [(2, 5), (3, 8), (4, 12)]
+)
+def test_on_train_batch_start_hook(max_epochs, batch_idx_):
+    class CurrentModel(EvalModelTemplate):
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs)
+    trainer.fit(model)
+    if batch_idx_ > len(model.val_dataloader()) - 1:
+        assert trainer.batch_idx == len(model.val_dataloader()) - 1
+        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
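The counter arithmetic the parametrized test encodes can be isolated in a small helper. This is only a restatement of the test's two branches, not trainer code: if the stopping index falls beyond the dataloader, the hook never fires and every epoch runs in full; otherwise each epoch stops at `batch_idx_`. The helper name and arguments are hypothetical.

```python
# Restates the expected-value arithmetic from the test above:
# returns (expected trainer.batch_idx, expected trainer.global_step).

def expected_counters(batch_idx_, max_epochs, len_dataloader):
    if batch_idx_ > len_dataloader - 1:
        # hook never fires: every epoch runs all batches
        return len_dataloader - 1, len_dataloader * max_epochs
    # hook fires at batch_idx_ in every epoch
    return batch_idx_, (batch_idx_ + 1) * max_epochs


print(expected_counters(5, 2, 10))   # hook fires each epoch → (5, 12)
print(expected_counters(12, 4, 10))  # hook never fires → (9, 40)
```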
