Skip to content

[aarch64] Fix for pytorch-2.1.0 aarch64 wheels crash on A1/Raspberry Pi #1562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aarch64_linux/aarch64_wheel_ci_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ def build_ArmComputeLibrary(git_clone_flags: str = "") -> None:
print('Building Arm Compute Library')
os.system("cd / && mkdir /acl")
os.system(f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v23.05.1 {git_clone_flags}")
os.system('sed -i -e \'s/"armv8.2-a"/"armv8-a"/g\' ComputeLibrary/SConscript; '
'sed -i -e \'s/-march=armv8.2-a+fp16/-march=armv8-a/g\' ComputeLibrary/SConstruct; '
'sed -i -e \'s/"-march=armv8.2-a"/"-march=armv8-a"/g\' ComputeLibrary/filedefs.json')
os.system("cd ComputeLibrary; export acl_install_dir=/acl; "
"scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build; "
"scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8a multi_isa=1 build=native build_dir=$acl_install_dir/build; "
"cp -r arm_compute $acl_install_dir; "
"cp -r include $acl_install_dir; "
"cp -r utils $acl_install_dir; "
Expand Down Expand Up @@ -108,6 +105,9 @@ def parse_arguments():
else:
print("build pytorch without mkldnn backend")

# Workaround to fix the crash on Raspberry Pi
print("Applying mkl-dnn patch to fix Raspberry pie crash")
os.system(f"cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch")
os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
pytorch_wheel_name = complete_wheel("pytorch")
print(f"Build Compelete. Created {pytorch_wheel_name}..")
4 changes: 3 additions & 1 deletion aarch64_linux/build_aarch64_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def start_build(host: RemoteHost, *,
build_ArmComputeLibrary(host, git_clone_flags)
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
host.run_cmd(f"cd pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}")
host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git")
host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch")
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}")
print('Repair the wheel')
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
host.run_cmd(f"export LD_LIBRARY_PATH=$HOME/acl/build:$HOME/pytorch/build/lib && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}")
Expand Down
29 changes: 29 additions & 0 deletions mkldnn_fix/aarch64-fix-default-build-flags-to-armv8-a.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
cmake/platform.cmake | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cmake/platform.cmake b/cmake/platform.cmake
index 8630460ce..602eafe8e 100644
--- a/cmake/platform.cmake
+++ b/cmake/platform.cmake
@@ -198,7 +198,7 @@ elseif(UNIX OR MINGW)
endif()
# For native compilation tune for the host processor
if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR)
- append(DEF_ARCH_OPT_FLAGS "-mcpu=native")
+ append(DEF_ARCH_OPT_FLAGS "-march=armv8-a")
endif()
elseif(DNNL_TARGET_ARCH STREQUAL "PPC64")
if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
@@ -295,7 +295,7 @@ elseif(UNIX OR MINGW)
endif()
# For native compilation tune for the host processor
if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR)
- append(DEF_ARCH_OPT_FLAGS "-mcpu=native")
+ append(DEF_ARCH_OPT_FLAGS "-march=armv8-a")
endif()
elseif(DNNL_TARGET_ARCH STREQUAL "PPC64")
if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
--
2.34.1