Skip to content

Commit 74484ed

Browse files
ref: separate te from ddp (#3810)
* ref: separate te from ddp * ref: separate te from ddp * ref: separate te from ddp
1 parent a28528c commit 74484ed

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

pytorch_lightning/accelerators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
88
from pytorch_lightning.accelerators.tpu_backend import TPUBackend
99
from pytorch_lightning.accelerators.horovod_backend import HorovodBackend
10+
from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def select_accelerator(self):
145145
accelerator_backend = accelerators.DDPBackend(self.trainer, mode='slurm_ddp')
146146

147147
elif use_torchelastic_ddp:
148-
accelerator_backend = accelerators.DDPBackend(self.trainer, mode='torchelastic_ddp')
148+
accelerator_backend = accelerators.DDPTorchElasticBackend(self.trainer)
149149

150150
elif use_ddp_spawn:
151151
accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
14+
import os
15+
import torch
16+
import torch.distributed as torch_distrib
17+
import torch.distributed as dist
18+
19+
from pytorch_lightning.accelerators.base_backend import Accelerator
20+
from pytorch_lightning import _logger as log
21+
from pytorch_lightning.utilities import AMPType
22+
from pytorch_lightning.utilities.distributed import rank_zero_only
23+
from pytorch_lightning.utilities.seed import seed_everything
24+
from pytorch_lightning.distributed.dist import LightningDistributed
25+
26+
27+
try:
28+
from hydra.utils import to_absolute_path, get_original_cwd
29+
from hydra.core.hydra_config import HydraConfig
30+
except ImportError:
31+
HYDRA_AVAILABLE = False
32+
else:
33+
HYDRA_AVAILABLE = True
34+
35+
36+
# -------------------------------------------
37+
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
38+
# TEMP CLASS WHILE WE DECOUPLE SLURM FROM DDP
39+
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
40+
# -------------------------------------------
41+
class DDPTorchElasticBackend(Accelerator):
42+
43+
def __init__(self, trainer):
44+
super().__init__(trainer)
45+
self.task_idx = None
46+
self._has_spawned_children = False
47+
self.dist = LightningDistributed()
48+
49+
def setup(self, model):
50+
self.trainer.model = model
51+
self.task_idx = int(os.environ['LOCAL_RANK'])
52+
53+
def train(self):
54+
model = self.trainer.model
55+
self.ddp_train(process_idx=self.task_idx, model=model)
56+
57+
def set_world_ranks(self, process_idx):
58+
self.trainer.local_rank = process_idx
59+
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
60+
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
61+
62+
def model_to_device(self, model, process_idx):
63+
self.trainer.root_gpu = process_idx
64+
torch.cuda.set_device(self.trainer.root_gpu)
65+
model.cuda(self.trainer.root_gpu)
66+
67+
def get_device_ids(self):
68+
device_ids = [self.trainer.root_gpu]
69+
return device_ids
70+
71+
def training_step(self, args):
72+
if self.trainer.amp_backend == AMPType.NATIVE:
73+
with torch.cuda.amp.autocast():
74+
output = self.trainer.model(*args)
75+
else:
76+
output = self.trainer.model(*args)
77+
return output
78+
79+
def validation_step(self, args):
80+
output = self.training_step(args)
81+
return output
82+
83+
def test_step(self, args):
84+
output = self.training_step(args)
85+
return output
86+
87+
def barrier(self, name: str = None):
88+
if torch_distrib.is_initialized():
89+
torch_distrib.barrier()
90+
91+
def early_stopping_should_stop(self, pl_module):
92+
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
93+
dist.all_reduce(stop, op=dist.reduce_op.SUM)
94+
dist.barrier()
95+
should_stop = stop == self.trainer.world_size
96+
return should_stop
97+
98+
def broadcast(self, obj, src=0):
99+
return self.dist.broadcast(obj)
100+
101+
def ddp_train(self, process_idx, model):
102+
"""
103+
Entry point for ddp
104+
105+
Args:
106+
process_idx:
107+
mp_queue: multiprocessing queue
108+
model:
109+
110+
Returns:
111+
112+
"""
113+
# determine which process we are and world size
114+
self.set_world_ranks(process_idx)
115+
116+
# toggle prog bar
117+
if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
118+
self.trainer.progress_bar_callback.disable()
119+
120+
# set warning rank
121+
rank_zero_only.rank = self.trainer.global_rank
122+
123+
# set up server using proc 0's ip address
124+
# try to init for 20 times at max in case ports are taken
125+
# where to store ip_table
126+
model.trainer = self.trainer
127+
model.init_ddp_connection(
128+
self.trainer.global_rank,
129+
self.trainer.world_size,
130+
self.trainer.is_slurm_managing_tasks
131+
)
132+
133+
# call setup after the ddp process has connected
134+
self.trainer.call_setup_hook(model)
135+
136+
# on world_size=0 let everyone know training is starting
137+
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
138+
log.info('-' * 100)
139+
log.info(f'distributed_backend={self.trainer.distributed_backend} (on SLURM)')
140+
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
141+
log.info('-' * 100)
142+
143+
# call sync_bn before .cuda(), configure_apex and configure_ddp
144+
if self.trainer.sync_batchnorm:
145+
model = model.configure_sync_batchnorm(model)
146+
147+
# move the model to the correct device
148+
self.model_to_device(model, process_idx)
149+
150+
# CHOOSE OPTIMIZER
151+
# allow for lr schedulers as well
152+
self.setup_optimizers(model)
153+
154+
# set model properties before going into wrapper
155+
self.trainer.model_connector.copy_trainer_model_properties(model)
156+
157+
# 16-bit
158+
model = self.trainer.precision_connector.connect(model)
159+
160+
# device ids change depending on the DDP setup
161+
device_ids = self.get_device_ids()
162+
163+
# allow user to configure ddp
164+
model = model.configure_ddp(model, device_ids)
165+
166+
# set up training routine
167+
self.trainer.train_loop.setup_training(model)
168+
169+
# train or test
170+
results = self.train_or_test()
171+
172+
# clean up memory
173+
torch.cuda.empty_cache()
174+
175+
return results

0 commit comments

Comments
 (0)