Commit f9fccdf

Move training_output validation to after train_step_end (#7868)

* move validation to after aggregation
* changelog
* add test for training_step_end
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 3427cb7 commit f9fccdf

3 files changed: 25 additions and 4 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `_check_training_step_output` to be called after `training_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
+
 - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
 
 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))

pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 2 deletions

@@ -296,10 +296,10 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
 
             self.trainer.logger_connector.cache_logged_metrics()
 
-            self._check_training_step_output(training_step_output)
-
             training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
 
+            self._check_training_step_output(training_step_output)
+
             training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
                 training_step_output, split_batch
             )
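
To illustrate why the ordering matters, here is a minimal, framework-free sketch of the two orderings. It is not the Lightning implementation: `FakeTensor`, `check_output`, `run_step_old`, and `run_step_new` are hypothetical names, and the check is a simplified stand-in that assumes a valid output is a tensor, a mapping with a `'loss'` key, or `None`.

```python
from collections.abc import Mapping


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch has no dependencies."""

    def __init__(self, value):
        self.value = value


def check_output(output):
    # Simplified stand-in for the output validation (assumption):
    # accept a tensor, a mapping containing 'loss', or None.
    valid = (
        isinstance(output, FakeTensor)
        or (isinstance(output, Mapping) and "loss" in output)
        or output is None
    )
    if not valid:
        raise ValueError("step output must be a tensor, a dict with key 'loss', or None")


def training_step(batch):
    # Returns intermediate values only; there is no 'loss' key yet.
    return {"output": [2 * x for x in batch], "batch": batch}


def training_step_end(outputs):
    # Reduces the intermediate values to a single loss value.
    total = sum(o - b for o, b in zip(outputs["output"], outputs["batch"]))
    return FakeTensor(total)


def run_step_old(batch):
    # Old order: validate first, then call the hook.
    out = training_step(batch)
    check_output(out)  # rejects the dict because it has no 'loss' key
    return training_step_end(out)


def run_step_new(batch):
    # New order: call the hook first, then validate its result.
    out = training_step(batch)
    out = training_step_end(out)
    check_output(out)  # the hook already produced a tensor-like loss
    return out
```

With the old ordering, a `training_step` that defers the loss computation to `training_step_end` fails validation before the hook runs. With the new ordering, only the hook's final result is validated.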

tests/trainer/loops/test_training_loop.py

Lines changed: 21 additions & 2 deletions

@@ -150,12 +150,12 @@ def validation_step(self, *args):
 @pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
 def test_warning_invalid_trainstep_output(tmpdir, output):
 
-    class TestModel(BoringModel):
+    class InvalidTrainStepModel(BoringModel):
 
         def training_step(self, batch, batch_idx):
             return output
 
-    model = TestModel()
+    model = InvalidTrainStepModel()
 
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
     with pytest.raises(
@@ -166,3 +166,22 @@ def training_step(self, batch, batch_idx):
         )
     ):
         trainer.fit(model)
+
+
+def test_warning_valid_train_step_end(tmpdir):
+
+    class ValidTrainStepEndModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = self(batch)
+            return {'output': output, 'batch': batch}
+
+        def training_step_end(self, outputs):
+            loss = self.loss(outputs['batch'], outputs['output'])
+            return loss
+
+    # No error is raised
+    model = ValidTrainStepEndModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+
+    trainer.fit(model)
