-
Notifications
You must be signed in to change notification settings - Fork 19
Add device_id parameter to all torch distributed initialization calls #203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Updated 20 example files to add device_id=torch.device(f"cuda:{local_rank}") - Updated 3 benchmark files to add device_id=torch.device(f"cuda:{rank}") - Updated 5 reference benchmark files to add device_id with LOCAL_RANK from environment - All changes follow the pattern: device_id=torch.device(f"cuda:{device_id}") Co-authored-by: mawad-amd <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR systematically adds the device_id
parameter to all dist.init_process_group()
calls across the Iris codebase to ensure proper PyTorch distributed initialization with device specification. The change addresses issue #202 by making distributed initialization consistent and explicit about which CUDA device each process should use.
Key changes:
- Added
device_id=torch.device(f"cuda:{device_id}")
parameter to all distributed process group initializations - For functions with explicit rank parameters, the rank is used as the device_id
- For scripts using environment variables, LOCAL_RANK is retrieved and used as device_id
- Added necessary
import os
statements where LOCAL_RANK environment variable access was needed
Reviewed Changes
Copilot reviewed 25 out of 25 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
examples/00_load/load_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/01_store/store_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/02_all_load/all_load_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/03_all_store/all_store_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/04_atomic_add/atomic_add_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/05_atomic_xchg/atomic_xchg_bench.py | Added device_id parameter using local_rank for distributed initialization |
examples/06_message_passing/message_passing_load_store.py | Added device_id parameter using local_rank for distributed initialization |
examples/06_message_passing/message_passing_put.py | Added device_id parameter using local_rank for distributed initialization |
examples/07_gemm_all_scatter/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/08_gemm_atomics_all_reduce/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/09_gemm_one_shot_all_reduce/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/10_gemm_all_scatter_wg_specialization/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/11_gemm_all_scatter_producer_consumer/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py | Added device_id parameter using local_rank for distributed initialization |
examples/13_flash_decode/example_run.py | Added device_id parameter using rank for distributed initialization |
examples/14_all_gather_gemm/example_run_pull.py | Added device_id parameter using rank for distributed initialization |
examples/14_all_gather_gemm/example_run_push.py | Added device_id parameter using rank for distributed initialization |
benchmark/examples/benchmark_all_gather_gemm_pull.py | Added device_id parameter using rank for distributed initialization |
benchmark/examples/benchmark_all_gather_gemm_push.py | Added device_id parameter using rank for distributed initialization |
benchmark/examples/benchmark_flash_decode.py | Added device_id parameter using rank for distributed initialization |
examples/benchmark/reference/all_gather.py | Added os import and device_id parameter using LOCAL_RANK environment variable |
examples/benchmark/reference/all_reduce.py | Added os import and device_id parameter using LOCAL_RANK environment variable |
examples/benchmark/reference/reduce_scatter.py | Added os import and device_id parameter using LOCAL_RANK environment variable |
examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py | Added device_id parameter using rank for distributed initialization |
examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py | Added LOCAL_RANK retrieval and device_id parameter for distributed initialization |
Problem
All PyTorch Distributed initialization calls should explicitly specify the
device_id
parameter to ensure proper GPU device assignment. Without this parameter, PyTorch may not correctly map processes to their intended GPU devices, which can lead to runtime issues in multi-GPU distributed training scenarios.Solution
Updated all
dist.init_process_group()
calls across the codebase to include thedevice_id=torch.device(f"cuda:{device_id}")
parameter. The changes follow two patterns based on how each file obtains its rank:Pattern 1: Files with explicit rank parameter
For files that pass rank explicitly to the worker function:
Pattern 2: Files using environment variables
For files that rely on environment variables for process group initialization:
Files Modified
examples/00_load
throughexamples/14_all_gather_gemm
benchmark/examples/benchmark_*.py
examples/benchmark/reference/
Testing
All files pass Python syntax validation. The changes maintain backward compatibility and will be validated by the CI pipeline on AMD GPU hardware.
Fixes issue requiring device_id parameter for all torch distributed initialization.
Original prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.