@@ -112,7 +112,8 @@ def __init__(
112112 self ._accelerator_type = None
113113
114114 self .strategy = strategy .lower () if isinstance (strategy , str ) else strategy
115- self .accelerator = accelerator
115+ # TODO: Rename this to something else once all the distributed flags are moved to strategy
116+ self .distributed_backend = accelerator
116117
117118 self ._init_deterministic (deterministic )
118119
@@ -202,7 +203,7 @@ def _init_deterministic(self, deterministic: bool) -> None:
202203 os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"
203204
204205 def select_accelerator_type (self ) -> None :
205- if self .accelerator == "auto" :
206+ if self .distributed_backend == "auto" :
206207 if self .has_tpu :
207208 self ._accelerator_type = DeviceType .TPU
208209 elif self .has_ipu :
@@ -212,34 +213,34 @@ def select_accelerator_type(self) -> None:
212213 else :
213214 self ._set_devices_to_cpu_num_processes ()
214215 self ._accelerator_type = DeviceType .CPU
215- elif self .accelerator == DeviceType .TPU :
216+ elif self .distributed_backend == DeviceType .TPU :
216217 if not self .has_tpu :
217218 msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
218219 raise MisconfigurationException (f"You passed `accelerator='tpu'`, but { msg } ." )
219220 self ._accelerator_type = DeviceType .TPU
220- elif self .accelerator == DeviceType .IPU :
221+ elif self .distributed_backend == DeviceType .IPU :
221222 if not self .has_ipu :
222223 msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
223224 raise MisconfigurationException (f"You passed `accelerator='ipu'`, but { msg } ." )
224225 self ._accelerator_type = DeviceType .IPU
225- elif self .accelerator == DeviceType .GPU :
226+ elif self .distributed_backend == DeviceType .GPU :
226227 if not self .has_gpu :
227228 msg = "you didn't pass `gpus` to `Trainer`" if torch .cuda .is_available () else "GPUs are not available"
228229 raise MisconfigurationException (f"You passed `accelerator='gpu'`, but { msg } ." )
229230 self ._accelerator_type = DeviceType .GPU
230- elif self .accelerator == DeviceType .CPU :
231+ elif self .distributed_backend == DeviceType .CPU :
231232 self ._set_devices_to_cpu_num_processes ()
232233 self ._accelerator_type = DeviceType .CPU
233234
234- if self .accelerator in self .accelerator_types :
235- self .accelerator = None
235+ if self .distributed_backend in self .accelerator_types :
236+ self .distributed_backend = None
236237
237238 def _validate_accelerator_and_devices (self ) -> None :
238- if self .accelerator not in self .accelerator_types and self .devices is not None :
239+ if self .distributed_backend not in self .accelerator_types and self .devices is not None :
239240 raise MisconfigurationException (
240241 f"You passed `devices={ self .devices } ` but haven't specified"
241242 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
242- f" got `accelerator={ self .accelerator !r} `."
243+ f" got `accelerator={ self .distributed_backend !r} `."
243244 )
244245
245246 def _validate_accelerator_type (self ) -> None :
@@ -255,16 +256,16 @@ def _warn_if_devices_flag_ignored(self) -> None:
255256 if self .devices is None :
256257 return
257258 devices_warning = f"The flag `devices={ self .devices } ` will be ignored, as you have set"
258- if self .accelerator in ("auto" , DeviceType .TPU ):
259+ if self .distributed_backend in ("auto" , DeviceType .TPU ):
259260 if self .tpu_cores is not None :
260261 rank_zero_warn (f"{ devices_warning } `tpu_cores={ self .tpu_cores } `" )
261- elif self .accelerator in ("auto" , DeviceType .IPU ):
262+ elif self .distributed_backend in ("auto" , DeviceType .IPU ):
262263 if self .ipus is not None :
263264 rank_zero_warn (f"{ devices_warning } `ipus={ self .ipus } `" )
264- elif self .accelerator in ("auto" , DeviceType .GPU ):
265+ elif self .distributed_backend in ("auto" , DeviceType .GPU ):
265266 if self .gpus is not None :
266267 rank_zero_warn (f"{ devices_warning } `gpus={ self .gpus } `" )
267- elif self .accelerator in ("auto" , DeviceType .CPU ):
268+ elif self .distributed_backend in ("auto" , DeviceType .CPU ):
268269 if self .num_processes != 1 :
269270 rank_zero_warn (f"{ devices_warning } `num_processes={ self .num_processes } `" )
270271
@@ -281,15 +282,15 @@ def _set_devices_if_none(self) -> None:
281282 self .devices = self .num_processes
282283
283284 def _handle_accelerator_and_strategy (self ) -> None :
284- if self .accelerator is not None and self .accelerator in list (DistributedType ):
285+ if self .distributed_backend is not None and self .distributed_backend in list (DistributedType ):
285286 rank_zero_deprecation (
286- f"Passing `Trainer(accelerator={ self .accelerator !r} )` has been deprecated"
287- f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={ self .accelerator !r} )` instead."
287+ f"Passing `Trainer(accelerator={ self .distributed_backend !r} )` has been deprecated"
288+ f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={ self .distributed_backend !r} )` instead."
288289 )
289290 if self .strategy is not None :
290291 raise MisconfigurationException (
291292 f"You have passed `Trainer(strategy={ self .strategy !r} )` but have"
292- f" also passed `Trainer(accelerator={ self .accelerator !r} )`."
293+ f" also passed `Trainer(accelerator={ self .distributed_backend !r} )`."
293294 f" HINT: Use just `Trainer(strategy={ self .strategy !r} )` instead."
294295 )
295296
@@ -635,8 +636,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
635636 return ApexMixedPrecisionPlugin (self .amp_level )
636637
637638 def select_training_type_plugin (self ) -> TrainingTypePlugin :
638- if isinstance (self .accelerator , Accelerator ) and self .accelerator .training_type_plugin is not None :
639- plugin = self .accelerator .training_type_plugin
639+ if (
640+ isinstance (self .distributed_backend , Accelerator )
641+ and self .distributed_backend .training_type_plugin is not None
642+ ):
643+ plugin = self .distributed_backend .training_type_plugin
640644 elif self .use_ddp2 :
641645 plugin = DDP2Plugin (parallel_devices = self .parallel_devices , cluster_environment = self .cluster_environment )
642646 elif self .use_ddp and self .use_deepspeed :
@@ -718,15 +722,15 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
718722 return training_type
719723
720724 def select_accelerator (self ) -> Accelerator :
721- if isinstance (self .accelerator , Accelerator ):
725+ if isinstance (self .distributed_backend , Accelerator ):
722726 # custom accelerator from user
723727 if self ._precision_plugin is not None or self ._training_type_plugin is not None :
724728 # plugins also specified by user
725729 rank_zero_warn (
726730 "Specified `Precision` and `TrainingType` plugins will be ignored,"
727731 " since an `Accelerator` instance was provided."
728732 )
729- return self .accelerator
733+ return self .distributed_backend
730734
731735 if self .use_gpu :
732736 acc_cls = GPUAccelerator
@@ -766,32 +770,32 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
766770 return
767771
768772 if strategy is not None and strategy in TrainingTypePluginsRegistry :
769- self .accelerator = TrainingTypePluginsRegistry [strategy ]["distributed_backend" ]
773+ self .distributed_backend = TrainingTypePluginsRegistry [strategy ]["distributed_backend" ]
770774 elif strategy is not None :
771- self .accelerator = strategy
775+ self .distributed_backend = strategy
772776
773- if isinstance (self .accelerator , Accelerator ):
777+ if isinstance (self .distributed_backend , Accelerator ):
774778 return
775779
776780 is_cpu_accelerator_type = self ._accelerator_type and self ._accelerator_type == DeviceType .CPU
777- _use_cpu = is_cpu_accelerator_type or self .accelerator and "cpu" in self .accelerator
781+ _use_cpu = is_cpu_accelerator_type or self .distributed_backend and "cpu" in self .distributed_backend
778782
779- if self .accelerator is None :
783+ if self .distributed_backend is None :
780784 if self .has_horovodrun ():
781785 self ._set_horovod_backend ()
782786 elif self .num_gpus == 0 and self .num_nodes > 1 :
783787 self ._distrib_type = DistributedType .DDP
784788 elif self .num_gpus == 0 and self .num_processes > 1 :
785- self .accelerator = DistributedType .DDP_SPAWN
789+ self .distributed_backend = DistributedType .DDP_SPAWN
786790 elif self .num_gpus > 1 and not _use_cpu :
787791 rank_zero_warn (
788792 "You requested multiple GPUs but did not specify a backend, e.g."
789793 ' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.'
790794 )
791- self .accelerator = DistributedType .DDP_SPAWN
795+ self .distributed_backend = DistributedType .DDP_SPAWN
792796
793797 # special case with DDP on CPUs
794- if self .accelerator == DistributedType .DDP_CPU :
798+ if self .distributed_backend == DistributedType .DDP_CPU :
795799 if _TPU_AVAILABLE :
796800 raise MisconfigurationException (
797801 "`accelerator='ddp_cpu'` is not supported on TPU machines. "
@@ -816,8 +820,8 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
816820 self ._distrib_type = DistributedType .TPU_SPAWN
817821 elif self .has_ipu and not _use_cpu :
818822 self ._device_type = DeviceType .IPU
819- elif self .accelerator and self ._distrib_type is None :
820- self ._distrib_type = DistributedType (self .accelerator )
823+ elif self .distributed_backend and self ._distrib_type is None :
824+ self ._distrib_type = DistributedType (self .distributed_backend )
821825
822826 if self .num_gpus > 0 and not _use_cpu :
823827 self ._device_type = DeviceType .GPU
@@ -850,7 +854,7 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
850854 self .num_processes = self .num_nodes
851855
852856 # Horovod is an extra case...
853- if self .accelerator == DistributedType .HOROVOD :
857+ if self .distributed_backend == DistributedType .HOROVOD :
854858 self ._set_horovod_backend ()
855859
856860 using_valid_distributed = self .use_ddp or self .use_ddp2
0 commit comments