Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions cuda_bindings/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
from Cython.Build import cythonize
from pyclibrary import CParser
from setuptools import find_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
from setuptools.extension import Extension

# ----------------------------------------------------------------------
# Fetch configuration options

CUDA_HOME = os.environ.get("CUDA_HOME")
if not CUDA_HOME:
CUDA_HOME = os.environ.get("CUDA_PATH")
CUDA_HOME = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", None))
if not CUDA_HOME:
raise RuntimeError("Environment variable CUDA_HOME or CUDA_PATH is not set")

Expand Down Expand Up @@ -283,24 +282,49 @@ def do_cythonize(extensions):
extensions += prep_extensions(sources)

# ---------------------------------------------------------------------
# Custom build_ext command
# Files are build in two steps:
# 1) Cythonized (in the do_cythonize() command)
# 2) Compiled to .o files as part of build_ext
# This class is solely for passing the value of nthreads to build_ext
# Custom cmdclass extensions

building_wheel = False


class WheelsBuildExtensions(bdist_wheel):
def run(self):
global building_wheel
building_wheel = True
super().run()


class ParallelBuildExtensions(build_ext):
def initialize_options(self):
build_ext.initialize_options(self)
super().initialize_options()
if nthreads > 0:
self.parallel = nthreads

def finalize_options(self):
build_ext.finalize_options(self)
def build_extension(self, ext):
if building_wheel:
# Strip binaries to remove debug symbols
extra_linker_flags = ["-Wl,--strip-all"]

# Allow extensions to discover libraries at runtime
# relative their wheels installation.
ldflag = "-Wl,--disable-new-dtags"
if ext.name == "cuda.bindings._bindings.cynvrtc":
ldflag += f",-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
elif ext.name == "cuda.bindings._internal.nvjitlink":
ldflag += f",-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"

extra_linker_flags.append(ldflag)
else:
extra_linker_flags = []

ext.extra_link_args += extra_linker_flags
super().build_extension(ext)


cmdclass = {"build_ext": ParallelBuildExtensions}
cmdclass = {
"bdist_wheel": WheelsBuildExtensions,
"build_ext": ParallelBuildExtensions,
}

# ----------------------------------------------------------------------
# Setup
Expand Down
Loading