Skip to content

Commit 401491a

Browse files
committed
avoid unnecessary sync
1 parent cb3424d commit 401491a

File tree

1 file changed

+70
-46
lines changed

1 file changed

+70
-46
lines changed

cuda_core/examples/simple_multi_gpu_example.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,25 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import sys
6+
57
import cupy as cp
68

7-
from cuda.core.experimental import Device, LaunchConfig, Program, launch
9+
from cuda.core.experimental import Device, LaunchConfig, Program, launch, system
10+
11+
if system.num_devices < 2:
12+
print("this example requires at least 2 GPUs", file=sys.stderr)
13+
sys.exit(0)
814

915
dtype = cp.float32
1016
size = 50000
1117

12-
# Set GPU0
18+
# Set GPU 0
1319
dev0 = Device(0)
1420
dev0.set_current()
1521
stream0 = dev0.create_stream()
1622

17-
# Allocate memory to GPU0
18-
a = cp.random.random(size, dtype=dtype)
19-
b = cp.random.random(size, dtype=dtype)
20-
c = cp.empty_like(a)
21-
22-
# Set GPU1
23-
dev1 = Device(1)
24-
dev1.set_current()
25-
stream1 = dev1.create_stream()
26-
27-
# Allocate memory to GPU1
28-
x = cp.random.random(size, dtype=dtype)
29-
y = cp.random.random(size, dtype=dtype)
30-
z = cp.empty_like(a)
31-
32-
# compute c = a + b
23+
# Compile a kernel targeting GPU 0 to compute c = a + b
3324
code_add = """
3425
extern "C"
3526
__global__ void vector_add(const float* A,
@@ -42,11 +33,26 @@
4233
}
4334
}
4435
"""
36+
arch0 = "".join(f"{i}" for i in dev0.compute_capability)
37+
prog_add = Program(code_add, code_type="c++")
38+
mod_add = prog_add.compile(
39+
"cubin",
40+
options=(
41+
"-std=c++17",
42+
"-arch=sm_" + arch0,
43+
),
44+
)
45+
ker_add = mod_add.get_kernel("vector_add")
4546

46-
# compute c = a - b
47+
# Set GPU 1
48+
dev1 = Device(1)
49+
dev1.set_current()
50+
stream1 = dev1.create_stream()
51+
52+
# Compile a kernel targeting GPU 1 to compute c = a - b
4753
code_sub = """
4854
extern "C"
49-
__global__ void vector_sub(const *float A,
55+
__global__ void vector_sub(const float* A,
5056
const float* B,
5157
float* C,
5258
size_t N) {
@@ -56,20 +62,6 @@
5662
}
5763
}
5864
"""
59-
60-
arch0 = "".join(f"{i}" for i in dev0.compute_capability)
61-
prog_add = Program(code_add, code_type="c++")
62-
mod_add = prog_add.compile(
63-
"cubin",
64-
options=(
65-
"-std=c++17",
66-
"-arch=sm_" + arch0,
67-
),
68-
)
69-
70-
# run in single precision
71-
ker_add = mod_add.get_kernel("vector_add")
72-
7365
arch1 = "".join(f"{i}" for i in dev1.compute_capability)
7466
prog_sub = Program(code_sub, code_type="c++")
7567
mod_sub = prog_sub.compile(
@@ -79,31 +71,63 @@
7971
"-arch=sm_" + arch1,
8072
),
8173
)
82-
83-
# run in single precision
8474
ker_sub = mod_sub.get_kernel("vector_sub")
8575

86-
# Synchronize devices to ensure that memory has been created
87-
dev0.sync()
88-
dev1.sync()
8976

77+
# This adaptor ensures that any foreign stream (e.g. from CuPy) that does not
78+
# yet support the __cuda_stream__ protocol can still be recognized by
79+
# cuda.core.
80+
class StreamAdaptor:
    """Wrap a foreign stream object so cuda.core can consume it.

    The wrapped object is expected to expose a ``.ptr`` attribute that
    holds the raw CUDA stream handle (CuPy streams do).
    """

    def __init__(self, obj):
        # Keep a reference to the foreign stream so its handle stays valid
        # for as long as this adaptor is alive.
        self.obj = obj

    @property
    def __cuda_stream__(self):
        # Protocol tuple: (version, raw stream handle).
        # Note: CuPy streams have a .ptr attribute.
        return (0, self.obj.ptr)
88+
89+
90+
# Create launch configs for each kernel that will be executed on the respective
91+
# CUDA streams.
9092
block = 256
9193
grid = (size + block - 1) // block
92-
9394
config0 = LaunchConfig(grid=grid, block=block, stream=stream0)
9495
config1 = LaunchConfig(grid=grid, block=block, stream=stream1)
9596

96-
# Launch GPU0 and Synchronize the stream
97+
# Allocate memory on GPU 0
98+
# Note: This runs on CuPy's current stream for GPU 0
9799
dev0.set_current()
98-
launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
99-
stream0.sync()
100+
a = cp.random.random(size, dtype=dtype)
101+
b = cp.random.random(size, dtype=dtype)
102+
c = cp.empty_like(a)
103+
cp_stream0 = StreamAdaptor(cp.cuda.get_current_stream())
100104

101-
# Validate result
102-
assert cp.allclose(c, a + b)
105+
# Establish a stream order to ensure that memory has been initialized before
106+
# accessed by the kernel.
107+
stream0.wait(cp_stream0)
103108

104-
# Launch GPU1 and Synchronize the stream
109+
# Launch the add kernel on GPU 0 / stream 0
110+
launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size))
111+
112+
# Allocate memory on GPU 1
113+
# Note: This runs on CuPy's current stream for GPU 1.
105114
dev1.set_current()
115+
x = cp.random.random(size, dtype=dtype)
116+
y = cp.random.random(size, dtype=dtype)
117+
z = cp.empty_like(a)
118+
cp_stream1 = StreamAdaptor(cp.cuda.get_current_stream())
119+
120+
# Establish a stream order
121+
stream1.wait(cp_stream1)
122+
123+
# Launch the subtract kernel on GPU 1 / stream 1
106124
launch(ker_sub, config1, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size))
125+
126+
# Synchronize both GPUs and validate the results
127+
dev0.set_current()
128+
stream0.sync()
129+
assert cp.allclose(c, a + b)
130+
dev1.set_current()
107131
stream1.sync()
108132
assert cp.allclose(z, x - y)
109133

0 commit comments

Comments
 (0)