Skip to content

Commit ab5fc90

Browse files
authored
aarch64: apply the cherrypicked onednn PR-1768 (#1717)
This is to improve the torch.compile() perf by 5.8x on AWS Graviton3 instances. This patching is required till PyTorch oneDNN is upgraded to v3.4.
1 parent c084122 commit ab5fc90

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

aarch64_linux/aarch64_wheel_ci_build.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def parse_arguments():
111111
with open("/builder/mkldnn_fix/fix-xbyak-failure.patch") as f:
112112
check_call(["patch", "-p1"], stdin=f, cwd="/pytorch/third_party/ideep/mkl-dnn")
113113

114+
print("Applying mkl-dnn patch to improve torch.compile() perf")
115+
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501
116+
114117
os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
115118
pytorch_wheel_name = complete_wheel("pytorch")
116119
print(f"Build Compelete. Created {pytorch_wheel_name}..")

aarch64_linux/build_aarch64_wheel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ def start_build(host: RemoteHost, *,
558558
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
559559
host.run_cmd("cd $HOME && git clone https://github.com/pytorch/builder.git")
560560
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/fix-xbyak-failure.patch") # noqa: E501
561+
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501
561562
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") # noqa: E501
562563
print('Repair the wheel')
563564
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]

0 commit comments

Comments
 (0)