Commit f35dda8

SeanNaren authored and Borda committed
[Fix] Move init dist connection into the setup function (#6506)
* Move connection setup into the setup function. Call setup hook after we set up the accelerator
* Added CHANGELOG.md
* fix setup order in callback test
* fix input arguments in test
* Mock distributed function, remove protection to turn into training type hook
* Remove import
* Add missing mock, ensure custom plugin does not create children process
* Skip test on windows
* Update deepspeed to init connection in setup
* Do not initialize distributed module
* Move DeepSpeed tests to special tests since dist communication is being set up
* Special the test to see if this fixes CI
* Delete accelerator connector test to see if its causing build to fail
* Delete deepspeed test
* Revert "Delete accelerator connector test to see if its causing build to fail" (reverts commit edde60b)
* Revert "Delete deepspeed test" (reverts commit 9d317429)
* Reverse hook
* Reverse setup hooks to debug again
* Add todo so i know where i left off
* For single device move in pre_dispatch after setup function
* Add additional model to device hook if any additional parameters have been set
* See if we can enable deepspeed tests
* Revert "See if we can enable deepspeed tests" (reverts commit b5450de)
* See if this hook approach works
* Introduce new granular hooks
* Remove import, fix tpu spawn by moving the function to setup
* Added missing special test

Co-authored-by: Adrian Wälchli <[email protected]>
(cherry picked from commit 4e9b453)
1 parent caebaea

14 files changed: +171 additions, -95 deletions
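What this buys users: with DDP, the process group is now initialized before the `setup` hook runs, so `torch.distributed` is usable inside `LightningModule.setup`. A minimal sketch of that (hypothetical module name; assumes a multi-process DDP run on a 1.2.x release carrying this fix):

import torch.distributed as dist
from pytorch_lightning import LightningModule

class MyModel(LightningModule):  # hypothetical example module
    def setup(self, stage=None):
        # Before this commit, DDP only created the process group in
        # pre_dispatch, after setup() had run, so this branch never fired.
        if dist.is_available() and dist.is_initialized():
            print(f"setup on rank {dist.get_rank()} of {dist.get_world_size()}")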

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 11 additions & 15 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, Union
+from typing import Any, Callable, Iterable, Optional, Union, Sequence
 
 import torch
 from torch.optim import Optimizer
@@ -53,21 +53,20 @@ def __init__(
         self.precision_plugin = precision_plugin
         self.training_type_plugin = training_type_plugin
 
-        self.optimizers = None
-        self.lr_schedulers = None
-        self.optimizer_frequencies = None
+        self.optimizers: Sequence = []
+        self.lr_schedulers: Sequence = []
+        self.optimizer_frequencies: Sequence = []
 
     def setup(self, trainer, model: LightningModule) -> None:
         """
-        Connects the plugins to the training process, creates optimizers
-
+        Setup plugins for the trainer fit and creates optimizers.
         Args:
-            trainer: the trainer instance to connect to
-            model: the model to train
+            trainer: the trainer instance
+            model: the LightningModule
         """
-        self.connect_training_type_plugin(self.training_type_plugin, model)
+        self.setup_training_type_plugin(self.training_type_plugin, model)
         self.setup_optimizers(trainer)
-        self.connect_precision_plugin(self.precision_plugin)
+        self.setup_precision_plugin(self.precision_plugin)
 
     def start_training(self, trainer):
         self.training_type_plugin.start_training(trainer)
@@ -319,11 +318,8 @@ def setup_optimizers(self, trainer):
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
-        """Attaches the training type plugin to the accelerator.
-        Also transfers ownership of the model to this plugin
-
-        """
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+        """Attaches the training type plugin to the accelerator."""
         plugin.connect(model)
 
     def connect_precision_plugin(self, plugin: PrecisionPlugin):
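The `connect_training_type_plugin` -> `setup_training_type_plugin` rename signals that plugin wiring now belongs to the accelerator's setup phase rather than an earlier connect phase. A sketch of hooking into it (hypothetical subclass, not part of this commit):

from pytorch_lightning.accelerators import Accelerator

class VerboseAccelerator(Accelerator):  # hypothetical, for illustration only
    def setup(self, trainer, model) -> None:
        # Per the diff above, the base implementation runs
        # setup_training_type_plugin, setup_optimizers and
        # setup_precision_plugin, in that order.
        print("setting up plugins and optimizers")
        super().setup(trainer, model)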

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 31 additions & 34 deletions
@@ -86,9 +86,7 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def setup(self, model):
-        self._model = model
-
+    def setup_environment(self):
         # start the other scripts
         # TODO: refactor and let generic cluster env hold the information about who spawns the processes
         if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
@@ -97,6 +95,8 @@ def setup(self, model):
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
 
+        self.setup_distributed()
+
     def _call_children_scripts(self):
 
         # bookkeeping of spawned processes
@@ -171,6 +171,34 @@ def _call_children_scripts(self):
         delay = np.random.uniform(1, 5, 1)[0]
         sleep(delay)
 
+    def setup_distributed(self):
+        # TODO: check if needed
+        seed = os.environ.get("PL_GLOBAL_SEED")
+        if seed is not None:
+            seed_everything(int(seed))
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        # set warning rank
+        rank_zero_only.rank = self.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        self.init_ddp_connection(self.global_rank, self.world_size)
+
+        # on world_size=0 let everyone know training is starting
+        if self.is_global_zero and not torch.distributed.is_initialized():
+            log.info("-" * 100)
+            log.info(f"distributed_backend={self.distributed_backend}")
+            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
+            log.info("-" * 100)
+
+        # set the ranks and devices
+        self.dist.rank = self.global_rank
+        self.dist.device = self.root_device
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
@@ -226,37 +254,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
         torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        # TODO: check if needed
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
-        # determine which process we are and world size
-        self.set_world_ranks()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
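Moving the process-group initialization out of `pre_dispatch` into a dedicated `setup_distributed` method also gives tests a single seam to stub, so accelerator-selection tests no longer need real networking; the test changes below rely on exactly this. A sketch of the pattern:

from unittest import mock

# The patch target matches the one used in tests/accelerators below; inside
# the block, constructing a DDP trainer skips the real process-group init.
with mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True):
    ...  # e.g. build a Trainer(accelerator='ddp_cpu', num_processes=2)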

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 0 additions & 10 deletions
@@ -192,17 +192,7 @@ def _load_config(self, config):
         return config
 
     def pre_dispatch(self):
-        self.set_world_ranks()
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
         self.init_deepspeed()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         self.barrier()
 
     def init_deepspeed(self):

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 0 additions & 8 deletions
@@ -60,14 +60,6 @@ def on_gpu(self):
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
 
-    @abstractmethod
-    def setup(self, model):
-        raise NotImplementedError
-
-    def connect(self, model, *args, **kwargs):
-        self.setup(model)
-        return self.model
-
     @property
     def is_global_zero(self) -> bool:
         return self.global_rank == 0

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:
 
         self._model.to(self.root_device)
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.model_to_device()
         return self.model

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 1 addition & 6 deletions
@@ -26,13 +26,8 @@ def __init__(self, device: Union[torch.device, int]):
     def on_tpu(self) -> bool:
         return True
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
-        self.model_to_device()
-        return self._model
-
     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)
 
     def pre_dispatch(self) -> None:
         if isinstance(self.device, int):

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 2 additions & 3 deletions
@@ -39,10 +39,9 @@ def __init__(
         self.tpu_local_core_rank = 0
         self.start_method = None
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
-        self._model = model
-        return self._model
+        return self.model
 
     def create_mp_queue(self):
         self.start_method = 'fork'
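Across the single-device and TPU plugins, `connect` (store the model, move it to the device) becomes `setup`, with model ownership now handled by the base class's `connect`. A custom plugin following the new contract might look like this (hypothetical sketch; assumes the base `connect` has already stored the model on the plugin):

import torch
from pytorch_lightning.plugins import SingleDevicePlugin

class MyDevicePlugin(SingleDevicePlugin):  # hypothetical illustration
    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        # connect() ran earlier and attached the model to the plugin,
        # so setup only handles device placement.
        self.model_to_device()
        return self.model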

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -455,8 +455,10 @@ def fit(
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
-        self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
+        self.accelerator.connect(model)
+        self.accelerator.setup_environment()
+        self.call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
         # ----------------------------
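This reordering is the heart of the fix: the accelerator takes the model and brings up the cluster environment (and, for DDP, the process group) before the user `setup` hook fires. The resulting sequence, annotated (excerpted from the hunk above, not the full trainer source):

# inside Trainer.fit(model), after this commit:
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.connect(model)       # training type plugin takes the model
self.accelerator.setup_environment()  # DDP: spawn children, init process group
self.call_setup_hook(model)           # user setup() sees torch.distributed ready
self.accelerator.setup(self, model)   # plugins, optimizers, precision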

tests/accelerators/test_accelerator_connector.py

Lines changed: 21 additions & 9 deletions
@@ -95,7 +95,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
         "SLURM_LOCALID": "10"
     }
 )
-def test_accelerator_choice_ddp_slurm():
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
 
     class CB(Callback):
 
@@ -133,7 +134,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -162,7 +164,8 @@ def on_fit_start(self, trainer, pl_module):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -190,7 +193,8 @@ def on_fit_start(self, trainer, pl_module):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -221,7 +225,8 @@ def on_fit_start(self, trainer, pl_module):
     "NODE_RANK": "0",
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -256,7 +261,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -291,7 +297,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
     """
     Test that we choose the custom cluster even when SLURM or TE flags are around
     """
@@ -301,6 +308,9 @@ class CustomCluster(ClusterEnvironment):
         def master_address(self):
             return 'asdf'
 
+        def creates_children(self) -> bool:
+            return True
+
     class CB(Callback):
 
         def on_fit_start(self, trainer, pl_module):
@@ -333,7 +343,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_custom_accelerator(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_custom_accelerator(device_count_mock, setup_distributed_mock):
 
     class Accel(Accelerator):
         pass
@@ -368,7 +379,8 @@ class TrainTypePlugin(SingleDevicePlugin):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_dist_backend_accelerator_mapping(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
