@@ -1,60 +1,35 @@
 # pyright: reportCallIssue=false

-from collections.abc import Sequence
+from typing import Any, Optional

-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist

-from .ops import _ops

 ###### NVSHMEM ######
+def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device: Any, uid: Optional[Any] = None) -> None:
+    uniqueid = nvshmem.get_unique_id(empty=True)
+    if local_rank == 0:
+        uniqueid = nvshmem.get_unique_id()
+        broadcast_objects = [uniqueid]
+    else:
+        broadcast_objects = [None]

+    dist.broadcast_object_list(broadcast_objects, src=0)
+    dist.barrier()

-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
+    nvshmem.init(device=device, uid=broadcast_objects[0], rank=global_rank, nranks=world_size, initializer_method="uid")


-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
+# This stream wrapper returns the format required by CUDA Python. This workaround will be removed when nvshmem4py supports Torch stream interoperability.
+# For more information see: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream

+    def __cuda_stream__(self) -> tuple[int, int]:
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)

-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")

-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
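
For reference, a minimal launch sketch for the new nvshmem_init, assuming a torchrun-style environment. The process-group bootstrap, the LOCAL_RANK environment variable, and the cuda.core.experimental.Device passed as `device` (the diff types it as Any) are illustrative assumptions, not part of the diff:

    import os

    import torch
    import torch.distributed as dist
    from cuda.core.experimental import Device  # assumed device type; the diff leaves it as Any

    # Standard torch.distributed bootstrap (assumption: launched via torchrun,
    # which sets RANK / LOCAL_RANK / WORLD_SIZE in the environment).
    dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)

    # nvshmem_init generates a unique ID on each local-rank-0 process,
    # broadcasts global rank 0's ID to every process, and joins the NVSHMEM
    # job with initializer_method="uid".
    nvshmem_init(global_rank, local_rank, world_size, Device(local_rank))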
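And a small usage sketch for PyTorchStreamWrapper. The wrapper exists so that cuda.core-based consumers such as nvshmem4py can read a Torch stream through the __cuda_stream__ protocol, which yields a (version, handle) tuple; the snippet below only demonstrates that contract and assumes a CUDA-capable process:

    import torch

    # Wrap the active Torch stream in the protocol shim added by the diff.
    pt_stream = torch.cuda.current_stream()
    wrapped = PyTorchStreamWrapper(pt_stream)

    # Protocol consumers read (version, handle); handle is the raw CUDA
    # stream address that Torch exposes as .cuda_stream.
    version, handle = wrapped.__cuda_stream__()
    assert version == 0 and handle == pt_stream.cuda_stream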