Skip to content

Commit 93c1e54

Browse files
committed
update tests, couple TODOs left
1 parent 4942a19 commit 93c1e54

File tree

8 files changed

+371
-43
lines changed

8 files changed

+371
-43
lines changed

cuda_core/tests/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cuda.core._context import Context
1+
from cuda.core.experimental._context import Context
22

33
def test_context_initialization():
44
try:

cuda_core/tests/test_device.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,53 @@
1-
from cuda.core._device import Device
1+
from cuda import cuda, cudart
2+
from cuda.core.experimental._device import Device
3+
from cuda.core.experimental._utils import handle_return, ComputeCapability, CUDAError, \
4+
precondition
25

36
def test_device_initialization():
47
device = Device()
5-
assert device is not None
8+
assert device is not None
9+
10+
def test_device_repr():
11+
device = Device()
12+
assert str(device).startswith('<Device 0')
13+
14+
def test_device_set_current():
15+
device = Device()
16+
device.set_current()
17+
18+
def test_device_create_stream():
19+
device = Device()
20+
stream = device.create_stream()
21+
assert stream is not None
22+
23+
24+
def test_pci_bus_id():
25+
device = Device(0)
26+
bus_id = handle_return(cudart.cudaDeviceGetPCIBusId(13, device.device_id))
27+
assert device.pci_bus_id == bus_id[:12].decode()
28+
29+
def test_uuid():
30+
device = Device(0)
31+
driver_ver = handle_return(cuda.cuDriverGetVersion())
32+
if driver_ver >= 11040:
33+
uuid = handle_return(cuda.cuDeviceGetUuid_v2(device.device_id))
34+
else:
35+
uuid = handle_return(cuda.cuDeviceGetUuid(device.device_id))
36+
uuid = uuid.bytes.hex()
37+
expected_uuid = f"{uuid[:8]}-{uuid[8:12]}-{uuid[12:16]}-{uuid[16:20]}-{uuid[20:]}"
38+
assert device.uuid == expected_uuid
39+
40+
def test_name():
41+
device = Device(0)
42+
name = handle_return(cuda.cuDeviceGetName(128, device.device_id))
43+
name = name.split(b'\0')[0]
44+
assert device.name == name.decode()
45+
46+
def test_compute_capability():
47+
device = Device(0)
48+
major = handle_return(cudart.cudaDeviceGetAttribute(
49+
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device.device_id))
50+
minor = handle_return(cudart.cudaDeviceGetAttribute(
51+
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device.device_id))
52+
expected_cc = ComputeCapability(major, minor)
53+
assert device.compute_capability == expected_cc

cuda_core/tests/test_event.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
1-
from cuda.core._event import Event
1+
from cuda import cuda
2+
from cuda.core.experimental._event import EventOptions, Event
3+
from cuda.core.experimental._utils import handle_return
24

3-
def test_event_initialization():
5+
def test_is_timing_disabled():
6+
options = EventOptions(enable_timing=False)
7+
event = Event._init(options)
8+
assert event.is_timing_disabled == True
9+
10+
def test_is_sync_busy_waited():
11+
options = EventOptions(busy_waited_sync=True)
12+
event = Event._init(options)
13+
assert event.is_sync_busy_waited == True
14+
15+
def test_is_ipc_supported():
16+
options = EventOptions(support_ipc=True)
417
try:
5-
event = Event()
6-
except NotImplementedError as e:
18+
event = Event._init(options)
19+
except NotImplementedError:
720
assert True
821
else:
9-
assert False, "Expected NotImplementedError was not raised"
22+
assert False
23+
24+
def test_sync():
25+
options = EventOptions()
26+
event = Event._init(options)
27+
event.sync()
28+
assert event.is_done == True
29+
30+
def test_is_done():
31+
options = EventOptions()
32+
event = Event._init(options)
33+
assert event.is_done == True
34+
35+
def test_handle():
36+
options = EventOptions()
37+
event = Event._init(options)
38+
assert isinstance(event.handle, int)

cuda_core/tests/test_launcher.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,77 @@
1-
from cuda.core._launcher import LaunchConfig
1+
from cuda.core.experimental._launcher import LaunchConfig
2+
from cuda.core.experimental._stream import Stream
3+
from cuda.core.experimental._device import Device
4+
from cuda.core.experimental._utils import handle_return
5+
from cuda import cuda
26

3-
def test_launch_initialization():
7+
def test_launch_config_init():
48
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0)
5-
69
assert config.grid == (1, 1, 1)
710
assert config.block == (1, 1, 1)
11+
assert config.stream is None
12+
assert config.shmem_size == 0
13+
14+
config = LaunchConfig(grid=(2, 2, 2), block=(2, 2, 2), stream=Device().create_stream(), shmem_size=1024)
15+
assert config.grid == (2, 2, 2)
16+
assert config.block == (2, 2, 2)
17+
assert isinstance(config.stream, Stream)
18+
assert config.shmem_size == 1024
19+
20+
def test_launch_config_cast_to_3_tuple():
21+
config = LaunchConfig(grid=1, block=1)
22+
assert config._cast_to_3_tuple(1) == (1, 1, 1)
23+
assert config._cast_to_3_tuple((1, 2)) == (1, 2, 1)
24+
assert config._cast_to_3_tuple((1, 2, 3)) == (1, 2, 3)
25+
26+
# Edge cases
27+
assert config._cast_to_3_tuple(999) == (999, 1, 1)
28+
assert config._cast_to_3_tuple((999, 888)) == (999, 888, 1)
29+
assert config._cast_to_3_tuple((999, 888, 777)) == (999, 888, 777)
30+
31+
def test_launch_config_invalid_values():
32+
try:
33+
LaunchConfig(grid=0, block=1)
34+
except ValueError:
35+
assert True
36+
else:
37+
assert False
38+
39+
try:
40+
LaunchConfig(grid=(0, 1), block=1)
41+
except ValueError:
42+
assert True
43+
else:
44+
assert False
45+
46+
try:
47+
LaunchConfig(grid=(1, 1, 1), block=0)
48+
except ValueError:
49+
assert True
50+
else:
51+
assert False
52+
53+
try:
54+
LaunchConfig(grid=(1, 1, 1), block=(0, 1))
55+
except ValueError:
56+
assert True
57+
else:
58+
assert False
59+
60+
def test_launch_config_stream():
61+
stream = Device().create_stream()
62+
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0)
63+
assert config.stream == stream
64+
65+
try:
66+
LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream="invalid_stream", shmem_size=0)
67+
except ValueError:
68+
assert True
69+
else:
70+
assert False
71+
72+
def test_launch_config_shmem_size():
73+
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=2048)
74+
assert config.shmem_size == 2048
75+
76+
config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None)
877
assert config.shmem_size == 0

cuda_core/tests/test_memory.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
from cuda.core._memory import Buffer, MemoryResource
1+
# FILE: test_memory.py
2+
3+
from cuda.core.experimental._memory import Buffer, MemoryResource
4+
from cuda.core.experimental._device import Device
5+
from cuda import cuda
6+
from cuda.core.experimental._utils import handle_return
27

38
class DummyMemoryResource(MemoryResource):
49
def __init__(self):
510
pass
611

712
def allocate(self, size, stream=None) -> Buffer:
8-
pass
13+
ptr = handle_return(cuda.cuMemAlloc(size))
14+
return Buffer(ptr=ptr, size=size, mr=self)
915

1016
def deallocate(self, ptr, size, stream=None):
11-
pass
17+
handle_return(cuda.cuMemFree(ptr))
1218

1319
@property
1420
def is_device_accessible(self) -> bool:
@@ -24,7 +30,42 @@ def device_id(self) -> int:
2430

2531
def test_buffer_initialization():
2632
dummy_mr = DummyMemoryResource()
27-
buffer = Buffer(ptr=1234, size=1024, mr=dummy_mr)
28-
assert buffer._ptr == 1234
29-
assert buffer._size == 1024
30-
assert buffer._mr == dummy_mr
33+
buffer = dummy_mr.allocate(size=1024)
34+
assert buffer.handle != 0
35+
assert buffer.size == 1024
36+
assert buffer.memory_resource == dummy_mr
37+
assert buffer.is_device_accessible == True
38+
assert buffer.is_host_accessible == True
39+
assert buffer.device_id == 0
40+
dummy_mr.deallocate(buffer.handle, buffer.size)
41+
42+
def test_buffer_copy_to():
43+
dummy_mr = DummyMemoryResource()
44+
src_buffer = dummy_mr.allocate(size=1024)
45+
dst_buffer = dummy_mr.allocate(size=1024)
46+
device = Device()
47+
device.set_current()
48+
stream = device.create_stream()
49+
src_buffer.copy_to(dst_buffer, stream=stream)
50+
# Assuming cuMemcpyAsync is correctly called, we can't directly check the result here
51+
dummy_mr.deallocate(src_buffer.handle, src_buffer.size)
52+
dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size)
53+
54+
def test_buffer_copy_from():
55+
dummy_mr = DummyMemoryResource()
56+
src_buffer = dummy_mr.allocate(size=1024)
57+
dst_buffer = dummy_mr.allocate(size=1024)
58+
device = Device()
59+
device.set_current()
60+
stream = device.create_stream()
61+
dst_buffer.copy_from(src_buffer, stream=stream)
62+
# Assuming cuMemcpyAsync is correctly called, we can't directly check the result here
63+
dummy_mr.deallocate(src_buffer.handle, src_buffer.size)
64+
dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size)
65+
66+
def test_buffer_close():
67+
dummy_mr = DummyMemoryResource()
68+
buffer = dummy_mr.allocate(size=1024)
69+
buffer.close()
70+
assert buffer.handle == 0
71+
assert buffer.memory_resource == None

cuda_core/tests/test_module.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,35 @@
1-
from cuda.core._module import ObjectCode
2-
3-
def test_module_initialization():
4-
module_code = b"dummy_code"
5-
code_type = "ptx"
6-
module = ObjectCode(module=module_code, code_type=code_type)
7-
assert module._handle is not None
8-
assert module._code_type == code_type
9-
assert module._module == module_code
10-
assert module._loader is not None
11-
assert module._sym_map == {}
1+
import pytest
2+
from cuda import cuda
3+
from cuda.core.experimental._device import Device
4+
from cuda.core.experimental._module import Kernel, ObjectCode
5+
from cuda.core.experimental._utils import handle_return
6+
7+
@pytest.fixture(scope='module')
8+
def init_cuda():
9+
Device().set_current()
10+
11+
def test_object_code_initialization():
12+
# Test with supported code types
13+
for code_type in ["cubin", "ptx", "fatbin"]:
14+
module_data = b"dummy_data"
15+
obj_code = ObjectCode(module_data, code_type)
16+
assert obj_code._code_type == code_type
17+
assert obj_code._module == module_data
18+
assert obj_code._handle is not None
19+
20+
# Test with unsupported code type
21+
with pytest.raises(ValueError):
22+
ObjectCode(b"dummy_data", "unsupported_code_type")
23+
24+
#TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile
25+
def test_object_code_initialization_with_str():
26+
assert True
27+
28+
def test_object_code_initialization_with_jit_options():
29+
assert True
30+
31+
def test_object_code_get_kernel():
32+
assert True
33+
34+
def test_kernel_from_obj():
35+
assert True

cuda_core/tests/test_program.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,50 @@
1-
from cuda.core._program import Program
1+
import pytest
2+
from cuda import nvrtc
3+
from cuda.core.experimental._program import Program
4+
from cuda.core.experimental._module import ObjectCode, Kernel
25

3-
def test_program_initialization():
4-
code = "__device__ int test_func() { return 0; }"
6+
def test_program_init_valid_code_type():
7+
code = "extern \"C\" __global__ void my_kernel() {}"
58
program = Program(code, "c++")
6-
assert program._handle is not None
7-
assert program._backend == "nvrtc"
9+
assert program.backend == "nvrtc"
10+
assert program.handle is not None
11+
12+
def test_program_init_invalid_code_type():
13+
code = "extern \"C\" __global__ void my_kernel() {}"
14+
with pytest.raises(NotImplementedError):
15+
Program(code, "python")
16+
17+
def test_program_init_invalid_code_format():
18+
code = 12345
19+
with pytest.raises(TypeError):
20+
Program(code, "c++")
21+
22+
def test_program_compile_valid_target_type():
23+
code = "extern \"C\" __global__ void my_kernel() {}"
24+
program = Program(code, "c++")
25+
object_code = program.compile("ptx")
26+
kernel = object_code.get_kernel("my_kernel")
27+
assert isinstance(object_code, ObjectCode)
28+
assert isinstance(kernel, Kernel)
29+
30+
def test_program_compile_invalid_target_type():
31+
code = "extern \"C\" __global__ void my_kernel() {}"
32+
program = Program(code, "c++")
33+
with pytest.raises(NotImplementedError):
34+
program.compile("invalid_target")
35+
36+
def test_program_backend_property():
37+
code = "extern \"C\" __global__ void my_kernel() {}"
38+
program = Program(code, "c++")
39+
assert program.backend == "nvrtc"
40+
41+
def test_program_handle_property():
42+
code = "extern \"C\" __global__ void my_kernel() {}"
43+
program = Program(code, "c++")
44+
assert program.handle is not None
45+
46+
def test_program_close():
47+
code = "extern \"C\" __global__ void my_kernel() {}"
48+
program = Program(code, "c++")
49+
program.close()
50+
assert program.handle is None

0 commit comments

Comments
 (0)