Supporting Adding DDP Communication Hooks #6736
@@ -17,9 +17,15 @@
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
    _TORCH_GREATER_EQUAL_1_8,
    _TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)

if torch.distributed.is_available():
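
The two flags imported here gate the new feature on the installed PyTorch release. As a rough illustration only (the actual pytorch_lightning.utilities.imports module may differ in detail), such availability flags can be derived along these lines:

# Illustrative sketch of version-availability flags; not the actual
# pytorch_lightning.utilities.imports implementation.
from packaging.version import Version

import torch

_TORCH_VERSION = Version(torch.__version__.split("+")[0])  # strip local tag, e.g. "+cu111"
_TORCH_GREATER_EQUAL_1_8 = _TORCH_VERSION >= Version("1.8.0")
_TORCH_GREATER_EQUAL_1_9 = _TORCH_VERSION >= Version("1.9.0")
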
@@ -198,3 +204,107 @@ def all_gather_ddp_if_available(
    with torch.no_grad():
        return AllGatherGrad.apply(tensor, group)
    return tensor


def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[callable] = None,
    ddp_comm_wrapper: Optional[callable] = None,
) -> None:
    """
    Register a communication hook on a DDP model:
    https://pytorch.org/docs/master/ddp_comm_hooks.html

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad, etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable is invoked once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (e.g. allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of the grad bucket's tensors. Once a bucket is ready,
            the c10d reducer calls this hook, takes the tensors returned
            by the Future and copies the gradients to the individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper that supports a communication hook such
            as FP16 compression as a wrapper, which can be combined with
            ddp_comm_hook

    .. warning::
        The DDP communication hook requires PyTorch 1.8.0 or later.

    .. warning::
        The DDP communication wrapper requires PyTorch 1.9.0 or later.

    Example:

        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks as default,
            powerSGD_hook as powerSGD,
        )

        # fp16_compress_hook to compress gradients
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_hook=default.fp16_compress_hook,
        )

        # powerSGD_hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
        )

        # fp16_compress_wrapper combined with another communication hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
            ddp_comm_wrapper=default.fp16_compress_wrapper,
        )
    """
    if not _TORCH_GREATER_EQUAL_1_8:

[Inline review on `if not _TORCH_GREATER_EQUAL_1_8:`]
Reviewer: Technically it's also available in 1.7.0, right? But protected with an underscore. Do we want to include it, or were important improvements made between 1.7 and 1.8?
Author: I encountered an import issue when I tried to import it; also, PowerSGD was introduced later.

        rank_zero_warn(
            "Not registering DDP comm hook. "
            "To use communication hooks, please use pytorch>=1.8.0."
        )
        return
    if ddp_comm_hook is None:
        return
    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn(
                "Not applying DDP comm wrapper. "
                "To use communication wrapper, please use pytorch>=1.9.0."
            )
        else:
            rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(
        state=ddp_comm_state,
        hook=ddp_comm_hook,
    )
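
For orientation only, here is a minimal, self-contained sketch of how the new register_ddp_comm_hook utility could be exercised directly on a DDP-wrapped module, outside the Lightning Trainer. The import path, port, world size, and toy model are assumptions for illustration; only the function signature comes from this diff.

# Hypothetical usage sketch (not part of this PR). Assumes the utility is
# importable from pytorch_lightning.utilities.distributed and that a single
# node with at least two NCCL-capable GPUs is available.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.utilities.distributed import register_ddp_comm_hook


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # illustrative port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(32, 2).cuda(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    # Compress gradients to FP16 for the all-reduce, decompress on return.
    register_ddp_comm_hook(model=ddp_model, ddp_comm_hook=default.fp16_compress_hook)

    loss = ddp_model(torch.randn(8, 32).cuda(rank)).sum()
    loss.backward()  # gradient buckets now pass through the registered hook
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)

Presumably the rest of the PR threads the same ddp_comm_state, ddp_comm_hook, and ddp_comm_wrapper arguments through Lightning's DDP plugin so they can be configured from the Trainer; only the standalone utility is shown in this diff.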