Lazy load code modules #269


Merged: 9 commits, Dec 10, 2024
40 changes: 27 additions & 13 deletions cuda_core/cuda/core/experimental/_module.py
@@ -5,7 +5,7 @@
import importlib.metadata

from cuda import cuda
from cuda.core.experimental._utils import handle_return
from cuda.core.experimental._utils import handle_return, precondition

_backend = {
"old": {
@@ -106,30 +106,43 @@ class ObjectCode:

"""

__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
__slots__ = ("_handle", "_backend_version", "_jit_options", "_code_type", "_module", "_loader", "_sym_map")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")

def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
if code_type not in self._supported_code_type:
raise ValueError
_lazy_init()

# handle is assigned during _lazy_load
self._handle = None
self._jit_options = jit_options

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._loader = _backend[self._backend_version]

self._code_type = code_type
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping

backend = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._loader = _backend[backend]
# TODO: do we want to unload in a finalizer? Probably not..

def _lazy_load_module(self, *args, **kwargs):
if self._handle is not None:
return
jit_options = self._jit_options
module = self._module
if isinstance(module, str):
# TODO: this option is only taken by the new library APIs, but we have
# a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
if jit_options is not None:
raise ValueError
module = module.encode()
Member: good catch, thanks

Contributor Author: don't thank me, thank the good folks at ruff

self._handle = handle_return(self._loader["file"](module))
else:
assert isinstance(module, bytes)
if jit_options is None:
jit_options = {}
if backend == "new":
if self._backend_version == "new":
args = (
module,
list(jit_options.keys()),
@@ -141,15 +154,15 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
0,
)
else: # "old" backend
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
args = (
module,
len(jit_options),
list(jit_options.keys()),
list(jit_options.values()),
)
self._handle = handle_return(self._loader["data"](*args))

self._code_type = code_type
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping

# TODO: do we want to unload in a finalizer? Probably not..

@precondition(_lazy_load_module)
def get_kernel(self, name):
"""Return the :obj:`Kernel` of a specified name from this object code.

@@ -168,6 +181,7 @@ def get_kernel(self, name):
name = self._sym_map[name]
except KeyError:
name = name.encode()

data = handle_return(self._loader["kernel"](self._handle, name))
return Kernel._from_obj(data, self)
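The central change in this file is that `__init__` no longer loads the module: `get_kernel` is decorated with `precondition(_lazy_load_module)`, so the driver-side load is deferred until the first kernel lookup and `self._handle` stays `None` until then. The `precondition` helper itself is not part of this diff, so the following is only a minimal, self-contained sketch of the pattern under the assumption that it runs the given check before the decorated method executes; `LazyThing` and this toy `precondition` are illustrative, not the library's actual code.

```python
import functools


def precondition(checker):
    """Toy decorator factory: call `checker(self, ...)` before the wrapped
    method runs. The real precondition in cuda.core.experimental._utils
    may behave differently."""
    def outer(method):
        @functools.wraps(method)
        def inner(self, *args, **kwargs):
            checker(self, *args, **kwargs)
            return method(self, *args, **kwargs)
        return inner
    return outer


class LazyThing:
    """Toy stand-in for ObjectCode: construction only records the inputs;
    the handle is created on first use."""

    def __init__(self, data):
        self._handle = None
        self._data = data

    def _lazy_load(self, *args, **kwargs):
        if self._handle is not None:
            return
        # The expensive driver call would happen here, exactly once.
        self._handle = object()

    @precondition(_lazy_load)
    def get_kernel(self, name):
        return (self._handle, name)


thing = LazyThing(b"dummy")
assert thing._handle is None       # construction does not load anything
thing.get_kernel("A")
assert thing._handle is not None   # first lookup triggered the load
```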

10 changes: 8 additions & 2 deletions cuda_core/tests/conftest.py
@@ -11,10 +11,10 @@
import sys

try:
from cuda.bindings import driver
from cuda.bindings import driver, nvrtc
except ImportError:
from cuda import cuda as driver

from cuda import nvrtc
import pytest

from cuda.core.experimental import Device, _device
@@ -65,3 +65,9 @@ def clean_up_cffi_files():
os.remove(f)
except FileNotFoundError:
pass # noqa: SIM105


def can_load_generated_ptx():
_, driver_ver = driver.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
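The new `can_load_generated_ptx` helper guards against the case where NVRTC is newer than the installed driver, since PTX generated by a newer toolchain generally cannot be JIT-loaded by an older driver. It re-encodes the NVRTC version in the driver's integer scheme (major * 1000 + minor * 10) and compares it against `cuDriverGetVersion`. A small arithmetic sketch of that encoding (no CUDA needed; the version values are illustrative):

```python
def encode(major: int, minor: int) -> int:
    # Same scheme the CUDA driver uses for its reported version number.
    return major * 1000 + minor * 10

# NVRTC 12.4 emitting PTX, driver reporting 12040 (12.4): loading should work.
assert encode(12, 4) <= 12040
# NVRTC 12.4 emitting PTX, driver reporting 12030 (12.3): expect the xfail path.
assert not (encode(12, 4) <= 12030)
```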
55 changes: 17 additions & 38 deletions cuda_core/tests/test_module.py
@@ -6,43 +6,22 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.

import importlib

import pytest

from cuda.core.experimental._module import ObjectCode


@pytest.mark.skipif(
int(importlib.metadata.version("cuda-python").split(".")[0]) < 12,
reason="Module loading for older drivers validate require valid module code.",
)
def test_object_code_initialization():
# Test with supported code types
for code_type in ["cubin", "ptx", "fatbin"]:
module_data = b"dummy_data"
obj_code = ObjectCode(module_data, code_type)
assert obj_code._code_type == code_type
assert obj_code._module == module_data
assert obj_code._handle is not None

# Test with unsupported code type
with pytest.raises(ValueError):
ObjectCode(b"dummy_data", "unsupported_code_type")


# TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile
def test_object_code_initialization_with_str():
assert True


def test_object_code_initialization_with_jit_options():
assert True


def test_object_code_get_kernel():
assert True


def test_kernel_from_obj():
assert True
from conftest import can_load_generated_ptx

from cuda.core.experimental import Program


@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
def test_get_kernel():
kernel = """
extern __device__ int B();
extern __device__ int C(int a, int b);
__global__ void A() { int result = C(B(), 1);}
"""
object_code = Program(kernel, "c++").compile("ptx", options=("-rdc=true",))
assert object_code._handle is None
kernel = object_code.get_kernel("A")
assert object_code._handle is not None
assert kernel._handle is not None
10 changes: 1 addition & 9 deletions cuda_core/tests/test_program.py
@@ -7,20 +7,12 @@
# is strictly prohibited.

import pytest
from conftest import can_load_generated_ptx
Member: My understanding is that we don't need explicit imports for things from conftest?

Contributor Author: I believe this only applies to fixtures; helper functions still need to be imported. This is what my research and tests have shown.

Contributor Author: I could add a new file called utils or helpers and import that instead. There are mixed opinions on best practice online.

Collaborator:

> This is what my research and tests have shown.

I was only relying on ChatGPT before, and it told me the imports are not needed. If that's not true, keeping it simple seems best to me, unless you anticipate that we're accumulating many helper functions, i.e. just keep what you have?

Contributor Author: yeah, it told me that too, then changed its mind. Keeping it simple was my idea as well; with only one helper function it seems premature to split it into a new file, but perhaps down the line it will make sense.

Member: yeah, we can move it to tests/utils next time we touch this file.
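As background for this thread: pytest auto-discovers fixtures defined in conftest.py and injects them by argument name, but plain helper functions are ordinary module attributes and do need an explicit import, which is why the `from conftest import can_load_generated_ptx` line is required. A minimal illustration with hypothetical file contents (not part of this PR):

```python
# conftest.py
import pytest

@pytest.fixture
def device_id():
    # Fixtures are discovered automatically; no import needed in tests.
    return 0

def some_helper():
    # Plain functions are not injected; tests must import them explicitly.
    return True
```

```python
# test_example.py
from conftest import some_helper  # explicit import required for helpers

def test_fixture_and_helper(device_id):  # fixture injected by name
    assert device_id == 0
    assert some_helper()
```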


from cuda import cuda, nvrtc
from cuda.core.experimental import Device, Program
from cuda.core.experimental._module import Kernel, ObjectCode


def can_load_generated_ptx():
_, driver_ver = cuda.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
if nvrtc_major * 1000 + nvrtc_minor * 10 > driver_ver:
return False
return True


def test_program_init_valid_code_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")