
Commit 20f37b8

awaelchli, pre-commit-ci[bot], and SkafteNicki authored
add warning when Trainer(log_every_n_steps) not well chosen (#7734)
* add warning
* update changelog
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* logger check
* add docstring for test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
1 parent 41be61c commit 20f37b8

File tree: CHANGELOG.md · pytorch_lightning/trainer/data_loading.py · tests/trainer/test_dataloaders.py

3 files changed: +30 -0 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
 
 
+- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
+
+
 ### Changed
 
 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))

pytorch_lightning/trainer/data_loading.py

Lines changed: 8 additions & 0 deletions
@@ -51,6 +51,7 @@ class TrainerDataLoadingMixin(ABC):
     test_dataloaders: Optional[List[DataLoader]]
     num_test_batches: List[Union[int, float]]
     limit_train_batches: Union[int, float]
+    log_every_n_steps: int
     overfit_batches: Union[int, float]
     distributed_sampler_kwargs: dict
     accelerator: Accelerator
@@ -302,6 +303,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                 self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                 self.val_check_batch = max(1, self.val_check_batch)
 
+        if self.logger and self.num_training_batches < self.log_every_n_steps:
+            rank_zero_warn(
+                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
+                f" you want to see logs for the training epoch."
+            )
+
     def _reset_eval_dataloader(
         self,
         model: LightningModule,
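
For context, a minimal sketch (not part of this commit) of the situation the new check guards against: a training dataloader that yields fewer batches than `log_every_n_steps`, so no training metrics would ever reach the logger during the epoch. `TinyDataset` and `TinyModel` are illustrative names, and the snippet assumes a Lightning version from around this commit, where `Trainer.fit` accepts a `train_dataloader` argument (as the test below also does).

# Minimal sketch, not part of the commit; `TinyDataset` and `TinyModel` are
# illustrative stand-ins rather than helpers from the Lightning codebase.
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class TinyDataset(Dataset):
    # only 10 samples -> at most 10 training batches per epoch with batch_size=1
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)  # sent to the logger only every `log_every_n_steps` steps
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# log_every_n_steps=50 exceeds the 10 available batches, so the check added in
# reset_train_dataloader should emit a UserWarning instead of silently logging nothing.
trainer = Trainer(max_epochs=1, log_every_n_steps=50)
trainer.fit(TinyModel(), train_dataloader=DataLoader(TinyDataset(), batch_size=1))

With these settings the run should complete, but it raises the new warning, matching the messages asserted in the test below.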

tests/trainer/test_dataloaders.py

Lines changed: 19 additions & 0 deletions
@@ -895,6 +895,25 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
     trainer.fit(model, train_dataloader=dataloader)
 
 
+def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
+    """ Test that a warning message is shown if the dataloader length is too short for the chosen logging interval. """
+    model = BoringModel()
+    dataloader = DataLoader(RandomDataset(32, length=10))
+    model.train_dataloader = lambda: dataloader
+
+    with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=1,
+            log_every_n_steps=11,
+        )
+        trainer.fit(model)
+
+    with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
+        trainer.fit(model)
+
+
 def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
     model = BoringModel()
