Skip to content

Commit 5e6c751

Browse files
authored
Merge branch '11.8.x' into backport-307-to-11.8.x
2 parents 51c2d6c + a20f0f4 commit 5e6c751

File tree

3 files changed

+105
-17
lines changed

3 files changed

+105
-17
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

+60-6
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88
{{if 'Windows' == platform.system()}}
9-
import win32api
9+
import os
10+
import site
1011
import struct
12+
import win32api
1113
from pywintypes import error
1214
{{else}}
1315
cimport cuda.bindings._lib.dlfcn as dlfcn
@@ -40,18 +42,70 @@ cdef int cuPythonInit() except -1 nogil:
4042

4143
# Load library
4244
{{if 'Windows' == platform.system()}}
43-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
4445
with gil:
46+
# First check if the DLL has been loaded by 3rd parties
4547
try:
46-
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
48+
handle = win32api.GetModuleHandle("nvrtc64_112_0.dll")
4749
except:
4850
try:
49-
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
51+
handle = win32api.GetModuleHandle("nvrtc64_111_0.dll")
52+
except:
53+
try:
54+
handle = win32api.GetModuleHandle("nvrtc64_110_0.dll")
55+
except:
56+
handle = None
57+
58+
# Else try default search
59+
if not handle:
60+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
61+
try:
62+
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
63+
except:
64+
try:
65+
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
66+
except:
67+
try:
68+
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
69+
except:
70+
pass
71+
72+
# Final check if DLLs can be found within pip installations
73+
if not handle:
74+
site_packages = [site.getusersitepackages()] + site.getsitepackages()
75+
for sp in site_packages:
76+
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
77+
if not os.path.isdir(mod_path):
78+
continue
79+
os.add_dll_directory(mod_path)
80+
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
81+
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
82+
try:
83+
handle = win32api.LoadLibraryEx(
84+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
85+
os.path.join(mod_path, "nvrtc64_112_0.dll"),
86+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
87+
88+
# Note: nvrtc64_*.dll calls into nvrtc-builtins64_*.dll which is
89+
# located in the same mod_path.
90+
# Update PATH environ so that the two dlls can find each other
91+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
5092
except:
5193
try:
52-
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
94+
handle = win32api.LoadLibraryEx(
95+
os.path.join(mod_path, "nvrtc64_111_0.dll"),
96+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
97+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
5398
except:
54-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
99+
try:
100+
handle = win32api.LoadLibraryEx(
101+
os.path.join(mod_path, "nvrtc64_110_0.dll"),
102+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
103+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
104+
except:
105+
pass
106+
107+
if not handle:
108+
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
55109
{{else}}
56110
handle = NULL
57111
if handle == NULL:

cuda_bindings/pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ dependencies = [
3232
"pywin32; sys_platform == 'win32'",
3333
]
3434

35+
[project.optional-dependencies]
36+
all = [
37+
"nvidia-cuda-nvrtc-cu11"
38+
]
39+
3540
[project.urls]
3641
Repository = "https://github.com/NVIDIA/cuda-python"
3742
Documentation = "https://nvidia.github.io/cuda-python/"

cuda_bindings/setup.py

+40-11
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717
from pyclibrary import CParser
1818
from setuptools import find_packages, setup
1919
from setuptools.extension import Extension
20+
from setuptools.command.bdist_wheel import bdist_wheel
2021
from setuptools.command.build_ext import build_ext
2122

2223

2324
# ----------------------------------------------------------------------
2425
# Fetch configuration options
2526

26-
CUDA_HOME = os.environ.get("CUDA_HOME")
27-
if not CUDA_HOME:
28-
CUDA_HOME = os.environ.get("CUDA_PATH")
27+
# Prefer CUDA_HOME and fall back to CUDA_PATH. Using ``or`` (rather than a
# ``get`` default) means a CUDA_HOME that is set but empty also falls through
# to CUDA_PATH instead of triggering the error below.
CUDA_HOME = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if not CUDA_HOME:
    raise RuntimeError('Environment variable CUDA_HOME or CUDA_PATH is not set')
3130

@@ -235,20 +234,50 @@ def do_cythonize(extensions):
235234
extensions += prep_extensions(sources)
236235

237236
# ---------------------------------------------------------------------
238-
# Custom build_ext command
239-
# Files are build in two steps:
240-
# 1) Cythonized (in the do_cythonize() command)
241-
# 2) Compiled to .o files as part of build_ext
242-
# This class is solely for passing the value of nthreads to build_ext
237+
# Custom cmdclass extensions
238+
239+
# Module-level flag: set to True by WheelsBuildExtensions.run() and read by
# ParallelBuildExtensions.build_extension() to decide on extra linker flags.
building_wheel = False
240+
241+
242+
class WheelsBuildExtensions(bdist_wheel):
    """``bdist_wheel`` command that flags that a wheel build is in progress."""

    def run(self):
        """Set the module-level ``building_wheel`` flag, then run ``bdist_wheel``."""
        global building_wheel
        building_wheel = True
        super(WheelsBuildExtensions, self).run()
247+
243248

244249
class ParallelBuildExtensions(build_ext):
    """``build_ext`` command that applies ``nthreads`` and wheel-only linker flags."""

    def initialize_options(self):
        """Run the default initialisation, then set ``parallel`` when ``nthreads`` is positive."""
        super().initialize_options()
        if nthreads > 0:
            self.parallel = nthreads

    def build_extension(self, ext):
        """Append linker flags to ``ext.extra_link_args`` and build the extension.

        Extra flags are added only when a wheel is being built on Linux;
        otherwise the extension's link arguments are left as they are.
        """
        wheel_link_args = []
        if building_wheel and sys.platform == "linux":
            # Strip binaries to remove debug symbols
            wheel_link_args.append("-Wl,--strip-all")

            # Allow extensions to discover libraries at runtime
            # relative to their wheel's installation.
            if ext.name == "cuda.bindings._bindings.cynvrtc":
                wheel_link_args.append(
                    "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
                )

        ext.extra_link_args += wheel_link_args
        super().build_extension(ext)
274+
275+
276+
# Register both custom commands. This must be the only ``cmdclass``
# assignment: a later ``cmdclass = {"build_ext": ...}`` would replace this
# dict and drop the ``bdist_wheel`` override, so ``building_wheel`` would
# never become True and the wheel-only linker flags would never be applied.
cmdclass = {
    "bdist_wheel": WheelsBuildExtensions,
    "build_ext": ParallelBuildExtensions,
}
254283

0 commit comments

Comments
 (0)