Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[ESIMD] Add more tests for new xmx::dpas() #1291

Merged
merged 1 commit into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions SYCL/ESIMD/dpas/dpas_bf16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//==---------------- dpas_bf16.cpp - DPC++ ESIMD on-device test ----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: gpu-intel-pvc || gpu-intel-dg2 || esimd_emulator
// UNSUPPORTED: cuda || hip
// RUN: %clangxx -fsycl -fsycl-device-code-split=per_kernel %s -o %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// This test verifies DPAS support for bfloat16.

#include "dpas_common.hpp"

int main(int argc, const char *argv[]) {
queue Q(esimd_test::ESIMDSelector{}, esimd_test::createExceptionHandler());

bool Print = argc > 1 && std::string(argv[1]) == "-debug";
bool Passed = true;

constexpr bool LetDeduceArgs = true;
Passed &= tests<8, 8, bf16, bf16, LetDeduceArgs>(Q, Print);
Passed &= tests<8, 4, bf16, bf16, LetDeduceArgs>(Q, Print);
Passed &= tests<8, 1, bf16, bf16, LetDeduceArgs>(Q, Print);

// TODO: Enable these cases when esimd::simd(ptr) constructor is fixed.
// Passed &= tests<8, 5, bf16, bf16, LetDeduceArgs>(Q, Print);
// Passed &= tests<8, 3, bf16, bf16, LetDeduceArgs>(Q, Print);

std::cout << (Passed ? "Test Passed\n" : "Test FAILED\n");
return Passed ? 0 : 1;
}
92 changes: 70 additions & 22 deletions SYCL/ESIMD/dpas/dpas_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ std::string toString(dpas_argument_type T) {
return "bf16";
case dpas_argument_type::tf32:
return "tf32";
case dpas_argument_type::S1:
case dpas_argument_type::U1:
case dpas_argument_type::s1:
case dpas_argument_type::u1:
case dpas_argument_type::Invalid:
return "UNSUPPORTED";
}
Expand All @@ -65,7 +65,7 @@ template <dpas_argument_type T> struct DpasPrintType {
static constexpr bool is_uint = T == dpas_argument_type::u2 ||
T == dpas_argument_type::u4 ||
T == dpas_argument_type::u8;
static constexpr bool is_fp = T == dpas_argument_type::FP16 ||
static constexpr bool is_fp = T == dpas_argument_type::fp16 ||
T == dpas_argument_type::bf16 ||
T == dpas_argument_type::tf32;

Expand Down Expand Up @@ -100,7 +100,7 @@ template <dpas_argument_type T> struct DpasNaturalOperandType {
is_uint, unsigned char,
std::conditional_t<
is_fp16, sycl::half,
std::conditional<
std::conditional_t<
is_bf16, sycl::ext::oneapi::experimental::bfloat16, void>>>>;
};

Expand All @@ -123,6 +123,11 @@ template <dpas_argument_type T> constexpr int getBitSize() {

case dpas_argument_type::tf32:
return 32;

case dpas_argument_type::Invalid:
case dpas_argument_type::s1:
case dpas_argument_type::u1:
break;
}
return 0;
}
Expand Down Expand Up @@ -282,7 +287,8 @@ void printMatrix(void *Vec, std::string Msg) {
}

template <int SystolicDepth, int RepeatCount, dpas_argument_type BPrec,
dpas_argument_type APrec, bool UseSrc0>
dpas_argument_type APrec, bool UseSrc0, int ExecSize,
bool LetDeduceArgs>
bool test(queue &Q, bool Print) {
constexpr unsigned Size = 128;
constexpr unsigned VL = 16;
Expand All @@ -300,12 +306,13 @@ bool test(queue &Q, bool Print) {
// where:
constexpr int M = RepeatCount;
constexpr int K = SystolicDepth * OpsPerChannel;
constexpr int N = 16; // Execution size: 16 for PVC.
constexpr int N = ExecSize; // 16 for PVC, 8 for DG2.

auto Dev = Q.get_device();
std::cout << "Running test case " << toString(BPrec, APrec)
<< " with UseSrc0 = " << UseSrc0 << " on "
<< Dev.get_info<info::device::name>() << "\n";
std::cout << "Running on " << Dev.get_info<info::device::name>()
<< " (ExecSize = " << ExecSize << "): " << toString(BPrec, APrec)
<< ", UseSrc0 = " << UseSrc0
<< ", LetDeduceArgs = " << LetDeduceArgs << std::endl;

using ANaturalType = typename DpasNaturalOperandType<APrec>::type;
using BNaturalType = typename DpasNaturalOperandType<BPrec>::type;
Expand All @@ -317,10 +324,10 @@ bool test(queue &Q, bool Print) {
auto BPacked = aligned_alloc_shared<BNaturalType>(128, BPackedSize, Q);
auto Res = aligned_alloc_shared<ResNaturalType>(128, M * N, Q);
// Init APacked;
int Value = 0;
float Value = 1.2;
for (int II = 0; II < M; II++) {
for (int JJ = 0; JJ < K; JJ++) {
Value++;
Value += 1.1;
writeToHorizontallyPackedMatrix<M, K, APrec>(
APacked, II, JJ, static_cast<ANaturalType>(Value));
}
Expand All @@ -345,15 +352,27 @@ bool test(queue &Q, bool Print) {
simd<BNaturalType, BPackedSize> B(BPacked, overaligned_tag<16>{});
simd<ResNaturalType, M * N> C;

if constexpr (UseSrc0) {
// Compute C = C + AxB;
C = 1;
C = dpas<8, RepeatCount, ResNaturalType, ResNaturalType, BNaturalType,
ANaturalType, BPrec, APrec>(C, B, A);
if constexpr (LetDeduceArgs) {
if constexpr (UseSrc0) {
// Compute C = C + AxB;
C = 1;
C = dpas<8, RepeatCount, ResNaturalType>(C, B, A);
} else {
// Compute C = AxB;
C = dpas<8, RepeatCount, ResNaturalType>(B, A);
}

} else {
// Compute C = AxB;
C = dpas<8, RepeatCount, ResNaturalType, BNaturalType, ANaturalType,
BPrec, APrec>(B, A);
if constexpr (UseSrc0) {
// Compute C = C + AxB;
C = 1;
C = dpas<8, RepeatCount, ResNaturalType, ResNaturalType, BNaturalType,
ANaturalType, BPrec, APrec>(C, B, A);
} else {
// Compute C = AxB;
C = dpas<8, RepeatCount, ResNaturalType, BNaturalType, ANaturalType,
BPrec, APrec>(B, A);
}
}

C.copy_to(Res);
Expand Down Expand Up @@ -396,11 +415,40 @@ bool test(queue &Q, bool Print) {
}

template <int SystolicDepth, int RepeatCount, dpas_argument_type T1,
dpas_argument_type T2>
dpas_argument_type T2, bool LetDeduceArgs = false>
bool tests(queue Q, bool Print) {
bool Passed = true;
constexpr bool UseSrc0 = true;
Passed &= test<SystolicDepth, RepeatCount, T1, T2, UseSrc0>(Q, Print);
Passed &= test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0>(Q, Print);
auto Dev = Q.get_device();

// Detect the execution size.
// The device trait is not implemented for esimd_emulator. Use both 8 and 16.
int ExecSize;
bool IsEmulator = false;
try {
ExecSize = Dev.get_info<ext::intel::info::device::gpu_eu_simd_width>();
} catch (sycl::exception e) {
IsEmulator = true;
}
assert((IsEmulator || (ExecSize == 8 || ExecSize == 16)) &&
"Execution size must be 8 or 16");

if (ExecSize == 16 || IsEmulator) {
Passed &=
test<SystolicDepth, RepeatCount, T1, T2, UseSrc0, 16, LetDeduceArgs>(
Q, Print);
Passed &=
test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0, 16, LetDeduceArgs>(
Q, Print);
}
if (ExecSize == 8 || IsEmulator) {
Passed &=
test<SystolicDepth, RepeatCount, T1, T2, UseSrc0, 8, LetDeduceArgs>(
Q, Print);
Passed &=
test<SystolicDepth, RepeatCount, T1, T2, !UseSrc0, 8, LetDeduceArgs>(
Q, Print);
}

return Passed;
}
16 changes: 8 additions & 8 deletions SYCL/ESIMD/dpas/dpas_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: gpu-intel-pvc || esimd_emulator
// REQUIRES: gpu-intel-pvc || gpu-intel-dg2 || esimd_emulator
// UNSUPPORTED: cuda || hip
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %clangxx -fsycl -fsycl-device-code-split=per_kernel %s -o %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// This test verifies DPAS support for float16.
Expand All @@ -20,14 +20,14 @@ int main(int argc, const char *argv[]) {
bool Print = argc > 1 && std::string(argv[1]) == "-debug";
bool Passed = true;

// Test unsigned 2-bit integers./
Passed &= tests<8, 8, fp16, fp16>(Q, Print);
Passed &= tests<8, 4, fp16, fp16>(Q, Print);
Passed &= tests<8, 1, fp16, fp16>(Q, Print);
constexpr bool LetDeduceArgs = true;
Passed &= tests<8, 8, fp16, fp16, LetDeduceArgs>(Q, Print);
Passed &= tests<8, 4, fp16, fp16, LetDeduceArgs>(Q, Print);
Passed &= tests<8, 1, fp16, fp16, LetDeduceArgs>(Q, Print);

// TODO: Enable these cases when esimd::simd(ptr) constructor is fixed.
// Passed &= tests<8, 5, fp16, fp16>(Q, Print);
// Passed &= tests<8, 3, fp16, fp16>(Q, Print);
// Passed &= tests<8, 5, fp16, fp16, LetDeduceArgs>(Q, Print);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as we talked, these testcases should be re-visited and maybe uncommented in a follow-up patch.

// Passed &= tests<8, 3, fp16, fp16, LetDeduceArgs>(Q, Print);

std::cout << (Passed ? "Test Passed\n" : "Test FAILED\n");
return Passed ? 0 : 1;
Expand Down
10 changes: 6 additions & 4 deletions SYCL/ESIMD/dpas/dpas_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: gpu-intel-pvc || esimd_emulator
// REQUIRES: gpu-intel-pvc || gpu-intel-dg2 || esimd_emulator
// UNSUPPORTED: cuda || hip
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %clangxx -fsycl -fsycl-device-code-split=per_kernel %s -o %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// This test verifies DPAS support for 2,4,8-bit integers.
Expand All @@ -20,7 +20,9 @@ int main(int argc, const char *argv[]) {
bool Print = argc > 1 && std::string(argv[1]) == "-debug";
bool Passed = true;

// Test unsigned 2-bit integers./
constexpr bool LetDeduceArgs = true;

// Test unsigned 2-bit integers.
Passed &= tests<8, 8, u2, u2>(Q, Print);
Passed &= tests<8, 4, u2, u2>(Q, Print);
// TODO: enable this case when the problem with simd constructor
Expand All @@ -46,7 +48,7 @@ int main(int argc, const char *argv[]) {

// Test couple combinations with 8-bit integers.
Passed &= tests<8, 8, s8, s8>(Q, Print);
Passed &= tests<8, 2, u8, s8>(Q, Print);
Passed &= tests<8, 2, u8, s8, LetDeduceArgs>(Q, Print);

// Test some mixes of 2/4/8-bit integers.
Passed &= tests<8, 8, s2, s4>(Q, Print);
Expand Down