DeepSpeed Integration #5954

Merged
merged 68 commits into from
Feb 17, 2021
Conversation

SeanNaren
Contributor

@SeanNaren SeanNaren commented Feb 13, 2021

What does this PR do?

Closes #817.

Allows users to enable the DeepSpeed training type plugin. Some constraints are placed on the user when training, as Lightning is built to be more research focused.

The API:

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(gpus=4, plugins='deepspeed', precision=16) # default enables ZeRO optimization/offload
trainer.fit(model)

Using config:

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config="/path/to/deepspeed_config.json", precision=16) # zero offload requires mixed precision
trainer.fit(model)

Or via config object:

from pytorch_lightning import Trainer

deepspeed_config = {
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 3e-5,
            "betas": [0.998, 0.999],
            "eps": 1e-5,
            "weight_decay": 1e-9,
        },
    },
    'scheduler': {
        "type": "WarmupLR",
        "params": {
            "last_batch_iteration": -1,
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 100,
        }
    },
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True,
        "contiguous_gradients": True,
        "overlap_comm": True
    }
}

model = MyModel()
trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config=deepspeed_config, precision=16) # zero offload requires mixed precision
trainer.fit(model)

Limitations

  • The largest limitation is that currently we have to define the optimizer/scheduler within the config for the DeepSpeed engine to initialise; optimizers/schedulers created via configure_optimizers are therefore ignored. This needs to be made clear within the README. Update: we now support configure_optimizers with one optimizer/scheduler, as well as the DeepSpeed config options :)
  • A limitation of the current Lightning accelerator API means the precision plugin needs to contain logic for the loop even if precision is handled within the DeepSpeed plugin. This is hopefully temporary, until we decide where the precision logic should live.
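Since the update above says a single optimizer/scheduler from configure_optimizers is now supported, here is a minimal sketch of the expected return shape. The stand-in objects are hypothetical (a real module would return e.g. a torch.optim.Adam instance and a scheduler); only the single-element-list contract is illustrated.

```python
class MyModel:
    """Stand-in for a LightningModule: only the configure_optimizers
    return shape matters here, so no torch objects are created."""

    def configure_optimizers(self):
        optimizer = object()  # stand-in for e.g. torch.optim.Adam(self.parameters(), lr=3e-5)
        scheduler = object()  # stand-in for e.g. a WarmupLR-style scheduler
        # The DeepSpeed plugin supports one optimizer and at most one
        # scheduler, so both are returned as single-element lists.
        return [optimizer], [scheduler]

optimizers, schedulers = MyModel().configure_optimizers()
assert len(optimizers) == 1 and len(schedulers) == 1
```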

Performance

  • I have tested across large transformer models, similar to this and this, which showed really in-depth breakdowns of DeepSpeed, and have replicated similar results in a different training environment using the DeepSpeed plugin. Both also highlight the finicky parameter tuning needed to get optimal performance, so I'll be adding references to these great pieces of info and bringing some of that information into our docs to highlight this!

Still need to Address

  • Currently we do not re-initialize the DeepSpeed engine on save/load, so some of the engine's state variables are not saved and reloaded for training. As a result resume_from_checkpoint isn't supported yet, and a note has to be made in the docs for this.
  • Because we use the latest DeepSpeed release, AMD processors with 1-bit Adam will segfault, with a fix already being worked on here. This is one of the key pieces of the memory reduction, so I'll keep an eye on this and add info to the docs in the meantime. There are also some nice helper functions to reduce the logging which have been merged into DeepSpeed master but have not made it into a release yet!
  • Ensure the test function works; there are a few cases where it crashes (possibly due to allocating more memory).

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified
  • Check that target branch and milestone match!

@SeanNaren SeanNaren added the feature Is an improvement or enhancement label Feb 13, 2021
@SeanNaren SeanNaren added this to the 1.2 milestone Feb 13, 2021
@SeanNaren SeanNaren self-assigned this Feb 13, 2021
@SeanNaren
Contributor Author

The other thing I forgot to mention is that mpi4py is required until a new release of DeepSpeed is made; it may work out of the box for CI, I'm not sure.

@codecov

codecov bot commented Feb 13, 2021

Codecov Report

Merging #5954 (981e735) into master (e0bb33c) will decrease coverage by 0%.
The diff coverage is 97%.

@@          Coverage Diff           @@
##           master   #5954   +/-   ##
======================================
- Coverage      93%     93%   -0%     
======================================
  Files         160     160           
  Lines       11343   11371   +28     
======================================
+ Hits        10554   10557    +3     
- Misses        789     814   +25     

Contributor

@tchaton tchaton left a comment

Awesome addition !

"Within the DeepSpeed config, do not set gradient_accumulation_steps "
"as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
)
self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size
Contributor

What happens if the model doesn't have a train_dataloader, as it will be attached by the datamodule?

Contributor Author

That's how I'm testing now, using a datamodule. I think internally the function also caches the train_dataloader, which means we don't create it twice.

Member

I don't think this works if not a batch_size but a batch_sampler was provided directly to the loader, right?

Contributor Author

I think this logic can actually be omitted, as it's only used for timing purposes, it seems.

Contributor Author

@SeanNaren SeanNaren Feb 16, 2021

I think this logic can actually be omitted, as it's only used for timing purposes, it seems.

This unfortunately cannot be omitted; there are some assertions internally that rely on this being set, even if it's just for throughput calculation.

I've added a comment to highlight that this default may be incorrect for certain uses that use a BatchSampler. I think for now this is acceptable considering that the DeepSpeed info messages that are printed are suppressed unless the user enables them.

To address this long term, we can make the change in the DeepSpeed repo to make this parameter optional for the DeepSpeedEngine

"Within the DeepSpeed config, do not set gradient_accumulation_steps "
"as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
)
self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this works, if not batchsize but directly a batchsampler was provided to the loader, right?

def batch_to(data):
return data.half()

def _move_float_tensors_to_half(self, batch: Any):
Contributor

Could be a staticmethod.

Contributor Author

Agreed; not a huge issue, however, and I think this function will eventually be useful for other accelerators.
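The float-to-half conversion being discussed can be sketched as a type-dispatched recursive walk over a collection, in the spirit of Lightning's apply_to_collection utility. This is only an illustration: `apply_to_type` is a hypothetical helper, and plain Python floats stand in for torch tensors (the real code matches torch.Tensor and calls .half() on floating-point tensors).

```python
from typing import Any, Callable


def apply_to_type(data: Any, dtype: type, fn: Callable) -> Any:
    """Recursively apply fn to every element of the given type,
    leaving everything else untouched."""
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_type(d, dtype, fn) for d in data)
    if isinstance(data, dict):
        return {k: apply_to_type(v, dtype, fn) for k, v in data.items()}
    return data


# Floats stand in for float tensors; strings are passed through untouched.
halved = apply_to_type({"x": 2.0, "y": [4.0, "keep"]}, float, lambda v: v / 2)
print(halved)  # {'x': 1.0, 'y': [2.0, 'keep']}
```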

precision = self.lightning_module.trainer.accelerator_backend.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.lightning_module.trainer.training:
Contributor

Smart !

"""
Test to ensure that the plugin can be passed via a string with an environment variable.
"""
config_path = os.path.join(tmpdir, 'temp.json')
Contributor

Smart !

"You have not specified an optimizer or scheduler within the DeepSpeed config."
"Using `configure_optimizers` to define optimizer and scheduler."
)
optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
Contributor

Love it !


def _initialize_deepspeed_train(self, model):
optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
if "optimizer" not in self.config:
Contributor

@tchaton tchaton Feb 17, 2021

Could the user specify the scheduler and not the optimizer (we may choose the one from the config by default)?

Contributor Author

Yes they could!

I plan to do a few follow-up PRs to ease DeepSpeed integration in these cases; these are not super essential but very valid points :)

Contributor

@tchaton tchaton left a comment

Amazing work !

@mergify mergify bot removed the has conflicts label Feb 17, 2021
@SeanNaren SeanNaren enabled auto-merge (squash) February 17, 2021 18:49
Contributor

@awaelchli awaelchli left a comment

what a beast of a plugin!

@tchaton tchaton added the _Will label Feb 17, 2021
@Borda
Member

Borda commented Feb 17, 2021

@SeanNaren seems to be missing the changelog entry

reduce_bucket_size: int = 2e8,
zero_allow_untested_optimizer: bool = True,
config: Optional[Union[Path, str, dict]] = None,
logging_level: int = logging.WARN,
Member

why would you need a separate logging level? Shouldn't it default to the global level?

Contributor Author

There are a lot of messages output from DeepSpeed; this helps to suppress some of their logging messages, but the user can enable them should they wish!
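The suppression described here can be sketched with the standard logging module. The logger name "deepspeed" is an assumption for illustration; the plugin's logging_level argument plays the same role.

```python
import logging

# Assumption: DeepSpeed logs through a named logger, here "deepspeed".
# Raising its level hides INFO-level chatter while leaving the root
# logger (and the user's own loggers) untouched.
logging.getLogger("deepspeed").setLevel(logging.WARN)
```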

distributed_backend = "deepspeed"
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

def __init__(
Member

are these defaults for most models, or just for very large ones?

Contributor Author

Great point, it's set for large models, but it's going to be slow without some tuning.

Comment on lines +161 to +167
if os.path.exists(config):
with open(config) as f:
config = json.load(f)
else:
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
Member

Suggested change
if os.path.exists(config):
with open(config) as f:
config = json.load(f)
else:
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
if not os.path.isfile(config):
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
with open(config) as f:
config = json.load(f)
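The suggested guard-clause shape can be exercised end to end with a small sketch. Here `load_deepspeed_config` is a hypothetical helper, and FileNotFoundError stands in for Lightning's MisconfigurationException; a dict is passed through unchanged, anything else is treated as a path to a JSON file.

```python
import json
import os
import tempfile


def load_deepspeed_config(config):
    """Hypothetical helper in the guard-clause style suggested above:
    pass a dict through unchanged, otherwise treat the argument as a
    path to a JSON config file."""
    if isinstance(config, dict):
        return config
    if not os.path.isfile(config):
        raise FileNotFoundError(
            f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
        )
    with open(config) as f:
        return json.load(f)


# Round-trip a config through a temporary JSON file.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"zero_optimization": {"stage": 2}}, f)
    path = f.name
loaded = load_deepspeed_config(path)
os.unlink(path)
print(loaded)  # {'zero_optimization': {'stage': 2}}
```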

optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
if (len(optimizers) != 1) or len(schedulers) > 1:
Member

Suggested change
if (len(optimizers) != 1) or len(schedulers) > 1:
if len(optimizers) > 1 or len(schedulers) > 1:

# set optimizer for save/load, but deepspeed manages the specific optimizer logic
trainer = self.lightning_module.trainer
trainer.optimizers = [optimizer]
self.model = model
Member

what is the diff between self.model and self._model below?

Contributor Author

Good point, I think this is an artifact that should be fixed in the ddp.py file as well

Contributor

yep, early on in the refactor we didn't have a setter yet, so we referred to _model, and this seems to be a leftover :)

HorovodPlugin,
NativeMixedPrecisionPlugin,
Plugin,
Member

do we want to expose this one?

Labels
feature Is an improvement or enhancement
Development

Successfully merging this pull request may close these issues.

Add deepspeed support
8 participants