
Commit 381600d

donlapark authored
fixes typing errors in auto_restart.py (#13904)
Co-authored-by: otaj <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent e5395de commit 381600d

File tree: 3 files changed (+63, -36 lines)


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
 ]
 ignore_errors = "True"

src/pytorch_lightning/utilities/auto_restart.py

Lines changed: 54 additions & 29 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Sized
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
@@ -24,6 +25,7 @@
     DataLoader,
     IterableDataset,
 )
+from typing_extensions import TypedDict

 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@
 from pytorch_lightning.utilities.types import _Stateful


+class _IteratorStateDict(TypedDict):
+    dataset_state: Dict[int, Any]
+    sampler_state: Dict[int, Any]
+    worker_id: int
+    num_workers: int
+    num_batches_fetched: int
+    name: Optional[str]
+
+
+class _MergedIteratorStateDict(TypedDict):
+    state: Dict[str, Any]
+    latest_worker_id: int
+    represent_map_dataset: Optional[bool]
+
+
 class FastForwardSampler(Sampler):
     """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
     performed during an epoch.
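The two TypedDicts added above describe the payloads accepted by the `from_state_dict` classmethods further down in this file, so mypy can check keys and value types instead of treating the argument as a bare dict. A minimal standalone sketch of the pattern, with hypothetical names (`_WorkerStateDict`, `WorkerState`) that are not part of the commit:

from dataclasses import dataclass
from typing import Any, Dict, Optional

from typing_extensions import TypedDict  # typing.TypedDict also works on Python 3.8+


class _WorkerStateDict(TypedDict):
    sampler_state: Dict[int, Any]
    worker_id: int
    name: Optional[str]


@dataclass
class WorkerState:
    sampler_state: Dict[int, Any]
    worker_id: int
    name: Optional[str] = None

    @classmethod
    def from_state_dict(cls, state_dict: _WorkerStateDict) -> "WorkerState":
        # mypy validates both the dict literal passed in and the ** expansion here.
        return cls(**state_dict)


state = WorkerState.from_state_dict({"sampler_state": {0: {"current_iteration": 3}}, "worker_id": 0, "name": None})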
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """

-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -79,7 +96,7 @@ def __iter__(self) -> Iterator[Any]:
         self._counter = 0
         return self

-    def __next__(self):
+    def __next__(self) -> Any:
         # the `state dict` was cached as workers were unavailable before.
         if self._cached_state_dict is not None:
             self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ def __next__(self):
         raise StopIteration

     def __len__(self) -> int:
+        assert isinstance(self._sampler, Sized)
         return len(self._sampler)

     def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
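The `assert isinstance(..., Sized)` added in `__len__` is a type-narrowing idiom: the sampler is now annotated as `Iterator`, which has no `__len__`, so the runtime-checkable `Sized` ABC both guards the call and tells mypy that `len()` is valid. A small self-contained sketch (toy `Wrapper` class, not from the commit):

from collections.abc import Sized
from typing import Iterator, Union


class Wrapper:
    def __init__(self, wrapped: Union[Iterator, Sized]) -> None:
        self._wrapped = wrapped

    def __len__(self) -> int:
        # Without the assert, mypy flags len() because Iterator is not Sized.
        assert isinstance(self._wrapped, Sized)
        return len(self._wrapped)


assert len(Wrapper([1, 2, 3])) == 3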
@@ -161,7 +179,7 @@ class IteratorState:
     name: Optional[str] = None

     @classmethod
-    def from_state_dict(cls, state_dict) -> "IteratorState":
+    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
         return cls(**state_dict)


@@ -173,22 +191,22 @@ class MergedIteratorState:
     worker states in this merged iterator state.
     """

-    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
+    state: Dict = field(default_factory=dict)
     latest_worker_id: int = 0
     represent_map_dataset: Optional[bool] = None

     def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
         # a map based dataset doesn't own a generator and therefore `generator_name` should be None.
         self.represent_map_dataset = generator_name is None
-        if self.represent_map_dataset:
-            state = self.state
+        latest_worker_id = new_state.worker_id
+        if generator_name is None:
+            self.state[latest_worker_id] = new_state
         else:
             if generator_name not in self.state:
                 self.state[generator_name] = {}
             state = self.state[generator_name]
+            state[latest_worker_id] = new_state

-        latest_worker_id = new_state.worker_id
-        state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id

     @property
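The rewritten `update` keeps the same two layouts for `state` but makes the control flow explicit enough for mypy: a map-style dataset (no generator name) stores worker states directly under the worker id, while an iterable-style dataset nests them under the generator name first. A simplified standalone sketch of that layout (not the Lightning class itself):

from typing import Any, Dict, Optional


class MergedState:
    def __init__(self) -> None:
        self.state: Dict[Any, Any] = {}

    def update(self, generator_name: Optional[str], worker_id: int, new_state: Any) -> None:
        if generator_name is None:
            # map-style dataset: keyed directly by worker id
            self.state[worker_id] = new_state
        else:
            # iterable-style dataset: keyed by generator name, then worker id
            self.state.setdefault(generator_name, {})[worker_id] = new_state


map_style = MergedState()
map_style.update(None, 0, {"current_iteration": 3})                # {0: {...}}

iterable_style = MergedState()
iterable_style.update("generator_0", 1, {"current_iteration": 5})  # {"generator_0": {1: {...}}}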
@@ -202,7 +220,7 @@ def dataset_states(self) -> Dict[int, Any]:
         return {k: self.state[k].dataset_state[k] for k in self.state.keys()}

     @classmethod
-    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
+    def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
             state_dict["state"] = {
                 worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
     """

     def __init__(self, dataset: Dataset) -> None:
-        self.dataset = dataset
-        self._cached_state_dict = None
+        self.dataset: Dataset = dataset
+        self._cached_state_dict: Optional[Dict[int, Any]] = None

     @property
     def worker_id(self) -> int:
         worker_info = get_worker_info()
         return worker_info.id if worker_info else 0

-    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
+    def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
         if self._cached_state_dict is not None:
             if self.worker_id in self._cached_state_dict:
                 _set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
         return self.dataset[item]

     def __len__(self) -> int:
+        assert isinstance(self.dataset, Sized)
         return len(self.dataset)

     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
@@ -268,17 +287,18 @@ def __init__(self, dataset: IterableDataset) -> None:
         super().__init__()
         self.dataset = deepcopy(dataset)
         self.samplers: Optional[Dict[str, FastForwardSampler]] = None
-        self._state_dict: Optional[Dict[int, Any]] = None
+        self._state_dict: Optional[Dict[str, Any]] = None
         self._has_wrapped: bool = False

     @property
     def sampler(self) -> Sampler:
         return self.dataset.sampler

     def state_dict(self) -> Dict[str, Any]:
+        assert self.samplers is not None
         return {k: v.state_dict() for k, v in self.samplers.items()}

-    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._state_dict = deepcopy(state_dict)

     def _wrap_generator_samplers(self) -> None:
@@ -311,7 +331,7 @@ def _wrap_generator_samplers(self) -> None:

         self.reset_on_epoch()

-    def reset_on_epoch(self):
+    def reset_on_epoch(self) -> None:
         self._state_dict = None

     def __iter__(self) -> Iterator:
@@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     for _ in range(state_dict["previous_worker"] - 1):
         next(iter_dataloader._worker_queue_idx_cycle)

-    # we can finally call reset and apply prefecthing.
-    iter_dataloader._reset = iter_dataloader._original_reset
+    # we can finally call reset and apply prefetching.
+    iter_dataloader._reset = iter_dataloader._original_reset  # type: ignore[assignment]
     iter_dataloader._reset(dataloader, first_iter=True)
     # return the iterator
     return iter_dataloader
@@ -445,6 +465,7 @@ def wrapper() -> Any:
                 ]
             elif isinstance(dataset, CaptureMapDataset):
                 ff_sampler = _find_fast_forward_samplers(dl)
+                assert ff_sampler is not None
                 state = [
                     IteratorState(
                         num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,

     # reload sampler state
     ff_sampler = _find_fast_forward_samplers(dataloader)
+    assert ff_sampler is not None
     ff_sampler.load_state_dict(iterator_state.sampler_state)

     # reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
     return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}


-class _StatefulDataLoaderIter:
+class _StatefulDataLoaderIter(_BaseDataLoaderIter):
     """This mixin is used to make PyTorch DataLoaderIter stateful."""

-    def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
+    def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
         # store sampler state within a queue alongside its idx.
-        self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
+        self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
         self._sampler_state.append((sampler_state, self._sampler_state_idx))

     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
-        sampler_state = {
-            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
+        sampler_state: Dict[int, Any] = {
+            k: v.state_dict()  # type: ignore[misc]
+            for k, v in self._loader.__dict__.items()
+            if isinstance(v, _Stateful) and k != "dataset"
         }
         self.__accumulate_state(sampler_state)
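Making `_StatefulDataLoaderIter` inherit from `_BaseDataLoaderIter` is what lets mypy resolve attributes such as `_loader` and `_sampler_state` that the mixin reads but does not define itself. A toy sketch of that typed-mixin pattern, with made-up class names:

from typing import Dict, List


class _BaseIter:
    """Stands in for torch's _BaseDataLoaderIter in this sketch."""

    def __init__(self, source: List[int]) -> None:
        self._source = source
        self._index = 0


class _StatefulMixin(_BaseIter):
    """Declared as a subclass purely so attribute access type-checks."""

    def state(self) -> Dict[str, int]:
        # `self._index` resolves for mypy because the mixin inherits _BaseIter.
        return {"index": self._index}


class StatefulIter(_StatefulMixin, _BaseIter):
    pass


assert StatefulIter([1, 2, 3]).state() == {"index": 0}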

@@ -630,12 +654,12 @@ def _next_index(self) -> Any:
         self._store_sampler_state()
         return indexes

-    def _prepare_loader(self, loader):
+    def _prepare_loader(self, loader: DataLoader) -> None:
         _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
-        self._sampler_state = []
+        self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
         self._sampler_state_idx = 0

     def __del__(self) -> None:
@@ -680,7 +704,7 @@ def __init__(self, loader: DataLoader):
         super().__init__(loader)


-def _get_iterator(self) -> "_BaseDataLoaderIter":
+def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
     if not hasattr(self, "_lightning_fetcher"):
         raise MisconfigurationException(
             "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,15 +723,15 @@ def _patch_dataloader_get_iterators() -> None:
         return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
-    DataLoader._get_iterator = _get_iterator
+    DataLoader._get_iterator = _get_iterator  # type: ignore[assignment]


 def _teardown_dataloader_get_iterators() -> None:
     """This function is used to restore the DataLoader `get_iterator` with its original one."""
     # cleanup the get_iterator replacement in case of Fault-tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
-        DataLoader._get_iterator = get_iterator
+        DataLoader._get_iterator = get_iterator  # type: ignore[assignment]
         del DataLoader._ori_get_iterator
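`_patch_dataloader_get_iterators` and `_teardown_dataloader_get_iterators` swap a method on the `DataLoader` class and later restore it; assigning to a method attribute is what mypy flags, hence the `# type: ignore[assignment]` comments. A minimal sketch of the same patch/restore cycle on a toy class (names invented for illustration):

class Loader:
    def _get_iterator(self) -> str:
        return "original"


def _patched_get_iterator(self: Loader) -> str:
    return "patched"


def patch() -> None:
    if not hasattr(Loader, "_ori_get_iterator"):
        Loader._ori_get_iterator = Loader._get_iterator  # keep the original around
    Loader._get_iterator = _patched_get_iterator


def teardown() -> None:
    original = getattr(Loader, "_ori_get_iterator", None)
    if original:
        Loader._get_iterator = original
        del Loader._ori_get_iterator


patch()
assert Loader()._get_iterator() == "patched"
teardown()
assert Loader()._get_iterator() == "original"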

@@ -781,16 +805,17 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -
         raise ValueError("Fault-tolerance supports only a single dataloader.")

     for dataloader in dl_loaders:
+        assert isinstance(dataloader, DataLoader)
         validator_fn = (
             _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
         )
         validator_fn(dataloader)


-def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
     """This utility collects the state across processes for a collection of state."""

-    def fn(state: Dict):
+    def fn(state: Dict) -> Dict:
         if key in state:
             return _collect_states_on_rank_zero(state)
         return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
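`_collect_states_on_rank_zero_over_collection` recurses through a nested state collection and gathers every sub-dict that carries the `state` key. A standalone sketch of that recursion, with a placeholder `gathered(...)` transform instead of the real cross-process collection, and a plain dict comprehension standing in for `apply_to_collection`:

from typing import Dict


def collect_over_collection(state_dict: Dict, key: str = "state") -> Dict:
    def fn(state: Dict) -> Dict:
        if key in state:
            # placeholder for the real rank-zero gathering of this state dict
            return {**state, key: f"gathered({state[key]})"}
        return {k: fn(v) if isinstance(v, dict) else v for k, v in state.items()}

    return fn(state_dict)


result = collect_over_collection({"loader_0": {"state": 1}, "loader_1": {"nested": {"state": 2}}})
assert result == {"loader_0": {"state": "gathered(1)"}, "loader_1": {"nested": {"state": "gathered(2)"}}}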

src/pytorch_lightning/utilities/types.py

Lines changed: 9 additions & 6 deletions
@@ -20,7 +20,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union

 import torch
 from torch import Tensor
@@ -90,21 +90,24 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         ...


+_DictKey = TypeVar("_DictKey")
+
+
 @runtime_checkable
-class _Stateful(Protocol):
+class _Stateful(Protocol[_DictKey]):
     """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[_DictKey, Any]:
         ...

-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         ...


 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful, Protocol):
+class _LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]

@@ -118,7 +121,7 @@ def step(self, epoch: Optional[int] = None) -> None:
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class ReduceLROnPlateau(_Stateful, Protocol):
+class ReduceLROnPlateau(_Stateful[str], Protocol):
     in_cooldown: bool
     optimizer: Optimizer
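The `_Stateful` protocol is now generic in the key type of its state dict, and the scheduler protocols pin that key to `str` via `_Stateful[str]`. A self-contained sketch of the same generic, runtime-checkable protocol (toy `Scheduler` class, names not from the commit); note that `isinstance` checks still use the bare, unparameterized protocol:

from typing import Any, Dict, List, Protocol, TypeVar, runtime_checkable

_DictKey = TypeVar("_DictKey")


@runtime_checkable
class Stateful(Protocol[_DictKey]):
    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


class Scheduler:  # structurally satisfies Stateful[str]
    def __init__(self) -> None:
        self.base_lrs: List[float] = [0.1]

    def state_dict(self) -> Dict[str, Any]:
        return {"base_lrs": self.base_lrs}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.base_lrs = state_dict["base_lrs"]


assert isinstance(Scheduler(), Stateful)  # runtime check against the bare protocol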
