@@ -92,6 +92,7 @@ def __init__(
         tpu_cores,
         ipus,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
@@ -109,14 +110,18 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None

+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator
+
+        self._init_deterministic(deterministic)
+
         self.num_processes = num_processes
         self.devices = devices
         # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
         self.gpus = gpus
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
-        self.accelerator = accelerator
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
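For reference, a minimal usage sketch of the new flag, assuming the v1.5 `Trainer` API this commit introduces (`DDPPlugin` and the device flags are from the existing public API; the GPU examples need the hardware at construction time):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# One flag for the training strategy, as a registered/enum name...
trainer = Trainer(strategy="ddp_spawn", num_processes=2)

# ...or as a TrainingTypePlugin instance (requires 2 GPUs).
trainer = Trainer(strategy=DDPPlugin(find_unused_parameters=False), gpus=2)

# Older spellings that `strategy` supersedes (deprecated in v1.5, removed in v1.7):
trainer = Trainer(accelerator="ddp", gpus=2)
trainer = Trainer(plugins="ddp", gpus=2)
```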
@@ -141,16 +146,23 @@ def __init__(

         self.plugins = plugins

+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()

         self._warn_if_devices_flag_ignored()

         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()

         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()

         self._validate_accelerator_type()
         self._set_devices_if_none()
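The constructor now branches on `strategy` before falling back to the legacy flag-based resolution. A simplified standalone sketch of that dispatch (all names below are stand-ins; the real `_set_training_type_plugin` also routes registered string names through `set_distributed_mode`):

```python
from typing import Optional, Union

class TrainingTypePlugin:  # stand-in for pytorch_lightning.plugins.TrainingTypePlugin
    pass

REGISTRY = {"ddp_sharded": TrainingTypePlugin}  # stand-in for TrainingTypePluginsRegistry

def select_training_type(strategy: Optional[Union[str, TrainingTypePlugin]]) -> Optional[TrainingTypePlugin]:
    """Simplified mirror of the new dispatch above."""
    if strategy is None:
        return None  # legacy path: set_distributed_mode() infers from the old flags
    if isinstance(strategy, TrainingTypePlugin):
        return strategy  # an instance is taken as-is
    # Registered names instantiate from the registry; plain DistributedType
    # names (e.g. "ddp") are resolved by set_distributed_mode(strategy).
    return REGISTRY[strategy]() if strategy in REGISTRY else None

assert select_training_type(None) is None
assert isinstance(select_training_type("ddp_sharded"), TrainingTypePlugin)
plugin = TrainingTypePlugin()
assert select_training_type(plugin) is plugin
```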
@@ -275,9 +287,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes

+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:

-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
         checkpoint = None
         precision = None
         cluster_environment = None
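The net effect for users, assuming the wiring above: the legacy spellings still work but warn, and combining them with `strategy` fails fast. A sketch of the observable behavior (the exception import path is Lightning's real one):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Deprecated but functional: emits a rank-zero deprecation warning.
Trainer(accelerator="ddp", num_processes=2)

# Mixing the new flag with a legacy spelling is rejected outright.
try:
    Trainer(strategy="ddp", accelerator="ddp")
except MisconfigurationException as err:
    print(err)  # ... HINT: Use just `Trainer(strategy=ddp)` instead.
```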
@@ -340,6 +399,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()

+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
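The new `accelerator_types` property enumerates the values the `accelerator` flag accepts. A standalone sketch of what it evaluates to (the `DeviceType` stand-in below uses illustrative lowercase values; Lightning's own enum compares case-insensitively):

```python
from enum import Enum

class DeviceType(str, Enum):  # stand-in for pytorch_lightning.utilities.DeviceType
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    IPU = "ipu"

accelerator_types = ["auto"] + list(DeviceType)

# str-backed enum members compare equal to their plain-string values,
# so membership tests like `accelerator in self.accelerator_types` work.
assert "gpu" in accelerator_types
assert "auto" in accelerator_types
```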
@@ -530,9 +593,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )

+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )

     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
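`_is_plugin_training_type` broadens the old registry-only check: a training type may now arrive as a registry key, a `DistributedType` name, or a plugin instance. A standalone sketch of the check (stand-in names and registry contents for illustration only):

```python
from enum import Enum

class DistributedType(str, Enum):  # stand-in for Lightning's enum
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"

class TrainingTypePlugin:  # stand-in base class
    pass

REGISTRY = {"ddp_sharded", "deepspeed"}  # stand-in for TrainingTypePluginsRegistry keys

def is_plugin_training_type(plugin) -> bool:
    if isinstance(plugin, str) and (plugin in REGISTRY or plugin in list(DistributedType)):
        return True
    return isinstance(plugin, TrainingTypePlugin)

assert is_plugin_training_type("ddp")                  # DistributedType name
assert is_plugin_training_type("deepspeed")            # registry key
assert is_plugin_training_type(TrainingTypePlugin())   # plugin instance
assert not is_plugin_training_type("native_amp")       # non-training-type strings fall through
```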
@@ -862,6 +934,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU

+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
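This last method re-derives `_device_type` only when a plugin instance was passed (via `strategy` or `plugins`), preferring what the user requested (`use_*`) over what the machine merely has (`has_*`). A standalone sketch of that priority order (stand-in names):

```python
from typing import Dict, Optional

def reconcile_device_type(
    use: Dict[str, bool], has: Dict[str, bool], accelerator_type_set: bool
) -> Optional[str]:
    """Return the device type to switch to, or None to leave it unchanged."""
    flags = use if accelerator_type_set else has  # requested flags win over detection
    for device in ("ipu", "tpu", "gpu"):  # same priority order as the method above
        if flags.get(device):
            return device
    return None

assert reconcile_device_type({"gpu": True}, {}, accelerator_type_set=True) == "gpu"
assert reconcile_device_type({}, {"tpu": True}, accelerator_type_set=False) == "tpu"
assert reconcile_device_type({}, {}, accelerator_type_set=False) is None
```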