
Commit 05b15e6

kaushikb11, rohitgr7, and pre-commit-ci[bot] authored
Add strategy argument to Trainer (#8597)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 28fc8d2 commit 05b15e6

File tree: 9 files changed (+323 -16 lines changed)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


+- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
+
+
 ### Changed

 - Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 95 additions & 12 deletions
@@ -94,6 +94,7 @@ def __init__(
         ipus,
         distributed_backend,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
@@ -111,12 +112,9 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None

-        if distributed_backend is not None:
-            rank_zero_deprecation(
-                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
-                f" Use `Trainer(accelerator={distributed_backend})` instead."
-            )
-        distributed_backend = distributed_backend or accelerator
+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator
+
         self._init_deterministic(deterministic)

         self.num_processes = num_processes
@@ -126,7 +124,6 @@ def __init__(
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
-        self.distributed_backend = distributed_backend
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
@@ -151,16 +148,23 @@ def __init__(

         self.plugins = plugins

+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()

         self._warn_if_devices_flag_ignored()

         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()

         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()

         self._validate_accelerator_type()
         self._set_devices_if_none()
@@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None:
             self._set_devices_to_cpu_num_processes()
             self._accelerator_type = DeviceType.CPU

-        if self.distributed_backend in ["auto"] + list(DeviceType):
+        if self.distributed_backend in self.accelerator_types:
             self.distributed_backend = None

     def _validate_accelerator_and_devices(self) -> None:
-        if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None:
+        if self.distributed_backend not in self.accelerator_types and self.devices is not None:
             raise MisconfigurationException(
                 f"You passed `devices={self.devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
@@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes

+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:

-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
         checkpoint = None
         precision = None
         cluster_environment = None
@@ -350,6 +401,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()

+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
@@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )

+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )

     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
@@ -875,6 +939,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU

+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
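
In effect, the connector now treats `strategy` as the single source of truth for the training type: the legacy `distributed_backend`, `accelerator`, and `plugins` routes still work but are deprecated, and combining any of them with `strategy` raises a `MisconfigurationException`. A minimal sketch of the resulting behaviour, mirroring the tests added further below (the specific strategy names are only illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # `strategy` cannot be combined with a training type passed through the old flags.
    try:
        Trainer(strategy="ddp_spawn", accelerator="ddp")
    except MisconfigurationException as err:
        print(err)  # "... but have also passed `Trainer(accelerator=ddp)` ..."

    try:
        Trainer(strategy="ddp_spawn", plugins="ddp_find_unused_parameters_false")
    except MisconfigurationException as err:
        print(err)  # "... you can only specify one training type plugin ..."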

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -155,6 +155,7 @@ def __init__(
         flush_logs_every_n_steps: Optional[int] = None,
         log_every_n_steps: int = 50,
         accelerator: Optional[Union[str, Accelerator]] = None,
+        strategy: Optional[Union[str, TrainingTypePlugin]] = None,
         sync_batchnorm: bool = False,
         precision: Union[int, str] = 32,
         enable_model_summary: bool = True,
@@ -354,6 +355,9 @@ def __init__(
                 no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                 training will start from the beginning of the next epoch.

+            strategy: Supports different training strategies with aliases
+                as well custom training type plugins.
+
             sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -424,6 +428,7 @@ def __init__(
             ipus,
             distributed_backend,
             accelerator,
+            strategy,
             gpus,
             gpu_ids,
             num_nodes,
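
For reference, a minimal usage sketch of the new flag, based on the tests added in this commit (device counts and plugin choices are illustrative; the second variant assumes at least two GPUs):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin

    # Select the training type by its registered alias ...
    trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)
    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)

    # ... or hand over a fully constructed training type plugin instance.
    trainer = Trainer(strategy=DDPPlugin(), accelerator="gpu", devices=2)
    assert isinstance(trainer.training_type_plugin, DDPPlugin)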

tests/accelerators/test_accelerator_connector.py

Lines changed: 74 additions & 1 deletion
@@ -26,6 +26,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import (
+    DataParallelPlugin,
     DDP2Plugin,
     DDPPlugin,
     DDPShardedPlugin,
@@ -42,7 +43,7 @@
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
-from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities import DeviceType, DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -631,6 +632,78 @@ def test_accelerator_ddp_for_cpu(tmpdir):
     assert isinstance(trainer.training_type_plugin, DDPPlugin)


+def test_exception_when_strategy_used_with_distributed_backend():
+    with pytest.raises(MisconfigurationException, match="but have also passed"):
+        Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")
+
+
+def test_exception_when_strategy_used_with_accelerator():
+    with pytest.raises(MisconfigurationException, match="but have also passed"):
+        Trainer(accelerator="ddp", strategy="ddp_spawn")
+
+
+def test_exception_when_strategy_used_with_plugins():
+    with pytest.raises(MisconfigurationException, match="only specify one training type plugin, but you have passed"):
+        Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")
+
+
+@pytest.mark.parametrize(
+    ["strategy", "plugin"],
+    [
+        ("ddp_spawn", DDPSpawnPlugin),
+        ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
+        ("ddp", DDPPlugin),
+        ("ddp_find_unused_parameters_false", DDPPlugin),
+    ],
+)
+def test_strategy_choice_cpu_str(tmpdir, strategy, plugin):
+    trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_strategy_choice_cpu_plugin(tmpdir, plugin):
+    trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize(
+    ["strategy", "plugin"],
+    [
+        ("ddp_spawn", DDPSpawnPlugin),
+        ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
+        ("ddp", DDPPlugin),
+        ("ddp_find_unused_parameters_false", DDPPlugin),
+        ("ddp2", DDP2Plugin),
+        ("dp", DataParallelPlugin),
+        ("ddp_sharded", DDPShardedPlugin),
+        ("ddp_sharded_spawn", DDPSpawnShardedPlugin),
+        pytest.param("deepspeed", DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
+    ],
+)
+def test_strategy_choice_gpu_str(tmpdir, strategy, plugin):
+    trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_strategy_choice_gpu_plugin(tmpdir, plugin):
+    trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
+
+    trainer = Trainer(strategy=plugin(), gpus=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+    assert trainer._device_type == DeviceType.GPU
+    assert isinstance(trainer.accelerator, GPUAccelerator)
+
+
 @pytest.mark.parametrize("precision", [1, 12, "invalid"])
 def test_validate_precision_type(tmpdir, precision):
636709

tests/accelerators/test_ipu.py

Lines changed: 17 additions & 2 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0


@@ -528,3 +528,18 @@ def test_set_devices_if_none_ipu():

     trainer = Trainer(accelerator="ipu", ipus=8)
     assert trainer.devices == 8
+
+
+@RunIf(ipu=True)
+def test_strategy_choice_ipu_plugin(tmpdir):
+    trainer = Trainer(strategy=IPUPlugin(), accelerator="ipu", devices=8)
+    assert isinstance(trainer.training_type_plugin, IPUPlugin)
+
+
+@RunIf(ipu=True)
+def test_device_type_when_training_plugin_ipu_passed(tmpdir):
+
+    trainer = Trainer(strategy=IPUPlugin(), ipus=8)
+    assert isinstance(trainer.training_type_plugin, IPUPlugin)
+    assert trainer._device_type == DeviceType.IPU
+    assert isinstance(trainer.accelerator, IPUAccelerator)

tests/accelerators/test_tpu_backend.py

Lines changed: 13 additions & 0 deletions
@@ -227,6 +227,19 @@ def test_ddp_cpu_not_supported_on_tpus():
         Trainer(accelerator="ddp_cpu")


+@RunIf(tpu=True)
+@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"])
+def test_strategy_choice_tpu_str(tmpdir, strategy):
+    trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8)
+    assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
+
+
+@RunIf(tpu=True)
+def test_strategy_choice_tpu_plugin(tmpdir):
+    trainer = Trainer(strategy=TPUSpawnPlugin(), accelerator="tpu", devices=8)
+    assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
+
+
 @RunIf(tpu=True)
 def test_auto_parameters_tying_tpus(tmpdir):

tests/deprecated_api/test_remove_1-7.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ def test_v1_7_0_deprecate_parameter_validation():
343343
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401
344344

345345

346+
def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag():
347+
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
348+
Trainer(accelerator="ddp_spawn")
349+
350+
351+
def test_v1_7_0_passing_strategy_to_plugins_flag():
352+
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
353+
Trainer(plugins="ddp_spawn")
354+
355+
346356
def test_v1_7_0_weights_summary_trainer(tmpdir):
347357
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
348358
t = Trainer(weights_summary="full")
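
Read together with the connector changes above, these deprecation tests imply the following migration; a rough sketch (the old spellings still run in v1.5 but emit the removal warning):

    from pytorch_lightning import Trainer

    # Deprecated in v1.5, scheduled for removal in v1.7:
    Trainer(accelerator="ddp_spawn")  # strategy name routed through `accelerator`
    Trainer(plugins="ddp_spawn")      # strategy name routed through `plugins`

    # Preferred from this commit on:
    Trainer(strategy="ddp_spawn")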
