
fixes mypy errors in trainer/supporters.py #14633

Merged
merged 17 commits into from Sep 12, 2022

1 change: 0 additions & 1 deletion pyproject.toml
@@ -54,7 +54,6 @@ module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.data",
68 changes: 38 additions & 30 deletions src/pytorch_lightning/trainer/supporters.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sized
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union

@@ -53,23 +54,24 @@ class TensorRunningAccum:

def __init__(self, window_length: int):
self.window_length = window_length
self.memory = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False
self.reset(window_length)

def reset(self, window_length: Optional[int] = None) -> None:
"""Empty the accumulator."""
if window_length is None:
window_length = self.window_length
self.__init__(window_length)
if window_length is not None:
self.window_length = window_length
self.memory: Optional[torch.Tensor] = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False

def last(self):
def last(self) -> Optional[torch.Tensor]:
"""Get the last added element."""
if self.last_idx is not None:
assert isinstance(self.memory, torch.Tensor)
return self.memory[self.last_idx].float()

def append(self, x):
def append(self, x: torch.Tensor) -> None:
"""Add an element to the accumulator."""
if self.memory is None:
# tradeoff memory for speed by keeping the memory on device
@@ -88,20 +90,21 @@ def append(self, x):
if self.current_idx == 0:
self.rotated = True

def mean(self):
def mean(self) -> Optional[torch.Tensor]:
"""Get mean value from stored elements."""
return self._agg_memory("mean")

def max(self):
def max(self) -> Optional[torch.Tensor]:
"""Get maximal value from stored elements."""
return self._agg_memory("max")

def min(self):
def min(self) -> Optional[torch.Tensor]:
"""Get minimal value from stored elements."""
return self._agg_memory("min")

def _agg_memory(self, how: str):
def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
if self.last_idx is not None:
assert isinstance(self.memory, torch.Tensor)
if self.rotated:
return getattr(self.memory.float(), how)()
return getattr(self.memory[: self.current_idx].float(), how)()
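
A usage sketch, not from the PR itself: with the annotations above, last()/mean()/max()/min() return Optional[torch.Tensor], so callers now have to narrow the result before using it. The loss values below are made up.

import torch

from pytorch_lightning.trainer.supporters import TensorRunningAccum

acc = TensorRunningAccum(window_length=3)
for step_loss in (0.9, 0.7, 0.5, 0.4):  # made-up losses
    acc.append(torch.tensor(step_loss))

# mean() is None until something has been appended, hence the explicit check.
mean = acc.mean()
if mean is not None:
    print(mean.item())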
@@ -139,7 +142,7 @@ def done(self) -> bool:
class CycleIterator:
"""Iterator for restarting a dataloader if it runs out of samples."""

def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
"""
Args:
loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -184,6 +187,8 @@ def __next__(self) -> Any:
Raises:
StopIteration: if more than :attr:`length` batches have been returned
"""
assert isinstance(self._loader_iter, Iterator)

# Note: if self.length is `inf`, then the iterator will never stop
if self.counter >= self.__len__() or self.state.done:
raise StopIteration
@@ -257,13 +262,13 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
Returns:
length: the length of `CombinedDataset`
"""
if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
if mode not in self.COMPUTE_FUNCS.keys():
raise MisconfigurationException(f"Invalid Mode: {mode}")

# extract the lengths
all_lengths = self._get_len_recursive(datasets)

compute_func = CombinedDataset.COMPUTE_FUNCS[mode]
compute_func = self.COMPUTE_FUNCS[mode]

if isinstance(all_lengths, (int, float)):
length = all_lengths
@@ -272,8 +277,9 @@ def _get_len_recursive(self, data) -> int:

return length

def _get_len_recursive(self, data) -> int:
def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
if isinstance(data, Dataset):
assert isinstance(data, Sized)
return len(data)

if isinstance(data, (float, int)):
@@ -290,13 +296,13 @@ def _get_len_recursive(self, data) -> int:
return self._get_len(data)

@staticmethod
def _get_len(dataset) -> int:
def _get_len(dataset: Any) -> Union[int, float]:
try:
return len(dataset)
except (TypeError, NotImplementedError):
return float("inf")

def __len__(self) -> int:
def __len__(self) -> Union[int, float]:
"""Return the minimum length of the datasets."""
return self._calc_num_data(self.datasets, self.mode)
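
For context on the Union[int, float] annotations above: a dataset without a usable __len__ falls back to float("inf"), which then propagates into the combined length. A standalone sketch of that behaviour (re-implemented here for illustration, not the library code):

from typing import Any, Union


def _length_or_inf(dataset: Any) -> Union[int, float]:
    # Same fallback as _get_len above: no usable __len__ means "infinite".
    try:
        return len(dataset)
    except (TypeError, NotImplementedError):
        return float("inf")


# "min_size" takes the minimum, so a length-less dataset does not dominate;
# "max_size_cycle" takes the maximum and would report inf here.
lengths = [_length_or_inf([1, 2, 3]), _length_or_inf(iter(range(10)))]
print(min(lengths))  # 3
print(max(lengths))  # inf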

@@ -348,8 +354,8 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
if self.mode == "max_size_cycle":
self._wrap_loaders_max_size_cycle()

self._loaders_iter_state_dict = None
self._iterator = None # assigned in __iter__
self._loaders_iter_state_dict: Optional[Dict] = None
self._iterator: Optional[Iterator] = None # assigned in __iter__

@staticmethod
def _state_dict_fn(iterator: Optional[Iterator], has_completed: int) -> Dict:
@@ -384,7 +390,7 @@ def state_dict(self, has_completed: bool = False) -> Dict:
has_completed=has_completed,
)

def load_state_dict(self, state_dict) -> None:
def load_state_dict(self, state_dict: Dict) -> None:
# store the samplers state.
# They would be reloaded once the `CombinedIterator` has been created
# and the workers are created.
@@ -482,18 +488,18 @@ def __iter__(self) -> Any:

# prevent `NotImplementedError` from PyTorch:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
def __getstate__patch__(*_):
def __getstate__patch__(*_: Any) -> Dict:
return {}

_BaseDataLoaderIter.__getstate__ = __getstate__patch__
_BaseDataLoaderIter.__getstate__ = __getstate__patch__ # type: ignore[assignment]
iterator = CombinedLoaderIterator(self.loaders)
# handle fault tolerant restart logic.
self.on_restart(iterator)
self._iterator = iterator
return iterator

@staticmethod
def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
"""Compute the length (aka the number of batches) of `CombinedLoader`.

Args:
@@ -509,16 +515,16 @@ def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
return all_lengths
return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)

def __len__(self) -> int:
def __len__(self) -> Union[int, float]:
return self._calc_num_batches(self.loaders, mode=self.mode)

@staticmethod
def _shutdown_workers_and_reset_iterator(dataloader) -> None:
def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
dataloader._iterator._shutdown_workers()
dataloader._iterator = None

def reset(self):
def reset(self) -> None:
if self._iterator:
self._iterator._loader_iters = None
if self.loaders is not None:
@@ -535,7 +541,7 @@ def __init__(self, loaders: Any):
loaders: the loaders to sample from. Can be all kind of collection
"""
self.loaders = loaders
self._loader_iters = None
self._loader_iters: Any = None

@property
def loader_iters(self) -> Any:
@@ -584,7 +590,9 @@ def create_loader_iters(
return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))


def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
def _nested_calc_num_data(
data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
) -> Union[int, float]:

if isinstance(data, (float, int)):
return data
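
A hedged example of what the tightened signature expresses: compute_func folds a list of per-loader lengths into a single number, recursing through nested collections first. This assumes the private helper is importable from the module, and the numbers are made up.

from pytorch_lightning.trainer.supporters import _nested_calc_num_data

# Nested per-dataloader lengths (made-up values).
lengths = {"train_a": 100, "train_b": [80, float("inf")]}

print(_nested_calc_num_data(lengths, min))  # 80  -> "min_size" behaviour
print(_nested_calc_num_data(lengths, max))  # inf -> "max_size_cycle" behaviour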
4 changes: 3 additions & 1 deletion src/pytorch_lightning/tuner/lr_finder.py
@@ -356,7 +356,9 @@ def on_train_batch_end(
if self.progress_bar:
self.progress_bar.update()

current_loss = trainer.fit_loop.running_loss.last().item()
loss_tensor = trainer.fit_loop.running_loss.last()
assert loss_tensor is not None
current_loss = loss_tensor.item()
current_step = trainer.global_step

# Avg loss (loss with momentum) + smoothing
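
The added assert is what lets .item() type-check: mypy narrows Optional[Tensor] to Tensor after it, and at runtime it fails loudly if no loss was ever recorded. A minimal standalone illustration of the same pattern (a generic helper, not the trainer code):

from typing import Optional

import torch


def loss_value(maybe_loss: Optional[torch.Tensor]) -> float:
    # After the assert, mypy treats maybe_loss as a Tensor, so .item() is accepted.
    assert maybe_loss is not None
    return float(maybe_loss.item())


print(loss_value(torch.tensor(0.25)))  # 0.25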