
Commit d377d0e

ar90n, awaelchli, otaj, and Borda authored

Fix type hints of tuner/batch_size_scaling.py (#13518)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Jirka <[email protected]>
1 parent 136d573 commit d377d0e

3 files changed: +17 additions, -6 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
-    "pytorch_lightning.tuner.batch_size_scaling",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",

src/pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@
 class BatchSizeFinder(Callback):
     SUPPORTED_MODES = ("power", "binsearch")

+    optimal_batch_size: Optional[int]
+
     def __init__(
         self,
         mode: str = "power",
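
The new class-level annotation declares the type of the optimal_batch_size attribute, so mypy knows it is Optional[int] wherever the callback (or the tuner code driving it) reads or assigns it, even though the value is only bound at runtime. A toy sketch of the pattern, assuming the attribute starts out as None and is later set to an int (not the actual BatchSizeFinder code):

    from typing import Optional

    class Finder:
        # PEP 526 class-level annotation: declares the instance attribute's
        # type for the type checker without binding a class attribute value.
        optimal_batch_size: Optional[int]

        def __init__(self) -> None:
            self.optimal_batch_size = None

        def record(self, size: int) -> None:
            self.optimal_batch_size = size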

src/pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 15 additions & 5 deletions
@@ -35,10 +35,10 @@ def scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
-):
+) -> Optional[int]:
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
-        return
+        return None

     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
@@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])


 def _run_power_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -179,7 +184,12 @@ def _run_power_scaling(


 def _run_binary_scaling(
-    trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params
+    trainer: "pl.Trainer",
+    pl_module: "pl.LightningModule",
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: Dict[str, Any],
 ) -> int:
     """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
     encountered.
@@ -309,7 +319,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
     reset_fn(pl_module)


-def _try_loop_run(trainer: "pl.Trainer", params) -> None:
+def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
     if trainer.state.fn == "fit":
         loop = trainer.fit_loop
     else:
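
The changes here follow the same pattern: scale_batch_size now declares an Optional[int] return and makes the fast_dev_run early exit return None explicitly, and the previously unannotated params argument of the helpers becomes Dict[str, Any]. A toy sketch of the annotated shape, assuming a params dict keyed by strings (this is not the real scaling logic):

    from typing import Any, Dict, Optional

    def scale(max_trials: int, params: Dict[str, Any], skip: bool = False) -> Optional[int]:
        # The early exit returns None explicitly, matching Optional[int].
        if skip:
            return None
        new_size: int = params.get("init_val", 2)
        for _ in range(max_trials):
            new_size *= 2  # stand-in for the doubling-until-OOM search
        return new_size

With the hints in place, the module can be checked on its own, for example with: mypy src/pytorch_lightning/tuner/batch_size_scaling.py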
