diff --git a/CMakeLists.txt b/CMakeLists.txt index 255f9275762..b34ed07a10e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -764,10 +764,6 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module) endif() -if(EXECUTORCH_BUILD_EXTENSION_TRAINING) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training) -endif() - if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util) endif() @@ -872,34 +868,13 @@ if(EXECUTORCH_BUILD_PYBIND) if(EXECUTORCH_BUILD_EXTENSION_TRAINING) - set(_pybind_training_dep_libs - ${TORCH_PYTHON_LIBRARY} - etdump - executorch - util - torch - extension_training - ) - - if(EXECUTORCH_BUILD_XNNPACK) - # need to explicitly specify XNNPACK and microkernels-prod - # here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu - list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod) - endif() - - # pybind training - pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp) - - target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS}) - target_compile_options(_training_lib PUBLIC ${_pybind_compile_options}) - target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs}) - - install(TARGETS _training_lib - LIBRARY DESTINATION executorch/extension/training/pybindings - ) endif() endif() +if(EXECUTORCH_BUILD_EXTENSION_TRAINING) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training) +endif() + if(EXECUTORCH_BUILD_KERNELS_CUSTOM) # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops) diff --git a/extension/training/CMakeLists.txt b/extension/training/CMakeLists.txt index 97e75955837..ee496a7e577 100644 --- a/extension/training/CMakeLists.txt +++ b/extension/training/CMakeLists.txt @@ -26,7 +26,7 @@ target_include_directories( target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..) target_compile_options(extension_training PUBLIC ${_common_compile_options}) target_link_libraries(extension_training executorch_core - extension_data_loader extension_module extension_tensor extension_flat_tensor) + extension_data_loader extension_module_static extension_tensor extension_flat_tensor) list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/") @@ -40,6 +40,33 @@ train_xor gflags executorch_core portable_ops_lib extension_tensor ) target_compile_options(train_xor PUBLIC ${_common_compile_options}) +# Pybind library. +set(_pybind_training_dep_libs + ${TORCH_PYTHON_LIBRARY} + etdump + executorch + util + torch + extension_training +) + +if(EXECUTORCH_BUILD_XNNPACK) +# need to explicitly specify XNNPACK and microkernels-prod +# here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu +list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod) +endif() + +# pybind training +pybind11_add_module(_training_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/pybindings/_training_lib.cpp) + +target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS}) +target_compile_options(_training_lib PUBLIC -Wno-deprecated-declarations -fPIC -frtti -fexceptions) +target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs}) + +install(TARGETS _training_lib + LIBRARY DESTINATION executorch/extension/training/pybindings +) + # Install libraries install( TARGETS extension_training diff --git a/pytest.ini b/pytest.ini index da96469d1e5..cd647c43a1c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -51,6 +51,7 @@ addopts = extension/llm/modules/test extension/llm/export extension/pybindings/test + extension/training/pybindings/test # Runtime runtime # test TODO: fix these tests diff --git a/setup.py b/setup.py index eac28e8e26c..f603488b16f 100644 --- a/setup.py +++ b/setup.py @@ -869,7 +869,7 @@ def get_ext_modules() -> List[Extension]: ext_modules.append( # Install the prebuilt pybindings extension wrapper for training BuiltExtension( - "_training_lib.*", + "extension/training/_training_lib.*", "executorch.extension.training.pybindings._training_lib", ) )