|
41 | 41 | #include "SPIRVInternal.h"
|
42 | 42 | #include "libSPIRV/SPIRVDebug.h"
|
43 | 43 |
|
| 44 | +#include "llvm/ADT/StringExtras.h" // llvm::isDigit |
| 45 | +#include "llvm/Demangle/Demangle.h" |
44 | 46 | #include "llvm/IR/InstVisitor.h"
|
45 | 47 | #include "llvm/IR/Instructions.h"
|
46 | 48 | #include "llvm/IR/Operator.h"
|
@@ -104,7 +106,7 @@ class SPIRVRegularizeLLVMBase {
|
104 | 106 | void buildUMulWithOverflowFunc(Function *UMulFunc);
|
105 | 107 |
|
106 | 108 | static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
|
107 |
| - |
| 109 | + void adaptStructTypes(StructType *ST); |
108 | 110 | static char ID;
|
109 | 111 |
|
110 | 112 | private:
|
@@ -291,6 +293,58 @@ void SPIRVRegularizeLLVMBase::lowerUMulWithOverflow(
|
291 | 293 | UMulIntrinsic->setCalledFunction(UMulFunc);
|
292 | 294 | }
|
293 | 295 |
|
| 296 | +void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) { |
| 297 | + if (!ST->hasName()) |
| 298 | + return; |
| 299 | + StringRef STName = ST->getName(); |
| 300 | + STName.consume_front("struct."); |
| 301 | + StringRef MangledName = STName.substr(0, STName.find('.')); |
| 302 | + |
| 303 | + // Demangle the name of a template struct and parse the template |
| 304 | + // parameters which look like: |
| 305 | + // <signed char, 2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3> |
| 306 | + // The result should look like SPIR-V friendly LLVM IR: |
| 307 | + // %spirv.JointMatrixINTEL._char_2_2_0_3 |
| 308 | + if (MangledName.startswith("_ZTSN5__spv24__spirv_JointMatrixINTEL")) { |
| 309 | + std::string DemangledName = llvm::demangle(MangledName.str()); |
| 310 | + StringRef Name(DemangledName); |
| 311 | + Name = Name.slice(Name.find('<') + 1, Name.rfind('>')); |
| 312 | + std::stringstream SPVName; |
| 313 | + // Name = signed char, 2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3 |
| 314 | + auto P = Name.split(", "); |
| 315 | + // P.first = "signed char |
| 316 | + // P.second = "2ul, 2ul, (spv::MatrixLayout)0, (spv::Scope)3" |
| 317 | + StringRef ElemType = P.first; |
| 318 | + // remove possile qualifiers, like "const" or "signed" |
| 319 | + ElemType.consume_back(" const"); |
| 320 | + size_t Space = ElemType.rfind(' '); |
| 321 | + if (Space != StringRef::npos) |
| 322 | + ElemType = ElemType.substr(Space + 1); |
| 323 | + P = P.second.split(", "); |
| 324 | + // P.first = "2ul" |
| 325 | + // P.second = "2ul, (spv::MatrixLayout)0, (spv::Scope)3" |
| 326 | + StringRef Rows = P.first.take_while(llvm::isDigit); |
| 327 | + P = P.second.split(", "); |
| 328 | + // P.first = "2ul" |
| 329 | + // P.second = "(spv::MatrixLayout)0, (spv::Scope)3" |
| 330 | + StringRef Cols = P.first.take_while(llvm::isDigit); |
| 331 | + P = P.second.split(", "); |
| 332 | + // P.first = "(spv::MatrixLayout)0" |
| 333 | + // P.second = "(spv::Scope)3" |
| 334 | + StringRef Layout = P.first.substr(P.first.rfind(')') + 1); |
| 335 | + StringRef Scope = P.second.substr(P.second.rfind(')') + 1); |
| 336 | + |
| 337 | + SPVName << kSPIRVTypeName::PrefixAndDelim |
| 338 | + << kSPIRVTypeName::JointMatrixINTEL << kSPIRVTypeName::Delimiter |
| 339 | + << kSPIRVTypeName::PostfixDelim << ElemType.str() |
| 340 | + << kSPIRVTypeName::PostfixDelim << Rows.str() |
| 341 | + << kSPIRVTypeName::PostfixDelim << Cols.str() |
| 342 | + << kSPIRVTypeName::PostfixDelim << Layout.str() |
| 343 | + << kSPIRVTypeName::PostfixDelim << Scope.str(); |
| 344 | + ST->setName(SPVName.str()); |
| 345 | + } |
| 346 | +} |
| 347 | + |
294 | 348 | bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
|
295 | 349 | M = &Module;
|
296 | 350 | Ctx = &M->getContext();
|
@@ -430,6 +484,9 @@ bool SPIRVRegularizeLLVMBase::regularize() {
|
430 | 484 | }
|
431 | 485 | }
|
432 | 486 |
|
| 487 | + for (StructType *ST : M->getIdentifiedStructTypes()) |
| 488 | + adaptStructTypes(ST); |
| 489 | + |
433 | 490 | if (SPIRVDbgSaveRegularizedModule)
|
434 | 491 | saveLLVMModule(M, RegularizedModuleTmpFile);
|
435 | 492 | return true;
|
|
0 commit comments