
Commit 70e31b6

awaelchli and carmocca authored
Make all_reduce consistent for both NCCL and GLOO (#18235)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 27d9125 commit 70e31b6


6 files changed: +55 -17 lines


src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -192,6 +192,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


+- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))
+

 ## [2.0.5] - 2023-07-07

src/lightning/fabric/fabric.py

Lines changed: 3 additions & 1 deletion
@@ -557,11 +557,13 @@ def all_reduce(
     ) -> Union[Tensor, Dict, List, Tuple]:
         """Reduce tensors or collections of tensors from multiple processes.

+        The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor.
         This method needs to be called on all processes and the tensors need to have the same shape across all
         processes, otherwise your program will stall forever.

         Args:
-            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
+            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. Tensor will be
+                modified in-place.
             group: the process group to reduce results across. Defaults to all processes (world).
             reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
                 Some strategies may limit the choices here.
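The docstring change above makes the in-place contract explicit: after this commit, a tensor passed to `Fabric.all_reduce()` is mutated on every rank, and the returned tensor is the same object as the input. A minimal sketch of what that means in user code, assuming a two-process CPU launch with per-rank values chosen only for illustration (the script itself is not part of this diff):

```python
import torch
from lightning.fabric import Fabric


def run(fabric: Fabric) -> None:
    # Each rank contributes a different value: rank 0 -> 1.0, rank 1 -> 2.0
    local = torch.tensor(float(fabric.global_rank + 1))
    original = local.clone()  # clone first if the pre-reduction value is still needed

    reduced = fabric.all_reduce(local, reduce_op="mean")

    # The reduction is in-place: `local` itself now holds the mean (1.5),
    # and the returned tensor is the same object as the input.
    assert reduced is local
    assert torch.equal(local, torch.tensor(1.5))
    assert torch.equal(original, torch.tensor(float(fabric.global_rank + 1)))


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(run)
```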

src/lightning/fabric/utilities/distributed.py

Lines changed: 18 additions & 11 deletions
@@ -107,7 +107,9 @@ def _sync_ddp_if_available(


 def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
-    """Function to reduce the tensors from several DDP processes to one main process.
+    """Reduces a tensor across several distributed processes.
+
+    This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.

     Args:
         result: The value to sync and reduce (typically tensor or number)
@@ -116,25 +118,26 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
             Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

     Return:
-        reduced value
+        The reduced value.

     """
     divide_by_world_size = False
-
-    if group is None:
-        group = torch.distributed.group.WORLD
+    group = torch.distributed.group.WORLD if group is None else group

     op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
-        if reduce_op.lower() in ("avg", "mean"):
+        reduce_op = "avg" if reduce_op == "mean" else reduce_op
+        if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
+            # The GLOO backend does not support the `ReduceOp.AVG` operation
             op = ReduceOp.SUM  # type: ignore[assignment]
             divide_by_world_size = True
         else:
             op = getattr(ReduceOp, reduce_op.upper())
     else:
         op = reduce_op

-    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
+    # HPU doesn't support Long types, forcefully set it to float
+    # TODO: move this to the `lightning_habana` package
     if (
         package_available("habana_frameworks")
         and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
@@ -150,11 +153,15 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
     # Sync all processes before reduction
     torch.distributed.barrier(group=group)
     torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
+    world_size = torch.distributed.get_world_size(group)

-    if divide_by_world_size:
-        result = result / torch.distributed.get_world_size(group)
-
-    return result
+    if not divide_by_world_size:
+        return result
+    # `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
+    # with the expected value
+    if not torch.is_floating_point(result):
+        return result.copy_(result / world_size)
+    return result.div_(world_size)


 class _AllGather(torch.autograd.Function):
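The heart of the fix sits in the branches above: gloo has no `ReduceOp.AVG`, so the mean is emulated with an all-reduce sum followed by an in-place division, keeping the result inside the input tensor on every rank just as NCCL's native average would. A standalone sketch of that technique with plain `torch.distributed` (the two-process gloo harness, port, and tensor values are illustrative assumptions, not part of this diff):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def mean_all_reduce(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Average a tensor across ranks on backends without ReduceOp.AVG (e.g. gloo)."""
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)  # in-place sum across ranks
    world_size = dist.get_world_size(group)
    if not torch.is_floating_point(tensor):
        # integer tensors: compute the float quotient out-of-place, then write it back in-place
        return tensor.copy_(tensor / world_size)
    return tensor.div_(world_size)  # in-place division keeps the mean in the input tensor


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor = torch.tensor(float(rank + 1))  # rank 0 -> 1.0, rank 1 -> 2.0
    out = mean_all_reduce(tensor)
    assert out is tensor and torch.equal(tensor, torch.tensor(1.5))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```

The `copy_()` branch mirrors the `torch.is_floating_point` check in the diff: `div_()` on an integer tensor cannot hold a floating-point quotient, so the division happens out-of-place and the result is written back into the original tensor.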

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:

     def compute(self) -> Tensor:
         if self.is_tensor:
-            value = self.meta.sync(self.value)
+            value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
             if self.meta.is_mean_reduction:
                 cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
                 return value / cumulated_batch_size
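The `clone()` above exists because the sync is now in-place: without it, the reduced value would overwrite the running accumulator, and every later `compute()` would reduce an already-reduced value. A minimal illustration of the pattern with a stand-in for the distributed sync (all names here are hypothetical, not Lightning APIs):

```python
import torch


def inplace_sync(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for an in-place all-reduce: mutates its input, here by simply doubling it.
    return t.mul_(2)


running_total = torch.tensor(3.0)  # accumulator that must survive multiple compute() calls

# Sync a clone so the accumulator itself is not mutated by the in-place reduction.
synced = inplace_sync(running_total.clone())

assert torch.equal(synced, torch.tensor(6.0))
assert torch.equal(running_total, torch.tensor(3.0))  # accumulator left untouched
```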

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 30 additions & 3 deletions
@@ -7,7 +7,7 @@
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
-from lightning.fabric.utilities.distributed import _gather_all_tensors
+from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp
 from tests_fabric.helpers.runif import RunIf


@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy):
         assert (val == torch.ones_like(val)).all()


+def _test_all_reduce(strategy):
+    rank = strategy.local_rank
+    device = strategy.root_device
+    world_size = strategy.num_processes
+
+    for dtype in (torch.long, torch.int, torch.float, torch.half):
+        # max
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(2, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="max")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # sum
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="sum")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # average
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)) / 2, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="avg")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+
+
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(
     "process",
     [
         _test_all_gather_uneven_tensors_multidim,
         _test_all_gather_uneven_tensors,
+        _test_all_reduce,
     ],
 )
 @pytest.mark.parametrize(
     "devices",
     [
         pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
-        [torch.device("cpu")] * 2,
+        [torch.device("cpu"), torch.device("cpu")],
     ],
 )
-def test_gather_all_tensors(devices, process):
+def test_collective_operations(devices, process):
     spawn_launch(process, devices)

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ def on_train_epoch_end(self) -> None:
         assert metrics["callback"]["tracking"] == expected
         assert computed_value == 2

-        assert self.results["training_step.tracking_2"].value == total * devices
+        assert self.results["training_step.tracking_2"].value == total
         assert metrics["callback"]["tracking_2"] == expected
         assert computed_value == 2
         self.has_validated_sum = True
