Skip to content

Commit f88cbf3

Browse files
committed
Fix bug in previous commit: need to break out of loop of suffix if site-package DLL was found. Using the most obvious approach to solve this problem: return immediately on success.
1 parent a0baf71 commit f88cbf3

File tree

2 files changed

+20
-39
lines changed

2 files changed

+20
-39
lines changed

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

+10-20
Original file line numberDiff line numberDiff line change
@@ -56,40 +56,30 @@ cdef load_library(const int driver_ver):
5656

5757
# First check if the DLL has been loaded by 3rd parties
5858
try:
59-
handle = win32api.GetModuleHandle(dll_name)
59+
return win32api.GetModuleHandle(dll_name)
6060
except:
6161
pass
62-
else:
63-
break
6462

6563
# Next, check if DLLs are installed via pip
6664
for sp in get_site_packages():
6765
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
6866
if os.path.isdir(mod_path):
6967
os.add_dll_directory(mod_path)
7068
try:
71-
handle = win32api.LoadLibraryEx(
69+
return win32api.LoadLibraryEx(
7270
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
7371
os.path.join(mod_path, dll_name),
7472
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
7573
except:
7674
pass
77-
else:
78-
break
79-
else:
80-
# Finally, try default search
81-
# Only reached if DLL wasn't found in any site-package path
82-
try:
83-
handle = win32api.LoadLibrary(dll_name)
84-
except:
85-
pass
86-
else:
87-
break
88-
else:
89-
raise RuntimeError('Failed to load nvJitLink')
90-
91-
assert handle != 0
92-
return handle
75+
# Finally, try default search
76+
# Only reached if DLL wasn't found in any site-package path
77+
try:
78+
return win32api.LoadLibrary(dll_name)
79+
except:
80+
pass
81+
82+
raise RuntimeError('Failed to load nvJitLink')
9383

9484

9585
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

+10-19
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,9 @@ cdef load_library(const int driver_ver):
5454

5555
# First check if the DLL has been loaded by 3rd parties
5656
try:
57-
handle = win32api.GetModuleHandle(dll_name)
57+
return win32api.GetModuleHandle(dll_name)
5858
except:
5959
pass
60-
else:
61-
break
6260

6361
# Next, check if DLLs are installed via pip or conda
6462
for sp in get_site_packages():
@@ -73,28 +71,21 @@ cdef load_library(const int driver_ver):
7371
if os.path.isdir(mod_path):
7472
os.add_dll_directory(mod_path)
7573
try:
76-
handle = win32api.LoadLibraryEx(
74+
return win32api.LoadLibraryEx(
7775
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
7876
os.path.join(mod_path, dll_name),
7977
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
8078
except:
8179
pass
82-
else:
83-
break
84-
else:
85-
# Finally, try default search
86-
# Only reached if DLL wasn't found in any site-package path
87-
try:
88-
handle = win32api.LoadLibrary(dll_name)
89-
except:
90-
pass
91-
else:
92-
break
93-
else:
94-
raise RuntimeError('Failed to load nvvm')
9580

96-
assert handle != 0
97-
return handle
81+
# Finally, try default search
82+
# Only reached if DLL wasn't found in any site-package path
83+
try:
84+
handle = win32api.LoadLibrary(dll_name)
85+
except:
86+
pass
87+
88+
raise RuntimeError('Failed to load nvvm')
9889

9990

10091
cdef int _check_or_init_nvvm() except -1 nogil:

0 commit comments

Comments
 (0)