6
6
# this software and related documentation outside the terms of the EULA
7
7
# is strictly prohibited.
8
8
{{if 'Windows' == platform.system()}}
9
- import win32api
9
+ import os
10
+ import site
10
11
import struct
12
+ import win32api
11
13
from pywintypes import error
12
14
{{else}}
13
15
cimport cuda.bindings._lib.dlfcn as dlfcn
@@ -40,8 +42,8 @@ cdef int cuPythonInit() except -1 nogil:
40
42
41
43
# Load library
42
44
{{if 'Windows' == platform.system()}}
43
- LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
44
45
with gil:
46
+ # First check if the DLL has been loaded by 3rd parties
45
47
try:
46
48
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
47
49
except:
@@ -51,7 +53,53 @@ cdef int cuPythonInit() except -1 nogil:
51
53
try:
52
54
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
53
55
except:
54
- raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
56
+ handle = None
57
+
58
+ # Else try default search
59
+ if not handle:
60
+ LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
61
+ try:
62
+ handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
63
+ except:
64
+ pass
65
+
66
+ # Final check if DLLs can be found within pip installations
67
+ if not handle:
68
+ site_packages = [site.getusersitepackages()] + site.getsitepackages()
69
+ for sp in site_packages:
70
+ mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
71
+ if not os.path.isdir(mod_path):
72
+ continue
73
+ os.add_dll_directory(mod_path)
74
+ LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
75
+ LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
76
+ try:
77
+ handle = win32api.LoadLibraryEx(
78
+ # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
79
+ os.path.join(mod_path, "nvrtc64_112_0.dll"),
80
+ 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
81
+
82
+ # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
83
+ # located in the same mod_path.
84
+ # Update PATH environ so that the two dlls can find each other
85
+ os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
86
+ except:
87
+ try:
88
+ handle = win32api.LoadLibraryEx(
89
+ os.path.join(mod_path, "nvrtc64_111_0.dll"),
90
+ 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
91
+ os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
92
+ except:
93
+ try:
94
+ handle = win32api.LoadLibraryEx(
95
+ os.path.join(mod_path, "nvrtc64_110_0.dll"),
96
+ 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
97
+ os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
98
+ except:
99
+ pass
100
+
101
+ if not handle:
102
+ raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
55
103
{{else}}
56
104
handle = NULL
57
105
if handle == NULL:
0 commit comments