
Commit f9bb7c6

SeanNaren and Borda authored

DeepSpeed ZeRO Docs update (#6752)

* Added base docs
* Add more information
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

1 parent 1302766 commit f9bb7c6

File tree

1 file changed: +163 -0 lines changed

docs/source/advanced/multi_gpu.rst

Lines changed: 163 additions & 0 deletions
@@ -788,6 +788,169 @@ For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM c
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16)
    trainer.fit(model)

DeepSpeed ZeRO Stage 3
""""""""""""""""""""""

DeepSpeed ZeRO Stage 3 shards the optimizer states, gradients and the model parameters (and optionally the activations). Sharding the model parameters and activations comes with an increase in distributed communication, but allows you to scale your models massively from one GPU to multiple GPUs.
**The DeepSpeed team report the ability to fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs.** For more information we suggest checking the `DeepSpeed ZeRO-3 Offload documentation <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`__.

We've run benchmarks and provide a simple example of how to use all of these features in Lightning, which you can see at `minGPT <https://github.com/SeanNaren/minGPT/tree/stage3>`_.

Currently this functionality is only available on master and will be included in our next 1.3 Release Candidate and the 1.3 release.

.. code-block:: bash

    pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/refs/heads/master.zip

To reach the highest memory efficiency or model size, you must do all of the following (a combined sketch follows this list):

1. Use the DeepSpeed Plugin with the stage 3 parameter
2. Use CPU Offloading to offload weights to the CPU, plus have a reasonable amount of CPU RAM to offload onto
3. Use DeepSpeed Activation Checkpointing to shard activations
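
For orientation, the sketch below combines all three, using the ``DeepSpeedPlugin`` arguments described individually in the subsections that follow (``MyModel`` is assumed to be defined as in the examples below; the exact values you need will depend on your model and hardware):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    model = MyModel()
    trainer = Trainer(
        gpus=4,
        plugins=DeepSpeedPlugin(
            stage=3,                     # 1. shard optimizer states, gradients and parameters
            cpu_offload=True,            # 2. offload optimizer states and gradients to CPU RAM
            cpu_offload_params=True,     # 2. offload the model parameters to CPU RAM as well
            partition_activations=True,  # 3. used together with activation checkpointing (see below)
            cpu_checkpointing=True,      # 3. move checkpointed activations to CPU RAM
        ),
        precision=16,
    )
    trainer.fit(model)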

Below we describe how to enable all of these to see the benefits. **With all of these improvements we reached 45 Billion parameters training a GPT model on 8 GPUs with ~1TB of CPU RAM available.**

Also please have a look at our :ref:`deepspeed-zero-stage-3-tips`, which contains a lot of helpful information for configuring your own models.

.. note::
    Currently we only support non-elastic checkpointing. This means saving the model across GPUs will save shards of the model on all processes, which will then require the same number of GPUs to load.
    This additionally means that for inference you must use the ``Trainer.test`` or ``Trainer.predict`` functionality as described below, to ensure we set up the distributed environment correctly.

    This limitation is actively being worked on and will be resolved in the near future.

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin
    from deepspeed.ops.adam import FusedAdam

    class MyModel(pl.LightningModule):
        ...

        def configure_optimizers(self):
            return FusedAdam(self.parameters())

    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(stage=3), precision=16)
    trainer.fit(model)

    trainer.test()
    trainer.predict()


Shard Model Instantly to Reduce Initialization Time/Memory
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

When instantiating really large models, it is sometimes necessary to shard the model layers instantly.

This is the case if the layers may not fit in a single machine's CPU or GPU memory, but would fit once sharded across multiple machines.
We expose a hook; layers initialized within this hook are sharded instantly, on a per-layer basis, allowing you to shard models as they are created.

This reduces the time taken to initialize very large models, as well as ensuring we do not run out of memory when instantiating larger models. For more information you can refer to the DeepSpeed docs for `Constructing Massive Models <https://deepspeed.readthedocs.io/en/latest/zero3.html>`_.

.. note::
    When using the ``configure_sharded_model`` hook to shard models, note that ``LightningModule.load_from_checkpoint`` for loading saved checkpoints may not work. If you've trained on one GPU, you can manually instantiate the model and call the hook;
    however, when using multiple GPUs, this will not work as ``LightningModule.load_from_checkpoint`` doesn't support sharded checkpoints.

    We recommend using the ``Trainer`` and calling ``Trainer.test`` or ``Trainer.predict`` for inference.

.. code-block:: python

    import pytorch_lightning as pl
    import torch.nn as nn
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin
    from deepspeed.ops.adam import FusedAdam

    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            # Created within the sharded model context, modules are instantly sharded
            # across processes as soon as they are made.
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def configure_optimizers(self):
            return FusedAdam(self.parameters())

    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(stage=3), precision=16)
    trainer.fit(model)

    trainer.test()
    trainer.predict()

DeepSpeed ZeRO Stage 3 Offload
""""""""""""""""""""""""""""""

DeepSpeed ZeRO Stage 3 Offload offloads the optimizer states and gradients to the host CPU to reduce memory usage, as ZeRO Stage 2 does; however, it additionally allows you to offload the parameters as well for even more memory saving.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Enable CPU Offloading
    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(stage=3, cpu_offload=True), precision=16)
    trainer.fit(model)

    # Enable CPU Offloading, and offload the parameters to CPU as well when possible
    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(stage=3, cpu_offload=True, cpu_offload_params=True), precision=16)
    trainer.fit(model)

DeepSpeed Activation Checkpointing
""""""""""""""""""""""""""""""""""

Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass.
They are then re-computed for the backward pass as needed.

This saves memory when training larger models, however it requires using a checkpoint function to run the modules as shown below.

.. code-block:: python

    import pytorch_lightning as pl
    import torch.nn as nn
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin
    import deepspeed


    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def forward(self, x):
            # Use the DeepSpeed checkpointing function instead of calling the module directly
            output = deepspeed.checkpointing.checkpoint(self.block, x)
            return output


    model = MyModel()
    trainer = Trainer(
        gpus=4,
        plugins=DeepSpeedPlugin(
            stage=3,
            cpu_offload=True,  # Enable CPU Offloading
            partition_activations=True,  # Optionally partition activations across GPUs
            cpu_checkpointing=True,  # Optionally move the checkpointed activations to CPU if you have enough CPU memory
        ),
        precision=16,
    )
    trainer.fit(model)

.. _deepspeed-zero-stage-3-tips:

DeepSpeed ZeRO Stage 3 Tips
"""""""""""""""""""""""""""

Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning.

* If you're using Adam or AdamW, ensure you use FusedAdam or DeepSpeedCPUAdam (for CPU Offloading) rather than the default torch optimizers, as they come with large speed benefits (see the sketch after this list)
* Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) in order to leave more space for offloading the model parameters
* When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed
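
For example, a minimal ``configure_optimizers`` following the first tip might look like the sketch below, assuming CPU offloading is enabled and that ``DeepSpeedCPUAdam`` can be imported from ``deepspeed.ops.adam`` in the same way as ``FusedAdam`` above:

.. code-block:: python

    from deepspeed.ops.adam import DeepSpeedCPUAdam

    class MyModel(pl.LightningModule):
        ...

        def configure_optimizers(self):
            # DeepSpeedCPUAdam runs the optimizer step on the CPU, which pairs well with
            # cpu_offload=True; without CPU offloading, prefer FusedAdam as shown above.
            return DeepSpeedCPUAdam(self.parameters())
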
Custom DeepSpeed Config
"""""""""""""""""""""""
