|
1 |
| -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. |
| 1 | +# Copyright 2021-2025 NVIDIA Corporation. All rights reserved. |
2 | 2 | #
|
3 | 3 | # Please refer to the NVIDIA end user license agreement (EULA) associated
|
4 | 4 | # with this source code for terms and conditions that govern your use of
|
|
10 | 10 | import contextlib
|
11 | 11 | import glob
|
12 | 12 | import os
|
| 13 | +import pathlib |
13 | 14 | import platform
|
14 | 15 | import shutil
|
15 | 16 | import sys
|
|
23 | 24 | from setuptools import find_packages, setup
|
24 | 25 | from setuptools.command.bdist_wheel import bdist_wheel
|
25 | 26 | from setuptools.command.build_ext import build_ext
|
| 27 | +from setuptools.command.build_py import build_py |
| 28 | +from setuptools.command.editable_wheel import _TopLevelFinder, editable_wheel |
26 | 29 | from setuptools.extension import Extension
|
27 | 30 |
|
28 | 31 | # ----------------------------------------------------------------------
|
@@ -402,9 +405,79 @@ def build_extension(self, ext):
|
402 | 405 | super().build_extension(ext)
|
403 | 406 |
|
404 | 407 |
|
| 408 | +################################################################################ |
| 409 | +# Adapted from NVIDIA/numba-cuda |
| 410 | +# TODO: Remove this block once we get rid of cuda.__version__ and the .pth files |
| 411 | + |
| 412 | +REDIRECTOR_PTH = "_cuda_bindings_redirector.pth" |
| 413 | +REDIRECTOR_PY = "_cuda_bindings_redirector.py" |
| 414 | +SITE_PACKAGES = pathlib.Path("site-packages") |
| 415 | + |
| 416 | + |
| 417 | +class build_py_with_redirector(build_py): # noqa: N801 |
| 418 | + """Include the redirector files in the generated wheel.""" |
| 419 | + |
| 420 | + def copy_redirector_file(self, source, destination="."): |
| 421 | + destination = pathlib.Path(self.build_lib) / destination |
| 422 | + self.copy_file(str(source), str(destination), preserve_mode=0) |
| 423 | + |
| 424 | + def run(self): |
| 425 | + super().run() |
| 426 | + self.copy_redirector_file(SITE_PACKAGES / REDIRECTOR_PTH) |
| 427 | + self.copy_redirector_file(SITE_PACKAGES / REDIRECTOR_PY) |
| 428 | + |
| 429 | + def get_source_files(self): |
| 430 | + src = super().get_source_files() |
| 431 | + src.extend( |
| 432 | + [ |
| 433 | + str(SITE_PACKAGES / REDIRECTOR_PTH), |
| 434 | + str(SITE_PACKAGES / REDIRECTOR_PY), |
| 435 | + ] |
| 436 | + ) |
| 437 | + return src |
| 438 | + |
| 439 | + def get_output_mapping(self): |
| 440 | + mapping = super().get_output_mapping() |
| 441 | + build_lib = pathlib.Path(self.build_lib) |
| 442 | + mapping[str(build_lib / REDIRECTOR_PTH)] = REDIRECTOR_PTH |
| 443 | + mapping[str(build_lib / REDIRECTOR_PY)] = REDIRECTOR_PY |
| 444 | + return mapping |
| 445 | + |
| 446 | + |
| 447 | +class TopLevelFinderWithRedirector(_TopLevelFinder): |
| 448 | + """Include the redirector files in the editable wheel.""" |
| 449 | + |
| 450 | + def get_implementation(self): |
| 451 | + for item in super().get_implementation(): # noqa: UP028 |
| 452 | + yield item |
| 453 | + |
| 454 | + with open(SITE_PACKAGES / REDIRECTOR_PTH) as f: |
| 455 | + yield (REDIRECTOR_PTH, f.read()) |
| 456 | + |
| 457 | + with open(SITE_PACKAGES / REDIRECTOR_PY) as f: |
| 458 | + yield (REDIRECTOR_PY, f.read()) |
| 459 | + |
| 460 | + |
| 461 | +class editable_wheel_with_redirector(editable_wheel): |
| 462 | + def _select_strategy(self, name, tag, build_lib): |
| 463 | + # The default mode is "lenient" - others are "strict" and "compat". |
| 464 | + # "compat" is deprecated. "strict" creates a tree of links to files in |
| 465 | + # the repo. It could be implemented, but we only handle the default |
| 466 | + # case for now. |
| 467 | + if self.mode is not None and self.mode != "lenient": |
| 468 | + raise RuntimeError(f"Only lenient mode is supported for editable install. Current mode is {self.mode}") |
| 469 | + |
| 470 | + return TopLevelFinderWithRedirector(self.distribution, name) |
| 471 | + |
| 472 | + |
| 473 | +################################################################################ |
| 474 | + |
| 475 | + |
405 | 476 | cmdclass = {
|
406 | 477 | "bdist_wheel": WheelsBuildExtensions,
|
407 | 478 | "build_ext": ParallelBuildExtensions,
|
| 479 | + "build_py": build_py_with_redirector, |
| 480 | + "editable_wheel": editable_wheel_with_redirector, |
408 | 481 | }
|
409 | 482 |
|
410 | 483 | # ----------------------------------------------------------------------
|
|
0 commit comments