
Fabric.all_reduce modifies the input value in place on GPU #18228

@function2-llx

Description


Bug description

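When Fabric.all_reduce is called on a tensor that lives on the GPU, the input tensor is modified in place: it is overwritten with the sum of the values across all processes, while the returned tensor holds the expected mean.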

What version are you seeing the problem on?

v2.0

How to reproduce the bug

Run the following code with multiple (e.g., 4) GPUs:

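from lightning import Fabric
import torch

def main():
    # Create new tensors on the current CUDA device by default
    torch.set_default_device('cuda')
    fabric = Fabric()  # default ("auto") accelerator, devices and strategy
    fabric.launch()
    # Each rank contributes its own index: 0, 1, 2, 3 with 4 GPUs
    x = torch.tensor(fabric.global_rank).float()
    rx = fabric.all_reduce(x)  # default reduce_op is "mean"
    # Print one rank at a time so the outputs do not interleave
    for i in range(fabric.world_size):
        if i == fabric.global_rank:
            print('rank:', i)
            print('input:', x)
            print('reduce:', rx)
        fabric.barrier()

if __name__ == '__main__':
    main()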

Error messages and logs

output:

rank: 0
input: tensor(6., device='cuda:0')
reduce: tensor(1.5000, device='cuda:0')
rank: 1
input: tensor(6., device='cuda:1')
reduce: tensor(1.5000, device='cuda:1')
rank: 2
input: tensor(6., device='cuda:2')
reduce: tensor(1.5000, device='cuda:2')
rank: 3
input: tensor(6., device='cuda:3')
reduce: tensor(1.5000, device='cuda:3')
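These values are consistent with a sum all-reduce that writes into the input buffer, followed by an out-of-place division by the world size: 0 + 1 + 2 + 3 = 6 is what is left in x on every rank, and 6 / 4 = 1.5 is the returned mean. torch.distributed.all_reduce is documented to operate in place, which would explain the first half. The sketch below is a hypothetical pure-Python model of that reading, not Lightning's source; simulated_mean_all_reduce is an illustrative name, and each rank's tensor is stood in for by a one-element list.

```python
from typing import List


def simulated_mean_all_reduce(buffers: List[List[float]]) -> List[float]:
    """Hypothetical model of the reported behavior (not Lightning's source).

    ``buffers[r]`` is a one-element list standing in for rank ``r``'s input
    tensor. The sum is written back into every input buffer (in place, like
    ``torch.distributed.all_reduce``), and the mean is computed out of place.
    """
    total = sum(buf[0] for buf in buffers)
    for buf in buffers:
        buf[0] = total  # in place: each rank's input now holds the global sum
    world_size = len(buffers)
    return [buf[0] / world_size for buf in buffers]  # new values: the mean


inputs = [[float(rank)] for rank in range(4)]  # ranks 0..3, as in the repro
results = simulated_mean_all_reduce(inputs)
print(inputs)   # [[6.0], [6.0], [6.0], [6.0]]
print(results)  # [1.5, 1.5, 1.5, 1.5]

# Handing the function copies (the analogue of x.clone()) keeps the originals intact.
originals = [[float(rank)] for rank in range(4)]
simulated_mean_all_reduce([list(buf) for buf in originals])
print(originals)  # [[0.0], [1.0], [2.0], [3.0]]
```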

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
      • NVIDIA A100 80GB PCIe
    • available: True
    • version: 11.8
  • Lightning:
    • lightning: 2.0.6
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.5
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.2
  • Packages:
    • addict: 2.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.1
    • appdirs: 1.4.4
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • brotli: 1.0.9
    • certifi: 2023.7.22
    • charset-normalizer: 3.2.0
    • click: 8.1.4
    • colorama: 0.4.6
    • contourpy: 1.1.0
    • croniter: 1.4.1
    • cycler: 0.11.0
    • cytoolz: 0.12.1
    • dataclasses: 0.8
    • datasets: 2.13.1
    • dateutils: 0.6.12
    • deepdiff: 6.3.1
    • dill: 0.3.6
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.15
    • einops: 0.6.1
    • et-xmlfile: 1.1.0
    • fastapi: 0.100.0
    • filelock: 3.12.2
    • fonttools: 4.41.0
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.32
    • gmpy2: 2.1.2
    • h11: 0.14.0
    • huggingface-hub: 0.16.2
    • idna: 3.4
    • importlib-metadata: 6.8.0
    • importlib-resources: 6.0.0
    • inquirer: 3.1.3
    • itk-core: 5.3.0
    • itk-filtering: 5.3.0
    • itk-io: 5.3.0
    • itk-numerics: 5.3.0
    • itk-registration: 5.3.0
    • itk-segmentation: 5.3.0
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • joblib: 1.3.0
    • jsonargparse: 4.23.0
    • kiwisolver: 1.4.4
    • lightning: 2.0.6
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.2
    • mdurl: 0.1.2
    • mmcv-full: 1.7.1
    • monai: 1.2.0+78.g71aaa2259
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • multiprocess: 0.70.14
    • munkres: 1.1.4
    • mypy-extensions: 1.0.0
    • networkx: 3.1
    • nibabel: 5.1.0
    • nptyping: 2.5.0
    • numpy: 1.25.1
    • omegaconf: 2.3.0
    • opencv-python: 4.8.0.74
    • openpyxl: 3.1.2
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pandas: 2.0.3
    • pathtools: 0.1.2
    • pillow: 9.4.0
    • pip: 23.1.2
    • platformdirs: 3.8.1
    • ply: 3.11
    • pooch: 1.7.0
    • protobuf: 4.23.3
    • psutil: 5.9.5
    • pyarrow: 12.0.1
    • pydantic: 1.10.11
    • pydicom: 2.4.1
    • pygments: 2.15.1
    • pyjwt: 2.7.0
    • pynrrd: 1.0.0
    • pyparsing: 3.0.9
    • pyqt5: 5.15.7
    • pyqt5-sip: 12.11.0
    • pyre-extensions: 0.0.29
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.5
    • pytz: 2023.3
    • pyyaml: 6.0
    • readchar: 4.0.5
    • regex: 2023.6.3
    • requests: 2.31.0
    • responses: 0.18.0
    • rich: 13.4.2
    • sacremoses: 0.0.53
    • safetensors: 0.3.1
    • scipy: 1.11.1
    • sentry-sdk: 1.28.0
    • setproctitle: 1.3.2
    • setuptools: 68.0.0
    • sip: 6.7.9
    • six: 1.16.0
    • smmap: 3.0.5
    • sniffio: 1.3.0
    • soupsieve: 2.4.1
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • timm: 0.9.2
    • tokenizers: 0.13.3
    • toml: 0.10.2
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.2
    • tornado: 6.3.2
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transformers: 4.30.2
    • triton: 2.0.0
    • typeshed-client: 2.3.0
    • typing-extensions: 4.7.1
    • typing-inspect: 0.9.0
    • tzdata: 2023.3
    • urllib3: 2.0.3
    • uvicorn: 0.22.0
    • wandb: 0.15.7
    • wcwidth: 0.2.6
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • wheel: 0.40.0
    • xformers: 0.0.20
    • xxhash: 0.0.0
    • yapf: 0.40.1
    • yarl: 1.9.2
    • zipp: 3.16.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.4
    • release: 5.15.0-73-generic
    • version: #80-Ubuntu SMP Mon May 15 15:18:26 UTC 2023

More info

The bug does not occur for tensors on the CPU.
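One plausible explanation, assuming Fabric.all_reduce first moves its input to fabric.device before reducing: torch.Tensor.to() returns the tensor itself when it is already on the target device and dtype, and a new copy otherwise. A GPU input would then be reduced in place, while a CPU input is copied to the GPU first, so only the copy gets overwritten. The pure-Python sketch below models that hypothesis; FakeTensor and reduce_sum_on_device are hypothetical names, not Lightning or PyTorch APIs, and device strings are simplified to "cpu" and "cuda".

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a scalar tensor living on a device."""
    value: float
    device: str

    def to(self, device: str) -> "FakeTensor":
        # Mirrors torch.Tensor.to(): return self if nothing changes, else a copy.
        return self if device == self.device else replace(self, device=device)


def reduce_sum_on_device(tensors: List[FakeTensor], device: str) -> List[FakeTensor]:
    """Move each rank's input to ``device``, then sum in place across ranks."""
    moved = [t.to(device) for t in tensors]
    total = sum(t.value for t in moved)
    for t in moved:
        t.value = total  # in-place write, like torch.distributed.all_reduce
    return moved


gpu_inputs = [FakeTensor(float(r), "cuda") for r in range(4)]
cpu_inputs = [FakeTensor(float(r), "cpu") for r in range(4)]
reduce_sum_on_device(gpu_inputs, "cuda")
reduce_sum_on_device(cpu_inputs, "cuda")
print([t.value for t in gpu_inputs])  # [6.0, 6.0, 6.0, 6.0]
print([t.value for t in cpu_inputs])  # [0.0, 1.0, 2.0, 3.0]
```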

cc @carmocca @justusschock @awaelchli

Labels: bug, fabric, strategy: ddp, ver: 2.0.x