Merged
41 commits
6eafa04
fixes typing errors in auto_restart.py
donlap Jul 28, 2022
1935564
fixes typing errors in auto_restart.py
donlap Jul 28, 2022
b0409b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2022
192f155
minor fix
donlap Jul 28, 2022
074f26f
minor fix
donlap Jul 28, 2022
217dbd5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2022
dd6b168
minor fix
donlap Jul 28, 2022
afb9612
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2022
8d4ef40
minor fix
donlap Jul 28, 2022
45d918d
minor fix
donlap Jul 28, 2022
0141f6e
Minor
donlapark Jul 28, 2022
03fc815
Change `FastForwardSampler`'s sampler to `Iterator`
donlapark Jul 28, 2022
c1df7c2
Update `pyproject.toml`
donlapark Jul 28, 2022
da5fc47
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 6, 2022
e888eb6
Merge branch 'master' into fixes_mypy_auto_restart_py
otaj Aug 26, 2022
f93b88c
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 26, 2022
0059ca5
Remove `_BaseLoaderIter`
donlapark Aug 28, 2022
1921203
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2022
f3ef4f8
Remove `_BaseLoaderIter`
donlapark Aug 28, 2022
e5844dd
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 29, 2022
e81d961
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 29, 2022
e754c86
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 29, 2022
02fff74
self review
rohitgr7 Aug 29, 2022
40c6aa4
Corrected `_IntStateful`'s comment
donlapark Aug 29, 2022
2d5f082
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 29, 2022
90e9617
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 30, 2022
582287c
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Aug 31, 2022
920541d
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 1, 2022
3355013
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 2, 2022
0779a77
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 3, 2022
c98917f
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 4, 2022
0fbc462
Update src/pytorch_lightning/utilities/auto_restart.py
donlapark Sep 5, 2022
bb8841a
Update src/pytorch_lightning/utilities/auto_restart.py
donlapark Sep 5, 2022
cdaeef0
Remove `_IntStateful` and make `_Stateful` generic
donlapark Sep 5, 2022
81a8d5c
Remove `_IntStateful` and make `_Stateful` generic
donlapark Sep 5, 2022
789e92e
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 5, 2022
c11d3e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2022
95a3760
Minor fix
donlapark Sep 5, 2022
a1d61b8
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 5, 2022
71d81da
Fixes to dictkey
carmocca Sep 5, 2022
e7cb707
Merge branch 'master' into fixes_mypy_auto_restart_py
donlapark Sep 5, 2022
1 change: 0 additions & 1 deletion pyproject.toml
@@ -56,7 +56,6 @@ module = [
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
]
ignore_errors = "True"
83 changes: 54 additions & 29 deletions src/pytorch_lightning/utilities/auto_restart.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sized
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
@@ -24,6 +25,7 @@
DataLoader,
IterableDataset,
)
from typing_extensions import TypedDict

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@
from pytorch_lightning.utilities.types import _Stateful


class _IteratorStateDict(TypedDict):
dataset_state: Dict[int, Any]
sampler_state: Dict[int, Any]
worker_id: int
num_workers: int
num_batches_fetched: int
name: Optional[str]


class _MergedIteratorStateDict(TypedDict):
state: Dict[str, Any]
latest_worker_id: int
represent_map_dataset: Optional[bool]


class FastForwardSampler(Sampler):
"""This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
performed during an epoch.
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
samples seen in the last iterations (for the current worker).
"""

def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
super().__init__(data_source=None)
self._sampler = sampler
self.restarting: bool = False
@@ -79,7 +96,7 @@ def __iter__(self) -> Iterator[Any]:
self._counter = 0
return self

def __next__(self):
def __next__(self) -> Any:
# the `state dict` was cached as workers were unavailable before.
if self._cached_state_dict is not None:
self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ def __next__(self):
raise StopIteration

def __len__(self) -> int:
assert isinstance(self._sampler, Sized)
return len(self._sampler)

def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
@@ -161,7 +179,7 @@ class IteratorState:
name: Optional[str] = None

@classmethod
def from_state_dict(cls, state_dict) -> "IteratorState":
def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
return cls(**state_dict)


@@ -173,22 +191,22 @@ class MergedIteratorState:
worker states in this merged iterator state.
"""

state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
state: Dict = field(default_factory=dict)
latest_worker_id: int = 0
represent_map_dataset: Optional[bool] = None

def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
# a map based dataset doesn't own a generator and therefore `generator_name` should be None.
self.represent_map_dataset = generator_name is None
if self.represent_map_dataset:
state = self.state
latest_worker_id = new_state.worker_id
if generator_name is None:
self.state[latest_worker_id] = new_state
else:
if generator_name not in self.state:
self.state[generator_name] = {}
state = self.state[generator_name]
state[latest_worker_id] = new_state

latest_worker_id = new_state.worker_id
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id

@property
@@ -202,7 +220,7 @@ def dataset_states(self) -> Dict[int, Any]:
return {k: self.state[k].dataset_state[k] for k in self.state.keys()}

@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
state_dict["state"] = {
worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
"""

def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
self._cached_state_dict = None
self.dataset: Dataset = dataset
self._cached_state_dict: Optional[Dict[int, Any]] = None

@property
def worker_id(self) -> int:
worker_info = get_worker_info()
return worker_info.id if worker_info else 0

def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
_set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
return self.dataset[item]

def __len__(self) -> int:
assert isinstance(self.dataset, Sized)
return len(self.dataset)

def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
Expand All @@ -268,17 +287,18 @@ def __init__(self, dataset: IterableDataset) -> None:
super().__init__()
self.dataset = deepcopy(dataset)
self.samplers: Optional[Dict[str, FastForwardSampler]] = None
self._state_dict: Optional[Dict[int, Any]] = None
self._state_dict: Optional[Dict[str, Any]] = None
self._has_wrapped: bool = False

@property
def sampler(self) -> Sampler:
return self.dataset.sampler

def state_dict(self) -> Dict[str, Any]:
assert self.samplers is not None
return {k: v.state_dict() for k, v in self.samplers.items()}

def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._state_dict = deepcopy(state_dict)

def _wrap_generator_samplers(self) -> None:
Expand Down Expand Up @@ -311,7 +331,7 @@ def _wrap_generator_samplers(self) -> None:

self.reset_on_epoch()

def reset_on_epoch(self):
def reset_on_epoch(self) -> None:
self._state_dict = None

def __iter__(self) -> Iterator:
Expand Down Expand Up @@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
for _ in range(state_dict["previous_worker"] - 1):
next(iter_dataloader._worker_queue_idx_cycle)

# we can finally call reset and apply prefecthing.
iter_dataloader._reset = iter_dataloader._original_reset
# we can finally call reset and apply prefetching.
iter_dataloader._reset = iter_dataloader._original_reset # type: ignore[assignment]
iter_dataloader._reset(dataloader, first_iter=True)
# return the iterator
return iter_dataloader
@@ -445,6 +465,7 @@ def wrapper() -> Any:
]
elif isinstance(dataset, CaptureMapDataset):
ff_sampler = _find_fast_forward_samplers(dl)
assert ff_sampler is not None
state = [
IteratorState(
num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
assert ff_sampler is not None
ff_sampler.load_state_dict(iterator_state.sampler_state)

# reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}


class _StatefulDataLoaderIter:
class _StatefulDataLoaderIter(_BaseDataLoaderIter):
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
# store sampler state within a queue alongside its idx.
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state.append((sampler_state, self._sampler_state_idx))

def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
sampler_state = {
k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
sampler_state: Dict[int, Any] = {
k: v.state_dict() # type: ignore[misc]
for k, v in self._loader.__dict__.items()
if isinstance(v, _Stateful) and k != "dataset"
}
self.__accumulate_state(sampler_state)

@@ -630,12 +654,12 @@ def _next_index(self) -> Any:
self._store_sampler_state()
return indexes

def _prepare_loader(self, loader):
def _prepare_loader(self, loader: DataLoader) -> None:
_add_capture_metadata_collate(loader)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0
self._sampler_state = []
self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
self._sampler_state_idx = 0

def __del__(self) -> None:
@@ -680,7 +704,7 @@ def __init__(self, loader: DataLoader):
super().__init__(loader)


def _get_iterator(self) -> "_BaseDataLoaderIter":
def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
if not hasattr(self, "_lightning_fetcher"):
raise MisconfigurationException(
"A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,15 +723,15 @@ def _patch_dataloader_get_iterators() -> None:
return
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator
DataLoader._get_iterator = _get_iterator # type: ignore[assignment]


def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault-tolerance.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
DataLoader._get_iterator = get_iterator # type: ignore[assignment]
del DataLoader._ori_get_iterator


@@ -781,16 +805,17 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -
raise ValueError("Fault-tolerance supports only a single dataloader.")

for dataloader in dl_loaders:
assert isinstance(dataloader, DataLoader)
validator_fn = (
_validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
)
validator_fn(dataloader)


def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
"""This utility collects the state across processes for a collection of state."""

def fn(state: Dict):
def fn(state: Dict) -> Dict:
if key in state:
return _collect_states_on_rank_zero(state)
return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
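Note: the snippet below is a minimal, standalone sketch of the `TypedDict` pattern the diff above introduces, not the Lightning classes themselves; the field values are made up for illustration, and `typing_extensions` is assumed to be installed, mirroring the import added in the diff. The idea is that a `TypedDict` spells out the exact shape of the serialized iterator state, so mypy can check callers of `from_state_dict` instead of accepting an arbitrary dict.

```python
# Standalone sketch of the TypedDict-typed state-dict pattern (illustrative values only).
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from typing_extensions import TypedDict  # `typing.TypedDict` on Python 3.8+


class _IteratorStateDict(TypedDict):
    dataset_state: Dict[int, Any]
    sampler_state: Dict[int, Any]
    worker_id: int
    num_workers: int
    num_batches_fetched: int
    name: Optional[str]


@dataclass
class IteratorState:
    """Simplified analogue of the dataclass shown in the diff above."""

    dataset_state: Dict[int, Any] = field(default_factory=dict)
    sampler_state: Dict[int, Any] = field(default_factory=dict)
    worker_id: int = 0
    num_workers: int = 0
    num_batches_fetched: int = 0
    name: Optional[str] = None

    @classmethod
    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
        # `**` unpacking works at runtime because the TypedDict keys match the fields.
        return cls(**state_dict)


# The dict literal is checked against the TypedDict: a missing key or a value of the
# wrong type is flagged by the type checker before the code ever runs.
restored = IteratorState.from_state_dict(
    {
        "dataset_state": {},
        "sampler_state": {0: {"current_iteration": 16}},
        "worker_id": 0,
        "num_workers": 2,
        "num_batches_fetched": 8,
        "name": None,
    }
)
```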
15 changes: 9 additions & 6 deletions src/pytorch_lightning/utilities/types.py
@@ -20,7 +20,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union

import torch
from torch import Tensor
@@ -90,21 +90,24 @@ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
...


_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol):
class _Stateful(Protocol[_DictKey]):
"""This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""

def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> Dict[_DictKey, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class _LRScheduler(_Stateful, Protocol):
class _LRScheduler(_Stateful[str], Protocol):
optimizer: Optimizer
base_lrs: List[float]

@@ -118,7 +121,7 @@ def step(self, epoch: Optional[int] = None) -> None:
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class ReduceLROnPlateau(_Stateful, Protocol):
class ReduceLROnPlateau(_Stateful[str], Protocol):
in_cooldown: bool
optimizer: Optimizer

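Note: the following is a self-contained sketch (an assumed simplification, not the exact Lightning definitions) of why `_Stateful` is made generic over its key type in the diff above: samplers keep state dicts keyed by worker id (`int`), while LR schedulers use `str` keys, and both should satisfy the same protocol. `IntKeyedSampler` and its attributes are hypothetical, used purely for illustration; only the Python 3.8+ standard library is required.

```python
# Sketch of a generic, runtime-checkable protocol parametrized over the state-dict key type.
from typing import Any, Dict, Protocol, TypeVar, runtime_checkable

_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    """An object is 'stateful' if it can dump and restore a dict keyed by `_DictKey`."""

    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


class IntKeyedSampler:
    """Hypothetical sampler whose state dict is keyed by worker id (an int)."""

    def __init__(self) -> None:
        self._state: Dict[int, Any] = {}

    def state_dict(self) -> Dict[int, Any]:
        return dict(self._state)

    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
        self._state = dict(state_dict)


sampler = IntKeyedSampler()

# The runtime check only looks for the two methods, regardless of the key type ...
assert isinstance(sampler, _Stateful)

# ... while a static checker sees the more precise `_Stateful[int]` here and would
# reject assigning `sampler` to a variable annotated as `_Stateful[str]`.
stateful_by_worker: "_Stateful[int]" = sampler
```

Because `runtime_checkable` inspects only method names at runtime, existing `isinstance(obj, _Stateful)` checks keep working unchanged, while static checkers can now distinguish `_Stateful[int]` from `_Stateful[str]` (the latter is what `_LRScheduler` and `ReduceLROnPlateau` subscribe to above).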