diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index 1db453fed..cdef7c3be 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -298,6 +298,7 @@ class ProgramOptions: split_compile: int | None = None fdevice_syntax_only: bool | None = None minimal: bool | None = None + numba_debug: bool | None = None # Custom option for Numba debugging def __post_init__(self): self._name = self.name.encode() @@ -418,6 +419,8 @@ def __post_init__(self): self._formatted_options.append("--fdevice-syntax-only") if self.minimal is not None and self.minimal: self._formatted_options.append("--minimal") + if self.numba_debug: + self._formatted_options.append("--numba-debug") def _as_bytes(self): # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index adc778973..3add6b6b4 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -31,12 +31,29 @@ def _is_nvvm_available(): ) try: - from cuda.core.experimental._utils.cuda_utils import driver, handle_return + from cuda.core.experimental._utils.cuda_utils import driver, handle_return, nvrtc _cuda_driver_version = handle_return(driver.cuDriverGetVersion()) except Exception: _cuda_driver_version = 0 + +def _get_nvrtc_version_for_tests(): + """ + Get NVRTC version. + + Returns: + int: Version in format major * 1000 + minor * 100 (e.g., 13200 for CUDA 13.2) + None: If NVRTC is not available + """ + try: + nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) + version = nvrtc_major * 1000 + nvrtc_minor * 100 + return version + except Exception: + return None + + _libnvvm_version = None _libnvvm_version_attempted = False @@ -176,6 +193,13 @@ def ptx_code_object(): [ ProgramOptions(name="abc"), ProgramOptions(device_code_optimize=True, debug=True), + pytest.param( + ProgramOptions(debug=True, numba_debug=True), + marks=pytest.mark.skipif( + (_get_nvrtc_version_for_tests() or 0) < 13200, + reason="numba_debug requires NVRTC >= 13.2", + ), + ), ProgramOptions(relocatable_device_code=True, max_register_count=32), ProgramOptions(ftz=True, prec_sqrt=False, prec_div=False), ProgramOptions(fma=False, use_fast_math=True),