
Commit 94e567e

jxtngx authored
Fix mypy errors attributed to pytorch_lightning.trainer.connectors.data_connector.py (#13806)
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: otaj <[email protected]>
1 parent e2221a0 commit 94e567e

7 files changed: +37 -30 lines changed


pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -55,7 +55,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",

src/pytorch_lightning/core/datamodule.py

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@

 from torch.utils.data import DataLoader, Dataset, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint

@@ -62,7 +63,7 @@ def teardown(self):
     def __init__(self) -> None:
         super().__init__()
         # Pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
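
A note on the annotation added above: without it, mypy infers an attribute's type from its first assignment, so `self.trainer = None` would be typed as plain `None` and any later assignment of a `Trainer` would be rejected. A minimal standalone sketch of the pattern (the `Trainerish` and `Ownerish` names are invented for illustration and are not part of the codebase):

    from typing import Optional


    class Trainerish:
        """Stand-in for pl.Trainer; illustrative only."""


    class Ownerish:
        def __init__(self) -> None:
            # Without the annotation, mypy would infer the attribute type from the
            # first assignment (None) and reject assigning a Trainerish later.
            self.trainer: Optional[Trainerish] = None

        def attach(self, trainer: Trainerish) -> None:
            self.trainer = trainer  # accepted because the attribute is Optional[Trainerish]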

src/pytorch_lightning/core/module.py

Lines changed: 2 additions & 1 deletion

@@ -105,7 +105,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._use_amp: bool = False

         # the precision used
-        self.precision: int = 32
+        self.precision: Union[int, str] = 32

         # optionally can be set by user
         self._example_input_array = None

@@ -294,6 +294,7 @@ def loggers(self) -> List[Logger]:
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
+            assert datahook_selector is not None
             obj = datahook_selector.get_instance(hook_name)
             if isinstance(obj, self.__class__):
                 trainer_method = self._trainer._call_lightning_module_hook
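
The `assert datahook_selector is not None` added in `_call_batch_hook` above is there for the type checker: `_datahook_selector` is now declared `Optional[_DataHookSelector]` (see the data_connector.py diff below), and the assert narrows it from `Optional[...]` to the concrete type before `.get_instance(...)` is called. A small standalone sketch of the narrowing (the `Selector` and `call_hook` names are invented for illustration):

    from typing import Optional


    class Selector:
        def get_instance(self, name: str) -> str:
            return name


    def call_hook(selector: Optional[Selector], name: str) -> str:
        # Without the assert, mypy reports: Item "None" of "Optional[Selector]"
        # has no attribute "get_instance"
        assert selector is not None  # narrows Optional[Selector] to Selector
        return selector.get_instance(name)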

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 4 additions & 2 deletions

@@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")

-    __verify_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)

@@ -149,10 +149,12 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")


-def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer")
     datahook_selector = trainer._data_connector._datahook_selector
+    assert datahook_selector is not None
+
     for hook in batch_transfer_hooks:
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 27 additions & 23 deletions

@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler

@@ -55,7 +55,7 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")

-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None

     @property
     def _should_reload_train_dl(self) -> bool:

@@ -230,7 +230,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 category=PossibleUserWarning,
             )

-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed

@@ -292,14 +292,18 @@ def _prepare_dataloader(

         return dataloader

-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )

             # update docs too once this is resolved

@@ -357,7 +361,7 @@ def _reset_eval_dataloader(
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)

         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]

         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")

@@ -426,7 +430,7 @@ def _reset_eval_dataloader(

         return loader_num_batches, dataloaders

-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.

         Returns:

@@ -447,10 +451,12 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
         return dataloader

     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True

-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler

@@ -460,19 +466,23 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader):

         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )

         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )

         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)

         return dataloaders

     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(

@@ -506,18 +516,14 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:

         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)

-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()

+        assert self.instance is not None
         return self.instance

     def is_defined(self) -> bool:

@@ -532,9 +538,7 @@ def is_module(self) -> bool:

         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))


 @dataclass

@@ -553,7 +557,7 @@ class _DataHookSelector:

     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
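
One typing detail from the `_DataHookSelector` change above: `Tuple[str]` describes a tuple of exactly one string, whereas `Tuple[str, ...]` describes a homogeneous tuple of any length, which is what the three-element `_valid_hooks` default actually is. A quick standalone illustration (not from the repository):

    from typing import Tuple

    single: Tuple[str] = ("a",)  # a tuple of exactly one string
    varlen: Tuple[str, ...] = ("a", "b", "c")  # a homogeneous tuple of any length
    # Annotating a three-element default as Tuple[str] is flagged by mypy,
    # which is why the field annotation gains the "..." above.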

src/pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion

@@ -2234,7 +2234,7 @@ def is_global_zero(self) -> bool:
         return self.strategy.is_global_zero

     @property
-    def distributed_sampler_kwargs(self) -> Optional[dict]:
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.strategy, ParallelStrategy):
             return self.strategy.distributed_sampler_kwargs
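
The narrower return annotation above also helps at the call site in the data connector, where the result is narrowed with an assert and then unpacked as `**distributed_sampler_kwargs`: a bare `dict` is implicitly `Dict[Any, Any]`, while `Dict[str, Any]` records that the keys are the string keyword names. A minimal standalone sketch (the `sampler_kwargs` and `make_sampler` names are invented for illustration):

    from typing import Any, Dict, Optional


    def sampler_kwargs() -> Optional[Dict[str, Any]]:
        # Illustrative stand-in for Trainer.distributed_sampler_kwargs.
        return {"num_replicas": 2, "rank": 0}


    def make_sampler(num_replicas: int = 1, rank: int = 0, shuffle: bool = True) -> None:
        ...


    kwargs = sampler_kwargs()
    assert kwargs is not None  # same Optional-narrowing pattern used in the commit
    make_sampler(**kwargs)  # ** unpacking passes the str keys as keyword arguments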

tests/tests_pytorch/trainer/flags/test_overfit_batches.py

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ def val_dataloader(self):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)

-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
        trainer.fit(model)

     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
