@@ -94,6 +94,7 @@ def __init__(
         ipus,
         distributed_backend,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
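
This hunk adds the new `strategy` argument to the connector's constructor. A minimal usage sketch, assuming the public `Trainer` constructor forwards `strategy` here as the rest of this diff implies (the device values are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    # Either a known strategy name ...
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=2)
    # ... or a TrainingTypePlugin instance.
    trainer = Trainer(strategy=DDPPlugin(find_unused_parameters=False), gpus=2)
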
@@ -111,12 +112,9 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None
 
-        if distributed_backend is not None:
-            rank_zero_deprecation(
-                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
-                f" Use `Trainer(accelerator={distributed_backend})` instead."
-            )
-        distributed_backend = distributed_backend or accelerator
+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator
+
         self._init_deterministic(deterministic)
 
         self.num_processes = num_processes
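
Note that string strategies are normalized with `.lower()`, so the new flag is case-insensitive:

    from pytorch_lightning import Trainer

    trainer = Trainer(strategy="DDP_Spawn")  # stored internally as "ddp_spawn"
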
@@ -126,7 +124,6 @@ def __init__(
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
-        self.distributed_backend = distributed_backend
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
@@ -151,16 +148,23 @@ def __init__(
 
         self.plugins = plugins
 
+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()
 
         self._warn_if_devices_flag_ignored()
 
         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()
 
         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()
 
         self._validate_accelerator_type()
         self._set_devices_if_none()
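
The constructor now branches on whether `strategy` was given: the new `_set_training_type_plugin()` path handles it, while the legacy flags still fall through to `set_distributed_mode()`. Roughly (values illustrative):

    from pytorch_lightning import Trainer

    Trainer(strategy="ddp_spawn")     # new path: _set_training_type_plugin()
    Trainer(accelerator="ddp_spawn")  # legacy path: set_distributed_mode(), plus a deprecation warning
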
@@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None:
             self._set_devices_to_cpu_num_processes()
             self._accelerator_type = DeviceType.CPU
 
-        if self.distributed_backend in ["auto"] + list(DeviceType):
+        if self.distributed_backend in self.accelerator_types:
             self.distributed_backend = None
 
     def _validate_accelerator_and_devices(self) -> None:
-        if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None:
+        if self.distributed_backend not in self.accelerator_types and self.devices is not None:
             raise MisconfigurationException(
                 f"You passed `devices={self.devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
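
With the inline `["auto"] + list(DeviceType)` check factored into the `accelerator_types` property (added below), the validation reads: `devices` is only meaningful when `accelerator` names a device type. A sketch of both sides of the check, with illustrative values:

    from pytorch_lightning import Trainer

    Trainer(accelerator="gpu", devices=2)  # OK: "gpu" is in accelerator_types
    Trainer(accelerator="ddp", devices=2)  # raises MisconfigurationException: a strategy, not a device type
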
@@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes
 
+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing `{accelerator}` as a strategy to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:
 
-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} as a `strategy` via the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
         checkpoint = None
         precision = None
         cluster_environment = None
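
The net effect of `_handle_accelerator_and_distributed_backend` and the new plugin loop is that `strategy` is mutually exclusive with every legacy way of picking a training type. A hedged sketch of the failure mode, assuming the deprecated spelling still reaches the connector:

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        Trainer(strategy="ddp", accelerator="ddp")  # deprecated spelling plus the new flag
    except MisconfigurationException as err:
        print(err)  # ... HINT: Use just `Trainer(strategy=ddp)` instead.
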
@@ -350,6 +401,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
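
`accelerator_types` is just the old inline list behind a property. Assuming `DeviceType` is the `LightningEnum` from `pytorch_lightning.utilities.enums`, its members compare equal to their lower-case string values, so membership tests work for plain strings:

    from pytorch_lightning.utilities.enums import DeviceType

    types = ["auto"] + list(DeviceType)
    # "auto" plus every DeviceType member (cpu/gpu/tpu/ipu)
    assert "gpu" in types
    assert "ddp" not in types  # strategy names are not device types
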
@@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )
 
+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )
 
     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
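
`_is_plugin_training_type` accepts both strategy names (registry keys or `DistributedType` values) and plugin instances, and `is_training_type_in_plugins` now mirrors that. Illustrative checks, assuming the connector class is `AcceleratorConnector` in `pytorch_lightning/trainer/connectors/accelerator_connector.py`:

    from pytorch_lightning.plugins import DDPPlugin
    from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector

    AcceleratorConnector._is_plugin_training_type("ddp")        # True: a DistributedType name
    AcceleratorConnector._is_plugin_training_type(DDPPlugin())  # True: a TrainingTypePlugin instance
    AcceleratorConnector._is_plugin_training_type("unknown")    # False: neither registered nor a DistributedType
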
@@ -875,6 +939,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU
 
+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
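
`update_device_type_if_training_type_plugin_passed` re-derives `_device_type` when a `TrainingTypePlugin` instance bypasses `set_distributed_mode()`: requested hardware (the `use_*` flags) wins if an accelerator type was resolved, otherwise detected hardware (the `has_*` flags) is used. A hedged usage sketch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    # The plugin instance skips set_distributed_mode(), so this hook is what
    # marks the run as a GPU run; _device_type ends up DeviceType.GPU.
    trainer = Trainer(strategy=DDPPlugin(), gpus=2)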