Skip to content

Commit c813450

Browse files
committed
Plug in ctypes.windll.kernel32.GetModuleFileNameW()
1 parent 016a103 commit c813450

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ctypes
66
import functools
77
import sys
8+
from typing import Optional, Tuple
89

910
if sys.platform == "win32":
1011
import ctypes.wintypes
@@ -60,24 +61,29 @@ def _windows_cuDriverGetVersion() -> int:
6061

6162

6263
@functools.cache
63-
def _windows_load_with_dll_basename(name: str) -> int:
64+
def _windows_load_with_dll_basename(name: str) -> Tuple[Optional[int], Optional[str]]:
6465
driver_ver = _windows_cuDriverGetVersion()
6566
del driver_ver # Keeping this here because it will probably be needed in the future.
6667

6768
dll_names = SUPPORTED_WINDOWS_DLLS.get(name)
6869
if dll_names is None:
6970
return None
7071

72+
kernel32 = ctypes.windll.kernel32
73+
7174
for dll_name in dll_names:
72-
try:
73-
return win32api.LoadLibrary(dll_name)
74-
except pywintypes.error:
75-
pass
75+
handle = kernel32.LoadLibraryW(ctypes.c_wchar_p(dll_name))
76+
if handle:
77+
buf = ctypes.create_unicode_buffer(260)
78+
n_chars = kernel32.GetModuleFileNameW(ctypes.wintypes.HMODULE(handle), buf, len(buf))
79+
if n_chars == 0:
80+
raise OSError("GetModuleFileNameW failed")
81+
return handle, buf.value
7682

77-
return None
83+
return None, None
7884

7985

80-
def _load_and_report_path_linux(libname, soname: str) -> (int, str):
86+
def _load_and_report_path_linux(libname, soname: str) -> Tuple[int, str]:
8187
handle = ctypes.CDLL(soname, _LINUX_CDLL_MODE)
8288
for symbol_name in EXPECTED_LIB_SYMBOLS[libname]:
8389
symbol = getattr(handle, symbol_name, None)
@@ -100,9 +106,10 @@ def load_nvidia_dynamic_library(libname: str) -> int:
100106
found = _find_nvidia_dynamic_library(libname)
101107
if found.abs_path is None:
102108
if sys.platform == "win32":
103-
handle = _windows_load_with_dll_basename(libname)
109+
handle, abs_path = _windows_load_with_dll_basename(libname)
104110
if handle:
105111
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
112+
print(f"SYSTEM ABS_PATH for {libname=!r}: {abs_path}", flush=True)
106113
return handle
107114
else:
108115
try:
@@ -122,6 +129,7 @@ def load_nvidia_dynamic_library(libname: str) -> int:
122129
except pywintypes.error as e:
123130
raise RuntimeError(f"Failed to load DLL at {found.abs_path}: {e}") from e
124131
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
132+
print(f"FOUND ABS_PATH for {libname=!r}: {found.abs_path}", flush=True)
125133
return handle # C signed int, matches win32api.GetProcAddress
126134
else:
127135
try:

0 commit comments

Comments
 (0)