
Fix mypy typing errors in pytorch_lightning/strategies/tpu_spawn.py #13813

Merged on Aug 2, 2022 (23 commits)

Commits
88766a6 Remove tpy_spawn in pyproject.toml (Jungwon-Lee, Jul 23, 2022)
972ffc2 Add assert before step (Jungwon-Lee, Jul 23, 2022)
35b2d72 fix mypy error in tpu-spawn.py (Jungwon-Lee, Jul 23, 2022)
bec228a fix is_distribued check world_size (Jungwon-Lee, Jul 25, 2022)
a1d1c2a check source.instance is DataLoader, List (Jungwon-Lee, Jul 25, 2022)
2efebe1 Update strategy.py suggestion (Jungwon-Lee, Jul 27, 2022)
f992165 Add checking self.accelerator is none (Jungwon-Lee, Jul 27, 2022)
dc3894a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2022)
6848700 undo Step type check (awaelchli, Jul 27, 2022)
8b0c161 Merge branch 'master' into typing_tpu_spawn (awaelchli, Jul 27, 2022)
c4554a6 update validation of dataloaders (awaelchli, Jul 27, 2022)
92280e6 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 27, 2022)
9e6dff4 Merge branch 'master' into typing_tpu_spawn (awaelchli, Jul 27, 2022)
446786a update (awaelchli, Jul 27, 2022)
4727920 Merge branch 'master' into typing_tpu_spawn (carmocca, Jul 28, 2022)
929d624 Fix docstring (carmocca, Jul 28, 2022)
485998d Merge branch 'master' into typing_tpu_spawn (Jul 29, 2022)
bdfbcf0 make mypy happier (Jul 29, 2022)
ffbd23c update broadcast obj to TBroadcast (Jungwon-Lee, Jul 30, 2022)
99790c2 Merge branch 'Lightning-AI:master' into typing_tpu_spawn (Jungwon-Lee, Jul 30, 2022)
f525150 Merge branch 'Lightning-AI:master' into typing_tpu_spawn (Jungwon-Lee, Aug 1, 2022)
dfa50ef Merge branch 'master' into typing_tpu_spawn (Jungwon-Lee, Aug 2, 2022)
3ee844e make pre-commit not complain (Aug 2, 2022)

pyproject.toml (1 change: 0 additions, 1 deletion)

@@ -66,7 +66,6 @@ module = [
"pytorch_lightning.strategies.ipu",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.strategies.tpu_spawn",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",

src/pytorch_lightning/strategies/tpu_spawn.py (48 changes: 30 additions, 18 deletions)

@@ -13,7 +13,7 @@
# limitations under the License.
import io
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor
@@ -29,15 +29,17 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
@@ -58,7 +60,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[int]] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
@@ -72,6 +74,7 @@ def __init__(
precision_plugin=precision_plugin,
start_method="fork",
)
self._checkpoint_io: Optional[CheckpointIO]
self.debug = debug
self._launched = False

@@ -95,17 +98,16 @@ def root_device(self) -> torch.device:
return xm.xla_device()

@staticmethod
def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

for dataloader in dataloaders:
def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
def check_has_len(dataloader: DataLoader) -> None:
if not has_len(dataloader):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)

apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)

@staticmethod
def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
@@ -118,32 +120,37 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
)
for source in sources:
if not source.is_module():
assert source.instance is not None
assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
TPUSpawnStrategy._validate_dataloader(source.instance)

def connect(self, model: "pl.LightningModule") -> None:
def connect(self, model: "pl.LightningModule") -> None: # type: ignore
TPUSpawnStrategy._validate_patched_dataloaders(model)
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)

def _configure_launcher(self):
def _configure_launcher(self) -> None:
self._launcher = _XLALauncher(self)

def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator
self.accelerator.setup(trainer)

if self.debug:
os.environ["PT_XLA_DEBUG"] = "1"

assert self.model
shared_params = find_shared_parameters(self.model)
self.model_to_device()
assert isinstance(self.model.module, Module)
set_shared_parameters(self.model.module, shared_params)
self.setup_precision_plugin()

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

def _setup_model(self, model: Module) -> Module:
def _setup_model(self, model: Module) -> Module: # type: ignore
return model

@property
@@ -168,11 +175,11 @@ def configure_ddp(self) -> None:
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

def barrier(self, name: Optional[str] = None) -> None:
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
if self.is_distributed:
rendezvous(name)

def broadcast(self, obj: object, src: int = 0) -> object:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not self.is_distributed:
return obj
buffer = io.BytesIO()
@@ -184,7 +191,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
def reduce(
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
if not isinstance(output, Tensor):
output = torch.tensor(output, device=self.root_device)

@@ -203,20 +212,23 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

return output

def _worker_setup(self, process_idx: int):
def _worker_setup(self, process_idx: int) -> None:
self._launched = True
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)

def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)

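For context on the `_validate_dataloader` rework above: `apply_to_collection` is called with `dtype=object` and `wrong_dtype=(Sequence, Mapping)`, so nested dataloader collections are traversed and the length check runs only on the leaf loaders. Below is a minimal, self-contained sketch of that traversal pattern; `_SizedLoader`, `_NoLenIterable`, `_check_has_len`, and `_validate` are illustrative stand-ins rather than Lightning code, and the real implementation uses `has_len` and raises `MisconfigurationException`.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Iterator


class _SizedLoader:
    """Stand-in for a map-style DataLoader that implements __len__."""

    def __len__(self) -> int:
        return 10


class _NoLenIterable:
    """Stand-in for an IterableDataset-backed DataLoader without __len__."""

    def __iter__(self) -> Iterator[int]:
        yield from range(3)


def _check_has_len(loader: Any) -> None:
    # Leaf check; the real code uses has_len() and raises MisconfigurationException.
    if not hasattr(loader, "__len__"):
        raise TypeError(f"{type(loader).__name__} has no __len__; IterableDataset is not supported on TPU")


def _validate(dataloaders: Any) -> None:
    # Rough equivalent of apply_to_collection(dataloaders, dtype=object,
    # wrong_dtype=(Sequence, Mapping), function=check_has_len):
    # recurse into containers, apply the check to every other object.
    if isinstance(dataloaders, Mapping):
        for dl in dataloaders.values():
            _validate(dl)
    elif isinstance(dataloaders, Sequence) and not isinstance(dataloaders, str):
        for dl in dataloaders:
            _validate(dl)
    else:
        _check_has_len(dataloaders)


_validate({"train": _SizedLoader(), "extra": [_SizedLoader(), _SizedLoader()]})  # passes silently
# _validate(_NoLenIterable())  # would raise: no __len__ on the leaf loader
```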

@@ -516,7 +516,7 @@ def is_defined(self) -> bool:
return not self.is_module() or is_overridden(self.name, self.instance)

def is_module(self) -> bool:
"""Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.
"""Returns whether the DataLoader source is a LightningModule or a LightningDataModule.

It does not check whether ``*_dataloader`` methods are actually overridden.
"""

src/pytorch_lightning/utilities/apply_func.py (2 changes: 1 addition, 1 deletion)

@@ -76,7 +76,7 @@ def apply_to_collection(
dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
*args: Any,
wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
include_none: bool = True,
**kwargs: Any,
) -> Any:
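One note on the `wrong_dtype` annotation change: to mypy, `Tuple[type]` means a tuple of exactly one type object, while `Tuple[type, ...]` means a tuple of type objects of any length, which is the relevant part for call sites that pass more than one excluded type. A minimal illustration, using hypothetical functions `f` and `g` whose parameters copy the old and new annotations:

```python
from typing import Optional, Tuple, Union


def f(wrong_dtype: Optional[Union[type, Tuple[type]]] = None) -> None:
    """Old annotation: a single type or a tuple of exactly one type."""


def g(wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None) -> None:
    """New annotation: a single type or a tuple of types of any length."""


f((list,))        # accepted under both annotations
f((list, dict))   # flagged by mypy: a 2-tuple does not match Tuple[type]
g((list, dict))   # accepted: an any-length tuple of types
```

The variadic form also mirrors how `isinstance` behaves at runtime, which accepts either a single type or an arbitrary-length tuple of types as its second argument.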