Skip to content

Commit 8cda903

Browse files
clean up device initialization in test (#507)
1 parent ba8de1d commit 8cda903

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

cuda_core/tests/test_device.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,26 @@ def test_device_set_current(deinit_cuda):
3737
assert handle_return(driver.cuCtxGetCurrent()) is not None
3838

3939

40-
def test_device_repr():
40+
def test_device_repr(deinit_cuda):
4141
device = Device(0)
42+
device.set_current()
4243
assert str(device).startswith("<Device 0")
4344

4445

45-
def test_device_alloc(init_cuda):
46+
def test_device_alloc(deinit_cuda):
4647
device = Device()
48+
device.set_current()
4749
buffer = device.allocate(1024)
4850
device.sync()
4951
assert buffer.handle != 0
5052
assert buffer.size == 1024
51-
assert buffer.device_id == 0
53+
assert buffer.device_id == int(device)
54+
55+
56+
def test_device_id(deinit_cuda):
57+
for device in cuda.core.experimental.system.devices:
58+
device.set_current()
59+
assert device.device_id == handle_return(runtime.cudaGetDevice())
5260

5361

5462
def test_device_create_stream(init_cuda):

0 commit comments

Comments
 (0)