Skip to content

Commit 0a97d53

Browse files
committed
Fix setup.py develop workflow
1 parent 6cfa477 commit 0a97d53

File tree

4 files changed

+41
-7
lines changed

4 files changed

+41
-7
lines changed

setup.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,17 +317,39 @@ def build_cmake(self, ext):
317317
if not os.path.exists(self.build_temp):
318318
os.makedirs(self.build_temp)
319319

320+
# Get the expected extension file name that Python will look for
321+
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
322+
ext_basename = os.path.splitext(ext_filename)[0]
323+
324+
# Add TORCHAO_CMAKE_EXT_SO_NAME is used in CMake to name the library
325+
# to something the python extension expects
326+
print("EXTENSION NAME", ext_basename)
320327
subprocess.check_call(
321328
[
322329
"cmake",
323330
ext.cmake_lists_dir,
324331
]
325332
+ ext.cmake_args
326-
+ ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir],
333+
+ [
334+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
335+
f"-DTORCHAO_CMAKE_EXT_SO_NAME={ext_basename}",
336+
],
327337
cwd=self.build_temp,
328338
)
329339
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
330340

341+
# # Handle the case where the file is created with .dylib extension instead of .so
342+
# # This is needed because Python expects .so files on macOS, but CMake might create .dylib
343+
# if ext.name == "torchao.experimental":
344+
# # Get the expected extension file name
345+
# ext_path = os.path.join(extdir, ext_filename)
346+
# dylib_path = os.path.join(extdir, f"{ext_basename}.dylib")
347+
348+
# # If the .dylib exists but the .so doesn't, rename it
349+
# if os.path.exists(dylib_path) and not os.path.exists(ext_path):
350+
# print(f"Renaming {dylib_path} to {ext_path}")
351+
# os.rename(dylib_path, ext_path)
352+
331353

332354
class CMakeExtension(Extension):
333355
def __init__(
@@ -702,7 +724,7 @@ def bool_to_on_off(value):
702724

703725
ext_modules.append(
704726
CMakeExtension(
705-
"torchao.experimental",
727+
"torchao._experimental_aten_ops",
706728
cmake_lists_dir="torchao/experimental",
707729
cmake_args=(
708730
[

torchao/experimental/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,18 @@ if(TORCHAO_BUILD_ATEN_OPS)
136136
ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp
137137
)
138138
list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
139+
140+
# Use the Python extension name if provided
139141
add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten})
142+
if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME)
143+
message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so")
144+
set_target_properties(torchao_ops_aten PROPERTIES
145+
OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME}
146+
PREFIX "" # Remove "lib" prefix for Python extensions
147+
SUFFIX ".so" # Add ".so" suffix for Python extensions
148+
)
149+
endif()
150+
140151
target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}")
141152
if (TORCHAO_BUILD_CPU_AARCH64)
142153
target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64)

torchao/experimental/op_lib.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@
2222

2323

2424
def find_and_load_libtorchao_ops(potential_paths):
25+
# import torchao._experimental_aten_ops
26+
2527
for lib_path in potential_paths:
26-
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
28+
libs = list(lib_path.glob("_experimental_aten_ops.*"))
2729

2830
if not libs:
2931
continue
3032

3133
assert len(libs) == 1, (
32-
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
34+
f"Expected to find one _experimental_aten_ops.* library at {lib_path}, but found {len(libs)}"
3335
)
3436

3537
target_lib = libs[0]

torchao/experimental/tests/test_embedding_xbit_quantizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,9 @@ def test_shared_embedding(self):
183183
self.assertTrue(torch.allclose(result, exported_result))
184184

185185
# Check the shared_embedding and linear ops use the same lifted weight
186-
weight = "b_getattr_l__fn_____0___unembedding_packed_weights"
187186
expected_lines = [
188-
f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)",
189-
f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)",
187+
"torch.ops.torchao._shared_embedding_4bit.default",
188+
"torch.ops.torchao._linear_8bit_act_4bit_weight.default",
190189
]
191190
for line in expected_lines:
192191
FileCheck().check_count(line, 1, exactly=True).run(

0 commit comments

Comments
 (0)