@@ -46,39 +46,45 @@ cdef int cuPythonInit() except -1 nogil:
4646
4747 # Load library
4848 {{if 'Windows' == platform.system()}}
49- handle = NULL
5049 with gil:
5150 # First check if the DLL has been loaded by 3rd parties
5251 try:
5352 handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
5453 except:
55- pass
54+ handle = None
5655
5756 # Try default search
58- if handle == NULL :
57+ if not handle :
5958 LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
6059 try:
6160 handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
6261 except:
6362 pass
6463
6564 # Check if DLLs are found within pip installations
66- if handle == NULL :
65+ if not handle :
6766 site_packages = [site.getusersitepackages()] + site.getsitepackages()
6867 for sp in site_packages:
6968 mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
7069 if not os.path.isdir(mod_path):
7170 continue
7271 os.add_dll_directory(mod_path)
72+ LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
73+ LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
7374 try:
7475 handle = win32api.LoadLibraryEx(
7576 # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
7677 os.path.join(mod_path, "nvrtc64_120_0.dll"),
7778 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
79+
80+ # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
81+ # located in the same mod_path.
82+ # Update PATH environ so that the two dlls can find each other
83+ os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
7884 except:
7985 pass
8086
81- if handle == NULL :
87+ if not handle :
8288 raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
8389 {{else}}
8490 handle = NULL
0 commit comments