Skip to content

Commit 290a3f0

Browse files
justusschock and lexierule
authored and committed
Add warning to trainstep output (#7779)
* Update training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update test_training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update test_training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update training_loop.py
* Update pytorch_lightning/trainer/training_loop.py (Co-authored-by: Ethan Harris <[email protected]>)
* Update test_training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update pytorch_lightning/trainer/training_loop.py (Co-authored-by: ananthsub <[email protected]>)
* Update training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update test_training_loop.py
* Update training_loop.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* escape regex

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: ananthsub <[email protected]>

(cherry picked from commit 6a0d503)
1 parent 55f52a1 commit 290a3f0

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

pytorch_lightning/trainer/training_loop.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from contextlib import contextmanager, suppress
1616
from copy import copy, deepcopy
17-
from typing import Any, Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Mapping, Optional, Union
1818

1919
import numpy as np
2020
import torch
@@ -265,6 +265,16 @@ def _check_training_step_output(self, training_step_output):
265265
if training_step_output.grad_fn is None:
266266
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
267267
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
268+
elif self.trainer.lightning_module.automatic_optimization:
269+
if not any((
270+
isinstance(training_step_output, torch.Tensor),
271+
(isinstance(training_step_output, Mapping)
272+
and 'loss' in training_step_output), training_step_output is None
273+
)):
274+
raise MisconfigurationException(
275+
"In automatic optimization, `training_step` must either return a Tensor, "
276+
"a dict with key 'loss' or None (where the step will be skipped)."
277+
)
268278

269279
def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
270280
# give the PL module a result for logging

tests/trainer/loops/test_training_loop.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
15+
1416
import pytest
1517
import torch
1618

1719
from pytorch_lightning import seed_everything, Trainer
20+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1821
from tests.helpers import BoringModel
1922

2023

@@ -222,3 +225,56 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
222225
else:
223226
assert trainer.batch_idx == batch_idx_
224227
assert trainer.global_step == batch_idx_ * max_epochs
228+
229+
230+
def test_should_stop_mid_epoch(tmpdir):
    """Test that training correctly stops mid epoch and that validation is still called at the right time"""

    class StopMidEpochModel(BoringModel):
        # Requests a stop on the fifth training batch and records the
        # (epoch, global_step) at which validation was invoked.

        def __init__(self):
            super().__init__()
            self.validation_called_at = None

        def training_step(self, batch, batch_idx):
            if batch_idx == 4:
                # Ask the trainer to stop once batch index 4 is reached.
                self.trainer.should_stop = True
            return super().training_step(batch, batch_idx)

        def validation_step(self, *args):
            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
            return super().validation_step(*args)

    model = StopMidEpochModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=1,
    )
    trainer.fit(model)

    # Training halted during epoch 0 after five optimizer steps, and
    # validation ran right at the stopping point.
    assert trainer.current_epoch == 0
    assert trainer.global_step == 5
    assert model.validation_called_at == (0, 4)
260+
261+
262+
@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
def test_warning_invalid_trainstep_output(tmpdir, output):
    # Returning a bare float or a dict without a 'loss' key from
    # ``training_step`` must raise a MisconfigurationException.

    class InvalidOutputModel(BoringModel):

        def training_step(self, batch, batch_idx):
            return output

    # Escape the message since ``match`` is interpreted as a regex.
    expected_message = re.escape(
        "In automatic optimization, `training_step` must either return a Tensor, "
        "a dict with key 'loss' or None (where the step will be skipped)."
    )
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer.fit(InvalidOutputModel())

0 commit comments

Comments (0)