diff --git a/aarch64_linux/aarch64_wheel_ci_build.py b/aarch64_linux/aarch64_wheel_ci_build.py index 4ad620ba2..bdc6717ef 100755 --- a/aarch64_linux/aarch64_wheel_ci_build.py +++ b/aarch64_linux/aarch64_wheel_ci_build.py @@ -21,11 +21,8 @@ def build_ArmComputeLibrary(git_clone_flags: str = "") -> None: print('Building Arm Compute Library') os.system("cd / && mkdir /acl") os.system(f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v23.05.1 {git_clone_flags}") - os.system('sed -i -e \'s/"armv8.2-a"/"armv8-a"/g\' ComputeLibrary/SConscript; ' - 'sed -i -e \'s/-march=armv8.2-a+fp16/-march=armv8-a/g\' ComputeLibrary/SConstruct; ' - 'sed -i -e \'s/"-march=armv8.2-a"/"-march=armv8-a"/g\' ComputeLibrary/filedefs.json') os.system("cd ComputeLibrary; export acl_install_dir=/acl; " - "scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build; " + "scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8a multi_isa=1 build=native build_dir=$acl_install_dir/build; " "cp -r arm_compute $acl_install_dir; " "cp -r include $acl_install_dir; " "cp -r utils $acl_install_dir; " @@ -108,6 +105,9 @@ def parse_arguments(): else: print("build pytorch without mkldnn backend") + # work around to fix Raspberry pie crash + print("Applying mkl-dnn patch to fix Raspberry pie crash") + os.system(f"cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch") os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel") pytorch_wheel_name = complete_wheel("pytorch") print(f"Build Compelete. Created {pytorch_wheel_name}..") diff --git a/aarch64_linux/build_aarch64_wheel.py b/aarch64_linux/build_aarch64_wheel.py index 0bab3126a..f75d4270e 100755 --- a/aarch64_linux/build_aarch64_wheel.py +++ b/aarch64_linux/build_aarch64_wheel.py @@ -557,7 +557,9 @@ def start_build(host: RemoteHost, *, build_ArmComputeLibrary(host, git_clone_flags) print("build pytorch with mkldnn+acl backend") build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON" - host.run_cmd(f"cd pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") + host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git") + host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch") + host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") print('Repair the wheel') pytorch_wheel_name = host.list_dir("pytorch/dist")[0] host.run_cmd(f"export LD_LIBRARY_PATH=$HOME/acl/build:$HOME/pytorch/build/lib && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}") diff --git a/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch b/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch new file mode 100644 index 000000000..f6e91010a --- /dev/null +++ b/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch @@ -0,0 +1,29 @@ +--- + cmake/platform.cmake | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/cmake/platform.cmake b/cmake/platform.cmake +index 8630460ce..602eafe8e 100644 +--- a/cmake/platform.cmake ++++ b/cmake/platform.cmake +@@ -198,7 +198,7 @@ elseif(UNIX OR MINGW) + endif() + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) +- append(DEF_ARCH_OPT_FLAGS "-mcpu=native") ++ append(DEF_ARCH_OPT_FLAGS "-march=armv8-a") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "PPC64") + if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") +@@ -295,7 +295,7 @@ elseif(UNIX OR MINGW) + endif() + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) +- append(DEF_ARCH_OPT_FLAGS "-mcpu=native") ++ append(DEF_ARCH_OPT_FLAGS "-march=armv8-a") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "PPC64") + if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") +-- +2.34.1 +