77# is strictly prohibited.
88
99import os
10+ import pathlib
1011import time
1112
13+ import numpy as np
1214import pytest
1315
1416import cuda .core .experimental
1517from cuda .core .experimental import Device , EventOptions , LaunchConfig , Program , ProgramOptions , launch
18+ from cuda .core .experimental ._memory import _DefaultPinnedMemorySource
1619
1720
1821def test_event_init_disabled ():
@@ -113,27 +116,44 @@ def test_error_timing_recorded():
113116 event3 - event2
114117
115118
119+ # TODO: improve this once path finder can find headers
120+ @pytest .mark .skipif (os .environ .get ("CUDA_PATH" ) is None , reason = "need libcu++ header" )
116121def test_error_timing_incomplete ():
117122 device = Device ()
118123 device .set_current ()
119124
120- # This kernel is designed to not complete
125+ # This kernel is designed to busy loop until a signal is received
121126 code = """
127+ #include <cuda/atomic>
128+
122129extern "C"
123- __global__ void wait() {
124- while (1 > 0) {
130+ __global__ void wait(int* val) {
131+ cuda::atomic_ref<int, cuda::thread_scope_system> signal{*val};
132+ while (true) {
133+ if (signal.load(cuda::memory_order_relaxed)) {
134+ break;
135+ }
125136 }
126137}
127138"""
128139
129140 arch = "" .join (f"{ i } " for i in device .compute_capability )
130- program_options = ProgramOptions (std = "c++11" , arch = f"sm_{ arch } " )
141+ program_options = ProgramOptions (
142+ std = "c++17" ,
143+ arch = f"sm_{ arch } " ,
144+ include_path = str (pathlib .Path (os .environ ["CUDA_PATH" ]) / pathlib .Path ("include" )),
145+ )
131146 prog = Program (code , code_type = "c++" , options = program_options )
132147 mod = prog .compile (target_type = "cubin" )
133148 ker = mod .get_kernel ("wait" )
134149
150+ mr = _DefaultPinnedMemorySource ()
151+ b = mr .allocate (4 )
152+ arr = np .from_dlpack (b ).view (np .int32 )
153+ arr [0 ] = 0
154+
135155 config = LaunchConfig (grid = 1 , block = 1 )
136- ker_args = ()
156+ ker_args = (arr . ctypes . data , )
137157
138158 enabled = EventOptions (enable_timing = True )
139159 stream = device .create_stream ()
@@ -145,3 +165,7 @@ def test_error_timing_incomplete():
145165 # event3 will never complete because the stream is waiting on wait() to complete
146166 with pytest .raises (RuntimeError , match = "^One or both events have not completed." ):
147167 event3 - event1
168+
169+ arr [0 ] = 1
170+ event3 .sync ()
171+ event3 - event1 # this should work
0 commit comments