
Commit d88b88c

Add test. Improve comment
1 parent 2e03886 commit d88b88c

2 files changed: +30 -22 lines changed


pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 2 additions & 2 deletions

@@ -161,7 +161,7 @@ def _run_power_scaling(
                 raise  # some other error not memory related
 
         if changed:
-            # set train dataloader to None so it is reset
+            # Force the train dataloader to reset as the batch size has changed
             trainer.train_dataloader = None
         else:
             break
@@ -196,7 +196,7 @@ def _run_binsearch_scaling(
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
 
             if changed:
-                # set train dataloader to None so it is reset
+                # Force the train dataloader to reset as the batch size has changed
                 trainer.train_dataloader = None
            else:
                break
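
The two comment edits above document the intent of `trainer.train_dataloader = None`: the attribute is cleared so that the next fit call inside the scaling loop rebuilds the dataloader and therefore picks up the adjusted `batch_size`. As a rough illustration of how this loop is normally driven from user code, here is a minimal sketch assuming the PyTorch Lightning API of this era (~1.2), where `Trainer(auto_scale_batch_size=...)` and `trainer.tune(model)` are available; `RandomDataset` and `LitModel` below are made-up names for the example.

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    # hypothetical stand-in for tests.helpers.RandomDataset
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class LitModel(pl.LightningModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner adjusts in place
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # rebuilt each time the tuner resets trainer.train_dataloader to None
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size="binsearch")
trainer.tune(LitModel())  # runs the binsearch scaling loop patched above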

tests/tuner/test_scale_batch_size.py

Lines changed: 28 additions & 20 deletions

@@ -24,14 +24,14 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf
 
 
 class BatchSizeDataModule(BoringDataModule):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,22 +42,23 @@ def train_dataloader(self):
 
 class BatchSizeModel(BoringModel):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-        (BatchSizeModel(16), BatchSizeDataModule(16)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -66,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+    if dm_bs == -1:
+        # datamodule batch size takes precedence
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 
 
 def test_model_reset_correctly(tmpdir):
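
The rewritten assertions check each place the tuner may find and write back a `batch_size` attribute: on the model only, on the datamodule only, or on both, and where a concrete dataloader ends up on the trainer, its batch size is compared against the value the tuner settled on. Below is a hedged sketch of the same call pattern outside pytest, reusing the `BatchSizeModel` and `BatchSizeDataModule` helpers from the diff above; the `Tuner` import path is assumed to match the one already used by this test module.

import tempfile

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

trainer = Trainer(
    default_root_dir=tempfile.mkdtemp(),
    limit_train_batches=1,
    limit_val_batches=0,
    max_epochs=1,
)
tuner = Tuner(trainer)

# batch_size lives on the datamodule only, mirroring the (None, 2) test case
model = BatchSizeModel(None)
datamodule = BatchSizeDataModule(2)

new_batch_size = tuner.scale_batch_size(
    model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
)
# the tuner writes the scaled value back onto the datamodule
assert datamodule.batch_size == new_batch_size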
