
Commit 4df060b

ArvinZhuang authored and Borda committed
Match the number of outputs of backward with forward for AllGatherGrad (#6625)
1 parent 836d02a commit 4df060b
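
Background: PyTorch's torch.autograd.Function contract requires backward to return one value (a gradient or None) per argument passed to forward; non-differentiable arguments such as a process group receive None. AllGatherGrad.forward takes a tensor and a group, but backward previously returned only the tensor gradient, hence this fix. A minimal single-process sketch of that contract (the Scale Function below is illustrative only, not part of this commit):

    import torch


    class Scale(torch.autograd.Function):
        """Toy Function whose forward takes two inputs: a tensor and a float."""

        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # backward must return one value per forward input:
            # a gradient for `x`, and None for the non-tensor `factor`.
            return grad_output * ctx.factor, None


    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    assert torch.equal(x.grad, torch.full((3,), 2.0))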

File tree: 2 files changed, +24 -1 lines changed


pytorch_lightning/utilities/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(
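
For context, here is a sketch of the full AllGatherGrad Function around this revision (the lines outside the hunk are reconstructed from the file and may differ slightly). forward receives (tensor, group), so after this fix backward returns (grad, None) to match:

    class AllGatherGrad(torch.autograd.Function):

        @staticmethod
        def forward(ctx, tensor, group=torch.distributed.group.WORLD):
            ctx.group = group
            # Gather one copy of `tensor` from every rank, then stack them
            # along a new leading world-size dimension.
            gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
            torch.distributed.all_gather(gathered_tensor, tensor, group=group)
            return torch.stack(gathered_tensor, dim=0)

        @staticmethod
        def backward(ctx, *grad_output):
            grad_output = torch.cat(grad_output)
            torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
            # forward received (tensor, group), so backward returns a gradient
            # for the tensor and None for the group argument.
            return grad_output[torch.distributed.get_rank()], None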

tests/utilities/test_all_gather_grad.py

Lines changed: 23 additions & 0 deletions
@@ -96,3 +96,26 @@ def training_epoch_end(self, outputs) -> None:
 
     trainer.fit(model)
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer.fit(model)
+    assert model.training_step_called
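
The test exercises the fixed path end to end: on two GPUs, self.all_gather(tensor, sync_grads=True) stacks the per-rank (2, 2) tensors into a (2, 2, 2) tensor (world size as the new leading dimension), and backpropagating the summed loss through the gathered tensor invokes AllGatherGrad.backward. A rough sketch of how sync_grads selects the differentiable path, loosely following all_gather_ddp_if_available at this revision (reconstructed, details may differ):

    def all_gather_ddp_if_available(tensor, group=None, sync_grads=False):
        """Gather a tensor from all processes; keep the autograd graph if sync_grads."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if sync_grads:
                # Differentiable path: apply() takes two inputs (tensor, group),
                # which is why backward must now return two outputs.
                return AllGatherGrad.apply(tensor, group)
            with torch.no_grad():
                return AllGatherGrad.apply(tensor, group)
        return tensor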
