Skip to content

Commit daae8e5

Browse files
committed
Add _windows_load_with_dll_basename()
1 parent b5dcd67 commit daae8e5

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,67 @@
22
import sys
33

44
if sys.platform == "win32":
5+
import ctypes.wintypes
6+
57
import pywintypes
68
import win32api
9+
10+
# Mirrors WinBase.h (unfortunately not defined already elsewhere)
11+
_WINBASE_LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
12+
713
else:
814
import ctypes
915
import os
1016

17+
_LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
18+
1119
from .find_nvidia_dynamic_library import find_nvidia_dynamic_library
1220

13-
_LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
21+
22+
@functools.cache
23+
def _windows_cuDriverGetVersion() -> int:
24+
handle = win32api.LoadLibrary("nvcuda.dll")
25+
26+
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
27+
GetProcAddress = kernel32.GetProcAddress
28+
GetProcAddress.argtypes = [ctypes.wintypes.HMODULE, ctypes.wintypes.LPCSTR]
29+
GetProcAddress.restype = ctypes.c_void_p
30+
cuDriverGetVersion = GetProcAddress(handle, b"cuDriverGetVersion")
31+
assert cuDriverGetVersion
32+
33+
FUNC_TYPE = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(ctypes.c_int))
34+
cuDriverGetVersion_fn = FUNC_TYPE(cuDriverGetVersion)
35+
driver_ver = ctypes.c_int()
36+
err = cuDriverGetVersion_fn(ctypes.byref(driver_ver))
37+
assert err == 0
38+
return driver_ver.value
39+
40+
41+
@functools.cache
42+
def _windows_load_with_dll_basename(name: str) -> int:
43+
driver_ver = _windows_cuDriverGetVersion()
44+
del driver_ver # Keeping this here because it will probably be needed in the future.
45+
46+
if name == "nvJitLink":
47+
dll_name = "nvJitLink_120_0.dll"
48+
elif name == "nvvm":
49+
dll_name = "nvvm64_40_0.dll"
50+
51+
try:
52+
return win32api.LoadLibrary(dll_name)
53+
except pywintypes.error:
54+
pass
55+
56+
return None
1457

1558

1659
@functools.cache
1760
def load_nvidia_dynamic_library(name: str) -> int:
1861
# First try using the platform-specific dynamic loader search mechanisms
1962
if sys.platform == "win32":
20-
pass # TODO
63+
handle = _windows_load_with_dll_basename(name)
64+
if handle:
65+
return handle
2166
else:
2267
dl_path = f"lib{name}.so" # Version intentionally no specified.
2368
try:

0 commit comments

Comments
 (0)