|
| 1 | +From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Sunita Nadampalli <[email protected]>
| 3 | +Date: Wed, 6 Mar 2024 00:37:30 +0000 |
| 4 | +Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product |
| 5 | + primitive |
| 6 | + |
| 7 | +--- |
| 8 | + src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++---- |
| 9 | + src/cpu/cpu_inner_product_list.cpp | 9 ++++++++ |
| 10 | + 2 files changed, 37 insertions(+), 5 deletions(-) |
| 11 | + |
| 12 | +diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp |
| 13 | +index a2be164f09..762d5d4896 100644 |
| 14 | +--- a/src/cpu/aarch64/acl_inner_product.hpp |
| 15 | ++++ b/src/cpu/aarch64/acl_inner_product.hpp |
| 16 | +@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t { |
| 17 | + |
| 18 | + status_t init(engine_t *engine) { |
| 19 | + using namespace data_type; |
| 20 | ++ const format_kind_t weights_format_kind_received |
| 21 | ++ = weights_md_.format_kind; |
| 22 | ++ |
| 23 | + const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) |
| 24 | + && attr()->has_default_values( |
| 25 | + primitive_attr_t::skip_mask_t::post_ops, f16); |
| 26 | + const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) |
| 27 | + && attr()->has_default_values( |
| 28 | + primitive_attr_t::skip_mask_t::post_ops, f32); |
| 29 | ++ |
| 30 | ++ const bool is_fp32_bf16_ok |
| 31 | ++ = expect_data_types(f32, bf16, f32, f32, undef) |
| 32 | ++ && attr()->has_default_values( |
| 33 | ++ primitive_attr_t::skip_mask_t::post_ops, f32); |
| 34 | ++ |
| 35 | ++ const bool is_weights_md_format_ok |
| 36 | ++ = utils::one_of(weights_format_kind_received, |
| 37 | ++ format_kind::any, format_kind::blocked); |
| 38 | ++ |
| 39 | + const bool ok = is_fwd() && !has_zero_dim_memory() |
| 40 | +- && utils::one_of(true, is_fp16_ok, is_fp32_ok) |
| 41 | +- && weights_md_.format_kind == format_kind::any |
| 42 | +- && set_default_params() == status::success; |
| 43 | ++ && utils::one_of(
| 43 | ++ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
| 44 | ++ && is_weights_md_format_ok |
| 45 | ++ && set_default_params(true) == status::success; |
| 46 | + |
| 47 | + if (!ok) return status::unimplemented; |
| 48 | + |
| 49 | +- CHECK(init_conf_ip(engine)); |
| 50 | ++ CHECK(init_conf_ip(engine, weights_format_kind_received)); |
| 51 | + |
| 52 | + return status::success; |
| 53 | + } |
| 54 | +@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t { |
| 55 | + |
| 56 | + acl_post_ops_t post_ops; |
| 57 | + |
| 58 | +- status_t init_conf_ip(engine_t *engine) { |
| 59 | ++ status_t init_conf_ip( |
| 60 | ++ engine_t *engine, format_kind_t weights_format_kind_received) { |
| 61 | + |
| 62 | + ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims, |
| 63 | + "source and weights dimensions must match"); |
| 64 | +@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t { |
| 65 | + return status::unimplemented; |
| 66 | + } |
| 67 | + |
| 68 | ++ const memory_desc_t weights_md_received = weights_md_; |
| 69 | + acl_utils::reorder_to_weight_format(aip.wei_tensor_info, |
| 70 | + weights_md_, expected_weight_format, inner_dim, o_dim, |
| 71 | + remaining_dims, {}); |
| 72 | + |
| 73 | ++ ACL_CHECK_SUPPORT( |
| 74 | ++ (weights_format_kind_received == format_kind::blocked) |
| 75 | ++ && !(dnnl_memory_desc_equal( |
| 76 | ++ &weights_md_received, &weights_md_)), |
| 77 | ++ "specific blocked format not supported by ACL, use " |
| 78 | ++ "format_kind_t::any to find a supported blocked format for " |
| 79 | ++ "your platform"); |
| 80 | ++ |
| 81 | + // clang-format off |
| 82 | + |
| 83 | + // Validate fully connected layer manually to check for return status |
| 84 | +diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp |
| 85 | +index fdd7b17769..1f59547304 100644 |
| 86 | +--- a/src/cpu/cpu_inner_product_list.cpp |
| 87 | ++++ b/src/cpu/cpu_inner_product_list.cpp |
| 88 | +@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() |
| 89 | + CPU_INSTANCE(ref_inner_product_fwd_t) |
| 90 | + nullptr, |
| 91 | + }}, |
| 92 | ++ /* With graph compilation, we are able to reorder and pre-pack the weights during the model load |
| 93 | ++ * and compilation phase itself so that redundant and on-the-fly reorders can be avoided. |
| 94 | ++ * This primitive definition is to support gemm fastmath mode for the compile scenario where src is |
| 95 | ++ * in fp32 and weights are in bf16 |
| 96 | ++ */ |
| 97 | ++ {{forward, f32, bf16, f32}, { |
| 98 | ++ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) |
| 99 | ++ nullptr, |
| 100 | ++ }}, |
| 101 | + {{backward_data, f32, f32, f32}, REG_BWD_PK({ |
| 102 | + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32 |
| 103 | + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>) |
| 104 | +-- |
| 105 | +2.34.1 |
| 106 | + |
0 commit comments