🚀 Feature
After the accelerator refactor, we've added some niceties around the `TrainingTypePlugin`, which now represents all the logic that sits on top of an accelerator. However, our API is still based on the old paradigm where the accelerator and the distributed backend were tied together, and this has already bitten us in #6089
What I suggest is we do the following:
```python
# Correct
Trainer(gpus=1, accelerator='ddp')  # this is fine, should be set to ddp
Trainer(gpus=1, plugins='ddp_sharded', accelerator='gpu')
Trainer(gpus=1, plugins='ddp_sharded')  # implicit: assume GPU for ddp_sharded, as it is the only supported accelerator

# Incorrect
Trainer(gpus=1, accelerator='ddp_sharded')  # throw a misconfiguration exception
Trainer(gpus=1, plugins='ddp_sharded', accelerator='ddp')  # this should throw a deprecation warning

# The Future
# Plugins can be TrainingTypePlugins, PrecisionPlugins, or whatever the user has provided.
# For precision, we don't do `Trainer(plugins='16bit_precision')` but instead `Trainer(precision=16)`,
# so why do we use the plugins flag for the training type?
Trainer(training_type='ddp', accelerator='gpu')
Trainer(training_type='ddp_spawn', accelerator='cpu')

# so Trainer(plugins=...) is left for cluster_environments, SLURM, ...
```
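To make the rules above concrete, here is a minimal sketch of how the Trainer could resolve these flags internally. The helper name `resolve_training_type` and the `IMPLICIT_ACCELERATOR` table are hypothetical, not existing Lightning internals; only `MisconfigurationException` is the real Lightning exception class. The deprecation path for the old-style `accelerator` values is sketched separately under "Backwards compatibility" below.

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Hypothetical: training types whose accelerator can be inferred because
# they only support a single one (e.g. ddp_sharded -> GPU).
IMPLICIT_ACCELERATOR = {"ddp_sharded": "gpu"}


def resolve_training_type(plugins=None, accelerator=None):
    """Hypothetical resolution of the proposed flags (illustration only)."""
    if accelerator == "ddp_sharded":
        # a TrainingTypePlugin passed through the accelerator flag -> error
        raise MisconfigurationException(
            "`ddp_sharded` is a training type plugin; pass it via `plugins=...`."
        )
    training_type = plugins
    if accelerator is None and training_type in IMPLICIT_ACCELERATOR:
        # implicit case: infer the only accelerator the plugin supports
        accelerator = IMPLICIT_ACCELERATOR[training_type]
    return training_type, accelerator
```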
Incompatible Plugin/Accelerator
This is something I haven't hashed out yet, but let's say the user does:
```python
Trainer(training_type='ddp_sharded', accelerator='tpu')
```
This combination is currently not compatible. We should throw an exception of some sort, but that delves into whitelisting/blacklisting supported plugins, which could get a bit unwieldy. An alternative is to just assume that the user already knows enough about compatibility.
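If we did go the whitelisting route, one way to express the check is a simple mapping from training type to the accelerators it supports. The names below (`SUPPORTED_ACCELERATORS`, `validate_combination`) are hypothetical and only meant to show the shape of the check, not a proposed implementation:

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Hypothetical whitelist; the entries would have to be maintained
# alongside the plugins themselves.
SUPPORTED_ACCELERATORS = {
    "ddp": {"cpu", "gpu"},
    "ddp_spawn": {"cpu", "gpu"},
    "ddp_sharded": {"gpu"},
}


def validate_combination(training_type: str, accelerator: str) -> None:
    supported = SUPPORTED_ACCELERATORS.get(training_type, set())
    if accelerator not in supported:
        raise MisconfigurationException(
            f"`training_type='{training_type}'` does not support "
            f"`accelerator='{accelerator}'`; supported: {sorted(supported)}"
        )


# validate_combination('ddp_sharded', 'tpu')  # -> MisconfigurationException
```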
Backwards compatibility
We keep allowing the `TrainingTypePlugin` to be specified via the `accelerator` Trainer flag, but raise a warning that this will be deprecated in the future.
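A minimal sketch of that deprecation path, with a hypothetical helper name (`_map_deprecated_accelerator`) and message wording that would need bikeshedding:

```python
import warnings

# Hypothetical: old-style values of `accelerator` that actually name a
# training type rather than a hardware accelerator.
_DEPRECATED_ACCELERATOR_VALUES = {"dp", "ddp", "ddp_spawn", "ddp_sharded"}


def _map_deprecated_accelerator(accelerator=None, training_type=None):
    """Accept old-style `accelerator` values but warn and remap them."""
    if accelerator in _DEPRECATED_ACCELERATOR_VALUES:
        warnings.warn(
            f"Passing `accelerator='{accelerator}'` to select the training type "
            "is deprecated; pass it via `training_type=...` (or `plugins=...`) "
            "and reserve `accelerator` for 'cpu'/'gpu'/'tpu'.",
            DeprecationWarning,
        )
        training_type = training_type or accelerator
        accelerator = None
    return accelerator, training_type
```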
cc @ananthsub @carmocca @awaelchli @justusschock @tchaton @Borda @kaushikb11 @williamFalcon