From 6eafa048a9951d3308b1f3381c6e3cd0afaae388 Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 18:38:39 +0700 Subject: [PATCH 01/24] fixes typing errors in auto_restart.py --- pyproject.toml | 1 - .../utilities/auto_restart.py | 88 +++++++++++++------ src/pytorch_lightning/utilities/types.py | 14 ++- 3 files changed, 72 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 32cc6e8452d25..f4cc9a5a14ac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,6 @@ module = [ "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", - "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", "pytorch_lightning.utilities.meta", ] diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 3877a1ab3944c..195aeac6cb59e 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sized from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, TypedDict, Union from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( @@ -31,9 +32,27 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states -from pytorch_lightning.utilities.types import _Stateful +from pytorch_lightning.utilities.types import _IntStateful, _Stateful +if TYPE_CHECKING: + _BaseLoaderIter = _BaseDataLoaderIter +else: + _BaseLoaderIter = object + +class IteratorStateDict(TypedDict): + dataset_state: Dict[int, Any] + sampler_state: Dict[int, Any] + worker_id: int + num_workers: int + num_batches_fetched: int + name: Optional[str] + +class MergedIteratorStateDict(TypedDict): + state: dict + latest_worker_id: int + represent_map_dataset: Optional[bool] + class FastForwardSampler(Sampler): """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations performed during an epoch. @@ -79,7 +98,7 @@ def __iter__(self) -> Iterator[Any]: self._counter = 0 return self - def __next__(self): + def __next__(self) -> Any: # the `state dict` was cached as workers were unavailable before. 
if self._cached_state_dict is not None: self._load_non_random_state(self._cached_state_dict) @@ -109,6 +128,7 @@ def __next__(self): raise StopIteration def __len__(self) -> int: + assert isinstance(self._sampler, Sized) return len(self._sampler) def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: @@ -161,7 +181,7 @@ class IteratorState: name: Optional[str] = None @classmethod - def from_state_dict(cls, state_dict) -> "IteratorState": + def from_state_dict(cls, state_dict: IteratorStateDict) -> "IteratorState": return cls(**state_dict) @@ -173,22 +193,24 @@ class MergedIteratorState: worker states in this merged iterator state. """ - state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) + state: Dict = field(default_factory=dict) latest_worker_id: int = 0 represent_map_dataset: Optional[bool] = None def update(self, generator_name: Optional[str], new_state: IteratorState) -> None: # a map based dataset doesn't own a generator and therefore `generator_name` should be None. self.represent_map_dataset = generator_name is None - if self.represent_map_dataset: - state = self.state + latest_worker_id = new_state.worker_id + if generator_name is None: + self.state = cast(Dict[int, IteratorState], self.state) + self.state[latest_worker_id] = new_state else: + self.state = cast(Dict[str, Dict[int, IteratorState]], self.state) if generator_name not in self.state: self.state[generator_name] = {} state = self.state[generator_name] + state[latest_worker_id] = new_state - latest_worker_id = new_state.worker_id - state[latest_worker_id] = new_state self.latest_worker_id = latest_worker_id @property @@ -202,12 +224,14 @@ def dataset_states(self) -> Dict[int, Any]: return {k: self.state[k].dataset_state[k] for k in self.state.keys()} @classmethod - def from_state_dict(cls, state_dict) -> "MergedIteratorState": + def from_state_dict(cls, state_dict: MergedIteratorStateDict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: + state_dict["state"] = cast(Dict[int, IteratorState], state_dict["state"]) state_dict["state"] = { worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items() } else: + state_dict["state"] = cast(Dict[str, Dict[int, IteratorState]], state_dict["state"]) state_dict["state"] = { sampler_name: { worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items() @@ -229,15 +253,15 @@ class CaptureMapDataset(Dataset): """ def __init__(self, dataset: Dataset) -> None: - self.dataset = dataset - self._cached_state_dict = None + self.dataset: Dataset = dataset + self._cached_state_dict: Optional[Dict[int, Any]] = None @property def worker_id(self) -> int: worker_info = get_worker_info() return worker_info.id if worker_info else 0 - def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: + def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: _set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) @@ -246,6 +270,7 @@ def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: return self.dataset[item] def __len__(self) -> int: + assert isinstance(self.dataset, Sized) return len(self.dataset) def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: @@ -268,7 +293,7 @@ def __init__(self, dataset: IterableDataset) -> None: super().__init__() self.dataset = 
deepcopy(dataset) self.samplers: Optional[Dict[str, FastForwardSampler]] = None - self._state_dict: Optional[Dict[int, Any]] = None + self._state_dict: Optional[Dict[str, Any]] = None self._has_wrapped: bool = False @property @@ -276,9 +301,10 @@ def sampler(self) -> Sampler: return self.dataset.sampler def state_dict(self) -> Dict[str, Any]: + assert self.samplers is not None return {k: v.state_dict() for k, v in self.samplers.items()} - def load_state_dict(self, state_dict: Dict[int, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._state_dict = deepcopy(state_dict) def _wrap_generator_samplers(self) -> None: @@ -294,7 +320,7 @@ def _wrap_generator_samplers(self) -> None: # it will be wrapped into a `FastForwardSampler`. for (generator_attr_name, generator) in dataset_sampler_generators.items(): - if isinstance(generator, Sampler): + if not isinstance(generator, Sampler): continue # wrap the generator into a `FastForwardSampler` @@ -311,7 +337,7 @@ def _wrap_generator_samplers(self) -> None: self.reset_on_epoch() - def reset_on_epoch(self): + def reset_on_epoch(self) -> None: self._state_dict = None def __iter__(self) -> Iterator: @@ -371,8 +397,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str for _ in range(state_dict["previous_worker"] - 1): next(iter_dataloader._worker_queue_idx_cycle) - # we can finally call reset and apply prefecthing. - iter_dataloader._reset = iter_dataloader._original_reset + # we can finally call reset and apply prefetching. + iter_dataloader._reset = iter_dataloader._original_reset # type: ignore[assignment] iter_dataloader._reset(dataloader, first_iter=True) # return the iterator return iter_dataloader @@ -445,6 +471,7 @@ def wrapper() -> Any: ] elif isinstance(dataset, CaptureMapDataset): ff_sampler = _find_fast_forward_samplers(dl) + assert ff_sampler is not None state = [ IteratorState( num_workers=dl.num_workers, @@ -519,6 +546,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, # reload sampler state ff_sampler = _find_fast_forward_samplers(dataloader) + assert ff_sampler is not None ff_sampler.load_state_dict(iterator_state.sampler_state) # reload dataset state @@ -575,6 +603,7 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic for worker_id in state_dict["state"].keys() } + assert isinstance(dataloader.dataset, _IntStateful) dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) @@ -610,10 +639,10 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} -class _StatefulDataLoaderIter: +class _StatefulDataLoaderIter(_BaseLoaderIter): """This mixin is used to make PyTorch DataLoaderIter stateful.""" - def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: + def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None: # store sampler state within a queue alongside its idx. 
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 self._sampler_state.append((sampler_state, self._sampler_state_idx)) @@ -621,7 +650,7 @@ def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" sampler_state = { - k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" + int(k): v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" } self.__accumulate_state(sampler_state) @@ -630,12 +659,12 @@ def _next_index(self) -> Any: self._store_sampler_state() return indexes - def _prepare_loader(self, loader): + def _prepare_loader(self, loader: DataLoader) -> None: _add_capture_metadata_collate(loader) self._loader = loader self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher self.num_batches_fetched = 0 - self._sampler_state = [] + self._sampler_state: List[Tuple[Dict[int, Any], int]] = [] self._sampler_state_idx = 0 def __del__(self) -> None: @@ -680,7 +709,7 @@ def __init__(self, loader: DataLoader): super().__init__(loader) -def _get_iterator(self) -> "_BaseDataLoaderIter": +def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter": if not hasattr(self, "_lightning_fetcher"): raise MisconfigurationException( "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader." @@ -699,7 +728,7 @@ def _patch_dataloader_get_iterators() -> None: return if not hasattr(DataLoader, "_ori_get_iterator"): DataLoader._ori_get_iterator = DataLoader._get_iterator - DataLoader._get_iterator = _get_iterator + DataLoader._get_iterator = _get_iterator # type: ignore[assignment] def _teardown_dataloader_get_iterators() -> None: @@ -707,7 +736,7 @@ def _teardown_dataloader_get_iterators() -> None: # cleanup the get_iterator replacement in case of Fault-tolerance. 
get_iterator = getattr(DataLoader, "_ori_get_iterator", None) if get_iterator: - DataLoader._get_iterator = get_iterator + DataLoader._get_iterator = get_iterator # type: ignore[assignment] del DataLoader._ori_get_iterator @@ -781,16 +810,17 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) - raise ValueError("Fault-tolerance supports only a single dataloader.") for dataloader in dl_loaders: + assert isinstance(dataloader, DataLoader) validator_fn = ( _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset ) validator_fn(dataloader) -def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any: +def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict: """This utility collects the state across processes for a collection of state.""" - def fn(state: Dict): + def fn(state: Dict) -> Dict: if key in state: return _collect_states_on_rank_zero(state) return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()} diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index f6c14d366805f..9ec0450e838e3 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -25,7 +25,8 @@ from torch import Tensor from torch._C._distributed_c10d import ProcessGroup from torch.optim import Optimizer -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.dataloader import _BaseDataLoaderIter from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable @@ -98,6 +99,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... +@runtime_checkable +class _IntStateful(Protocol): + """This class is used to detect if an object is stateful, whose states are integers, using `isinstance(obj, _Stateful)`.""" + + def state_dict(self) -> Dict[int, Any]: + ... + + def load_state_dict(self, state_dict: Dict[int, Any]) -> None: + ... 
+ + # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @runtime_checkable From 193556481319a693f807c35ed42eebdc4e13b77d Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 18:41:09 +0700 Subject: [PATCH 02/24] fixes typing errors in auto_restart.py --- src/pytorch_lightning/utilities/types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 9ec0450e838e3..099acb71035b3 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -25,8 +25,7 @@ from torch import Tensor from torch._C._distributed_c10d import ProcessGroup from torch.optim import Optimizer -from torch.utils.data import DataLoader, Dataset -from torch.utils.data.dataloader import _BaseDataLoaderIter +from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable From b0409b8af7ae040bab4a8be6fa7291660dacc001 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Jul 2022 11:58:25 +0000 Subject: [PATCH 03/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../utilities/auto_restart.py | 24 ++++++++++++++++--- src/pytorch_lightning/utilities/types.py | 3 ++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 195aeac6cb59e..5b904cb193cf1 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -15,7 +15,21 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps -from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, TypedDict, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Tuple, + TYPE_CHECKING, + TypedDict, + Union, +) from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( @@ -34,12 +48,12 @@ from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states from pytorch_lightning.utilities.types import _IntStateful, _Stateful - if TYPE_CHECKING: _BaseLoaderIter = _BaseDataLoaderIter else: _BaseLoaderIter = object + class IteratorStateDict(TypedDict): dataset_state: Dict[int, Any] sampler_state: Dict[int, Any] @@ -48,11 +62,13 @@ class IteratorStateDict(TypedDict): num_batches_fetched: int name: Optional[str] + class MergedIteratorStateDict(TypedDict): state: dict latest_worker_id: int represent_map_dataset: Optional[bool] + class FastForwardSampler(Sampler): """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations performed during an epoch. 
@@ -650,7 +666,9 @@ def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None: def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" sampler_state = { - int(k): v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" + int(k): v.state_dict() + for k, v in self._loader.__dict__.items() + if isinstance(v, _Stateful) and k != "dataset" } self.__accumulate_state(sampler_state) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 099acb71035b3..bcd5dd153b1f7 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -100,7 +100,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @runtime_checkable class _IntStateful(Protocol): - """This class is used to detect if an object is stateful, whose states are integers, using `isinstance(obj, _Stateful)`.""" + """This class is used to detect if an object is stateful, whose states are integers, using `isinstance(obj, + _Stateful)`.""" def state_dict(self) -> Dict[int, Any]: ... From 192f155e11dec472a6ed05b0292a0e21b5ad11f8 Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 18:59:02 +0700 Subject: [PATCH 04/24] minor fix --- src/pytorch_lightning/utilities/auto_restart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 195aeac6cb59e..d768186ba1381 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -15,7 +15,8 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps -from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, TypedDict, Union +from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union +from typing_extensions import TypedDict from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( From 217dbd5d8b68faf1bfa70e1ccb995f896ababbb6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Jul 2022 12:03:23 +0000 Subject: [PATCH 05/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../utilities/auto_restart.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 93175c856821f..fa96185dfafc6 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -15,24 +15,7 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps - - -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - Iterable, - Iterator, - List, - Optional, - Tuple, - TYPE_CHECKING, - Union, -) -from typing_extensions import TypedDict - +from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( @@ -42,6 +25,7 
@@ DataLoader, IterableDataset, ) +from typing_extensions import TypedDict import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection From dd6b168245e3b2d04bb964c56c6be9f0e33b94af Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 20:02:46 +0700 Subject: [PATCH 06/24] minor fix --- src/pytorch_lightning/utilities/auto_restart.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index fa96185dfafc6..1c10046fc3d1c 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -645,15 +645,15 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor class _StatefulDataLoaderIter(_BaseLoaderIter): """This mixin is used to make PyTorch DataLoaderIter stateful.""" - def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None: + def __accumulate_state(self, sampler_state: Dict[Union[int, str], Any]) -> None: # store sampler state within a queue alongside its idx. - self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 + self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1 self._sampler_state.append((sampler_state, self._sampler_state_idx)) def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" - sampler_state = { - int(k): v.state_dict() + sampler_state: Dict[Union[int, str], Any] = { + k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" } @@ -669,7 +669,7 @@ def _prepare_loader(self, loader: DataLoader) -> None: self._loader = loader self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher self.num_batches_fetched = 0 - self._sampler_state: List[Tuple[Dict[int, Any], int]] = [] + self._sampler_state: List[Tuple[Dict[Union[int, str], Any], int]] = [] self._sampler_state_idx = 0 def __del__(self) -> None: @@ -687,6 +687,7 @@ def _next_data(self) -> Any: # there is no workers within the samplers worker_id = list(state.keys())[0] + sampler_state = cast(Dict[int, Any], sampler_state) state = [ IteratorState( num_workers=self._loader.num_workers, From afb9612077354d9134025a0b12a3f2cb22c501cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Jul 2022 13:04:26 +0000 Subject: [PATCH 07/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/auto_restart.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 1c10046fc3d1c..97854544d8ded 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -653,9 +653,7 @@ def __accumulate_state(self, sampler_state: Dict[Union[int, str], Any]) -> None: def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" sampler_state: Dict[Union[int, str], Any] = { - k: v.state_dict() - for k, v in self._loader.__dict__.items() - if isinstance(v, _Stateful) and k != "dataset" + k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" } self.__accumulate_state(sampler_state) From 
8d4ef40cc8cd9b43b4a8be9dc605cd875274ac02 Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 20:18:04 +0700 Subject: [PATCH 08/24] minor fix --- src/pytorch_lightning/utilities/auto_restart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 97854544d8ded..d79ab73757ceb 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -323,10 +323,11 @@ def _wrap_generator_samplers(self) -> None: # it will be wrapped into a `FastForwardSampler`. for (generator_attr_name, generator) in dataset_sampler_generators.items(): - if not isinstance(generator, Sampler): + if isinstance(generator, Sampler): continue # wrap the generator into a `FastForwardSampler` + assert isinstance(generator, (Sampler, Generator)) sampler = FastForwardSampler(generator, attr_name=generator_attr_name) # if `CaptureIterableDataset` was available, the sampler should reload its own state. From 45d918d73cfbe4bc8390b3a49306f3d5d155d2fd Mon Sep 17 00:00:00 2001 From: donlapark Date: Thu, 28 Jul 2022 20:30:57 +0700 Subject: [PATCH 09/24] minor fix --- src/pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d79ab73757ceb..80794b6efe15b 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -327,7 +327,7 @@ def _wrap_generator_samplers(self) -> None: continue # wrap the generator into a `FastForwardSampler` - assert isinstance(generator, (Sampler, Generator)) + assert isinstance(generator, Generator) sampler = FastForwardSampler(generator, attr_name=generator_attr_name) # if `CaptureIterableDataset` was available, the sampler should reload its own state. From 0141f6eabd0dcfd4bed533d7593d771ea111fdde Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Thu, 28 Jul 2022 21:53:45 +0700 Subject: [PATCH 10/24] Minor --- src/pytorch_lightning/utilities/auto_restart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 80794b6efe15b..14bb7bbb4264e 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -327,6 +327,7 @@ def _wrap_generator_samplers(self) -> None: continue # wrap the generator into a `FastForwardSampler` + print(type(generator)) assert isinstance(generator, Generator) sampler = FastForwardSampler(generator, attr_name=generator_attr_name) From 03fc815c845faacb9d5ac3684f67a436e9051205 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Thu, 28 Jul 2022 22:21:10 +0700 Subject: [PATCH 11/24] Change `FastForwardSampler`'s sampler to `Iterator` --- src/pytorch_lightning/utilities/auto_restart.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 14bb7bbb4264e..2ec9bcec566eb 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -67,7 +67,7 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). 
""" - def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None: + def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False @@ -327,8 +327,6 @@ def _wrap_generator_samplers(self) -> None: continue # wrap the generator into a `FastForwardSampler` - print(type(generator)) - assert isinstance(generator, Generator) sampler = FastForwardSampler(generator, attr_name=generator_attr_name) # if `CaptureIterableDataset` was available, the sampler should reload its own state. From c1df7c21b82f8a093afbe75b2b7ea8398cbc9c56 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Fri, 29 Jul 2022 02:01:10 +0700 Subject: [PATCH 12/24] Update `pyproject.toml` --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f4cc9a5a14ac8..97200c46e757e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ module = [ "pytorch_lightning.profilers.pytorch", "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.ddp", - "pytorch_lightning.strategies.ddp_spawn", "pytorch_lightning.strategies.fully_sharded", "pytorch_lightning.strategies.ipu", "pytorch_lightning.strategies.sharded", From 0059ca57c9edbf9749fa5cceec809a2355bcc306 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 29 Aug 2022 01:25:14 +0700 Subject: [PATCH 13/24] Remove `_BaseLoaderIter` --- src/pytorch_lightning/utilities/auto_restart.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 2ec9bcec566eb..69eb2152b402e 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -35,10 +35,6 @@ from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states from pytorch_lightning.utilities.types import _IntStateful, _Stateful -if TYPE_CHECKING: - _BaseLoaderIter = _BaseDataLoaderIter -else: - _BaseLoaderIter = object class IteratorStateDict(TypedDict): From 1921203a37b4876e368499cab6a9e2301eb9cfaa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 28 Aug 2022 18:26:40 +0000 Subject: [PATCH 14/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/auto_restart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 69eb2152b402e..d22651bb1bbdc 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -36,7 +36,6 @@ from pytorch_lightning.utilities.types import _IntStateful, _Stateful - class IteratorStateDict(TypedDict): dataset_state: Dict[int, Any] sampler_state: Dict[int, Any] From f3ef4f8b4d5231442b0305997da718a2684040e0 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 29 Aug 2022 01:28:58 +0700 Subject: [PATCH 15/24] Remove `_BaseLoaderIter` --- src/pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d22651bb1bbdc..d423886564726 100644 --- 
a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -637,7 +637,7 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} -class _StatefulDataLoaderIter(_BaseLoaderIter): +class _StatefulDataLoaderIter(_BaseDataLoaderIter): """This mixin is used to make PyTorch DataLoaderIter stateful.""" def __accumulate_state(self, sampler_state: Dict[Union[int, str], Any]) -> None: From 02fff74026b80ce9dee8316ed8085f217bed52f5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 30 Aug 2022 01:24:11 +0530 Subject: [PATCH 16/24] self review --- .../utilities/auto_restart.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d423886564726..d2dfae58a1672 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -15,7 +15,7 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps -from typing import Any, Callable, cast, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( @@ -46,7 +46,7 @@ class IteratorStateDict(TypedDict): class MergedIteratorStateDict(TypedDict): - state: dict + state: Dict[str, Any] latest_worker_id: int represent_map_dataset: Optional[bool] @@ -200,10 +200,8 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non self.represent_map_dataset = generator_name is None latest_worker_id = new_state.worker_id if generator_name is None: - self.state = cast(Dict[int, IteratorState], self.state) self.state[latest_worker_id] = new_state else: - self.state = cast(Dict[str, Dict[int, IteratorState]], self.state) if generator_name not in self.state: self.state[generator_name] = {} state = self.state[generator_name] @@ -224,12 +222,10 @@ def dataset_states(self) -> Dict[int, Any]: @classmethod def from_state_dict(cls, state_dict: MergedIteratorStateDict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: - state_dict["state"] = cast(Dict[int, IteratorState], state_dict["state"]) state_dict["state"] = { worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items() } else: - state_dict["state"] = cast(Dict[str, Dict[int, IteratorState]], state_dict["state"]) state_dict["state"] = { sampler_name: { worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items() @@ -640,15 +636,17 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor class _StatefulDataLoaderIter(_BaseDataLoaderIter): """This mixin is used to make PyTorch DataLoaderIter stateful.""" - def __accumulate_state(self, sampler_state: Dict[Union[int, str], Any]) -> None: + def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None: # store sampler state within a queue alongside its idx. 
self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1 self._sampler_state.append((sampler_state, self._sampler_state_idx)) def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" - sampler_state: Dict[Union[int, str], Any] = { - k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset" + sampler_state: Dict[int, Any] = { + k: v.state_dict() # type: ignore[misc] + for k, v in self._loader.__dict__.items() + if isinstance(v, _Stateful) and k != "dataset" } self.__accumulate_state(sampler_state) @@ -662,7 +660,7 @@ def _prepare_loader(self, loader: DataLoader) -> None: self._loader = loader self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher self.num_batches_fetched = 0 - self._sampler_state: List[Tuple[Dict[Union[int, str], Any], int]] = [] + self._sampler_state: List[Tuple[Dict[int, Any], int]] = [] self._sampler_state_idx = 0 def __del__(self) -> None: @@ -680,7 +678,6 @@ def _next_data(self) -> Any: # there is no workers within the samplers worker_id = list(state.keys())[0] - sampler_state = cast(Dict[int, Any], sampler_state) state = [ IteratorState( num_workers=self._loader.num_workers, From 40c6aa4d820b5a94a583157429b7d92b2e1c29dd Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Tue, 30 Aug 2022 04:05:02 +0700 Subject: [PATCH 17/24] Corrected `_IntStateful`'s comment --- src/pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 5ea435f7915fe..28e1c7f9af4e6 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -104,7 +104,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @runtime_checkable class _IntStateful(Protocol): """This class is used to detect if an object is stateful, whose states are integers, using `isinstance(obj, - _Stateful)`.""" + _IntStateful)`.""" def state_dict(self) -> Dict[int, Any]: ... 
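For readers less familiar with the structural typing these commits lean on: a `runtime_checkable` Protocol matches any object that merely defines the listed methods, and `isinstance` verifies only that the methods exist — it inspects neither their signatures nor the key type of the returned dict, which is why `_IntStateful` and `_Stateful` are indistinguishable at runtime and the distinction only helps mypy. A minimal, self-contained sketch of that behaviour follows; `CounterSampler` is an illustrative name, not something from the codebase:

    from typing import Any, Dict

    from typing_extensions import Protocol, runtime_checkable


    @runtime_checkable
    class _Stateful(Protocol):
        """Structural type: anything defining `state_dict`/`load_state_dict` matches."""

        def state_dict(self) -> Dict[str, Any]:
            ...

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            ...


    class CounterSampler:
        """Does not inherit from `_Stateful`, yet satisfies it structurally."""

        def __init__(self) -> None:
            self.counter = 0

        def state_dict(self) -> Dict[str, Any]:
            return {"counter": self.counter}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.counter = state_dict["counter"]


    sampler = CounterSampler()
    # `runtime_checkable` only checks that the two methods are present; the dict
    # key type is invisible here, so an int-keyed variant passes the same check.
    assert isinstance(sampler, _Stateful)
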
From 0fbc4625fd8ba55135fcd12cd1b6dd2e38bb070e Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:33:30 +0700 Subject: [PATCH 18/24] Update src/pytorch_lightning/utilities/auto_restart.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d2dfae58a1672..16879a1ec7495 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -36,7 +36,7 @@ from pytorch_lightning.utilities.types import _IntStateful, _Stateful -class IteratorStateDict(TypedDict): +class _IteratorStateDict(TypedDict): dataset_state: Dict[int, Any] sampler_state: Dict[int, Any] worker_id: int From bb8841a7958980e244169d5af2a24b39c4052d02 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:33:48 +0700 Subject: [PATCH 19/24] Update src/pytorch_lightning/utilities/auto_restart.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 16879a1ec7495..a2c4bdfcd65a9 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -45,7 +45,7 @@ class _IteratorStateDict(TypedDict): name: Optional[str] -class MergedIteratorStateDict(TypedDict): +class _MergedIteratorStateDict(TypedDict): state: Dict[str, Any] latest_worker_id: int represent_map_dataset: Optional[bool] From cdaeef0e3772b71f9be7d0200dacb3ce6e2e2bac Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:39:24 +0700 Subject: [PATCH 20/24] Remove `_IntStateful` and make `_Stateful` generic --- src/pytorch_lightning/utilities/auto_restart.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index a2c4bdfcd65a9..1b3ae92138002 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states -from pytorch_lightning.utilities.types import _IntStateful, _Stateful +from pytorch_lightning.utilities.types import _Stateful class _IteratorStateDict(TypedDict): @@ -179,7 +179,7 @@ class IteratorState: name: Optional[str] = None @classmethod - def from_state_dict(cls, state_dict: IteratorStateDict) -> "IteratorState": + def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState": return cls(**state_dict) @@ -220,7 +220,7 @@ def dataset_states(self) -> Dict[int, Any]: return {k: self.state[k].dataset_state[k] for k in self.state.keys()} @classmethod - def from_state_dict(cls, state_dict: MergedIteratorStateDict) -> "MergedIteratorState": + def from_state_dict(cls, state_dict: 
_MergedIteratorStateDict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: state_dict["state"] = { worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items() @@ -597,7 +597,7 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic for worker_id in state_dict["state"].keys() } - assert isinstance(dataloader.dataset, _IntStateful) + assert isinstance(dataloader.dataset, _Stateful) dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) From 81a8d5c455832db3d6d6d99018868670f6d5e974 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:42:07 +0700 Subject: [PATCH 21/24] Remove `_IntStateful` and make `_Stateful` generic --- src/pytorch_lightning/utilities/types.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 28e1c7f9af4e6..e5192982978f1 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union import torch from torch import Tensor @@ -30,6 +30,8 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable +T = TypeVar('T') + _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] @@ -91,25 +93,13 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: @runtime_checkable -class _Stateful(Protocol): +class _Stateful(Protocol[T]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... - - -@runtime_checkable -class _IntStateful(Protocol): - """This class is used to detect if an object is stateful, whose states are integers, using `isinstance(obj, - _IntStateful)`.""" - - def state_dict(self) -> Dict[int, Any]: + def state_dict(self) -> Dict[T, Any]: ... - def load_state_dict(self, state_dict: Dict[int, Any]) -> None: + def load_state_dict(self, state_dict: Dict[T, Any]) -> None: ... 
From c11d3e03ea8923f25d098478328ab60da6e9a325 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Sep 2022 13:44:45 +0000 Subject: [PATCH 22/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index e5192982978f1..c37afc8b39f6a 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -30,7 +30,7 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable -T = TypeVar('T') +T = TypeVar("T") _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] From 95a3760238f9103447ae0f57ef19deb9aba4f637 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:58:28 +0700 Subject: [PATCH 23/24] Minor fix --- src/pytorch_lightning/utilities/auto_restart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 1b3ae92138002..e90dcc7172690 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -597,7 +597,6 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic for worker_id in state_dict["state"].keys() } - assert isinstance(dataloader.dataset, _Stateful) dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) From 71d81da6d870b16b526297ee05d658affd44fe7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 5 Sep 2022 16:32:12 +0200 Subject: [PATCH 24/24] Fixes to dictkey --- src/pytorch_lightning/utilities/types.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index c37afc8b39f6a..c90657b34e868 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -30,8 +30,6 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable -T = TypeVar("T") - _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] @@ -92,21 +90,24 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: ... +_DictKey = TypeVar("_DictKey") + + @runtime_checkable -class _Stateful(Protocol[T]): +class _Stateful(Protocol[_DictKey]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> Dict[T, Any]: + def state_dict(self) -> Dict[_DictKey, Any]: ... - def load_state_dict(self, state_dict: Dict[T, Any]) -> None: + def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: ... 
# Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @runtime_checkable -class _LRScheduler(_Stateful, Protocol): +class _LRScheduler(_Stateful[str], Protocol): optimizer: Optimizer base_lrs: List[float] @@ -120,7 +121,7 @@ def step(self, epoch: Optional[int] = None) -> None: # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @runtime_checkable -class ReduceLROnPlateau(_Stateful, Protocol): +class ReduceLROnPlateau(_Stateful[str], Protocol): in_cooldown: bool optimizer: Optimizer
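Circling back to the first commit, the other typing device the series relies on is `TypedDict`, which gives mypy a precise view of the dictionaries fed to `IteratorState.from_state_dict`. The snippet below is a trimmed, runnable stand-in, not the library code itself: the field names follow `_IteratorStateDict` from the patch, while the example values are made up.

    from dataclasses import dataclass, field
    from typing import Any, Dict, Optional

    from typing_extensions import TypedDict


    class _IteratorStateDict(TypedDict):
        dataset_state: Dict[int, Any]
        sampler_state: Dict[int, Any]
        worker_id: int
        num_workers: int
        num_batches_fetched: int
        name: Optional[str]


    @dataclass
    class IteratorState:
        dataset_state: Dict[int, Any] = field(default_factory=dict)
        sampler_state: Dict[int, Any] = field(default_factory=dict)
        worker_id: int = 0
        num_workers: int = 0
        num_batches_fetched: int = 0
        name: Optional[str] = None

        @classmethod
        def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
            # The TypedDict documents the expected keys and value types, so a
            # type checker can validate the `**` expansion against the fields.
            return cls(**state_dict)


    state: _IteratorStateDict = {
        "dataset_state": {},
        "sampler_state": {0: {"current_iteration": 16}},
        "worker_id": 0,
        "num_workers": 1,
        "num_batches_fetched": 16,
        "name": None,
    }
    print(IteratorState.from_state_dict(state))
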