Commit c084122

aarch64: cherrypick onednn pr#1768 to improve torch.compile perf (#1716)
This improves BERT-base torch.compile performance by 5.8x on an AWS c7g instance.
1 parent fa3dbb0 commit c084122

1 file changed: 106 additions, 0 deletions

@@ -0,0 +1,106 @@
From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001
From: Sunita Nadampalli <[email protected]>
Date: Wed, 6 Mar 2024 00:37:30 +0000
Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product
 primitive

---
 src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++----
 src/cpu/cpu_inner_product_list.cpp    |  9 ++++++++
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index a2be164f09..762d5d4896 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t {
 
         status_t init(engine_t *engine) {
             using namespace data_type;
+            const format_kind_t weights_format_kind_received
+                    = weights_md_.format_kind;
+
             const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
                     && attr()->has_default_values(
                             primitive_attr_t::skip_mask_t::post_ops, f16);
             const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
                     && attr()->has_default_values(
                             primitive_attr_t::skip_mask_t::post_ops, f32);
+
+            const bool is_fp32_bf16_ok
+                    = expect_data_types(f32, bf16, f32, f32, undef)
+                    && attr()->has_default_values(
+                            primitive_attr_t::skip_mask_t::post_ops, f32);
+
+            const bool is_weights_md_format_ok
+                    = utils::one_of(weights_format_kind_received,
+                            format_kind::any, format_kind::blocked);
+
             const bool ok = is_fwd() && !has_zero_dim_memory()
-                    && utils::one_of(true, is_fp16_ok, is_fp32_ok)
-                    && weights_md_.format_kind == format_kind::any
-                    && set_default_params() == status::success;
+                    && utils::one_of(true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
+                    && is_weights_md_format_ok
+                    && set_default_params(true) == status::success;
 
             if (!ok) return status::unimplemented;
 
-            CHECK(init_conf_ip(engine));
+            CHECK(init_conf_ip(engine, weights_format_kind_received));
 
             return status::success;
         }
@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t {
 
         acl_post_ops_t post_ops;
 
-        status_t init_conf_ip(engine_t *engine) {
+        status_t init_conf_ip(
+                engine_t *engine, format_kind_t weights_format_kind_received) {
 
             ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims,
                     "source and weights dimensions must match");
@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t {
                 return status::unimplemented;
             }
 
+            const memory_desc_t weights_md_received = weights_md_;
             acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
                     weights_md_, expected_weight_format, inner_dim, o_dim,
                     remaining_dims, {});
 
+            ACL_CHECK_SUPPORT(
+                    (weights_format_kind_received == format_kind::blocked)
+                            && !(dnnl_memory_desc_equal(
+                                    &weights_md_received, &weights_md_)),
+                    "specific blocked format not supported by ACL, use "
+                    "format_kind_t::any to find a supported blocked format for "
+                    "your platform");
+
             // clang-format off
 
             // Validate fully connected layer manually to check for return status
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b17769..1f59547304 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
             CPU_INSTANCE(ref_inner_product_fwd_t)
             nullptr,
         }},
+        /* With graph compilation, we are able to reorder and pre-pack the weights during the model load
+         * and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
+         * This primitive definition is to support gemm fastmath mode for the compile scenario where src is
+         * in fp32 and weights are in bf16
+         */
+        {{forward, f32, bf16, f32}, {
+            CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+            nullptr,
+        }},
         {{backward_data, f32, f32, f32}, REG_BWD_PK({
             CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
             CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
-- 
2.34.1

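For context on the new {{forward, f32, bf16, f32}} entry above: it registers the ACL inner product for the case where the source and destination stay in f32 while the weights are supplied as bf16 and pre-packed once up front. The sketch below is a minimal, hypothetical illustration (not part of this commit) of how that configuration might be reached through the oneDNN v3.x C++ API; the shapes, buffer names, and standalone main() are assumptions, and the weights descriptor uses memory::format_tag::any so the implementation can choose a blocked layout it supports, which is what the new ACL_CHECK_SUPPORT message advises.

// Hypothetical usage sketch, assuming the oneDNN v3.x C++ API; sizes and
// names are illustrative. On aarch64 with ACL, this is the configuration the
// new {{forward, f32, bf16, f32}} impl-list entry dispatches to.
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // Illustrative shapes: batch 4, 64 input channels, 32 output channels.
    const memory::dim N = 4, IC = 64, OC = 32;
    auto src_md = memory::desc({N, IC}, memory::data_type::f32, memory::format_tag::nc);
    auto dst_md = memory::desc({N, OC}, memory::data_type::f32, memory::format_tag::nc);
    // bf16 weights with format_tag::any: let the implementation choose a
    // blocked layout it supports (arbitrary user-specified blocked formats
    // are rejected by the new ACL_CHECK_SUPPORT guard).
    auto wei_md = memory::desc({OC, IC}, memory::data_type::bf16, memory::format_tag::any);

    auto ip_pd = inner_product_forward::primitive_desc(
            eng, prop_kind::forward_inference, src_md, wei_md, dst_md);

    // One-time pre-pack: reorder plain f32 weights into the bf16 layout the
    // primitive selected (the load/compile-time reorder the comment added to
    // cpu_inner_product_list.cpp refers to).
    std::vector<float> wei_data(OC * IC, 0.5f);
    auto wei_plain_md = memory::desc({OC, IC}, memory::data_type::f32, memory::format_tag::oi);
    memory wei_plain(wei_plain_md, eng, wei_data.data());
    memory wei_packed(ip_pd.weights_desc(), eng);
    reorder(wei_plain, wei_packed).execute(strm, wei_plain, wei_packed);

    // Per-inference execution: only the f32 source changes between calls.
    std::vector<float> src_data(N * IC, 1.0f), dst_data(N * OC, 0.0f);
    memory src_mem(src_md, eng, src_data.data());
    memory dst_mem(dst_md, eng, dst_data.data());
    inner_product_forward(ip_pd).execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_packed},
                    {DNNL_ARG_DST, dst_mem}});
    strm.wait();
    return 0;
}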