 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple

 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):

     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
@@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"
         return optimizers

     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers

         return self._reinit_optimizers_with_oss(optimizers)

-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator:
         yield None

     @rank_zero_only
-    def _optim_state_dict(self, optimizer):
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass

-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass

     @classmethod
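
Note on the pattern used above: the hunks appear to add return-type annotations plus `assert` statements so that a static type checker can narrow `Optional` attributes (such as `self.lightning_module`) before they are dereferenced, while also failing fast at runtime if the attribute is unset. A minimal sketch of that narrowing pattern, using hypothetical names (`Owner`, `label`) that are not taken from this file:

```python
from typing import Optional


class Owner:
    """Illustrative holder whose attribute may be unset (hypothetical example)."""

    def __init__(self, label: Optional[str] = None) -> None:
        self.label = label

    def describe(self) -> str:
        # Without this assert, mypy sees `self.label` as Optional[str] and flags
        # the call to .upper(); the assert narrows it to str for the rest of the
        # method and raises AssertionError at runtime if the attribute is unset.
        assert self.label is not None
        return self.label.upper()
```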