Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,21 @@ def build_cmake(self, ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

# Get the expected extension file name that Python will look for
# We force CMake to use this library name
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
ext_basename = os.path.splitext(ext_filename)[0]

subprocess.check_call(
[
"cmake",
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir],
+ [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
cwd=self.build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
Expand Down Expand Up @@ -708,7 +716,7 @@ def bool_to_on_off(value):

ext_modules.append(
CMakeExtension(
"torchao.experimental",
"torchao._experimental_aten_ops",
cmake_lists_dir="torchao/experimental",
cmake_args=(
[
Expand Down
11 changes: 11 additions & 0 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,18 @@ if(TORCHAO_BUILD_ATEN_OPS)
ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp
)
list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")

# Use the Python extension name if provided
add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten})
if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME)
message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so")
set_target_properties(torchao_ops_aten PROPERTIES
OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME}
PREFIX "" # Remove "lib" prefix for Python extensions
SUFFIX ".so" # Add ".so" suffix for Python extensions
)
endif()

target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}")
if (TORCHAO_BUILD_CPU_AARCH64)
target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64)
Expand Down
8 changes: 6 additions & 2 deletions torchao/experimental/op_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@


def find_and_load_libtorchao_ops(potential_paths):
"""
Finds and loads torchao._experimental_aten_ops from one of the provided paths
"""

for lib_path in potential_paths:
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
libs = list(lib_path.glob("_experimental_aten_ops.*"))

if not libs:
continue

assert len(libs) == 1, (
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
f"Expected to find one _experimental_aten_ops.* library at {lib_path}, but found {len(libs)}"
)

target_lib = libs[0]
Expand Down
5 changes: 2 additions & 3 deletions torchao/experimental/tests/test_embedding_xbit_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,9 @@ def test_shared_embedding(self):
self.assertTrue(torch.allclose(result, exported_result))

# Check the shared_embedding and linear ops use the same lifted weight
weight = "b_getattr_l__fn_____0___unembedding_packed_weights"
expected_lines = [
f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)",
f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)",
"torch.ops.torchao._shared_embedding_4bit.default",
"torch.ops.torchao._linear_8bit_act_4bit_weight.default",
]
for line in expected_lines:
FileCheck().check_count(line, 1, exactly=True).run(
Expand Down
Loading