fix mypy typing errors in pytorch_lightning.utilities.data.py #13901

Merged (8 commits) on Sep 14, 2022
pyproject.toml (1 addition, 3 deletions)

@@ -53,8 +53,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
-    "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.data",
-    "lightning_lite.utilities.data",
+    "pytorch_lightning.tuner.batch_size_scaling"
 ]
 ignore_errors = "True"
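This is the per-module override block in the mypy config: every module listed under `module = [...]` gets `ignore_errors = "True"`. Dropping `pytorch_lightning.utilities.data` and `lightning_lite.utilities.data` from the list is what opts those files back into full type checking; the rest of the PR makes them pass.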
src/lightning_lite/utilities/data.py (17 additions, 16 deletions)

@@ -21,7 +21,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
 
 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
 
 from lightning_lite.utilities.enums import LightningEnum
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum):
     SET = "set"
     DEL = "del"
 
-    def __call__(self, *args):
+    def __call__(self, *args: Any) -> None:
+        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
         if self == self.SET:
             fn = setattr
         else:
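Mypy infers a variable's type from its first assignment, so without the explicit annotation `fn` would be pinned to `setattr`'s three-argument signature and the `delattr` branch would fail to type-check. A minimal sketch of the pattern outside the enum (the `apply_attr` helper is illustrative, not part of the PR):

    from typing import Any, Callable, Union

    def apply_attr(set_it: bool, *args: Any) -> None:
        # The up-front Union annotation lets two builtins with different
        # arities (setattr takes three arguments, delattr two) share one
        # variable without an incompatible-assignment error from mypy.
        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
        fn = setattr if set_it else delattr
        fn(*args)

    class Box:
        pass

    box = Box()
    apply_attr(True, box, "x", 1)  # setattr(box, "x", 1)
    apply_attr(False, box, "x")    # delattr(box, "x")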
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
 
 
-def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
+def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
         # try getting the length
-        if len(dataloader) == 0:
+        if len(dataloader) == 0:  # type: ignore [arg-type]
             rank_zero_warn(
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
         has_len = True
     except (TypeError, NotImplementedError):
         has_len = False
 
-    if has_len and has_iterable_dataset(dataloader):
+    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
 
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
@@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs(
     arg_names = ()
 
     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
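Mypy reports accessing `__init__` on an instance as unsound (error code `misc`), because the bound method could come from an incompatible subclass; here that dynamic lookup is exactly the intent, since a user's custom `DataLoader` subclass should be inspected, so the error is silenced at this one line. A small standalone sketch:

    import inspect

    from torch.utils.data import DataLoader

    loader = DataLoader([0, 1, 2, 3], batch_size=2)
    # The subclass signature is resolved at runtime through the instance;
    # mypy cannot verify this statically, hence the narrowly scoped ignore.
    params = dict(inspect.signature(loader.__init__).parameters)  # type: ignore[misc]
    print("batch_size" in params)  # True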
@@ -141,36 +142,36 @@ def _get_dataloader_init_args_and_kwargs(
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
 
     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise TypeError(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )
 
     return dl_args, dl_kwargs
 
 
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
@@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
 
     @functools.wraps(method)
-    def wrapper(obj: Any, *args: Any):
+    def wrapper(obj: Any, *args: Any) -> None:
         # First, let's find out if we're the first in inheritance chain calling the patched method.
         name, *_ = args
         prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
src/pytorch_lightning/strategies/ipu.py (1 addition, 1 deletion)

@@ -245,7 +245,7 @@ def _convert_to_poptorch_loader(
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
+            dataloader, sampler, mode, self.replication_factor > 1
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = _reinstantiate_wrapped_cls(
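With `sampler` widened to `Union[Sampler, Iterable]` in `_get_dataloader_init_args_and_kwargs`, the argument types at this call site line up on their own, so the blanket `type: ignore[arg-type]` can be dropped.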
src/pytorch_lightning/utilities/auto_restart.py (1 addition, 1 deletion)

@@ -62,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
src/pytorch_lightning/utilities/data.py (16 additions, 28 deletions)

@@ -30,7 +30,6 @@
 )
 
 import pytorch_lightning as pl
-from lightning_lite.utilities import LightningEnum
 from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
 from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
 from lightning_lite.utilities.data import has_len as new_has_len
@@ -41,24 +40,13 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 
-BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
+# might be supported in later releases, see https://github.com/python/mypy/pull/13297
+BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]
 
 warning_cache = WarningCache()
 
 
-class _WrapAttrTag(LightningEnum):
-    SET = "set"
-    DEL = "del"
-
-    def __call__(self, *args):
-        if self == self.SET:
-            fn = setattr
-        else:
-            fn = delattr
-        return fn(*args)
-
-
-def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
+def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
             yield 1
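Two things happen in this hunk: the `_WrapAttrTag` duplicate is deleted in favor of the `lightning_lite` copy, and the recursive `BType` alias gets a scoped ignore because the mypy release in use could not yet resolve aliases that reference themselves (the linked mypy PR was adding that support). The widened `Generator[Optional[int], None, None]` return type matches branches of `_extract_batch_size` outside this hunk that yield `None` when a leaf has no inferable batch size. A standalone sketch of the alias:

    from typing import Any, Iterable, Mapping, Union

    from torch import Tensor

    # A batch is a tensor, a string, or an arbitrarily nested mapping or
    # iterable of batches. The self-reference is what older mypy rejects;
    # newer releases support recursive aliases natively.
    BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]

    batch: BType = {"x": [Tensor([1.0]), "pad"]}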
@@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int:
 
 def has_len_all_ranks(
     dataloader: DataLoader,
-    strategy: "pl.Strategy",
+    strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
@@ -151,14 +139,14 @@ def has_len_all_ranks(
     return has_len
 
 
-def get_len(dataloader: DataLoader) -> Union[int, float]:
+def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
     """Return the length of the given DataLoader.
 
     If ``__len__`` method is not implemented, return float('inf').
     """
 
     if new_has_len(dataloader):
-        return len(dataloader)
+        return len(dataloader)  # type: ignore [arg-type]
 
     return float("inf")
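`Dataset` itself does not declare `__len__` (map-style datasets provide it by convention), so `len(dataloader)` on the widened union cannot be proven safe statically, even though `new_has_len` has just verified it at runtime; hence the targeted ignore. An ignore-free variant, sketched with a hypothetical `safe_len` helper that narrows through `Sized`:

    from typing import Sized, Union

    from torch.utils.data import DataLoader, Dataset

    def safe_len(obj: Union[DataLoader, Dataset]) -> Union[int, float]:
        # isinstance against Sized both checks for __len__ at runtime and
        # narrows the type for mypy, so no ignore comment is needed.
        if isinstance(obj, Sized):
            return len(obj)
        return float("inf")

    print(safe_len(DataLoader([1, 2, 3])))  # 3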

@@ -173,7 +161,7 @@ def _update_dataloader(
 
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -197,7 +185,7 @@ def _get_dataloader_init_args_and_kwargs(
     arg_names = ()
 
     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -239,28 +227,28 @@ def _get_dataloader_init_args_and_kwargs(
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
 
     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )
 
     if _FaultTolerantMode.detect_current_mode().is_automatic:
@@ -273,7 +261,7 @@
 
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]: