 # pyright: reportCallIssue=false
 
-from collections.abc import Sequence
+from typing import Any, Optional
 
-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist
 
-from .ops import _ops
 
 ###### NVSHMEM ######
-
-
-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
-
-
-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
-
-
-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
-
-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
+def nvshmem_init(
+    global_rank: int,
+    local_rank: int,
+    world_size: int,
+    device: Any,
+    uid: Optional[Any] = None,
+) -> None:
+    # Bootstrap NVSHMEM over torch.distributed: each node's local rank 0
+    # creates a unique ID, and global rank 0's copy is broadcast to every
+    # process. Note: the `uid` parameter is currently unused.
+    uniqueid = nvshmem.get_unique_id(empty=True)
+    if local_rank == 0:
+        uniqueid = nvshmem.get_unique_id()
+        broadcast_objects = [uniqueid]
+    else:
+        broadcast_objects = [None]
+
+    dist.broadcast_object_list(broadcast_objects, src=0)
+    dist.barrier()
+
+    nvshmem.init(
+        device=device,
+        uid=broadcast_objects[0],
+        rank=global_rank,
+        nranks=world_size,
+        initializer_method="uid",
+    )
+
+
+# This wrapper exposes a Torch stream in the format required by the CUDA
+# Python stream protocol. The workaround can be removed once nvshmem4py
+# supports Torch stream interoperability. For more information see:
+# https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream
+
+    def __cuda_stream__(self) -> tuple[int, int]:
+        # The protocol returns (version, handle): version 0 plus the raw
+        # CUDA stream handle backing the Torch stream.
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)
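
A minimal usage sketch for the new initializer, assuming a torchrun launch (which sets RANK, LOCAL_RANK, and WORLD_SIZE) and that this module is importable as `nvshmem_utils` (a hypothetical name). nvshmem4py's own examples pass a cuda.core Device for the `device` argument, which is presumably why the parameter is typed Any rather than torch.device:

    import os

    import torch
    import torch.distributed as dist
    from cuda.core.experimental import Device

    import nvshmem_utils  # hypothetical import path for the module in this diff

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # nvshmem_init broadcasts the unique ID over torch.distributed, so a
    # process group must exist before it is called.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    dev = Device(local_rank)
    dev.set_current()
    nvshmem_utils.nvshmem_init(rank, local_rank, world_size, dev)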
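
And a short sketch of the stream wrapper in use, assuming a CUDA-capable PyTorch build; any nvshmem4py call that consumes streams through the CUDA stream protocol can accept the wrapped object:

    import torch

    from nvshmem_utils import PyTorchStreamWrapper  # hypothetical module name, as above

    s = torch.cuda.Stream()
    wrapped = PyTorchStreamWrapper(s)

    # The protocol returns (version, handle); the handle is the raw CUDA
    # stream handle backing the Torch stream.
    version, handle = wrapped.__cuda_stream__()
    assert version == 0 and handle == s.cuda_stream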