Skip to content

Commit 9fba2b7

Browse files
committed
use precondition and update test
1 parent 016055e commit 9fba2b7

File tree

2 files changed

+43
-61
lines changed

2 files changed

+43
-61
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import importlib.metadata
66

77
from cuda import cuda
8-
from cuda.core.experimental._utils import handle_return
8+
from cuda.core.experimental._utils import handle_return, precondition
99

1010
_backend = {
1111
"old": {
@@ -127,31 +127,10 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
127127
self._sym_map = {} if symbol_mapping is None else symbol_mapping
128128

129129
# TODO: do we want to unload in a finalizer? Probably not..
130-
131-
def get_kernel(self, name):
132-
"""Return the :obj:`Kernel` of a specified name from this object code.
133-
134-
Parameters
135-
----------
136-
name : Any
137-
Name of the kernel to retrieve.
138-
139-
Returns
140-
-------
141-
:obj:`Kernel`
142-
Newly created kernel object.
143-
144-
"""
145-
try:
146-
name = self._sym_map[name]
147-
except KeyError:
148-
name = name.encode()
149-
150-
self._lazy_load_module()
151-
data = handle_return(self._loader["kernel"](self._handle, name))
152-
return Kernel._from_obj(data, self)
153-
154-
def _lazy_load_module(self):
130+
131+
def _lazy_load_module(self, *args, **kwargs):
132+
if self._handle is not None:
133+
return
155134
if isinstance(self._module, str):
156135
# TODO: this option is only taken by the new library APIs, but we have
157136
# a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
@@ -178,4 +157,28 @@ def _lazy_load_module(self):
178157
args = (self._module, len(self._jit_options), list(self._jit_options.keys()), list(self._jit_options.values()))
179158
self._handle = handle_return(self._loader["data"](*args))
180159

160+
@precondition(_lazy_load_module)
161+
def get_kernel(self, name):
162+
"""Return the :obj:`Kernel` of a specified name from this object code.
163+
164+
Parameters
165+
----------
166+
name : Any
167+
Name of the kernel to retrieve.
168+
169+
Returns
170+
-------
171+
:obj:`Kernel`
172+
Newly created kernel object.
173+
174+
"""
175+
try:
176+
name = self._sym_map[name]
177+
except KeyError:
178+
name = name.encode()
179+
180+
data = handle_return(self._loader["kernel"](self._handle, name))
181+
return Kernel._from_obj(data, self)
182+
183+
181184
# TODO: implement from_handle()

cuda_core/tests/test_module.py

Lines changed: 14 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -10,38 +10,17 @@
1010

1111
import pytest
1212

13-
from cuda.core.experimental._module import ObjectCode
14-
15-
16-
@pytest.mark.skipif(
17-
int(importlib.metadata.version("cuda-python").split(".")[0]) < 12,
18-
reason="Module loading for older drivers validate require valid module code.",
19-
)
20-
def test_object_code_initialization():
21-
# Test with supported code types
22-
for code_type in ["cubin", "ptx", "fatbin"]:
23-
module_data = b"dummy_data"
24-
obj_code = ObjectCode(module_data, code_type)
25-
assert obj_code._code_type == code_type
26-
assert obj_code._module == module_data
27-
28-
# Test with unsupported code type
29-
with pytest.raises(ValueError):
30-
ObjectCode(b"dummy_data", "unsupported_code_type")
31-
32-
33-
# TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile
34-
def test_object_code_initialization_with_str():
35-
assert True
36-
37-
38-
def test_object_code_initialization_with_jit_options():
39-
assert True
40-
41-
42-
def test_object_code_get_kernel():
43-
assert True
44-
45-
46-
def test_kernel_from_obj():
47-
assert True
13+
from cuda.core.experimental import Program
14+
15+
16+
def test_get_kernel():
17+
kernel = """
18+
extern __device__ int B();
19+
extern __device__ int C(int a, int b);
20+
__global__ void A() { int result = C(B(), 1);}
21+
"""
22+
object_code = Program(kernel, "c++").compile("ptx", options=("-rdc=true",))
23+
assert object_code._handle is None
24+
kernel = object_code.get_kernel("A")
25+
assert object_code._handle is not None
26+
assert kernel._handle is not None

0 commit comments

Comments
 (0)