1+ import logging
2+
3+ import nvshmem .core as nvshmem # type: ignore[import]
14import pytest
25import torch
6+ import torch .distributed as dist
7+ from cuda .core .experimental import Device # type: ignore[import]
8+ from nvshmem .core import Teams # type: ignore[import]
9+
10+ from pplx_kernels import nvshmem_init
311
412from .distributed_utils import (
513 ProcessGroupInfo ,
816 require_multi_node ,
917)
1018
11- from cuda .core .experimental import Device
12- import nvshmem .core as nvshmem
13- import torch .distributed as dist
14- from nvshmem .core import Teams
15- from pplx_kernels import nvshmem_init , PyTorchStreamWrapper
19+ logger = logging .getLogger (__name__ )
1620
1721def test_nvshmem_1_gpu () -> None :
1822
1923 local_rank = 0
20- world_size = 1
24+ rank_id = 0 # Define rank_id for single GPU test
2125
2226 torch .cuda .set_device (local_rank )
23- device = torch .device ("cuda" , local_rank )
2427 dev = Device (local_rank )
2528 dev .set_current ()
2629
@@ -39,17 +42,15 @@ def test_nvshmem_1_gpu() -> None:
3942 assert nvshmem .n_pes () == 1
4043
4144 nvshmem .finalize ()
42-
4345
4446
4547def _worker_test_nvshmem_4_gpu (pgi : ProcessGroupInfo ) -> None :
4648 local_rank = dist .get_rank ()
47- world_size = dist .get_world_size ()
4849
4950 dev = Device (local_rank )
5051 dev .set_current ()
5152
52- nvshmem_init (global_rank = pgi .rank , local_rank = local_rank , world_size = world_size , device = dev )
53+ nvshmem_init (global_rank = pgi .rank , local_rank = local_rank , world_size = pgi . world_size , device = dev )
5354
5455 # Check host initialization status
5556 test_script_init_status = nvshmem .direct .init_status ()
@@ -72,12 +73,10 @@ def test_nvshmem_4_gpu() -> None:
7273
7374def _worker_test_all_to_all (pgi : ProcessGroupInfo ) -> None :
7475 local_rank = dist .get_rank ()
75- world_size = dist .get_world_size ()
7676
7777 dev = Device (local_rank )
7878 dev .set_current ()
79- stream = PyTorchStreamWrapper (torch .cuda .current_stream ())
80-
79+
8180 num_ranks = dist .get_world_size ()
8281 rank_id = dist .get_rank ()
8382
@@ -98,9 +97,9 @@ def _worker_test_all_to_all(pgi: ProcessGroupInfo) -> None:
9897 t_out = nvshmem .tensor ( (pgi .world_size ,), dtype = torch .int32 )
9998
10099 team = Teams .TEAM_WORLD
101- nvshmem .collective .alltoall (team , t_out , t_in , stream = stream )
100+ nvshmem .collective .alltoall (team , t_out , t_in )
102101
103- nvshmem .collective .barrier (team , stream = stream )
102+ nvshmem .collective .barrier (team )
104103 torch .cuda .synchronize ()
105104
106105 assert t_out .tolist () == list (range (pgi .world_size ))
0 commit comments