# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist

from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed


try:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig
except ImportError:
    HYDRA_AVAILABLE = False
else:
    HYDRA_AVAILABLE = True


# -------------------------------------------
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# TEMP CLASS WHILE WE DECOUPLE SLURM FROM DDP
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# -------------------------------------------
class DDPTorchElasticBackend(Accelerator):
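    """Accelerator that runs DDP in processes launched externally by torchelastic.

    Each worker process drives a single GPU and discovers its index through the
    ``LOCAL_RANK`` environment variable set by the launcher.
    """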

    def __init__(self, trainer):
        super().__init__(trainer)
        self.task_idx = None
        self._has_spawned_children = False
        self.dist = LightningDistributed()

    def setup(self, model):
        self.trainer.model = model
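        # torchelastic launches one process per GPU and exposes its index via LOCAL_RANK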
        self.task_idx = int(os.environ['LOCAL_RANK'])

    def train(self):
        model = self.trainer.model
        self.ddp_train(process_idx=self.task_idx, model=model)

    def set_world_ranks(self, process_idx):
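        # global rank counts processes across nodes: node_rank * num_processes + local index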
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def model_to_device(self, model, process_idx):
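        # the local process index doubles as the GPU index on this node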
        self.trainer.root_gpu = process_idx
        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

    def get_device_ids(self):
        device_ids = [self.trainer.root_gpu]
        return device_ids

    def training_step(self, args):
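        # with native AMP, run the forward pass under an autocast context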
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

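    # validation and test reuse the same forward call into the wrapped model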
    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
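        # no-op unless the default process group has been initialized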
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
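        # sum the should_stop flags across ranks; stop only when every rank agrees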
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.ReduceOp.SUM)
        dist.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def broadcast(self, obj, src=0):
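        # delegates to LightningDistributed; the src argument is not forwarded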
        return self.dist.broadcast(obj)

    def ddp_train(self, process_idx, model):
        """
        Entry point for ddp

        Args:
            process_idx: index of the local process (the torchelastic LOCAL_RANK)
            model: the LightningModule to set up and train

        Returns:
            the results from ``train_or_test``
        """
        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # show the progress bar only on global rank 0
        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up the server using proc 0's ip address
        # and initialize the distributed process group
        model.trainer = self.trainer
        model.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # let global rank 0 announce that training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend} (on torchelastic)')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # clean up memory
        torch.cuda.empty_cache()

        return results