
Commit a31ed67

Author: otaj
Commit message: finish typing
1 parent 1b9c987 · commit a31ed67

File tree (4 files changed: +15 −26 lines)

  src/lightning_lite/utilities/data.py
  src/pytorch_lightning/strategies/ipu.py
  src/pytorch_lightning/utilities/auto_restart.py
  src/pytorch_lightning/utilities/data.py


src/lightning_lite/utilities/data.py

Lines changed: 7 additions & 6 deletions

@@ -21,7 +21,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union

 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler

 from lightning_lite.utilities.enums import LightningEnum
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -34,6 +34,7 @@ class _WrapAttrTag(LightningEnum):
     DEL = "del"

     def __call__(self, *args: Any) -> None:
+        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
         if self == self.SET:
             fn = setattr
         else:
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


-def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
+def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
         # try getting the length
-        if len(dataloader) == 0:
+        if len(dataloader) == 0:  # type: ignore [arg-type]
             rank_zero_warn(
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
         has_len = True
     except (TypeError, NotImplementedError):
         has_len = False

-    if has_len and has_iterable_dataset(dataloader):
+    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]

 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
@@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs(

 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its

src/pytorch_lightning/strategies/ipu.py

Lines changed: 1 addition & 1 deletion

@@ -245,7 +245,7 @@ def _convert_to_poptorch_loader(
             return dataloader

         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
+            dataloader, sampler, mode, self.replication_factor > 1
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = _reinstantiate_wrapped_cls(
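The suppression can be dropped because `_get_dataloader_init_args_and_kwargs` now accepts `sampler: Union[Sampler, Iterable]` (see the first file above), so this call site type-checks as written. A small sketch of how widening a parameter removes the need for an ignore, with hypothetical function names:

    from typing import Iterable, Optional, Union
    from torch.utils.data import Sampler, SequentialSampler

    def takes_optional_sampler(sampler: Optional[Sampler]) -> None:
        """Old-style signature: rejects a Union argument, forcing ignores at call sites."""

    def takes_sampler_or_iterable(sampler: Union[Sampler, Iterable]) -> None:
        """Widened signature: samplers and plain iterables both type-check."""

    sampler: Union[Sampler, Iterable] = SequentialSampler(range(4))
    # takes_optional_sampler(sampler)  # mypy error without `# type: ignore[arg-type]`
    takes_sampler_or_iterable(sampler)  # no suppression needed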

src/pytorch_lightning/utilities/auto_restart.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """

-    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
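`Iterator` was too narrow an annotation here: a `Sampler` is an iterable (each `iter()` call yields a fresh iterator), not an iterator itself, so `Union[Sampler, Iterable]` matches what actually gets passed in. A quick runnable sketch of the distinction:

    from collections.abc import Iterable, Iterator
    from torch.utils.data import SequentialSampler

    sampler = SequentialSampler(range(3))
    assert isinstance(sampler, Iterable)      # defines __iter__
    assert not isinstance(sampler, Iterator)  # does not define __next__

    # A sampler can be iterated repeatedly from the start, which is why
    # a re-iterable type fits a wrapper like FastForwardSampler.
    assert list(sampler) == [0, 1, 2]
    assert list(sampler) == [0, 1, 2]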

src/pytorch_lightning/utilities/data.py

Lines changed: 6 additions & 18 deletions

@@ -30,7 +30,6 @@
 )

 import pytorch_lightning as pl
-from lightning_lite.utilities import LightningEnum
 from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
 from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
 from lightning_lite.utilities.data import has_len as new_has_len
@@ -41,23 +40,12 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

-BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
+# might be supported in later releases, see https://github.com/python/mypy/pull/13297
+BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]

 warning_cache = WarningCache()


-class _WrapAttrTag(LightningEnum):
-    SET = "set"
-    DEL = "del"
-
-    def __call__(self, *args: Any) -> None:
-        if self == self.SET:
-            fn = setattr
-        else:
-            fn = delattr
-        return fn(*args)
-
-
 def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
@@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int:

 def has_len_all_ranks(
     dataloader: DataLoader,
-    strategy: "pl.Strategy",
+    strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
@@ -158,7 +146,7 @@ def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
     """

     if new_has_len(dataloader):
-        return len(dataloader)
+        return len(dataloader)  # type: ignore [arg-type]

     return float("inf")

@@ -173,7 +161,7 @@ def _update_dataloader(

 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs(

 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
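The `# type: ignore[misc]` on `BType` works around mypy's lack of support for recursive type aliases at the time; the comment added in the diff points at the mypy PR tracking it. At runtime the alias is just a description of nested containers bottoming out in tensors or strings, which the illustrative walker below (not part of the diff) makes concrete:

    from typing import Any, Iterable, Mapping, Union
    from torch import Tensor, tensor

    # Recursive alias: a batch is a tensor, a string, or a container of batches.
    BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]

    def count_tensors(batch: "BType") -> int:
        # Walk the BType tree, counting tensor leaves; strings count as leaves too.
        if isinstance(batch, Tensor):
            return 1
        if isinstance(batch, str):
            return 0
        if isinstance(batch, Mapping):
            return sum(count_tensors(v) for v in batch.values())
        return sum(count_tensors(item) for item in batch)

    assert count_tensors({"x": [tensor(1.0), tensor(2.0)], "y": "label"}) == 2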
