Skip to content

Commit e53c4e8

Browse files
krishnakalyan3rohitgr7Bordaawaelchli
authored
Fix mypy errors attributed to pytorch_lightning. strategies.sharded_spawn (#14102)
Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: awaelchli <[email protected]>
1 parent 31ecf9b commit e53c4e8

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ module = [
5858
"pytorch_lightning.profilers.base",
5959
"pytorch_lightning.profilers.pytorch",
6060
"pytorch_lightning.strategies.sharded",
61-
"pytorch_lightning.strategies.sharded_spawn",
6261
"pytorch_lightning.trainer.callback_hook",
6362
"pytorch_lightning.trainer.connectors.data_connector",
6463
"pytorch_lightning.trainer.supporters",

src/pytorch_lightning/overrides/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
7575
trainer = pl_module._trainer
7676

7777
if trainer is not None:
78+
assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
7879
if trainer.training:
7980
output = self.module.training_step(*inputs, **kwargs)
8081
# In manual_optimization, we need to prevent DDP reducer as

src/pytorch_lightning/strategies/sharded_spawn.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from contextlib import contextmanager
15-
from typing import Dict, Generator, List, Optional, Tuple
15+
from typing import Any, Dict, Generator, List, Optional, Tuple
1616

1717
from torch import Tensor
1818
from torch.nn import Module
1919
from torch.optim import Optimizer
2020

2121
import pytorch_lightning as pl
22+
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
2223
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
2324
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
2425
from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
4243

4344
def configure_ddp(self) -> None:
4445
# set up optimizers after the wrapped module has been moved to the device
46+
assert self.lightning_module is not None
4547
self.setup_optimizers(self.lightning_module.trainer)
48+
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
4649
self.model, self.optimizers = self._setup_model_and_optimizers(
4750
model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
4851
)
@@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"
6972
return optimizers
7073

7174
def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
72-
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
75+
assert self.lightning_module
76+
if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
7377
return optimizers
7478

7579
return self._reinit_optimizers_with_oss(optimizers)
7680

77-
def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
81+
def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
7882
if isinstance(optimizer, OSS):
7983
optimizer.consolidate_state_dict()
8084
return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator:
9397
yield None
9498

9599
@rank_zero_only
96-
def _optim_state_dict(self, optimizer):
100+
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
97101
"""
98102
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
99103
:meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
112116
def pre_backward(self, closure_loss: Tensor) -> None:
113117
pass
114118

115-
def post_training_step(self):
119+
def post_training_step(self) -> None:
116120
pass
117121

118122
@classmethod

0 commit comments

Comments
 (0)