@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Sized
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
@@ -24,6 +25,7 @@
     DataLoader,
     IterableDataset,
 )
+from typing_extensions import TypedDict
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@
 from pytorch_lightning.utilities.types import _Stateful
 
 
+class _IteratorStateDict(TypedDict):
+    dataset_state: Dict[int, Any]
+    sampler_state: Dict[int, Any]
+    worker_id: int
+    num_workers: int
+    num_batches_fetched: int
+    name: Optional[str]
+
+
+class _MergedIteratorStateDict(TypedDict):
+    state: Dict[str, Any]
+    latest_worker_id: int
+    represent_map_dataset: Optional[bool]
+
+
 class FastForwardSampler(Sampler):
     """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
     performed during an epoch.
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -79,7 +96,7 @@ def __iter__(self) -> Iterator[Any]:
         self._counter = 0
         return self
 
-    def __next__(self):
+    def __next__(self) -> Any:
         # the `state dict` was cached as workers were unavailable before.
         if self._cached_state_dict is not None:
             self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ def __next__(self):
         raise StopIteration
 
     def __len__(self) -> int:
+        assert isinstance(self._sampler, Sized)
         return len(self._sampler)
 
     def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
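For context, a rough usage sketch of `FastForwardSampler` in a single-process run (worker id 0), assuming the `setup`, `state_dict`, and `load_state_dict` API this file defines elsewhere; values are illustrative:

```python
from torch.utils.data import SequentialSampler

from pytorch_lightning.utilities.auto_restart import FastForwardSampler

# Wrap a plain sampler so that each yielded index is counted.
sampler = FastForwardSampler(SequentialSampler(range(6)))
sampler.setup(dataloader_batch_size=1)

it = iter(sampler)
next(it), next(it)  # consume the first two indices: 0, 1

# The snapshot is keyed by worker id (0 in the main process).
state = sampler.state_dict(num_batches_processed=2)

# A fresh wrapper fast-forwards past the two indices already seen.
restarted = FastForwardSampler(SequentialSampler(range(6)))
restarted.setup(dataloader_batch_size=1)
restarted.load_state_dict(state)
assert list(restarted) == [2, 3, 4, 5]
```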
@@ -161,7 +179,7 @@ class IteratorState:
     name: Optional[str] = None
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "IteratorState":
+    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
         return cls(**state_dict)
 
 
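The `_IteratorStateDict` TypedDict added above describes exactly the keyword arguments this `cls(**state_dict)` call expands to; a sketch of the round trip, with placeholder values standing in for real per-worker snapshots:

```python
from pytorch_lightning.utilities.auto_restart import IteratorState

# Payload shaped like _IteratorStateDict; the nested dicts are illustrative.
payload = {
    "dataset_state": {0: {"rng_states": {}}},
    "sampler_state": {0: {"current_iteration": 8}},
    "worker_id": 0,
    "num_workers": 2,
    "num_batches_fetched": 4,
    "name": None,
}

state = IteratorState.from_state_dict(payload)
assert state.num_batches_fetched == 4 and state.name is None
```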
@@ -173,22 +191,22 @@ class MergedIteratorState:
     worker states in this merged iterator state.
     """
 
-    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
+    state: Dict = field(default_factory=dict)
     latest_worker_id: int = 0
     represent_map_dataset: Optional[bool] = None
 
     def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
         # a map based dataset doesn't own a generator and therefore `generator_name` should be None.
         self.represent_map_dataset = generator_name is None
-        if self.represent_map_dataset:
-            state = self.state
+        latest_worker_id = new_state.worker_id
+        if generator_name is None:
+            self.state[latest_worker_id] = new_state
         else:
             if generator_name not in self.state:
                 self.state[generator_name] = {}
             state = self.state[generator_name]
+            state[latest_worker_id] = new_state
 
-        latest_worker_id = new_state.worker_id
-        state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id
 
     @property
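The rewritten `update` has two code paths; a small sketch of both, assuming `IteratorState`'s fields all carry defaults as they do in this file:

```python
from pytorch_lightning.utilities.auto_restart import IteratorState, MergedIteratorState

worker_state = IteratorState(worker_id=0, num_workers=1, num_batches_fetched=4)

# Map-style dataset: no generator name, states are keyed directly by worker id.
merged = MergedIteratorState()
merged.update(None, worker_state)
assert merged.represent_map_dataset and merged.state[0] is worker_state

# Iterable-style dataset: states are grouped under the generator's name first.
merged = MergedIteratorState()
merged.update("sampler", worker_state)
assert not merged.represent_map_dataset and merged.state["sampler"][0] is worker_state
```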
@@ -202,7 +220,7 @@ def dataset_states(self) -> Dict[int, Any]:
         return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
+    def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
             state_dict["state"] = {
                 worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
     """
 
     def __init__(self, dataset: Dataset) -> None:
-        self.dataset = dataset
-        self._cached_state_dict = None
+        self.dataset: Dataset = dataset
+        self._cached_state_dict: Optional[Dict[int, Any]] = None
 
     @property
     def worker_id(self) -> int:
         worker_info = get_worker_info()
         return worker_info.id if worker_info else 0
 
-    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
+    def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
         if self._cached_state_dict is not None:
             if self.worker_id in self._cached_state_dict:
                 _set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
         return self.dataset[item]
 
     def __len__(self) -> int:
+        assert isinstance(self.dataset, Sized)
         return len(self.dataset)
 
     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
@@ -268,17 +287,18 @@ def __init__(self, dataset: IterableDataset) -> None:
         super().__init__()
         self.dataset = deepcopy(dataset)
         self.samplers: Optional[Dict[str, FastForwardSampler]] = None
-        self._state_dict: Optional[Dict[int, Any]] = None
+        self._state_dict: Optional[Dict[str, Any]] = None
         self._has_wrapped: bool = False
 
     @property
     def sampler(self) -> Sampler:
         return self.dataset.sampler
 
     def state_dict(self) -> Dict[str, Any]:
+        assert self.samplers is not None
         return {k: v.state_dict() for k, v in self.samplers.items()}
 
-    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._state_dict = deepcopy(state_dict)
 
     def _wrap_generator_samplers(self) -> None:
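For orientation, the nesting that `state_dict` above produces, with purely illustrative names and counts: outer keys are the wrapped generator/sampler attribute names, each mapping to a per-worker `FastForwardSampler` snapshot.

```python
# Illustrative shape of CaptureIterableDataset.state_dict() for a dataset with
# one wrapped generator named "sampler" running in a single worker:
captured = {
    "sampler": {0: {"current_iteration": 16}},
}
# load_state_dict() simply deep-copies such a mapping for replay on restart.
```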
@@ -311,7 +331,7 @@ def _wrap_generator_samplers(self) -> None:
 
         self.reset_on_epoch()
 
-    def reset_on_epoch(self):
+    def reset_on_epoch(self) -> None:
         self._state_dict = None
 
     def __iter__(self) -> Iterator:
@@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     for _ in range(state_dict["previous_worker"] - 1):
         next(iter_dataloader._worker_queue_idx_cycle)
 
-    # we can finally call reset and apply prefecthing.
-    iter_dataloader._reset = iter_dataloader._original_reset
+    # we can finally call reset and apply prefetching.
+    iter_dataloader._reset = iter_dataloader._original_reset  # type: ignore[assignment]
     iter_dataloader._reset(dataloader, first_iter=True)
     # return the iterator
     return iter_dataloader
@@ -445,6 +465,7 @@ def wrapper() -> Any:
             ]
         elif isinstance(dataset, CaptureMapDataset):
             ff_sampler = _find_fast_forward_samplers(dl)
+            assert ff_sampler is not None
             state = [
                 IteratorState(
                     num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,
 
     # reload sampler state
     ff_sampler = _find_fast_forward_samplers(dataloader)
+    assert ff_sampler is not None
    ff_sampler.load_state_dict(iterator_state.sampler_state)
 
     # reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
     return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
 
 
-class _StatefulDataLoaderIter:
+class _StatefulDataLoaderIter(_BaseDataLoaderIter):
     """This mixin is used to make PyTorch DataLoaderIter stateful."""
 
-    def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
+    def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
         # store sampler state within a queue alongside its idx.
-        self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
+        self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
         self._sampler_state.append((sampler_state, self._sampler_state_idx))
 
     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
-        sampler_state = {
-            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
+        sampler_state: Dict[int, Any] = {
+            k: v.state_dict()  # type: ignore[misc]
+            for k, v in self._loader.__dict__.items()
+            if isinstance(v, _Stateful) and k != "dataset"
         }
         self.__accumulate_state(sampler_state)
@@ -630,12 +654,12 @@ def _next_index(self) -> Any:
         self._store_sampler_state()
         return indexes
 
-    def _prepare_loader(self, loader):
+    def _prepare_loader(self, loader: DataLoader) -> None:
         _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
-        self._sampler_state = []
+        self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
         self._sampler_state_idx = 0
 
     def __del__(self) -> None:
@@ -680,7 +704,7 @@ def __init__(self, loader: DataLoader):
         super().__init__(loader)
 
 
-def _get_iterator(self) -> "_BaseDataLoaderIter":
+def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
     if not hasattr(self, "_lightning_fetcher"):
         raise MisconfigurationException(
             "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,15 +723,15 @@ def _patch_dataloader_get_iterators() -> None:
         return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
-    DataLoader._get_iterator = _get_iterator
+    DataLoader._get_iterator = _get_iterator  # type: ignore[assignment]
 
 
 def _teardown_dataloader_get_iterators() -> None:
     """This function is used to restore the DataLoader `get_iterator` with its original one."""
     # cleanup the get_iterator replacement in case of Fault-tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
-        DataLoader._get_iterator = get_iterator
+        DataLoader._get_iterator = get_iterator  # type: ignore[assignment]
         del DataLoader._ori_get_iterator
 
 
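The patch/teardown pair above is a standard save-and-restore monkey patch; a generic, self-contained sketch of the same pattern (the `Greeter` class is purely illustrative):

```python
class Greeter:
    def hello(self) -> str:
        return "hello"


def _patched_hello(self: Greeter) -> str:
    return "patched " + Greeter._ori_hello(self)


def patch_greeter() -> None:
    # Save the original exactly once so repeated patching stays reversible.
    if not hasattr(Greeter, "_ori_hello"):
        Greeter._ori_hello = Greeter.hello
    Greeter.hello = _patched_hello


def teardown_greeter() -> None:
    # Restore the original method and drop the backup reference.
    original = getattr(Greeter, "_ori_hello", None)
    if original:
        Greeter.hello = original
        del Greeter._ori_hello


patch_greeter()
assert Greeter().hello() == "patched hello"
teardown_greeter()
assert Greeter().hello() == "hello"
```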
@@ -781,16 +805,17 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -
         raise ValueError("Fault-tolerance supports only a single dataloader.")
 
     for dataloader in dl_loaders:
+        assert isinstance(dataloader, DataLoader)
         validator_fn = (
             _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
         )
         validator_fn(dataloader)
 
 
-def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
     """This utility collects the state across processes for a collection of state."""
 
-    def fn(state: Dict):
+    def fn(state: Dict) -> Dict:
         if key in state:
             return _collect_states_on_rank_zero(state)
         return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
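The recursion in `fn` bottoms out at any sub-dict containing `key` and otherwise walks nested dicts via `apply_to_collection`; a pure-Python stand-in (with a placeholder instead of the real rank-zero collection) to show the traversal:

```python
from typing import Any, Dict


def collect(state: Dict[str, Any], key: str = "state") -> Dict[str, Any]:
    # Placeholder for _collect_states_on_rank_zero: just tag the matching sub-dict.
    if key in state:
        return {"rank_zero": state}
    return {k: collect(v, key) if isinstance(v, dict) else v for k, v in state.items()}


nested = {"train_dl": {"state": {"current_iteration": 8}}, "epoch": 3}
assert collect(nested) == {"train_dl": {"rank_zero": {"state": {"current_iteration": 8}}}, "epoch": 3}
```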