@@ -35,10 +35,10 @@ def scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
-):
+) -> Optional[int]:
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
-        return
+        return None
 
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
@@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
 
 
 def _run_power_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
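
As context for the hunk above: the power mode keeps doubling the batch size until an out-of-memory error is raised. A simplified sketch of that pattern (not the library code; run_trial stands in for the short trial run, and the OOM is assumed to surface as a RuntimeError):

def power_scale(init_val: int, max_trials: int, run_trial) -> int:
    """Double the batch size after each successful trial; stop at the first OOM or after max_trials."""
    size, last_ok = init_val, init_val
    for _ in range(max_trials):
        try:
            run_trial(size)       # run a few steps at this batch size
        except RuntimeError:      # assumption: OOM surfaces as RuntimeError
            break                 # keep the last size that fit in memory
        last_ok = size
        size *= 2                 # success: try a batch twice as large
    return last_ok
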
@@ -179,7 +184,12 @@ def _run_power_scaling(
 
 
 def _run_binary_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
     encountered.
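
Judging by the function name, _run_binary_scaling is the "binsearch" variant: double until the first failure, then bisect between the last size that fit and the first size that did not. A rough sketch under that assumption (not the library code; run_trial and the RuntimeError-as-OOM check are stand-ins, and edge cases such as an OOM at init_val are glossed over):

def binary_scale(init_val: int, max_trials: int, run_trial) -> int:
    """Double until OOM, then binary-search between the last good and the first bad size."""
    low, size, high = init_val, init_val, None
    for _ in range(max_trials):
        try:
            run_trial(size)                                        # short trial at this batch size
            low = size                                             # this size fits
            size = size * 2 if high is None else (low + high) // 2
        except RuntimeError:                                       # assumption: OOM surfaces as RuntimeError
            high = size                                            # this size is too large
            size = (low + high) // 2
        if high is not None and high - low <= 1:
            break                                                  # search interval has collapsed
    return low
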
@@ -309,7 +319,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
     reset_fn(pl_module)
 
 
-def _try_loop_run(trainer: "pl.Trainer", params) -> None:
+def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
     if trainer.state.fn == "fit":
         loop = trainer.fit_loop
     else:
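
With the return annotation changed to Optional[int], call sites should expect None when the scaler is skipped under fast_dev_run. A minimal, hypothetical caller sketch (only init_val, max_trials, and batch_arg_name are visible in this diff; the leading trainer/model arguments and their order are assumed):

new_size = scale_batch_size(trainer, model, init_val=2, max_trials=25, batch_arg_name="batch_size")
if new_size is None:
    print("fast_dev_run is enabled; batch size scaling was skipped")  # nothing to apply
else:
    setattr(model, "batch_size", new_size)  # apply the tuned value to the attribute named by batch_arg_name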