Skip to content

Commit 1f4d738

Browse files
committed
Support discovery of nvrtc libraries at run time
CTK installations distribute their libraries using personal packages: - nvidia-cuda-nvrtc-cuXX The relative path of their libraries to cuda-bindings is consistent, and allows us to use relative paths to discover them when loading at run time.
1 parent ca9e641 commit 1f4d738

File tree

3 files changed

+98
-14
lines changed

3 files changed

+98
-14
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

+51-3
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88
{{if 'Windows' == platform.system()}}
9-
import win32api
9+
import os
10+
import site
1011
import struct
12+
import win32api
1113
from pywintypes import error
1214
{{else}}
1315
cimport cuda.bindings._lib.dlfcn as dlfcn
@@ -40,8 +42,8 @@ cdef int cuPythonInit() except -1 nogil:
4042

4143
# Load library
4244
{{if 'Windows' == platform.system()}}
43-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
4445
with gil:
46+
# First check if the DLL has been loaded by 3rd parties
4547
try:
4648
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
4749
except:
@@ -51,7 +53,53 @@ cdef int cuPythonInit() except -1 nogil:
5153
try:
5254
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
5355
except:
54-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
56+
handle = None
57+
58+
# Else try default search
59+
if not handle:
60+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
61+
try:
62+
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
63+
except:
64+
pass
65+
66+
# Final check if DLLs can be found within pip installations
67+
if not handle:
68+
site_packages = [site.getusersitepackages()] + site.getsitepackages()
69+
for sp in site_packages:
70+
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
71+
if not os.path.isdir(mod_path):
72+
continue
73+
os.add_dll_directory(mod_path)
74+
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
75+
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
76+
try:
77+
handle = win32api.LoadLibraryEx(
78+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
79+
os.path.join(mod_path, "nvrtc64_112_0.dll"),
80+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
81+
82+
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
83+
# located in the same mod_path.
84+
# Update PATH environ so that the two dlls can find each other
85+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
86+
except:
87+
try:
88+
handle = win32api.LoadLibraryEx(
89+
os.path.join(mod_path, "nvrtc64_111_0.dll"),
90+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
91+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
92+
except:
93+
try:
94+
handle = win32api.LoadLibraryEx(
95+
os.path.join(mod_path, "nvrtc64_110_0.dll"),
96+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
97+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
98+
except:
99+
pass
100+
101+
if not handle:
102+
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
55103
{{else}}
56104
handle = NULL
57105
if handle == NULL:

cuda_bindings/pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ dependencies = [
3232
"pywin32; sys_platform == 'win32'",
3333
]
3434

35+
[project.optional-dependencies]
36+
all = [
37+
"nvidia-cuda-nvrtc-cu11"
38+
]
39+
3540
[project.urls]
3641
Repository = "https://github.com/NVIDIA/cuda-python"
3742
Documentation = "https://nvidia.github.io/cuda-python/"

cuda_bindings/setup.py

+42-11
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717
from pyclibrary import CParser
1818
from setuptools import find_packages, setup
1919
from setuptools.extension import Extension
20+
from setuptools.command.bdist_wheel import bdist_wheel
2021
from setuptools.command.build_ext import build_ext
2122
import versioneer
2223

2324

2425
# ----------------------------------------------------------------------
2526
# Fetch configuration options
2627

27-
CUDA_HOME = os.environ.get("CUDA_HOME")
28-
if not CUDA_HOME:
29-
CUDA_HOME = os.environ.get("CUDA_PATH")
28+
CUDA_HOME = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", None))
3029
if not CUDA_HOME:
3130
raise RuntimeError('Environment variable CUDA_HOME or CUDA_PATH is not set')
3231

@@ -236,20 +235,52 @@ def do_cythonize(extensions):
236235
extensions += prep_extensions(sources)
237236

238237
# ---------------------------------------------------------------------
239-
# Custom build_ext command
240-
# Files are build in two steps:
241-
# 1) Cythonized (in the do_cythonize() command)
242-
# 2) Compiled to .o files as part of build_ext
243-
# This class is solely for passing the value of nthreads to build_ext
238+
# Custom cmdclass extensions
239+
240+
building_wheel = False
241+
242+
243+
class WheelsBuildExtensions(bdist_wheel):
244+
def run(self):
245+
global building_wheel
246+
building_wheel = True
247+
super().run()
248+
244249

245250
class ParallelBuildExtensions(build_ext):
246251
def initialize_options(self):
247-
build_ext.initialize_options(self)
252+
super().initialize_options()
248253
if nthreads > 0:
249254
self.parallel = nthreads
250255

251-
def finalize_options(self):
252-
build_ext.finalize_options(self)
256+
def build_extension(self, ext):
257+
if building_wheel and sys.platform == "linux":
258+
# Strip binaries to remove debug symbols
259+
extra_linker_flags = ["-Wl,--strip-all"]
260+
261+
# Allow extensions to discover libraries at runtime
262+
# relative their wheels installation.
263+
if ext.name == "cuda.bindings._bindings.cynvrtc":
264+
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
265+
elif ext.name == "cuda.bindings._internal.nvjitlink":
266+
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
267+
else:
268+
ldflag = None
269+
270+
if ldflag:
271+
extra_linker_flags.append(ldflag)
272+
else:
273+
extra_linker_flags = []
274+
275+
ext.extra_link_args += extra_linker_flags
276+
super().build_extension(ext)
277+
278+
279+
cmdclass = {
280+
"bdist_wheel": WheelsBuildExtensions,
281+
"build_ext": ParallelBuildExtensions,
282+
}
283+
253284

254285
cmdclass = {"build_ext": ParallelBuildExtensions}
255286
cmdclass = versioneer.get_cmdclass(cmdclass)

0 commit comments

Comments
 (0)