Skip to content
Merged
63 changes: 9 additions & 54 deletions cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
# This code was automatically generated with version 12.8.0. Do not modify it directly.
{{if 'Windows' == platform.system()}}
import os
import site
import struct
import win32api
from pywintypes import error
{{else}}
cimport cuda.bindings._lib.dlfcn as dlfcn
from libc.stdint cimport uintptr_t
{{endif}}
from cuda.bindings import path_finder

cdef bint __cuPythonInit = False
{{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
Expand Down Expand Up @@ -46,64 +45,18 @@ cdef bint __cuPythonInit = False
{{if 'nvrtcSetFlowCallback' in found_functions}}cdef void *__nvrtcSetFlowCallback = NULL{{endif}}

cdef int cuPythonInit() except -1 nogil:
{{if 'Windows' != platform.system()}}
cdef void* handle = NULL
{{endif}}

global __cuPythonInit
if __cuPythonInit:
return 0
__cuPythonInit = True

# Load library
{{if 'Windows' == platform.system()}}
with gil:
# First check if the DLL has been loaded by 3rd parties
try:
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
except:
handle = None

# Else try default search
if not handle:
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
try:
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
pass

# Final check if DLLs can be found within pip installations
if not handle:
site_packages = [site.getusersitepackages()] + site.getsitepackages()
for sp in site_packages:
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
if not os.path.isdir(mod_path):
continue
os.add_dll_directory(mod_path)
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, "nvrtc64_120_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)

# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
# located in the same mod_path.
# Update PATH environ so that the two dlls can find each other
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
except:
pass

if not handle:
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
{{else}}
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
if handle == NULL:
with gil:
raise RuntimeError('Failed to dlopen libnvrtc.so.12')
{{endif}}


# Load function
{{if 'Windows' == platform.system()}}
with gil:
handle = path_finder.load_nvidia_dynamic_library("nvrtc")
{{if 'nvrtcGetErrorString' in found_functions}}
try:
global __nvrtcGetErrorString
Expand Down Expand Up @@ -288,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil:
{{endif}}

{{else}}
with gil:
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc")
{{if 'nvrtcGetErrorString' in found_functions}}
global __nvrtcGetErrorString
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,50 @@ def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachm
return so_name
# Look for a versioned library
# Using sort here mainly to make the result deterministic.
for node in sorted(glob.glob(os.path.join(lib_dir, file_wild))):
so_name = os.path.join(lib_dir, node)
for so_name in sorted(glob.glob(os.path.join(lib_dir, file_wild))):
if os.path.isfile(so_name):
return so_name
_no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)
return None


def _append_to_os_environ_path(dirpath):
curr_path = os.environ.get("PATH")
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))


def _find_dll_using_nvidia_bin_dirs(libname, error_messages, attachments):
if libname == "nvvm": # noqa: SIM108
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin")
else:
nvidia_sub_dirs = ("nvidia", "*", "bin")
file_wild = libname + "*.dll"
for bin_dir in sys_path_find_sub_dirs(nvidia_sub_dirs):
for node in sorted(glob.glob(os.path.join(bin_dir, file_wild))):
dll_name = os.path.join(bin_dir, node)
if os.path.isfile(dll_name):
return dll_name
dll_name = None
have_builtins = False
for path in sorted(glob.glob(os.path.join(bin_dir, file_wild))):
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
# nvidia\cuda_nvrtc\bin\
# nvrtc-builtins64_128.dll
# nvrtc64_120_0.alt.dll
# nvrtc64_120_0.dll
node = os.path.basename(path)
if node.endswith(".alt.dll"):
continue
if "-builtins" in node:
have_builtins = True
continue
if dll_name is not None:
continue
if os.path.isfile(path):
dll_name = path
if dll_name is not None:
if have_builtins:
# Add the DLL directory to the search path
os.add_dll_directory(bin_dir)
# Update PATH as a fallback for dependent DLL resolution
_append_to_os_environ_path(bin_dir)
return dll_name
_no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)
return None

Expand Down Expand Up @@ -78,7 +103,6 @@ def _find_so_using_cudalib_dir(so_basename, error_messages, attachments):
candidate_so_dirs.append(alt_dir)
libs.reverse()
candidate_so_names = [so_dirname + so_basename for so_dirname in candidate_so_dirs]
error_messages = []
for so_name in candidate_so_names:
if os.path.isfile(so_name):
return so_name
Expand All @@ -98,8 +122,7 @@ def _find_dll_using_cudalib_dir(libname, error_messages, attachments):
if cudalib_dir is None:
return None
file_wild = libname + "*.dll"
for node in sorted(glob.glob(os.path.join(cudalib_dir, file_wild))):
dll_name = os.path.join(cudalib_dir, node)
for dll_name in sorted(glob.glob(os.path.join(cudalib_dir, file_wild))):
if os.path.isfile(dll_name):
return dll_name
error_messages.append(f"No such file: {file_wild}")
Expand All @@ -123,7 +146,7 @@ def find_nvidia_dynamic_library(name: str) -> str:
dll_name = _find_dll_using_cudalib_dir(name, error_messages, attachments)
if dll_name is None:
attachments = "\n".join(attachments)
raise RuntimeError(f"Failure finding {name}*.dll: {', '.join(error_messages)}\n{attachments}")
raise RuntimeError(f'Failure finding "{name}*.dll": {", ".join(error_messages)}\n{attachments}')
return dll_name

so_basename = f"lib{name}.so"
Expand All @@ -135,5 +158,5 @@ def find_nvidia_dynamic_library(name: str) -> str:
so_name = _find_so_using_cudalib_dir(so_basename, error_messages, attachments)
if so_name is None:
attachments = "\n".join(attachments)
raise RuntimeError(f"Failure finding {so_basename}: {', '.join(error_messages)}\n{attachments}")
raise RuntimeError(f'Failure finding "{so_basename}": {", ".join(error_messages)}\n{attachments}')
return so_name
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import win32api

# Mirrors WinBase.h (unfortunately not defined already elsewhere)
_WINBASE_LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
_WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
_WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000

else:
import ctypes
Expand Down Expand Up @@ -77,8 +78,9 @@ def load_nvidia_dynamic_library(name: str) -> int:

dl_path = find_nvidia_dynamic_library(name)
if sys.platform == "win32":
flags = _WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | _WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
try:
handle = win32api.LoadLibrary(dl_path)
handle = win32api.LoadLibraryEx(dl_path, 0, flags)
except pywintypes.error as e:
raise RuntimeError(f"Failed to load DLL at {dl_path}: {e}") from e
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
Expand Down
13 changes: 11 additions & 2 deletions cuda_bindings/tests/path_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

for k, v in paths.items():
print(f"{k}: {v}", flush=True)
print()

print(path_finder.find_nvidia_dynamic_library("nvvm"))
print(path_finder.find_nvidia_dynamic_library("nvJitLink"))
libnames = ("nvJitLink", "nvrtc", "nvvm")

for libname in libnames:
print(path_finder.find_nvidia_dynamic_library(libname))
print()

for libname in libnames:
print(libname)
print(path_finder.load_nvidia_dynamic_library(libname))
print()
Loading