Skip to content

[ESIMD] Implement the new non-experimental low-level API for DPAS #6834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/sycl/ext/intel/esimd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
#include <sycl/ext/intel/esimd/simd.hpp>
#include <sycl/ext/intel/esimd/simd_view.hpp>
#include <sycl/ext/intel/esimd/xmx/dpas.hpp>
#include <sycl/ext/intel/experimental/esimd/kernel_properties.hpp>
#include <sycl/ext/intel/experimental/esimd/math.hpp>
#include <sycl/ext/intel/experimental/esimd/memory.hpp>
Expand Down
47 changes: 47 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/xmx/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//==-------------- xmx/common.hpp - DPC++ Explicit SIMD API ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Explicit SIMD API types used in ESIMD Intel Xe Matrix eXtension.
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/detail/defines_elementary.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::intel::esimd::xmx {

/// Precision of the elements held by a DPAS/DPASW matrix operand.
///
/// The lower-case enumerators are the canonical names. Each upper-case
/// spelling is a deprecated alias that has the same numeric value as its
/// lower-case counterpart.
enum class dpas_argument_type {
  Invalid = 0,

  // Integer precisions.
  u1 = 1, // unsigned 1 bit
  s1 = 2, // signed 1 bit
  u2 = 3, // unsigned 2 bits
  s2 = 4, // signed 2 bits
  u4 = 5, // unsigned 4 bits
  s4 = 6, // signed 4 bits
  u8 = 7, // unsigned 8 bits
  s8 = 8, // signed 8 bits

  // Floating-point precisions. Note that the value 11 is not assigned.
  bf16 = 9,  // bfloat 16
  fp16 = 10, // half float
  tf32 = 12, // tensorfloat 32

  // Deprecated upper-case aliases of the enumerators above.
  U1 __SYCL_DEPRECATED("use u1") = u1,
  S1 __SYCL_DEPRECATED("use s1") = s1,
  U2 __SYCL_DEPRECATED("use u2") = u2,
  S2 __SYCL_DEPRECATED("use s2") = s2,
  U4 __SYCL_DEPRECATED("use u4") = u4,
  S4 __SYCL_DEPRECATED("use s4") = s4,
  U8 __SYCL_DEPRECATED("use u8") = u8,
  S8 __SYCL_DEPRECATED("use s8") = s8,
  BF16 __SYCL_DEPRECATED("use bf16") = bf16,
  FP16 __SYCL_DEPRECATED("use fp16") = fp16,
  TF32 __SYCL_DEPRECATED("use tf32") = tf32
};

} // namespace ext::intel::esimd::xmx
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
340 changes: 340 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
//==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Explicit SIMD API for DPAS Intel Xe Matrix eXtension.
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/detail/defines_elementary.hpp>
#include <sycl/ext/intel/esimd/detail/types.hpp>
#include <sycl/ext/intel/esimd/xmx/common.hpp>
#include <sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {

namespace ext::intel::esimd::xmx {

namespace detail {

/// Deduces the DPAS operand precision that corresponds to the element type
/// \p T.
///
/// The mapping is:
///   - sycl::half                                 -> fp16
///   - sycl::ext::oneapi::experimental::bfloat16  -> bf16
///   - unsigned char                              -> u8
///   - char, signed char                          -> s8
/// Any other type yields dpas_argument_type::Invalid, in which case the
/// precision has to be specified explicitly by the caller.
///
/// @tparam T the storage element type of a DPAS operand.
/// @return the deduced precision or dpas_argument_type::Invalid.
template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
  // Only the non-deprecated (lower-case) enumerators are used here so that
  // including this header does not trigger deprecation warnings.
  // TODO: add support for tfloat32 here.
  if constexpr (std::is_same_v<T, sycl::half>)
    return dpas_argument_type::fp16;
  else if constexpr (std::is_same_v<T,
                                    sycl::ext::oneapi::experimental::bfloat16>)
    return dpas_argument_type::bf16;
  else if constexpr (std::is_same_v<T, unsigned char>)
    return dpas_argument_type::u8;
  else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
    return dpas_argument_type::s8;
  else
    return dpas_argument_type::Invalid;
}

/// Returns the size in bits of one element of the precision \p T.
///
/// @tparam T the operand precision.
/// @return 2, 4, 8, 16 or 32 for the 2-bit, 4-bit, 8-bit, 16-bit (bf16/fp16)
/// and tf32 precisions respectively; -1 for every other value (Invalid and
/// the 1-bit precisions).
template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() {
  // The deprecated upper-case enumerators share their values with the
  // lower-case ones, so they are covered by the same case labels without
  // being named here (naming them would trigger deprecation warnings).
  switch (T) {
  case dpas_argument_type::u2:
  case dpas_argument_type::s2:
    return 2;
  case dpas_argument_type::u4:
  case dpas_argument_type::s4:
    return 4;
  case dpas_argument_type::u8:
  case dpas_argument_type::s8:
    return 8;
  case dpas_argument_type::bf16:
  case dpas_argument_type::fp16:
    return 16;
  case dpas_argument_type::tf32:
    return 32;
  default:
    return -1;
  }
}

/// Compile-time validation of the repeat count of a DPAS/DPASW operation.
///
/// The repeat count must always lie in the [1, 8] range. For DPASW with a
/// repeat count different from 8, additional restrictions that depend on the
/// element bit sizes of the A and B operands are enforced as well.
///
/// @tparam RepeatCount the repeat count to verify.
/// @tparam AElemBitSize size in bits of one element of the A operand.
/// @tparam BElemBitSize size in bits of one element of the B operand.
/// @tparam IsDPASW true if the check is done for DPASW, false for DPAS.
template <int RepeatCount, int AElemBitSize, int BElemBitSize, bool IsDPASW>
constexpr void verify_repeat_count() {
  constexpr bool CountInRange = 1 <= RepeatCount && RepeatCount <= 8;
  static_assert(CountInRange, "Repeat count must be within 1 to 8 range");

  // The remaining checks only apply to DPASW with a repeat count below 8.
  constexpr bool IsRestrictedCase = IsDPASW && RepeatCount != 8;

  // A 2-bit A operand may not be combined with a B operand wider than 4 bits.
  constexpr bool OperandPairAllowed = AElemBitSize != 2 || BElemBitSize <= 4;
  static_assert(!IsRestrictedCase || OperandPairAllowed,
                "Unsupported repeat count for DPASW operation");

  // A repeat count of 4 is always accepted here. Any other count requires
  // that A is not 2-bit and, if A is 4-bit, that B is at most 4 bits wide.
  constexpr bool FourBitPairAllowed = AElemBitSize != 4 || BElemBitSize <= 4;
  constexpr bool CountAllowed =
      RepeatCount == 4 || (AElemBitSize != 2 && FourBitPairAllowed);
  static_assert(!IsRestrictedCase || CountAllowed,
                "Unsupported repeat count for DPASW operation");
}

/// Validates the template arguments of a DPAS/DPASW call at compile time and
/// deduces the execution size from the size of the B operand.
///
/// @tparam SystolicDepth the systolic depth; must be equal to 8.
/// @tparam RepeatCount the repeat count (the M dimension of A and the result).
/// @tparam T element type of the result.
/// @tparam CT element type of the accumulator operand.
/// @tparam BT storage element type of the B operand.
/// @tparam AT storage element type of the A operand.
/// @tparam BPrecision precision of the elements packed into B.
/// @tparam APrecision precision of the elements packed into A.
/// @tparam BN number of elements of type BT in B.
/// @tparam AN number of elements of type AT in A.
/// @tparam IsDPASW true when validating DPASW, which expects half as many
/// elements in A as DPAS does.
/// @return the deduced execution size (8 or 16).
template <int SystolicDepth, int RepeatCount, typename T, typename CT,
          typename BT, typename AT, dpas_argument_type BPrecision,
          dpas_argument_type APrecision, int BN, int AN, bool IsDPASW = false>
constexpr int verify_parameters_and_deduce_exec_size() {

  static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8");
  static_assert(
      APrecision != dpas_argument_type::Invalid &&
          BPrecision != dpas_argument_type::Invalid,
      "The types of dpas arguments are either incorrect or cannot be deduced. "
      "Fix the types and/or explicitly specify them.");

  constexpr int AElemBitSize = dpas_bitsize_from_precision<APrecision>();
  constexpr int BElemBitSize = dpas_bitsize_from_precision<BPrecision>();
  static_assert(AElemBitSize != -1 && BElemBitSize != -1,
                "Cannot deduce element size of input arguments");
  verify_repeat_count<RepeatCount, AElemBitSize, BElemBitSize, IsDPASW>();

  // Number of elements of the wider operand precision that fit into 32 bits,
  // capped at 8.
  constexpr int OpsPerChannel =
      std::min(32 / std::max(AElemBitSize, BElemBitSize), 8);

  // A(_Mx_K) * B(_Kx_N) + C(_Mx_N)
  // where:
  //   _M = RepeatCount;
  //   _K = SystolicDepth * OpsPerChannel;
  //   _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
  constexpr int _M = RepeatCount;
  constexpr int _K = SystolicDepth * OpsPerChannel;

  // Compute _N (aka ExecutionSize) from the matrix B.
  // It has _K*_N elements of BPrecision type, and BN elements of BT type
  // hold those _K*_N*BPrecision bits, which lets us compute _N.
  constexpr int BMatrixBitSize = sizeof(BT) * BN * 8;
  constexpr int BNumElems = BMatrixBitSize / BElemBitSize;
  constexpr int _N = BNumElems / _K;
  static_assert(_K * _N == BNumElems, "Cannot deduce the execution size.");

  // Now verify that AN elements of AT type hold exactly _M*_K elements
  // of APrecision type/size. Similarly for B: BN elements of BT type must
  // hold _K*_N elements of BPrecision type/size.
  // DPASW accepts 2x less expected AN elements than regular DPAS.
  constexpr int AFactorForDPASW = IsDPASW ? 2 : 1;
  static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW,
                "The first matrix multiplier has wrong size.");
  static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8,
                "The second matrix multiplier has wrong size.");

  // Execution size may be 8 or 16 depending on the target device.
  // User must check if used execution size is supported before calling DPAS.
  constexpr int ExecutionSize = _N;

  static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
                "Execution size must be 8 or 16 for DPAS and 8 for DPASW.");

  // Verify the combination of the result, accumulator and operand types.
  // Only the non-deprecated (lower-case) enumerators are used below so that
  // this header does not trigger deprecation warnings on its own.
  if constexpr (APrecision == dpas_argument_type::fp16 ||
                BPrecision == dpas_argument_type::fp16) {
    if constexpr (ExecutionSize == 8) {
      static_assert(APrecision == BPrecision &&
                        __ESIMD_DNS::is_type<T, float>() &&
                        __ESIMD_DNS::is_type<CT, float>(),
                    "Unsupported DPAS types! The supported types are:\n"
                    " Result | C | B | A \n"
                    " f | f | hf | hf \n");
    } else {
      static_assert(APrecision == BPrecision &&
                        __ESIMD_DNS::is_type<T, float, sycl::half>() &&
                        __ESIMD_DNS::is_type<CT, float, sycl::half>(),
                    "Unsupported DPAS types! The supported types are:\n"
                    " Result | C | B | A \n"
                    " f, hf | f, hf | hf | hf \n");
    }
  } else if constexpr (APrecision == dpas_argument_type::bf16 ||
                       BPrecision == dpas_argument_type::bf16) {
    using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
    if constexpr (ExecutionSize == 8) {
      // NOTE(review): this condition also accepts bfloat16 for T and CT,
      // while the message lists only float (the fp16 branch above accepts
      // only float for ExecutionSize == 8). Confirm which one is intended.
      static_assert(APrecision == BPrecision &&
                        __ESIMD_DNS::is_type<T, float, bfloat16>() &&
                        __ESIMD_DNS::is_type<CT, float, bfloat16>(),
                    "Unsupported DPAS types! The supported types are:\n"
                    " Result | C | B | A \n"
                    " f | f | bf | bf \n");
    } else {
      static_assert(APrecision == BPrecision &&
                        __ESIMD_DNS::is_type<T, float, bfloat16>() &&
                        __ESIMD_DNS::is_type<CT, float, bfloat16>(),
                    "Unsupported DPAS types! The supported types are:\n"
                    " Result | C | B | A \n"
                    " f, bf | f, bf | bf | bf \n");
    }
  } else if constexpr (APrecision == dpas_argument_type::tf32 ||
                       BPrecision == dpas_argument_type::tf32) {
    static_assert(ExecutionSize == 16,
                  "tf32 type can be used only with ExecutionSize=16");
    static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
                      std::is_same_v<CT, float>,
                  "Unsupported DPAS types! The supported types are:\n"
                  " Result | C | B | A \n"
                  " f | f | tf32 | tf32 \n");
  } else {
    // Neither operand is fp16, bf16 or tf32: both must be 2-, 4- or 8-bit
    // integers. T and CT are not checked in this branch.
    static_assert((APrecision == dpas_argument_type::u2 ||
                   APrecision == dpas_argument_type::s2 ||
                   APrecision == dpas_argument_type::u4 ||
                   APrecision == dpas_argument_type::s4 ||
                   APrecision == dpas_argument_type::u8 ||
                   APrecision == dpas_argument_type::s8) &&
                      (BPrecision == dpas_argument_type::u2 ||
                       BPrecision == dpas_argument_type::s2 ||
                       BPrecision == dpas_argument_type::u4 ||
                       BPrecision == dpas_argument_type::s4 ||
                       BPrecision == dpas_argument_type::u8 ||
                       BPrecision == dpas_argument_type::s8),
                  "Unsupported DPAS types! The supported types are:\n"
                  " Result | C | B | A \n"
                  " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
  }
  return ExecutionSize;
}

} // namespace detail

/// @defgroup sycl_esimd_xmx_systolic_array_api Systolic Array APIs.
/// APIs below are used to implement dot product accumulate systolic functions
/// @ingroup sycl_esimd

/// @addtogroup sycl_esimd_xmx_systolic_array_api
/// @{
/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = C + A x B;
/// @tparam SystolicDepth the systolic depth; must be equal to 8.
/// @tparam RepeatCount the repeat count; must be in the [1, 8] range.
/// @tparam T element type of the result.
/// @tparam CT element type of the accumulator \p C.
/// @tparam BT storage element type of \p B.
/// @tparam AT storage element type of \p A.
/// @tparam BPrecision precision of the elements of \p B; deduced from BT
/// unless specified explicitly.
/// @tparam APrecision precision of the elements of \p A; deduced from AT
/// unless specified explicitly.
/// @param C represents DPAS accumulator operand.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
    int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT,
    typename AT,
    dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
    dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
    int N, int BN, int AN>
__ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
                            __ESIMD_NS::simd<BT, BN> B,
                            __ESIMD_NS::simd<AT, AN> A) {
  // NOTE(review): the deduced execution size is discarded, so N is not checked
  // against RepeatCount * ExecutionSize here - confirm whether it should be.
  (void)detail::verify_parameters_and_deduce_exec_size<
      SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN,
      AN>();

  // A and B are passed to the intrinsic as vectors of int. Their lengths are
  // computed from the total byte size of the operand, which stays well-formed
  // for element types wider than int (sizeof(int) / sizeof(AT) would be 0
  // there and lead to a division by zero).
  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
  using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
  return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T,
                       CRawT, int, int, N, BNCasted, ANCasted>(
      C.data(), BCasted.data(), ACasted.data());
}

/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = A x B;
/// @tparam SystolicDepth the systolic depth; must be equal to 8.
/// @tparam RepeatCount the repeat count; must be in the [1, 8] range.
/// @tparam T element type of the result.
/// @tparam BT storage element type of \p B.
/// @tparam AT storage element type of \p A.
/// @tparam BPrecision precision of the elements of \p B; deduced from BT
/// unless specified explicitly.
/// @tparam APrecision precision of the elements of \p A; deduced from AT
/// unless specified explicitly.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result. It has
/// RepeatCount * ExecutionSize elements, where ExecutionSize is deduced from
/// the size of \p B.
template <
    int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
    dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
    dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
    int BN, int AN>
auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {

  // There is no accumulator operand, so T is used for the accumulator type
  // during the verification as well.
  constexpr int ExecutionSize =
      detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount,
                                                     T, T, BT, AT, BPrecision,
                                                     APrecision, BN, AN>();
  // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
  // where:
  //   _M = RepeatCount;
  //   _K = SystolicDepth * OpsPerChannel;
  //   _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
  constexpr int ResultN = RepeatCount * ExecutionSize;

  // A and B are passed to the intrinsic as vectors of int. Their lengths are
  // computed from the total byte size of the operand, which stays well-formed
  // for element types wider than int (sizeof(int) / sizeof(AT) would be 0
  // there and lead to a division by zero).
  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

  // Pack the operation parameters into a single integer: repeat count in
  // bits 24 and up, systolic depth from bit 16, A precision from bit 8 and
  // B precision in the lowest bits.
  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
                       (static_cast<int>(APrecision) << 8) +
                       static_cast<int>(BPrecision);
  // The named simd variable fixes the deduced return type of this function
  // to simd<T, ResultN>.
  __ESIMD_NS::simd<T, ResultN> Result =
      __esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
          BCasted.data(), ACasted.data());
  return Result;
}

/// DPASW: the variant of DPAS that expects half as many elements in \p A as
/// DPAS does and supports the execution size 8 only.
/// Computes the result of matrix operations: Result = C + A x B;
/// @tparam SystolicDepth the systolic depth; must be equal to 8.
/// @tparam RepeatCount the repeat count; must be in the [1, 8] range.
/// @tparam T element type of the result and of the accumulator \p C.
/// @tparam BT storage element type of \p B.
/// @tparam AT storage element type of \p A.
/// @tparam BPrecision precision of the elements of \p B; deduced from BT
/// unless specified explicitly.
/// @tparam APrecision precision of the elements of \p A; deduced from AT
/// unless specified explicitly.
/// @param C represents DPAS accumulator operand.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
    int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
    dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
    dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
    int N, int BN, int AN>
__ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
                             __ESIMD_NS::simd<BT, BN> B,
                             __ESIMD_NS::simd<AT, AN> A) {

  constexpr bool IsDPASW = true;
  (void)detail::verify_parameters_and_deduce_exec_size<
      SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
      IsDPASW>();

  // A and B are passed to the intrinsic as vectors of int. Their lengths are
  // computed from the total byte size of the operand, which stays well-formed
  // for element types wider than int (sizeof(int) / sizeof(AT) would be 0
  // there and lead to a division by zero).
  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

  // Pack the operation parameters into a single integer: repeat count in
  // bits 24 and up, systolic depth from bit 16, A precision from bit 8 and
  // B precision in the lowest bits.
  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
                       (static_cast<int>(APrecision) << 8) +
                       static_cast<int>(BPrecision);
  return __esimd_dpasw<Info, T, int, int, N, BNCasted, ANCasted>(
      C.data(), BCasted.data(), ACasted.data());
}

/// DPASW: the variant of DPAS that expects half as many elements in \p A as
/// DPAS does and supports the execution size 8 only.
/// Computes the result of matrix operations: Result = A x B;
/// @tparam SystolicDepth the systolic depth; must be equal to 8.
/// @tparam RepeatCount the repeat count; must be in the [1, 8] range.
/// @tparam T element type of the result.
/// @tparam BT storage element type of \p B.
/// @tparam AT storage element type of \p A.
/// @tparam BPrecision precision of the elements of \p B; deduced from BT
/// unless specified explicitly.
/// @tparam APrecision precision of the elements of \p A; deduced from AT
/// unless specified explicitly.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result. It has
/// RepeatCount * ExecutionSize elements, where ExecutionSize is deduced from
/// the size of \p B.
template <
    int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
    dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
    dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
    int BN, int AN>
auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {

  // There is no accumulator operand, so T is used for the accumulator type
  // during the verification as well.
  constexpr bool IsDPASW = true;
  constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size<
      SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
      IsDPASW>();

  // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
  // where:
  //   _M = RepeatCount;
  //   _K = SystolicDepth * OpsPerChannel;
  //   _N = ExecutionSize (unknown, but deducible), must be 8 for DPASW.
  constexpr int ResultN = RepeatCount * ExecutionSize;

  // A and B are passed to the intrinsic as vectors of int. Their lengths are
  // computed from the total byte size of the operand, which stays well-formed
  // for element types wider than int (sizeof(int) / sizeof(AT) would be 0
  // there and lead to a division by zero).
  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

  // Pack the operation parameters into a single integer: repeat count in
  // bits 24 and up, systolic depth from bit 16, A precision from bit 8 and
  // B precision in the lowest bits.
  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
                       (static_cast<int>(APrecision) << 8) +
                       static_cast<int>(BPrecision);
  // The named simd variable fixes the deduced return type of this function
  // to simd<T, ResultN>.
  __ESIMD_NS::simd<T, ResultN> Result =
      __esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
          BCasted.data(), ACasted.data());
  return Result;
}

/// @} sycl_esimd_xmx_systolic_array_api

} // namespace ext::intel::esimd::xmx
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Loading