21
21
from typing import Any , Callable , Dict , Generator , Iterable , Optional , Tuple , Type , Union
22
22
23
23
from lightning_utilities .core .inheritance import get_all_subclasses
24
- from torch .utils .data import BatchSampler , DataLoader , IterableDataset , Sampler
24
+ from torch .utils .data import BatchSampler , DataLoader , Dataset , IterableDataset , Sampler
25
25
26
26
from lightning_lite .utilities .enums import LightningEnum
27
27
from lightning_lite .utilities .exceptions import MisconfigurationException
@@ -34,6 +34,7 @@ class _WrapAttrTag(LightningEnum):
34
34
DEL = "del"
35
35
36
36
def __call__ (self , * args : Any ) -> None :
37
+ fn : Union [Callable [[object , str ], None ], Callable [[object , str , Any ], None ]]
37
38
if self == self .SET :
38
39
fn = setattr
39
40
else :
@@ -45,20 +46,20 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
45
46
return hasattr (dataloader , "dataset" ) and isinstance (dataloader .dataset , IterableDataset )
46
47
47
48
48
- def has_len (dataloader : Union [DataLoader , Iterable ]) -> bool :
49
+ def has_len (dataloader : Union [DataLoader , Iterable , Dataset ]) -> bool :
49
50
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
50
51
infinite dataloader."""
51
52
try :
52
53
# try getting the length
53
- if len (dataloader ) == 0 :
54
+ if len (dataloader ) == 0 : # type: ignore [arg-type]
54
55
rank_zero_warn (
55
56
f"`{ dataloader .__class__ .__name__ } ` returned 0 length. Please make sure this was your intention."
56
57
)
57
58
has_len = True
58
59
except (TypeError , NotImplementedError ):
59
60
has_len = False
60
61
61
- if has_len and has_iterable_dataset (dataloader ):
62
+ if has_len and isinstance ( dataloader , DataLoader ) and has_iterable_dataset (dataloader ):
62
63
rank_zero_warn (
63
64
"Your `IterableDataset` has `__len__` defined."
64
65
" In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
76
77
77
78
def _get_dataloader_init_args_and_kwargs (
78
79
dataloader : DataLoader ,
79
- sampler : Optional [Sampler ],
80
+ sampler : Union [Sampler , Iterable ],
80
81
disallow_batch_sampler : bool = False ,
81
82
) -> Tuple [Tuple [Any ], Dict [str , Any ]]:
82
83
if not isinstance (dataloader , DataLoader ):
@@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs(
170
171
171
172
def _dataloader_init_kwargs_resolve_sampler (
172
173
dataloader : DataLoader ,
173
- sampler : Optional [Sampler ],
174
+ sampler : Union [Sampler , Iterable ],
174
175
disallow_batch_sampler : bool = False ,
175
176
) -> Dict [str , Any ]:
176
177
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
0 commit comments