Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed dataloaders not being reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))


- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))

Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def _run_power_scaling(
else:
raise # some other error not memory related

if not changed:
if changed:
# set train dataloader to None so it is reset
trainer.train_dataloader = None
else:
break
return new_size

Expand Down Expand Up @@ -192,7 +195,10 @@ def _run_binsearch_scaling(
else:
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

if not changed:
if changed:
# set train dataloader to None so it is reset
trainer.train_dataloader = None
else:
break

except RuntimeError as exception:
Expand Down
1 change: 1 addition & 0 deletions tests/tuner/test_scale_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, batch_size=None):
(BatchSizeModel(2), BatchSizeDataModule(2)),
(BatchSizeModel(2), BatchSizeDataModule(None)),
(BatchSizeModel(None), BatchSizeDataModule(2)),
(BatchSizeModel(16), BatchSizeDataModule(16)),
]
)
def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
Expand Down