5
5
import ctypes
6
6
import functools
7
7
import sys
8
+ from typing import Optional , Tuple
8
9
9
10
if sys .platform == "win32" :
10
11
import ctypes .wintypes
@@ -60,24 +61,29 @@ def _windows_cuDriverGetVersion() -> int:
60
61
61
62
62
63
@functools .cache
63
- def _windows_load_with_dll_basename (name : str ) -> int :
64
+ def _windows_load_with_dll_basename (name : str ) -> Tuple [ Optional [ int ], Optional [ str ]] :
64
65
driver_ver = _windows_cuDriverGetVersion ()
65
66
del driver_ver # Keeping this here because it will probably be needed in the future.
66
67
67
68
dll_names = SUPPORTED_WINDOWS_DLLS .get (name )
68
69
if dll_names is None :
69
70
return None
70
71
72
+ kernel32 = ctypes .windll .kernel32
73
+
71
74
for dll_name in dll_names :
72
- try :
73
- return win32api .LoadLibrary (dll_name )
74
- except pywintypes .error :
75
- pass
75
+ handle = kernel32 .LoadLibraryW (ctypes .c_wchar_p (dll_name ))
76
+ if handle :
77
+ buf = ctypes .create_unicode_buffer (260 )
78
+ n_chars = kernel32 .GetModuleFileNameW (ctypes .wintypes .HMODULE (handle ), buf , len (buf ))
79
+ if n_chars == 0 :
80
+ raise OSError ("GetModuleFileNameW failed" )
81
+ return handle , buf .value
76
82
77
- return None
83
+ return None , None
78
84
79
85
80
- def _load_and_report_path_linux (libname , soname : str ) -> ( int , str ) :
86
+ def _load_and_report_path_linux (libname , soname : str ) -> Tuple [ int , str ] :
81
87
handle = ctypes .CDLL (soname , _LINUX_CDLL_MODE )
82
88
for symbol_name in EXPECTED_LIB_SYMBOLS [libname ]:
83
89
symbol = getattr (handle , symbol_name , None )
@@ -100,9 +106,10 @@ def load_nvidia_dynamic_library(libname: str) -> int:
100
106
found = _find_nvidia_dynamic_library (libname )
101
107
if found .abs_path is None :
102
108
if sys .platform == "win32" :
103
- handle = _windows_load_with_dll_basename (libname )
109
+ handle , abs_path = _windows_load_with_dll_basename (libname )
104
110
if handle :
105
111
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
112
+ print (f"SYSTEM ABS_PATH for { libname = !r} : { abs_path } " , flush = True )
106
113
return handle
107
114
else :
108
115
try :
@@ -122,6 +129,7 @@ def load_nvidia_dynamic_library(libname: str) -> int:
122
129
except pywintypes .error as e :
123
130
raise RuntimeError (f"Failed to load DLL at { found .abs_path } : { e } " ) from e
124
131
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
132
+ print (f"FOUND ABS_PATH for { libname = !r} : { found .abs_path } " , flush = True )
125
133
return handle # C signed int, matches win32api.GetProcAddress
126
134
else :
127
135
try :
0 commit comments