diff --git a/pyproject.toml b/pyproject.toml
index 166447dd655f6..5a8f632481127 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
-    "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.data",
diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index e183bdcc644d6..454143416f735 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sized
 from dataclasses import asdict, dataclass, field
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union
 
@@ -53,23 +54,24 @@ class TensorRunningAccum:
 
     def __init__(self, window_length: int):
         self.window_length = window_length
-        self.memory = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
+        self.reset(window_length)
 
     def reset(self, window_length: Optional[int] = None) -> None:
         """Empty the accumulator."""
-        if window_length is None:
-            window_length = self.window_length
-        self.__init__(window_length)
+        if window_length is not None:
+            self.window_length = window_length
+        self.memory: Optional[torch.Tensor] = None
+        self.current_idx: int = 0
+        self.last_idx: Optional[int] = None
+        self.rotated: bool = False
 
-    def last(self):
+    def last(self) -> Optional[torch.Tensor]:
         """Get the last added element."""
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             return self.memory[self.last_idx].float()
 
-    def append(self, x):
+    def append(self, x: torch.Tensor) -> None:
         """Add an element to the accumulator."""
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
@@ -88,20 +90,21 @@ def append(self, x):
         if self.current_idx == 0:
             self.rotated = True
 
-    def mean(self):
+    def mean(self) -> Optional[torch.Tensor]:
         """Get mean value from stored elements."""
         return self._agg_memory("mean")
 
-    def max(self):
+    def max(self) -> Optional[torch.Tensor]:
         """Get maximal value from stored elements."""
         return self._agg_memory("max")
 
-    def min(self):
+    def min(self) -> Optional[torch.Tensor]:
         """Get minimal value from stored elements."""
         return self._agg_memory("min")
 
-    def _agg_memory(self, how: str):
+    def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             if self.rotated:
                 return getattr(self.memory.float(), how)()
             return getattr(self.memory[: self.current_idx].float(), how)()
@@ -139,7 +142,7 @@ def done(self) -> bool:
 class CycleIterator:
     """Iterator for restarting a dataloader if it runs out of samples."""
 
-    def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
+    def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -184,6 +187,8 @@ def __next__(self) -> Any:
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
         """
+        assert isinstance(self._loader_iter, Iterator)
+
         # Note: if self.length is `inf`, then the iterator will never stop
         if self.counter >= self.__len__() or self.state.done:
             raise StopIteration
@@ -257,13 +262,13 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
         Returns:
             length: the length of `CombinedDataset`
         """
-        if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
+        if mode not in self.COMPUTE_FUNCS.keys():
             raise MisconfigurationException(f"Invalid Mode: {mode}")
 
         # extract the lengths
         all_lengths = self._get_len_recursive(datasets)
 
-        compute_func = CombinedDataset.COMPUTE_FUNCS[mode]
+        compute_func = self.COMPUTE_FUNCS[mode]
 
         if isinstance(all_lengths, (int, float)):
             length = all_lengths
@@ -272,8 +277,9 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
 
         return length
 
-    def _get_len_recursive(self, data) -> int:
+    def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
         if isinstance(data, Dataset):
+            assert isinstance(data, Sized)
             return len(data)
 
         if isinstance(data, (float, int)):
@@ -290,13 +296,13 @@ def _get_len_recursive(self, data) -> int:
         return self._get_len(data)
 
     @staticmethod
-    def _get_len(dataset) -> int:
+    def _get_len(dataset: Any) -> Union[int, float]:
         try:
             return len(dataset)
         except (TypeError, NotImplementedError):
             return float("inf")
 
-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
         """Return the minimum length of the datasets."""
         return self._calc_num_data(self.datasets, self.mode)
 
@@ -348,8 +354,8 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         if self.mode == "max_size_cycle":
             self._wrap_loaders_max_size_cycle()
 
-        self._loaders_iter_state_dict = None
-        self._iterator = None  # assigned in __iter__
+        self._loaders_iter_state_dict: Optional[Dict] = None
+        self._iterator: Optional[Iterator] = None  # assigned in __iter__
 
     @staticmethod
     def _state_dict_fn(iterator: Optional[Iterator], has_completed: int) -> Dict:
@@ -384,7 +390,7 @@ def state_dict(self, has_completed: bool = False) -> Dict:
             has_completed=has_completed,
         )
 
-    def load_state_dict(self, state_dict) -> None:
+    def load_state_dict(self, state_dict: Dict) -> None:
         # store the samplers state.
         # They would be reloaded once the `CombinedIterator` as been created
         # and the workers are created.
@@ -482,10 +488,10 @@ def __iter__(self) -> Any:
 
         # prevent `NotImplementedError` from PyTorch:
         # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
-        def __getstate__patch__(*_):
+        def __getstate__patch__(*_: Any) -> Dict:
             return {}
 
-        _BaseDataLoaderIter.__getstate__ = __getstate__patch__
+        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore[assignment]
         iterator = CombinedLoaderIterator(self.loaders)
         # handle fault tolerant restart logic.
         self.on_restart(iterator)
@@ -493,7 +499,7 @@ def __getstate__patch__(*_):
         return iterator
 
     @staticmethod
-    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
+    def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
         """Compute the length (aka the number of batches) of `CombinedLoader`.
 
         Args:
@@ -509,16 +515,16 @@ def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
             return all_lengths
         return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)
 
-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
         return self._calc_num_batches(self.loaders, mode=self.mode)
 
     @staticmethod
-    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+    def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
         if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
             dataloader._iterator._shutdown_workers()
         dataloader._iterator = None
 
-    def reset(self):
+    def reset(self) -> None:
         if self._iterator:
             self._iterator._loader_iters = None
         if self.loaders is not None:
@@ -535,7 +541,7 @@ def __init__(self, loaders: Any):
             loaders: the loaders to sample from. Can be all kind of collection
         """
         self.loaders = loaders
-        self._loader_iters = None
+        self._loader_iters: Any = None
 
     @property
     def loader_iters(self) -> Any:
@@ -584,7 +590,9 @@ def create_loader_iters(
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
 
 
-def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
+def _nested_calc_num_data(
+    data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
+) -> Union[int, float]:
 
     if isinstance(data, (float, int)):
         return data
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 6b6d5771a4751..d8e5e6fc4a79b 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -356,7 +356,9 @@ def on_train_batch_end(
         if self.progress_bar:
             self.progress_bar.update()
 
-        current_loss = trainer.fit_loop.running_loss.last().item()
+        loss_tensor = trainer.fit_loop.running_loss.last()
+        assert loss_tensor is not None
+        current_loss = loss_tensor.item()
         current_step = trainer.global_step
 
         # Avg loss (loss with momentum) + smoothing
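
The annotations above make `TensorRunningAccum.last()` and `mean()` explicitly `Optional[torch.Tensor]`, which is why `lr_finder.py` now narrows the value before calling `.item()`. Below is a minimal caller-side sketch of that pattern (not part of the diff), assuming PyTorch Lightning 1.x where the class is importable from `pytorch_lightning.trainer.supporters`; the variable names are illustrative only.

```python
# Sketch: handling the Optional[torch.Tensor] contract of TensorRunningAccum.
import torch

from pytorch_lightning.trainer.supporters import TensorRunningAccum

accum = TensorRunningAccum(window_length=3)
print(accum.last())  # None -- nothing has been appended yet

for step in range(5):
    accum.append(torch.tensor(float(step)))

last = accum.last()  # Optional[torch.Tensor]: narrow it before calling .item()
assert last is not None
print(last.item(), accum.mean())  # last value and the mean over the 3-element window
```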