@@ -509,6 +509,7 @@ class SPIRVModuleImpl : public SPIRVModule {
509
509
SPIRVForwardPointerVec ForwardPointerVec;
510
510
SPIRVTypeVec TypeVec;
511
511
SPIRVIdToEntryMap IdEntryMap;
512
+ SPIRVIdToEntryMap IdTypeForwardMap; // Forward declared IDs
512
513
SPIRVFunctionVector FuncVec;
513
514
SPIRVConstantVector ConstVec;
514
515
SPIRVVariableVec VariableVec;
@@ -706,9 +707,15 @@ SPIRVEntry *SPIRVModuleImpl::addEntry(SPIRVEntry *Entry) {
706
707
} else
707
708
IdEntryMap[Id] = Entry;
708
709
} else {
710
+ // Collect entries with no ID to de-allocate them at the end.
709
711
// Entry of OpLine will be deleted by std::shared_ptr automatically.
710
712
if (Entry->getOpCode () != OpLine)
711
713
EntryNoId.insert (Entry);
714
+
715
+ // Store the known ID of pointer type that would be declared later.
716
+ if (Entry->getOpCode () == OpTypeForwardPointer)
717
+ IdTypeForwardMap[static_cast <SPIRVTypeForwardPointer *>(Entry)
718
+ ->getPointerId ()] = Entry;
712
719
}
713
720
714
721
Entry->setModule (this );
@@ -762,8 +769,15 @@ SPIRVId SPIRVModuleImpl::getId(SPIRVId Id, unsigned Increment) {
762
769
SPIRVEntry *SPIRVModuleImpl::getEntry (SPIRVId Id) const {
763
770
assert (Id != SPIRVID_INVALID && " Invalid Id" );
764
771
SPIRVIdToEntryMap::const_iterator Loc = IdEntryMap.find (Id);
765
- assert (Loc != IdEntryMap.end () && " Id is not in map" );
766
- return Loc->second ;
772
+ if (Loc != IdEntryMap.end ()) {
773
+ return Loc->second ;
774
+ }
775
+ SPIRVIdToEntryMap::const_iterator LocFwd = IdTypeForwardMap.find (Id);
776
+ if (LocFwd != IdTypeForwardMap.end ()) {
777
+ return LocFwd->second ;
778
+ }
779
+ assert (false && " Id is not in map" );
780
+ return nullptr ;
767
781
}
768
782
769
783
SPIRVExtInstSetKind SPIRVModuleImpl::getBuiltinSet (SPIRVId SetId) const {
@@ -1732,6 +1746,11 @@ class TopologicalSort {
1732
1746
return true ;
1733
1747
State = Discovered;
1734
1748
for (SPIRVEntry *Op : E->getNonLiteralOperands ()) {
1749
+ if (Op->getOpCode () == OpTypeForwardPointer) {
1750
+ SPIRVEntry *FP = E->getModule ()->getEntry (
1751
+ static_cast <SPIRVTypeForwardPointer *>(Op)->getPointerId ());
1752
+ Op = FP;
1753
+ }
1735
1754
if (EntryStateMap[Op] == Visited)
1736
1755
continue ;
1737
1756
if (visit (Op)) {
@@ -1745,7 +1764,7 @@ class TopologicalSort {
1745
1764
SPIRVTypePointer *Ptr = static_cast <SPIRVTypePointer *>(E);
1746
1765
SPIRVModule *BM = E->getModule ();
1747
1766
ForwardPointerSet.insert (BM->add (new SPIRVTypeForwardPointer (
1748
- BM, Ptr, Ptr->getPointerStorageClass ())));
1767
+ BM, Ptr-> getId () , Ptr->getPointerStorageClass ())));
1749
1768
return false ;
1750
1769
}
1751
1770
return true ;
@@ -1776,11 +1795,11 @@ class TopologicalSort {
1776
1795
: ForwardPointerSet(
1777
1796
16 , // bucket count
1778
1797
[](const SPIRVTypeForwardPointer *Ptr) {
1779
- return std::hash<SPIRVId>()(Ptr->getPointer ()-> getId ());
1798
+ return std::hash<SPIRVId>()(Ptr->getPointerId ());
1780
1799
},
1781
1800
[](const SPIRVTypeForwardPointer *Ptr1,
1782
1801
const SPIRVTypeForwardPointer *Ptr2) {
1783
- return Ptr1->getPointer ()-> getId () == Ptr2->getPointer ()-> getId ();
1802
+ return Ptr1->getPointerId () == Ptr2->getPointerId ();
1784
1803
}),
1785
1804
EntryStateMap ([](SPIRVEntry *A, SPIRVEntry *B) -> bool {
1786
1805
return A->getId () < B->getId ();
0 commit comments