Commit 9c6979a

Yi Wang authored and facebook-github-bot committed
[Gradient Compression] Error feedback for PowerSGD (still need to fix the key in error_dict) (#48670)
Summary: Pull Request resolved: #48670

Support an optional error feedback for PowerSGD: store the difference (i.e., the local error caused by compression) between the input gradient (adjusted by the existing error) and the gradient after decompression, and reinsert it at the next iteration.

An index field still needs to be added to GradBucket to serve as the key of error_dict. The current key, the input tensor of the bucket, can change across steps, because the buckets may be rebuilt in the forward pass to save peak memory usage. This commit therefore implements only half of error feedback; the new index field is planned for a separate PR.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

ghstack-source-id: 117636492

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

Reviewed By: rohan-varma

Differential Revision: D25240290

fbshipit-source-id: 5b6e11e711caccfb8984ac2767dd107dbf4c9b3b
1 parent 463e5d2 commit 9c6979a
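
To make the mechanism concrete, here is a minimal, single-process sketch of the error-feedback idea described in the summary. The compress/decompress pair and the bucket key are hypothetical placeholders, not the hook's actual API, and the allreduce step is omitted.

import torch

# Hypothetical sketch of error feedback (not the hook implementation):
# the local error left over from compression is stored per bucket and
# folded back into the gradient at the next iteration.
error_dict = {}  # bucket key -> accumulated local compression error

def allreduce_with_error_feedback(key, grad, compress, decompress):
    # Reinsert the error memorized at the previous iteration (zero at first use).
    adjusted = grad + error_dict.get(key, torch.zeros_like(grad))
    # Compress and decompress (a real hook would allreduce the compressed factors).
    approx = decompress(compress(adjusted))
    # The local error is whatever the lossy round trip failed to preserve.
    error_dict[key] = adjusted - approx
    return approx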

File tree

2 files changed: +51 −4 lines changed


torch/distributed/algorithms/ddp_comm_hooks/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,12 @@ def _ddp_comm_hook_wrapper(comm_hook, model, state):
 
 
 def _powerSGD_comm_hook_wrapper(
-    comm_hook, model, state, matrix_approximation_rank, random_seed=0
+    comm_hook,
+    model,
+    state,
+    matrix_approximation_rank,
+    use_error_feedback=True,
+    random_seed=0,
 ):
     """
     To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
@@ -25,6 +30,7 @@ def _powerSGD_comm_hook_wrapper(
     powerSGD_state = powerSGD.PowerSGDState(
         process_group=state,
         matrix_approximation_rank=matrix_approximation_rank,
+        use_error_feedback=use_error_feedback,
         random_seed=random_seed,
     )
     model.register_comm_hook(powerSGD_state, comm_hook)
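
For context, a hedged usage sketch of how a DDP model might register this hook with the new flag. It assumes the process group is already initialized and that model and rank are defined by the training script; passing process_group=None to fall back to the default group is an assumption, not something shown in this diff.

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# `model` and `rank` are assumed to be defined by the surrounding training script.
ddp_model = DDP(model, device_ids=[rank])
state = powerSGD.PowerSGDState(
    process_group=None,           # assumption: None falls back to the default group
    matrix_approximation_rank=1,
    use_error_feedback=True,      # the flag added by this commit
    random_seed=0,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)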

torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py

Lines changed: 44 additions & 3 deletions
@@ -30,17 +30,44 @@ def _orthogonalize(matrix, epsilon=1e-8):
 
 
 class PowerSGDState(object):
-    __slots__ = ["process_group", "matrix_approximation_rank", "rng"]
-
-    def __init__(self, process_group, matrix_approximation_rank=1, random_seed=0):
+    __slots__ = [
+        "process_group",
+        "matrix_approximation_rank",
+        "use_error_feedback",
+        "rng",
+        "error_dict",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        matrix_approximation_rank=1,
+        use_error_feedback=True,
+        random_seed=0,
+    ):
         self.process_group = process_group
         self.matrix_approximation_rank = matrix_approximation_rank
+        # Error feedback is usually crucial for both convergence and generalization,
+        # because PowerSGD is a biased compressor,
+        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
+        # This mechanism requires a temporary copy of the input gradients,
+        # so it increases the peak memory consumption by the size of the gradient tensor.
+        # However, if the target matrices are known to be exactly low-rank (instead of just low stable rank),
+        # sometimes it is possible to converge to the optima without error feedback.
+        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
+        self.use_error_feedback = use_error_feedback
         # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
         # but in the same order for all the DDP replicas.
         # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
         # If the same random projection is used,
         # there will be differences between the gradients that are never synchronized.
         self.rng = np.random.RandomState(random_seed)
+        # Since there is only a single state instance for all the input buckets,
+        # need to maintain a dictionary that maps each bucket to the local error.
+        # TODO(wayi): Currently the key is the (hashcode of) input tensor, which may change across steps,
+        # since the bucket can be rebuilt in the forward pass (to save peak memory usage).
+        # Need to add an index field to the input bucket of comm hook.
+        self.error_dict = {}
 
 
 def powerSGD_hook(
@@ -98,6 +125,17 @@ def powerSGD_hook(
     padded_total_length = square_side_length ** 2
     input_tensor.resize_(padded_total_length)
     input_tensor[total_length:padded_total_length].fill_(0)
+
+    # Incorporate the error from the previous state into the gradients.
+    if state.use_error_feedback:
+        if input_tensor in state.error_dict:
+            input_tensor.add_(state.error_dict[input_tensor])
+        else:
+            state.error_dict[input_tensor] = torch.zeros(padded_total_length, device=device)
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = torch.clone(input_tensor).detach()
     matrix = input_tensor.view(square_side_length, square_side_length)
 
     def create_low_rank_tensor(fill_random_values, rng):
@@ -141,6 +179,9 @@ def decompress(fut):
         q = fut.value()[0].div_(world_size)
         torch.matmul(p, q.t(), out=matrix)
 
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            state.error_dict[input_tensor] = input_tensor_cp - input_tensor
         ret = input_tensor.resize_(total_length)
         return [ret]
 
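
For intuition, the per-bucket arithmetic added above can be mimicked on a single process. This is a hedged sketch under stated assumptions: the allreduce steps are skipped, torch.linalg.qr stands in for the file's _orthogonalize helper, and the function name and signature are made up for illustration.

import math
import torch

def local_powersgd_step(grad, error=None, rank=1):
    # Pad the flattened gradient so it can be viewed as a square matrix,
    # mirroring the resize_/fill_(0) logic in powerSGD_hook.
    total_length = grad.numel()
    side = int(math.ceil(math.sqrt(total_length)))
    padded = torch.zeros(side * side, device=grad.device, dtype=grad.dtype)
    padded[:total_length] = grad.flatten()
    # Error feedback: fold the previous local error back in, then keep a copy
    # so the new error can be measured after decompression.
    if error is not None:
        padded += error
    padded_cp = padded.clone().detach()
    matrix = padded.view(side, side)
    # One power-iteration step to build the low-rank factors P and Q
    # (torch.linalg.qr is a stand-in for the hook's _orthogonalize helper).
    q = torch.randn(side, rank, device=grad.device, dtype=grad.dtype)
    p, _ = torch.linalg.qr(matrix @ q)
    q = matrix.t() @ p
    # Decompress in place (this writes through the view into `padded`)
    # and memorize the local error caused by compression.
    torch.matmul(p, q.t(), out=matrix)
    new_error = padded_cp - padded
    return padded[:total_length].view_as(grad), new_error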
