Skip to content

Commit 664d87d

Browse files
committed
Support complex recursive types that use an array of pointers
The use of pointer arrays inside the recursive type leads to the SPIR-V IR, where the TypePointer is declared after TypeArray. In turn, TypeArray can't handle the unknown element type, because the info from TypeForwardPointer is not used. Now the translator stores forward declared IDs to use them in backward translation. It actually means that the translator "knows" about the type but it would be handled later. Original commit: KhronosGroup/SPIRV-LLVM-Translator@8a0235a
1 parent 6c44f4c commit 664d87d

File tree

7 files changed

+62
-13
lines changed

7 files changed

+62
-13
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,12 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
475475
getSPIRVTypeName(kSPIRVTypeName::JointMatrixINTEL, SS.str());
476476
return mapType(T, getOrCreateOpaquePtrType(M, Name));
477477
}
478+
case OpTypeForwardPointer: {
479+
SPIRVTypeForwardPointer *FP =
480+
static_cast<SPIRVTypeForwardPointer *>(static_cast<SPIRVEntry *>(T));
481+
return mapType(T, transType(static_cast<SPIRVType *>(
482+
BM->getEntry(FP->getPointerId()))));
483+
}
478484

479485
default: {
480486
auto OC = T->getOpCode();

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ class SPIRVModuleImpl : public SPIRVModule {
509509
SPIRVForwardPointerVec ForwardPointerVec;
510510
SPIRVTypeVec TypeVec;
511511
SPIRVIdToEntryMap IdEntryMap;
512+
SPIRVIdToEntryMap IdTypeForwardMap; // Forward declared IDs
512513
SPIRVFunctionVector FuncVec;
513514
SPIRVConstantVector ConstVec;
514515
SPIRVVariableVec VariableVec;
@@ -706,9 +707,15 @@ SPIRVEntry *SPIRVModuleImpl::addEntry(SPIRVEntry *Entry) {
706707
} else
707708
IdEntryMap[Id] = Entry;
708709
} else {
710+
// Collect entries with no ID to de-allocate them at the end.
709711
// Entry of OpLine will be deleted by std::shared_ptr automatically.
710712
if (Entry->getOpCode() != OpLine)
711713
EntryNoId.insert(Entry);
714+
715+
// Store the known ID of pointer type that would be declared later.
716+
if (Entry->getOpCode() == OpTypeForwardPointer)
717+
IdTypeForwardMap[static_cast<SPIRVTypeForwardPointer *>(Entry)
718+
->getPointerId()] = Entry;
712719
}
713720

714721
Entry->setModule(this);
@@ -762,8 +769,15 @@ SPIRVId SPIRVModuleImpl::getId(SPIRVId Id, unsigned Increment) {
762769
SPIRVEntry *SPIRVModuleImpl::getEntry(SPIRVId Id) const {
763770
assert(Id != SPIRVID_INVALID && "Invalid Id");
764771
SPIRVIdToEntryMap::const_iterator Loc = IdEntryMap.find(Id);
765-
assert(Loc != IdEntryMap.end() && "Id is not in map");
766-
return Loc->second;
772+
if (Loc != IdEntryMap.end()) {
773+
return Loc->second;
774+
}
775+
SPIRVIdToEntryMap::const_iterator LocFwd = IdTypeForwardMap.find(Id);
776+
if (LocFwd != IdTypeForwardMap.end()) {
777+
return LocFwd->second;
778+
}
779+
assert(false && "Id is not in map");
780+
return nullptr;
767781
}
768782

769783
SPIRVExtInstSetKind SPIRVModuleImpl::getBuiltinSet(SPIRVId SetId) const {
@@ -1732,6 +1746,11 @@ class TopologicalSort {
17321746
return true;
17331747
State = Discovered;
17341748
for (SPIRVEntry *Op : E->getNonLiteralOperands()) {
1749+
if (Op->getOpCode() == OpTypeForwardPointer) {
1750+
SPIRVEntry *FP = E->getModule()->getEntry(
1751+
static_cast<SPIRVTypeForwardPointer *>(Op)->getPointerId());
1752+
Op = FP;
1753+
}
17351754
if (EntryStateMap[Op] == Visited)
17361755
continue;
17371756
if (visit(Op)) {
@@ -1745,7 +1764,7 @@ class TopologicalSort {
17451764
SPIRVTypePointer *Ptr = static_cast<SPIRVTypePointer *>(E);
17461765
SPIRVModule *BM = E->getModule();
17471766
ForwardPointerSet.insert(BM->add(new SPIRVTypeForwardPointer(
1748-
BM, Ptr, Ptr->getPointerStorageClass())));
1767+
BM, Ptr->getId(), Ptr->getPointerStorageClass())));
17491768
return false;
17501769
}
17511770
return true;
@@ -1776,11 +1795,11 @@ class TopologicalSort {
17761795
: ForwardPointerSet(
17771796
16, // bucket count
17781797
[](const SPIRVTypeForwardPointer *Ptr) {
1779-
return std::hash<SPIRVId>()(Ptr->getPointer()->getId());
1798+
return std::hash<SPIRVId>()(Ptr->getPointerId());
17801799
},
17811800
[](const SPIRVTypeForwardPointer *Ptr1,
17821801
const SPIRVTypeForwardPointer *Ptr2) {
1783-
return Ptr1->getPointer()->getId() == Ptr2->getPointer()->getId();
1802+
return Ptr1->getPointerId() == Ptr2->getPointerId();
17841803
}),
17851804
EntryStateMap([](SPIRVEntry *A, SPIRVEntry *B) -> bool {
17861805
return A->getId() < B->getId();

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVStream.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ template <class T> const SPIRVEncoder &encode(const SPIRVEncoder &O, T V) {
121121
return O << static_cast<SPIRVWord>(V);
122122
}
123123

124+
template <>
125+
const SPIRVEncoder &operator<<(const SPIRVEncoder &O, SPIRVType *P) {
126+
if (!P->hasId() && P->getOpCode() == OpTypeForwardPointer)
127+
return O << static_cast<SPIRVTypeForwardPointer *>(
128+
static_cast<SPIRVEntry *>(P))
129+
->getPointerId();
130+
return O << P->getId();
131+
}
132+
124133
#define SPIRV_DEF_ENCDEC(Type) \
125134
const SPIRVDecoder &operator>>(const SPIRVDecoder &I, Type &V) { \
126135
return decode(I, V); \

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVStream.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ template <typename T>
193193
const SPIRVEncoder &operator<<(const SPIRVEncoder &O, T *P) {
194194
return O << P->getId();
195195
}
196+
template <> const SPIRVEncoder &operator<<(const SPIRVEncoder &O, SPIRVType *P);
196197

197198
template <typename T>
198199
const SPIRVEncoder &operator<<(const SPIRVEncoder &O, const std::vector<T> &V) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,11 @@ SPIRVConstant *SPIRVTypeArray::getLength() const {
266266
_SPIRV_IMP_ENCDEC3(SPIRVTypeArray, Id, ElemType, Length)
267267

268268
void SPIRVTypeForwardPointer::encode(spv_ostream &O) const {
269-
getEncoder(O) << Pointer << SC;
269+
getEncoder(O) << PointerId << SC;
270270
}
271271

272272
void SPIRVTypeForwardPointer::decode(std::istream &I) {
273273
auto Decoder = getDecoder(I);
274-
SPIRVId PointerId;
275274
Decoder >> PointerId >> SC;
276275
}
277276

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,17 +274,17 @@ class SPIRVTypePointer : public SPIRVType {
274274

275275
class SPIRVTypeForwardPointer : public SPIRVEntryNoId<OpTypeForwardPointer> {
276276
public:
277-
SPIRVTypeForwardPointer(SPIRVModule *M, SPIRVTypePointer *Pointer,
277+
SPIRVTypeForwardPointer(SPIRVModule *M, SPIRVId PointerId,
278278
SPIRVStorageClassKind SC)
279-
: SPIRVEntryNoId(M, 3), Pointer(Pointer), SC(SC) {}
279+
: SPIRVEntryNoId(M, 3), PointerId(PointerId), SC(SC) {}
280280

281281
SPIRVTypeForwardPointer()
282-
: Pointer(nullptr), SC(StorageClassUniformConstant) {}
282+
: PointerId(SPIRVID_INVALID), SC(StorageClassUniformConstant) {}
283283

284-
SPIRVTypePointer *getPointer() const { return Pointer; }
284+
SPIRVId getPointerId() const { return PointerId; }
285285
_SPIRV_DCL_ENCDEC
286286
private:
287-
SPIRVTypePointer *Pointer;
287+
SPIRVId PointerId;
288288
SPIRVStorageClassKind SC;
289289
};
290290

llvm-spirv/test/transcoding/RecursiveType.ll

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,49 @@
33
; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV
44
; RUN: llvm-spirv %t.bc -o %t.spv
55
; RUN: spirv-val %t.spv
6+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
7+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
8+
69
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
710
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
811

12+
; RUN: llvm-spirv -r %t.spt -spirv-text -o %t.rev.bc
13+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
14+
915
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
1016
target triple = "spir-unknown-unknown"
1117

1218
%struct.A = type { i32, %struct.C }
1319
%struct.C = type { i32, %struct.B }
1420
%struct.B = type { i32, %struct.A addrspace(4)* }
1521
%struct.Node = type { %struct.Node addrspace(1)*, i32 }
22+
%struct.Flag = type { [3 x %struct.Flag addrspace(3)*] }
1623

24+
; CHECK-SPIRV-DAG: 3 TypeForwardPointer [[FlagFwdPtr:[0-9]+]] 4
1725
; CHECK-SPIRV-DAG: 3 TypeForwardPointer [[NodeFwdPtr:[0-9]+]] 5
1826
; CHECK-SPIRV-DAG: 3 TypeForwardPointer [[AFwdPtr:[0-9]+]] 8
1927
; CHECK-SPIRV: 4 TypeInt [[IntID:[0-9]+]] 32 0
2028
; CHECK-SPIRV: 4 TypeStruct [[BID:[0-9]+]] {{[0-9]+}} [[AFwdPtr]]
2129
; CHECK-SPIRV: 4 TypeStruct [[CID:[0-9]+]] {{[0-9]+}} [[BID]]
2230
; CHECK-SPIRV: 4 TypeStruct [[AID:[0-9]+]] {{[0-9]+}} [[CID]]
31+
2332
; CHECK-SPIRV: 4 TypePointer [[AFwdPtr]] 8 [[AID:[0-9]+]]
2433
; CHECK-SPIRV: 4 TypeStruct [[NodeID:[0-9]+]] [[NodeFwdPtr]]
34+
2535
; CHECK-SPIRV: 4 TypePointer [[NodeFwdPtr]] 5 [[NodeID]]
36+
; CHECK-SPIRV: 4 TypeArray [[FlagID:[0-9]+]] [[FlagFwdPtr]]
37+
; CHECK-SPIRV: 3 TypeStruct [[FlagStructID:[0-9]+]] [[FlagID]]
38+
39+
; CHECK-SPIRV: 4 TypePointer [[FlagFwdPtr]] 4 [[FlagStructID]]
2640

2741
; CHECK-LLVM: %struct.A = type { i32, %struct.C }
2842
; CHECK-LLVM: %struct.C = type { i32, %struct.B }
2943
; CHECK-LLVM: %struct.B = type { i32, %struct.A addrspace(4)* }
3044
; CHECK-LLVM: %struct.Node = type { %struct.Node addrspace(1)*, i32 }
45+
; CHECK-LLVM: %struct.Flag = type { [3 x %struct.Flag addrspace(3)*] }
3146

3247
; Function Attrs: nounwind
33-
define spir_kernel void @test(%struct.A addrspace(1)* %result, %struct.Node addrspace(1)* %node) #0 !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
48+
define spir_kernel void @test(%struct.A addrspace(1)* %result, %struct.Node addrspace(1)* %node, %struct.Flag addrspace(1)* %flag) #0 !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
3449
ret void
3550
}
3651

0 commit comments

Comments
 (0)