 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf
 
 
 class BatchSizeDataModule(BoringDataModule):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
@@ -42,22 +42,23 @@ def train_dataloader(self):
 
 class BatchSizeModel(BoringModel):
 
-    def __init__(self, batch_size=None):
+    def __init__(self, batch_size):
         super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
-@pytest.mark.parametrize(
-    "model,datamodule", [
-        (BatchSizeModel(2), None),
-        (BatchSizeModel(2), BatchSizeDataModule(2)),
-        (BatchSizeModel(2), BatchSizeDataModule(None)),
-        (BatchSizeModel(None), BatchSizeDataModule(2)),
-        (BatchSizeModel(16), BatchSizeDataModule(16)),
-    ]
-)
-def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+
+@pytest.mark.parametrize(["model_bs", "dm_bs"], [
+    (2, -1),
+    (2, 2),
+    (2, None),
+    (None, 2),
+    (16, 16),
+])
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -66,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
         max_epochs=1,
     )
     tuner = Tuner(trainer)
-    new_batch_size = tuner.scale_batch_size(
-        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
-    )
+
+    model = BatchSizeModel(model_bs)
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
-    if hasattr(model, "batch_size"):
-        assert model.batch_size == 16
-    if datamodule is not None and hasattr(datamodule, "batch_size"):
-        assert datamodule.batch_size == 16
+
+    if model_bs is not None:
+        assert model.batch_size == new_batch_size
+        if dm_bs == -1:
+            # datamodule batch size takes precedence
+            assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    if dm_bs not in (-1, None):
+        assert datamodule.batch_size == new_batch_size
+        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
 
 
 def test_model_reset_correctly(tmpdir):
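
For reference (not part of the diff): a minimal sketch of how the API exercised by this test could be called directly. It reuses the fixture classes from the hunks above; the `Tuner` import path and the `Trainer` arguments are assumptions, not taken from this PR.

# Sketch only: assumes the BatchSizeModel / BatchSizeDataModule fixtures defined above
# and that `Tuner` is importable from pytorch_lightning.tuner.tuning.
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

model = BatchSizeModel(batch_size=2)            # LightningModule exposing a `batch_size` attribute
datamodule = BatchSizeDataModule(batch_size=2)  # optional; pass datamodule=None to scale the model's own loader

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)

# Runs short trial fits, updates the `batch_size` attribute in place, and returns the found value;
# per the comment in the test above, a provided datamodule's batch size takes precedence.
new_batch_size = tuner.scale_batch_size(
    model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
)
print(new_batch_size)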