From 6a449f2b72552147ae7e0f0dbd70ad61d3843d5c Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Fri, 5 Aug 2022 16:30:21 -0400 Subject: [PATCH 1/5] [SYCL][SPIR-V] Change the LLVM type name of SPIR-V matrix types. This eliminates the need for the SPIR-V translator to query the pointer element type of the members in the struct to figure out what matrix type it really is. --- clang/lib/CodeGen/CodeGenTypes.cpp | 49 +++++++++++++++++++++++++++ clang/test/CodeGenSYCL/matrix.cpp | 31 +++++++++++++++++ sycl/test/matrix/matrix-int8-test.cpp | 6 ++-- 3 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 clang/test/CodeGenSYCL/matrix.cpp diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index fffaecaae3ad2..8fdce47e3529b 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -51,6 +51,55 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, StringRef suffix) { SmallString<256> TypeName; llvm::raw_svector_ostream OS(TypeName); + // If RD is spirv_JointMatrixINTEL type, mangle differently. + if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { + if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { + if (auto TemplateDecl = dyn_cast(RD)) { + auto TemplateArgs = TemplateDecl->getTemplateArgs().asArray(); + OS << "spirv.JointMatrixINTEL."; + for (auto &TemplateArg : TemplateArgs) { + OS << "_"; + if (TemplateArg.getKind() == TemplateArgument::Type) { + llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); + if (TTy->isIntegerTy()) { + switch (TTy->getIntegerBitWidth()) { + case 8: + OS << "char"; + break; + case 16: + OS << "short"; + break; + case 32: + OS << "int"; + break; + case 64: + OS << "long"; + break; + default: + OS << "i" << TTy->getIntegerBitWidth(); + break; + } + } else if (TTy->isBFloatTy()) { + OS << "bfloat16"; + } else if (TTy->isStructTy()) { + StringRef LlvmTyName = TTy->getStructName(); + // Emit half/bfloat16 for cl::sycl[::*]::{half,bfloat16} + if (LlvmTyName.startswith("class.cl::sycl::") || + LlvmTyName.startswith("class.__sycl_internal::")) + LlvmTyName = LlvmTyName.rsplit("::").second; + OS << LlvmTyName; + } else { + TTy->print(OS, false, true); + } + } else if (TemplateArg.getKind() == TemplateArgument::Integral) { + OS << TemplateArg.getAsIntegral(); + } + } + Ty->setName(OS.str()); + return; + } + } + } OS << RD->getKindName() << '.'; // FIXME: We probably want to make more tweaks to the printing policy. For diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp new file mode 100644 index 0000000000000..1cba24ee81fe3 --- /dev/null +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -0,0 +1,31 @@ +// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s +#include +#include + +namespace __spv { + template + struct __spirv_JointMatrixINTEL; +} + +// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 +void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 +void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 +void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} + +namespace cl { + namespace sycl { + class half {}; + class bfloat16 {}; + } +} +typedef cl::sycl::half my_half; + +// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 +void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index aa38ece1b8713..891d37aec696a 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) From da02eb35e5b21ba8f364d3adddac7e922801655f Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Mon, 8 Aug 2022 15:58:16 -0400 Subject: [PATCH 2/5] Change code formatting. --- clang/lib/CodeGen/CodeGenTypes.cpp | 13 ++++++------- clang/test/CodeGenSYCL/matrix.cpp | 5 +++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 8fdce47e3529b..7ba61a69956f0 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -55,7 +55,8 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { if (auto TemplateDecl = dyn_cast(RD)) { - auto TemplateArgs = TemplateDecl->getTemplateArgs().asArray(); + ArrayRef TemplateArgs = + TemplateDecl->getTemplateArgs().asArray(); OS << "spirv.JointMatrixINTEL."; for (auto &TemplateArg : TemplateArgs) { OS << "_"; @@ -79,21 +80,19 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, OS << "i" << TTy->getIntegerBitWidth(); break; } - } else if (TTy->isBFloatTy()) { + } else if (TTy->isBFloatTy()) OS << "bfloat16"; - } else if (TTy->isStructTy()) { + else if (TTy->isStructTy()) { StringRef LlvmTyName = TTy->getStructName(); // Emit half/bfloat16 for cl::sycl[::*]::{half,bfloat16} if (LlvmTyName.startswith("class.cl::sycl::") || LlvmTyName.startswith("class.__sycl_internal::")) LlvmTyName = LlvmTyName.rsplit("::").second; OS << LlvmTyName; - } else { + } else TTy->print(OS, false, true); - } - } else if (TemplateArg.getKind() == TemplateArgument::Integral) { + } else if (TemplateArg.getKind() == TemplateArgument::Integral) OS << TemplateArg.getAsIntegral(); - } } Ty->setName(OS.str()); return; diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index 1cba24ee81fe3..f2589c4e07790 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -1,4 +1,6 @@ // RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s +// Test that SPIR-V codegen generates the expected LLVM struct name for the +// JointMatrixINTEL type. #include #include @@ -29,3 +31,6 @@ void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 +void f6(__spv::__spirv_JointMatrixINTEL<__int128, 10, 2, 0, 0> *matrix) {} From 53a12bceffe7d21676c73fab689f1a25e2982831 Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Mon, 8 Aug 2022 16:20:11 -0400 Subject: [PATCH 3/5] Delete the __int128 test, since it's not valid for the SPIR-V target. --- clang/test/CodeGenSYCL/matrix.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index f2589c4e07790..8edfc9f2ab6e9 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -31,6 +31,3 @@ void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} - -// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 -void f6(__spv::__spirv_JointMatrixINTEL<__int128, 10, 2, 0, 0> *matrix) {} From 290b8d27fe73889740c138b67aaece51b88e7d5b Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Tue, 9 Aug 2022 14:50:14 -0400 Subject: [PATCH 4/5] Add i128 check via _BitInt extension. --- clang/test/CodeGenSYCL/matrix.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index 8edfc9f2ab6e9..22b65a0411345 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -31,3 +31,6 @@ void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} + +// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 +void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} From cbf7da49d62326892c099e40be6bb7f064939c99 Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Wed, 17 Aug 2022 15:09:27 -0400 Subject: [PATCH 5/5] Update patch after cl::sycl -> sycl:: rename landed. --- clang/lib/CodeGen/CodeGenTypes.cpp | 4 ++-- clang/test/CodeGenSYCL/matrix.cpp | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 7ba61a69956f0..7bf7ba0b84d20 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -84,8 +84,8 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, OS << "bfloat16"; else if (TTy->isStructTy()) { StringRef LlvmTyName = TTy->getStructName(); - // Emit half/bfloat16 for cl::sycl[::*]::{half,bfloat16} - if (LlvmTyName.startswith("class.cl::sycl::") || + // Emit half/bfloat16 for sycl[::*]::{half,bfloat16} + if (LlvmTyName.startswith("class.sycl::") || LlvmTyName.startswith("class.__sycl_internal::")) LlvmTyName = LlvmTyName.rsplit("::").second; OS << LlvmTyName; diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index 22b65a0411345..a361518590519 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -18,19 +18,17 @@ void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} -namespace cl { - namespace sycl { - class half {}; - class bfloat16 {}; - } +namespace sycl { + class half {}; + class bfloat16 {}; } -typedef cl::sycl::half my_half; +typedef sycl::half my_half; // CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 -void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}