Skip to content

Commit c546c5c

Browse files
authored
[1/2] Wean off of PYBIND in favor of torch.ops.load_library (#1276)
Wean ao off of PYBIND [part 1]
1 parent 01dc7da commit c546c5c

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

packaging/post_build_script.sh

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
set -eux
99

10-
WHEEL_NAME=$(ls dist/)
10+
# Prepare manywheel, only for CUDA.
11+
# The wheel is a pure python wheel for other platforms.
12+
if [[ "$CU_VERSION" == cu* ]]; then
13+
WHEEL_NAME=$(ls dist/)
1114

12-
pushd dist
13-
# Prepare manywheel
14-
manylinux_plat=manylinux2014_x86_64
15-
if [[ "$CU_VERSION" == "xpu" ]]; then
16-
manylinux_plat=manylinux_2_28_x86_64
17-
fi
18-
auditwheel repair --plat "$manylinux_plat" -w . \
15+
pushd dist
16+
manylinux_plat=manylinux2014_x86_64
17+
auditwheel repair --plat "$manylinux_plat" -w . \
1918
--exclude libtorch.so \
2019
--exclude libtorch_python.so \
2120
--exclude libtorch_cuda.so \
@@ -26,10 +25,11 @@ auditwheel repair --plat "$manylinux_plat" -w . \
2625
--exclude libcudart.so.11.0 \
2726
"${WHEEL_NAME}"
2827

29-
ls -lah .
30-
# Clean up the linux_x86_64 wheel
31-
rm "${WHEEL_NAME}"
32-
popd
28+
ls -lah .
29+
# Clean up the linux_x86_64 wheel
30+
rm "${WHEEL_NAME}"
31+
popd
32+
fi
3333

3434
MANYWHEEL_NAME=$(ls dist/)
3535
# Try to install the new wheel

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def get_extensions():
110110
if use_cuda:
111111
sources += cuda_sources
112112

113+
if len(sources) == 0:
114+
return None
115+
113116
ext_modules = [
114117
extension(
115118
"torchao._C",

torchao/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
)
2323
if not _IS_FBCODE:
2424
try:
25-
from . import _C
25+
from pathlib import Path
26+
so_files = list(Path(__file__).parent.glob("_C*.so"))
27+
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
28+
torch.ops.load_library(so_files[0])
2629
from . import ops
2730
except:
28-
_C = None
2931
logging.info("Skipping import of cpp extensions")
3032

3133
from torchao.quantization import (

torchao/csrc/init.cpp

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)