Skip to content

Commit 3d2fd88

Browse files
authored
Merge pull request #574 from leofang/backport-563-to-11.8.x
[Backport] Fix support for NVVM from conda on Windows + other fixes
2 parents a302a42 + e027b05 commit 3d2fd88

File tree

3 files changed

+62
-58
lines changed

3 files changed

+62
-58
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,54 +55,60 @@ cdef int cuPythonInit() except -1 nogil:
5555
except:
5656
handle = None
5757

58-
# Else try default search
59-
if not handle:
60-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
61-
try:
62-
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
63-
except:
64-
try:
65-
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
66-
except:
67-
try:
68-
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
69-
except:
70-
pass
71-
72-
# Final check if DLLs can be found within pip installations
58+
# Next check if DLLs can be found within pip installations
7359
if not handle:
60+
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
61+
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
7462
site_packages = [site.getusersitepackages()] + site.getsitepackages()
7563
for sp in site_packages:
7664
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
7765
if not os.path.isdir(mod_path):
7866
continue
7967
os.add_dll_directory(mod_path)
80-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
81-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
82-
try:
83-
handle = win32api.LoadLibraryEx(
84-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
85-
os.path.join(mod_path, "nvrtc64_112_0.dll"),
86-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
87-
88-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
89-
# located in the same mod_path.
90-
# Update PATH environ so that the two dlls can find each other
91-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
92-
except:
9368
try:
9469
handle = win32api.LoadLibraryEx(
95-
os.path.join(mod_path, "nvrtc64_111_0.dll"),
70+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
71+
os.path.join(mod_path, "nvrtc64_112_0.dll"),
9672
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
73+
74+
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
75+
# located in the same mod_path.
76+
# Update PATH environ so that the two dlls can find each other
9777
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
9878
except:
9979
try:
10080
handle = win32api.LoadLibraryEx(
101-
os.path.join(mod_path, "nvrtc64_110_0.dll"),
81+
os.path.join(mod_path, "nvrtc64_111_0.dll"),
10282
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
10383
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
10484
except:
105-
pass
85+
try:
86+
handle = win32api.LoadLibraryEx(
87+
os.path.join(mod_path, "nvrtc64_110_0.dll"),
88+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
89+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
90+
except:
91+
pass
92+
else:
93+
break
94+
else:
95+
break
96+
else:
97+
break
98+
else:
99+
# Else try default search
100+
# Only reached if DLL wasn't found in any site-package path
101+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
102+
try:
103+
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
104+
except:
105+
try:
106+
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
107+
except:
108+
try:
109+
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
110+
except:
111+
pass
106112

107113
if not handle:
108114
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL
4141

4242

4343
cdef inline list get_site_packages():
44-
return [site.getusersitepackages()] + site.getsitepackages()
44+
return [site.getusersitepackages()] + site.getsitepackages() + ["conda"]
4545

4646

4747
cdef load_library(const int driver_ver):
@@ -50,44 +50,42 @@ cdef load_library(const int driver_ver):
5050
for suffix in get_nvvm_dso_version_suffix(driver_ver):
5151
if len(suffix) == 0:
5252
continue
53-
dll_name = "nvvm64_40_0"
53+
dll_name = "nvvm64_40_0.dll"
5454

5555
# First check if the DLL has been loaded by 3rd parties
5656
try:
57-
handle = win32api.GetModuleHandle(dll_name)
57+
return win32api.GetModuleHandle(dll_name)
5858
except:
5959
pass
60-
else:
61-
break
6260

63-
# Next, check if DLLs are installed via pip
61+
# Next, check if DLLs are installed via pip or conda
6462
for sp in get_site_packages():
65-
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
66-
if not os.path.isdir(mod_path):
67-
continue
68-
os.add_dll_directory(mod_path)
69-
try:
70-
handle = win32api.LoadLibraryEx(
71-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
72-
os.path.join(mod_path, dll_name),
73-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
74-
except:
75-
pass
76-
else:
77-
break
63+
if sp == "conda":
64+
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65+
conda_prefix = os.environ.get("CONDA_PREFIX")
66+
if conda_prefix is None:
67+
continue
68+
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
69+
else:
70+
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
71+
if os.path.isdir(mod_path):
72+
os.add_dll_directory(mod_path)
73+
try:
74+
return win32api.LoadLibraryEx(
75+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76+
os.path.join(mod_path, dll_name),
77+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78+
except:
79+
pass
7880

7981
# Finally, try default search
82+
# Only reached if DLL wasn't found in any site-package path
8083
try:
81-
handle = win32api.LoadLibrary(dll_name)
84+
return win32api.LoadLibrary(dll_name)
8285
except:
8386
pass
84-
else:
85-
break
86-
else:
87-
raise RuntimeError('Failed to load nvvm')
8887

89-
assert handle != 0
90-
return handle
88+
raise RuntimeError('Failed to load nvvm')
9189

9290

9391
cdef int _check_or_init_nvvm() except -1 nogil:

cuda_bindings/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def build_extension(self, ext):
307307
# to <loc>/site-packages/nvidia/cuda_nvcc/nvvm/lib64/
308308
rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64"
309309
# from <loc>/lib/python3.*/site-packages/cuda/bindings/_internal/
310-
# to <loc>/lib/nvvm/lib64/
310+
# to <loc>/nvvm/lib64/
311311
rel2 = "$ORIGIN/../../../../../../nvvm/lib64"
312312
ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}"
313313
else:

0 commit comments

Comments
 (0)