
Commit b8ef52b

Match the number of outputs of backward with forward for AllGatherGrad (#6625)
1 parent 2dd6f9e commit b8ef52b

2 files changed: +24 −1 lines changed

pytorch_lightning/utilities/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(
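The one-line change follows the general torch.autograd.Function contract: backward must return exactly one value per argument that forward received, using None for arguments that need no gradient (here, the process group passed alongside the tensor). A minimal sketch of that rule, using a hypothetical ScaleBy function rather than the library's AllGatherGrad:

import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, factor):
        # forward takes two inputs: a tensor and a plain Python number
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # backward must return two values to match forward's two inputs;
        # None is the "gradient" for `factor`, which needs no gradient
        return grad_output * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])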

tests/utilities/test_all_gather_grad.py

Lines changed: 23 additions & 0 deletions
@@ -95,3 +95,26 @@ def training_epoch_end(self, outputs) -> None:
 
     trainer.fit(model)
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer.fit(model)
+    assert model.training_step_called
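As a reading note on the new test: with gpus=2 each process contributes a (2, 2) tensor, so gathering stacks them into shape (2, 2, 2), and sync_grads=True sends the backward pass through AllGatherGrad, which is the code path the one-line fix above exercises.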
