diff --git a/setup.py b/setup.py index 1e7d76a704..d9d7e80506 100644 --- a/setup.py +++ b/setup.py @@ -317,13 +317,21 @@ def build_cmake(self, ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) + # Get the expected extension file name that Python will look for + # We force CMake to use this library name + ext_filename = os.path.basename(self.get_ext_filename(ext.name)) + ext_basename = os.path.splitext(ext_filename)[0] + subprocess.check_call( [ "cmake", ext.cmake_lists_dir, ] + ext.cmake_args - + ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir], + + [ + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + "-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename, + ], cwd=self.build_temp, ) subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) @@ -708,7 +716,7 @@ def bool_to_on_off(value): ext_modules.append( CMakeExtension( - "torchao.experimental", + "torchao._experimental_aten_ops", cmake_lists_dir="torchao/experimental", cmake_args=( [ diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 656101d9e7..317b35643b 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -136,7 +136,18 @@ if(TORCHAO_BUILD_ATEN_OPS) ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp ) list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") + + # Use the Python extension name if provided add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) + if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME) + message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so") + set_target_properties(torchao_ops_aten PROPERTIES + OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME} + PREFIX "" # Remove "lib" prefix for Python extensions + SUFFIX ".so" # Add ".so" suffix for Python extensions + ) + endif() + target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}") if (TORCHAO_BUILD_CPU_AARCH64) target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64) diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index 182d1c3312..e895858d55 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -22,14 +22,18 @@ def find_and_load_libtorchao_ops(potential_paths): + """ + Finds and loads torchao._experimental_aten_ops from one of the provided paths + """ + for lib_path in potential_paths: - libs = list(lib_path.glob("libtorchao_ops_aten.*")) + libs = list(lib_path.glob("_experimental_aten_ops.*")) if not libs: continue assert len(libs) == 1, ( - f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" + f"Expected to find one _experimental_aten_ops.* library at {lib_path}, but found {len(libs)}" ) target_lib = libs[0] diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 1a87245ad4..459c1c5e97 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -183,10 +183,9 @@ def test_shared_embedding(self): self.assertTrue(torch.allclose(result, exported_result)) # Check the shared_embedding and linear ops use the same lifted weight - weight = "b_getattr_l__fn_____0___unembedding_packed_weights" expected_lines = [ - f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)", - f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)", + "torch.ops.torchao._shared_embedding_4bit.default", + "torch.ops.torchao._linear_8bit_act_4bit_weight.default", ] for line in expected_lines: FileCheck().check_count(line, 1, exactly=True).run(