|
2 | 2 | import sys
|
3 | 3 |
|
4 | 4 | if sys.platform == "win32":
|
| 5 | + import ctypes.wintypes |
| 6 | + |
5 | 7 | import pywintypes
|
6 | 8 | import win32api
|
| 9 | + |
| 10 | + # Mirrors WinBase.h (unfortunately not defined already elsewhere) |
| 11 | + _WINBASE_LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800 |
| 12 | + |
7 | 13 | else:
|
8 | 14 | import ctypes
|
9 | 15 | import os
|
10 | 16 |
|
| 17 | + _LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL |
| 18 | + |
11 | 19 | from .find_nvidia_dynamic_library import find_nvidia_dynamic_library
|
12 | 20 |
|
13 |
| -_LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL |
| 21 | + |
| 22 | +@functools.cache |
| 23 | +def _windows_cuDriverGetVersion() -> int: |
| 24 | + handle = win32api.LoadLibrary("nvcuda.dll") |
| 25 | + |
| 26 | + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) |
| 27 | + GetProcAddress = kernel32.GetProcAddress |
| 28 | + GetProcAddress.argtypes = [ctypes.wintypes.HMODULE, ctypes.wintypes.LPCSTR] |
| 29 | + GetProcAddress.restype = ctypes.c_void_p |
| 30 | + cuDriverGetVersion = GetProcAddress(handle, b"cuDriverGetVersion") |
| 31 | + assert cuDriverGetVersion |
| 32 | + |
| 33 | + FUNC_TYPE = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(ctypes.c_int)) |
| 34 | + cuDriverGetVersion_fn = FUNC_TYPE(cuDriverGetVersion) |
| 35 | + driver_ver = ctypes.c_int() |
| 36 | + err = cuDriverGetVersion_fn(ctypes.byref(driver_ver)) |
| 37 | + assert err == 0 |
| 38 | + return driver_ver.value |
| 39 | + |
| 40 | + |
| 41 | +@functools.cache |
| 42 | +def _windows_load_with_dll_basename(name: str) -> int: |
| 43 | + driver_ver = _windows_cuDriverGetVersion() |
| 44 | + del driver_ver # Keeping this here because it will probably be needed in the future. |
| 45 | + |
| 46 | + if name == "nvJitLink": |
| 47 | + dll_name = "nvJitLink_120_0.dll" |
| 48 | + elif name == "nvvm": |
| 49 | + dll_name = "nvvm64_40_0.dll" |
| 50 | + |
| 51 | + try: |
| 52 | + return win32api.LoadLibrary(dll_name) |
| 53 | + except pywintypes.error: |
| 54 | + pass |
| 55 | + |
| 56 | + return None |
14 | 57 |
|
15 | 58 |
|
16 | 59 | @functools.cache
|
17 | 60 | def load_nvidia_dynamic_library(name: str) -> int:
|
18 | 61 | # First try using the platform-specific dynamic loader search mechanisms
|
19 | 62 | if sys.platform == "win32":
|
20 |
| - pass # TODO |
| 63 | + handle = _windows_load_with_dll_basename(name) |
| 64 | + if handle: |
| 65 | + return handle |
21 | 66 | else:
|
22 | 67 | dl_path = f"lib{name}.so" # Version intentionally no specified.
|
23 | 68 | try:
|
|
0 commit comments