
Commit b2bcad1

Fix tuner.scale_batch_size not finding batch size attribute when using datamodule (#5968)
1 parent 680e83a commit b2bcad1

File tree: 4 files changed, +80 -3 lines changed

  CHANGELOG.md
  pytorch_lightning/trainer/training_loop.py
  pytorch_lightning/tuner/tuning.py
  tests/tuner/test_scale_batch_size.py
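For context, the scenario this commit addresses is a `batch_size` attribute defined on the `LightningDataModule` rather than on the model. Below is a minimal, hedged sketch of that usage, reusing the `BoringModel`/`BoringDataModule` test helpers that the new test file in this commit also uses; the datamodule subclass and the exact call are illustrative, not part of the commit itself:

```python
# Illustrative sketch only: the batch size lives on the datamodule, not the model.
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringDataModule, BoringModel


class DataModuleWithBatchSize(BoringDataModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner needs to discover

    def train_dataloader(self):
        return DataLoader(self.random_train, batch_size=self.batch_size)


trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
tuner = Tuner(trainer)
# Before this commit, the batch-size lookup only considered the model, so a call
# like this could not find `batch_size` on the datamodule; with the fix it can.
new_batch_size = tuner.scale_batch_size(BoringModel(), datamodule=DataModuleWithBatchSize())
```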

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


+- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))
+
+
 ## [1.2.1] - 2021-02-23

 ### Fixed

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def on_train_start(self):
         # provide rank to profiler
         self.trainer.profile_connector.on_train_start(self.trainer)

-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)

pytorch_lightning/tuner/tuning.py

Lines changed: 11 additions & 2 deletions
@@ -33,13 +33,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
         self.trainer.auto_lr_find = auto_lr_find
         self.trainer.auto_scale_batch_size = auto_scale_batch_size

-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
+        self.trainer.model_connector.copy_trainer_model_properties(model)
         # setup data, etc...
         self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
-
         # hook
         self.trainer.data_connector.prepare_data(model)

+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
         # Run auto batch size scaling
         if self.trainer.auto_scale_batch_size:
             if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -104,6 +111,7 @@ def scale_batch_size(
                 or datamodule.

        """
+        self.setup_trainer(model, **fit_kwargs)
         return scale_batch_size(
             self.trainer,
             model,
@@ -128,6 +136,7 @@ def lr_find(
         datamodule: Optional[LightningDataModule] = None,
         update_attr: bool = False,
     ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
         return lr_find(
             self.trainer,
             model,
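A hedged reading of why the change above fixes the reported issue: `scale_batch_size` and `lr_find` now call `setup_trainer` first, which copies the model properties and runs `train_loop.setup_fit(...)` with the passed datamodule, so the datamodule is attached before the batch-size attribute is looked up. Conceptually, the lookup then has more than one place to search, roughly along these lines (an illustrative sketch, not the library's actual helper):

```python
# Illustrative only: a rough stand-in for "where does batch_size live?".
# The real resolution in pytorch-lightning is done by its own utilities.
def find_batch_size_owner(model, datamodule=None):
    # Check the model first, then fall back to an attached datamodule.
    for owner in (model, datamodule):
        if owner is not None and hasattr(owner, "batch_size"):
            return owner
    raise AttributeError("no `batch_size` found on the model or the datamodule")
```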

tests/tuner/test_scale_batch_size.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from torch.utils.data import DataLoader
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.tuner.tuning import Tuner
+from tests.helpers import BoringDataModule, BoringModel
+
+
+class BatchSizeDataModule(BoringDataModule):
+
+    def __init__(self, batch_size=None):
+        super().__init__()
+        if batch_size is not None:
+            self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))
+
+
+class BatchSizeModel(BoringModel):
+
+    def __init__(self, batch_size=None):
+        super().__init__()
+        if batch_size is not None:
+            self.batch_size = batch_size
+
+
+@pytest.mark.parametrize(
+    "model,datamodule", [
+        (BatchSizeModel(2), None),
+        (BatchSizeModel(2), BatchSizeDataModule(2)),
+        (BatchSizeModel(2), BatchSizeDataModule(None)),
+        (BatchSizeModel(None), BatchSizeDataModule(2)),
+    ]
+)
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=1,
+    )
+    tuner = Tuner(trainer)
+    new_batch_size = tuner.scale_batch_size(
+        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
+    )
+    assert new_batch_size == 16
+    if hasattr(model, "batch_size"):
+        assert model.batch_size == 16
+    if datamodule is not None and hasattr(datamodule, "batch_size"):
+        assert datamodule.batch_size == 16
