Skip to content

Commit 0dfae43

Browse files
leofangrwgk
andauthored
Fix support for NVVM from conda on Windows + other fixes (NVIDIA#563)
* also search for conda nvvm on windows * fix comment for conda nvvm on linux * fix path loop & dll name * Ensure mod_path is always defined when used. Make DLL search order consistent between all three cases. * Fix bug in previous commit: need to break out of loop over suffix if site-package DLL was found. Using the most obvious approach to solve this problem: return immediately on success. * Move LOAD_LIBRARY_SEARCH_* constants outside loop. * Fix oversight (forgot to replace one assignment with return) --------- Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent d425a88 commit 0dfae43

File tree

4 files changed

+65
-76
lines changed

4 files changed

+65
-76
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,36 +60,37 @@ cdef int cuPythonInit() except -1 nogil:
6060
except:
6161
handle = None
6262

63-
# Else try default search
64-
if not handle:
65-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
66-
try:
67-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
68-
except:
69-
pass
70-
71-
# Final check if DLLs can be found within pip installations
63+
# Check if DLLs can be found within pip installations
7264
if not handle:
65+
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
66+
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
7367
site_packages = [site.getusersitepackages()] + site.getsitepackages()
7468
for sp in site_packages:
7569
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
76-
if not os.path.isdir(mod_path):
77-
continue
78-
os.add_dll_directory(mod_path)
79-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
80-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
81-
try:
82-
handle = win32api.LoadLibraryEx(
83-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
84-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
85-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
86-
87-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
88-
# located in the same mod_path.
89-
# Update PATH environ so that the two dlls can find each other
90-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
91-
except:
92-
pass
70+
if os.path.isdir(mod_path):
71+
os.add_dll_directory(mod_path)
72+
try:
73+
handle = win32api.LoadLibraryEx(
74+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
75+
os.path.join(mod_path, "nvrtc64_120_0.dll"),
76+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
77+
78+
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
79+
# located in the same mod_path.
80+
# Update PATH environ so that the two dlls can find each other
81+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
82+
except:
83+
pass
84+
else:
85+
break
86+
else:
87+
# Else try default search
88+
# Only reached if DLL wasn't found in any site-package path
89+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
90+
try:
91+
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
92+
except:
93+
pass
9394

9495
if not handle:
9596
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,40 +56,30 @@ cdef load_library(const int driver_ver):
5656

5757
# First check if the DLL has been loaded by 3rd parties
5858
try:
59-
handle = win32api.GetModuleHandle(dll_name)
59+
return win32api.GetModuleHandle(dll_name)
6060
except:
6161
pass
62-
else:
63-
break
6462

6563
# Next, check if DLLs are installed via pip
6664
for sp in get_site_packages():
6765
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
68-
if not os.path.isdir(mod_path):
69-
continue
70-
os.add_dll_directory(mod_path)
71-
try:
72-
handle = win32api.LoadLibraryEx(
73-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
74-
os.path.join(mod_path, dll_name),
75-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
76-
except:
77-
pass
78-
else:
79-
break
80-
66+
if os.path.isdir(mod_path):
67+
os.add_dll_directory(mod_path)
68+
try:
69+
return win32api.LoadLibraryEx(
70+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
71+
os.path.join(mod_path, dll_name),
72+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
73+
except:
74+
pass
8175
# Finally, try default search
76+
# Only reached if DLL wasn't found in any site-package path
8277
try:
83-
handle = win32api.LoadLibrary(dll_name)
78+
return win32api.LoadLibrary(dll_name)
8479
except:
8580
pass
86-
else:
87-
break
88-
else:
89-
raise RuntimeError('Failed to load nvJitLink')
9081

91-
assert handle != 0
92-
return handle
82+
raise RuntimeError('Failed to load nvJitLink')
9383

9484

9585
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL
4141

4242

4343
cdef inline list get_site_packages():
44-
return [site.getusersitepackages()] + site.getsitepackages()
44+
return [site.getusersitepackages()] + site.getsitepackages() + ["conda"]
4545

4646

4747
cdef load_library(const int driver_ver):
@@ -50,44 +50,42 @@ cdef load_library(const int driver_ver):
5050
for suffix in get_nvvm_dso_version_suffix(driver_ver):
5151
if len(suffix) == 0:
5252
continue
53-
dll_name = "nvvm64_40_0"
53+
dll_name = "nvvm64_40_0.dll"
5454

5555
# First check if the DLL has been loaded by 3rd parties
5656
try:
57-
handle = win32api.GetModuleHandle(dll_name)
57+
return win32api.GetModuleHandle(dll_name)
5858
except:
5959
pass
60-
else:
61-
break
6260

63-
# Next, check if DLLs are installed via pip
61+
# Next, check if DLLs are installed via pip or conda
6462
for sp in get_site_packages():
65-
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
66-
if not os.path.isdir(mod_path):
67-
continue
68-
os.add_dll_directory(mod_path)
69-
try:
70-
handle = win32api.LoadLibraryEx(
71-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
72-
os.path.join(mod_path, dll_name),
73-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
74-
except:
75-
pass
76-
else:
77-
break
63+
if sp == "conda":
64+
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65+
conda_prefix = os.environ.get("CONDA_PREFIX")
66+
if conda_prefix is None:
67+
continue
68+
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
69+
else:
70+
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
71+
if os.path.isdir(mod_path):
72+
os.add_dll_directory(mod_path)
73+
try:
74+
return win32api.LoadLibraryEx(
75+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76+
os.path.join(mod_path, dll_name),
77+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78+
except:
79+
pass
7880

7981
# Finally, try default search
82+
# Only reached if DLL wasn't found in any site-package path
8083
try:
81-
handle = win32api.LoadLibrary(dll_name)
84+
return win32api.LoadLibrary(dll_name)
8285
except:
8386
pass
84-
else:
85-
break
86-
else:
87-
raise RuntimeError('Failed to load nvvm')
8887

89-
assert handle != 0
90-
return handle
88+
raise RuntimeError('Failed to load nvvm')
9189

9290

9391
cdef int _check_or_init_nvvm() except -1 nogil:

cuda_bindings/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def build_extension(self, ext):
390390
# to <loc>/site-packages/nvidia/cuda_nvcc/nvvm/lib64/
391391
rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64"
392392
# from <loc>/lib/python3.*/site-packages/cuda/bindings/_internal/
393-
# to <loc>/lib/nvvm/lib64/
393+
# to <loc>/nvvm/lib64/
394394
rel2 = "$ORIGIN/../../../../../../nvvm/lib64"
395395
ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}"
396396
else:

0 commit comments

Comments
 (0)