Skip to content

Commit ea2da57

Browse files
TylerNowickitnowicki
and
tnowicki
authored
[Coroutines] Move the SuspendCrossingInfo analysis helper into its own header/source (#106306)
* Move the SuspendCrossingInfo analysis helper into its own header/source See RFC for more info: https://discourse.llvm.org/t/rfc-abi-objects-for-coroutines/81057 Co-authored-by: tnowicki <[email protected]>
1 parent 1651014 commit ea2da57

File tree

4 files changed

+394
-310
lines changed

4 files changed

+394
-310
lines changed

llvm/lib/Transforms/Coroutines/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMCoroutines
77
CoroElide.cpp
88
CoroFrame.cpp
99
CoroSplit.cpp
10+
SuspendCrossingInfo.cpp
1011

1112
ADDITIONAL_HEADER_DIRS
1213
${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Coroutines

llvm/lib/Transforms/Coroutines/CoroFrame.cpp

Lines changed: 12 additions & 310 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//===----------------------------------------------------------------------===//
1717

1818
#include "CoroInternal.h"
19+
#include "SuspendCrossingInfo.h"
1920
#include "llvm/ADT/BitVector.h"
2021
#include "llvm/ADT/PostOrderIterator.h"
2122
#include "llvm/ADT/ScopeExit.h"
@@ -51,315 +52,6 @@ extern cl::opt<bool> UseNewDbgInfoFormat;
5152
// "coro-frame", which results in leaner debug spew.
5253
#define DEBUG_TYPE "coro-suspend-crossing"
5354

54-
enum { SmallVectorThreshold = 32 };
55-
56-
// Provides two way mapping between the blocks and numbers.
57-
namespace {
58-
class BlockToIndexMapping {
59-
SmallVector<BasicBlock *, SmallVectorThreshold> V;
60-
61-
public:
62-
size_t size() const { return V.size(); }
63-
64-
BlockToIndexMapping(Function &F) {
65-
for (BasicBlock &BB : F)
66-
V.push_back(&BB);
67-
llvm::sort(V);
68-
}
69-
70-
size_t blockToIndex(BasicBlock const *BB) const {
71-
auto *I = llvm::lower_bound(V, BB);
72-
assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
73-
return I - V.begin();
74-
}
75-
76-
BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
77-
};
78-
} // end anonymous namespace
79-
80-
// The SuspendCrossingInfo maintains data that allows to answer a question
81-
// whether given two BasicBlocks A and B there is a path from A to B that
82-
// passes through a suspend point.
83-
//
84-
// For every basic block 'i' it maintains a BlockData that consists of:
85-
// Consumes: a bit vector which contains a set of indices of blocks that can
86-
// reach block 'i'. A block can trivially reach itself.
87-
// Kills: a bit vector which contains a set of indices of blocks that can
88-
// reach block 'i' but there is a path crossing a suspend point
89-
// not repeating 'i' (path to 'i' without cycles containing 'i').
90-
// Suspend: a boolean indicating whether block 'i' contains a suspend point.
91-
// End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
92-
// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
93-
// crosses a suspend point.
94-
//
95-
namespace {
96-
class SuspendCrossingInfo {
97-
BlockToIndexMapping Mapping;
98-
99-
struct BlockData {
100-
BitVector Consumes;
101-
BitVector Kills;
102-
bool Suspend = false;
103-
bool End = false;
104-
bool KillLoop = false;
105-
bool Changed = false;
106-
};
107-
SmallVector<BlockData, SmallVectorThreshold> Block;
108-
109-
iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
110-
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
111-
return llvm::predecessors(BB);
112-
}
113-
114-
BlockData &getBlockData(BasicBlock *BB) {
115-
return Block[Mapping.blockToIndex(BB)];
116-
}
117-
118-
/// Compute the BlockData for the current function in one iteration.
119-
/// Initialize - Whether this is the first iteration, we can optimize
120-
/// the initial case a little bit by manual loop switch.
121-
/// Returns whether the BlockData changes in this iteration.
122-
template <bool Initialize = false>
123-
bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
124-
125-
public:
126-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
127-
void dump() const;
128-
void dump(StringRef Label, BitVector const &BV,
129-
const ReversePostOrderTraversal<Function *> &RPOT) const;
130-
#endif
131-
132-
SuspendCrossingInfo(Function &F, coro::Shape &Shape);
133-
134-
/// Returns true if there is a path from \p From to \p To crossing a suspend
135-
/// point without crossing \p From a 2nd time.
136-
bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const {
137-
size_t const FromIndex = Mapping.blockToIndex(From);
138-
size_t const ToIndex = Mapping.blockToIndex(To);
139-
bool const Result = Block[ToIndex].Kills[FromIndex];
140-
LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
141-
<< " answer is " << Result << "\n");
142-
return Result;
143-
}
144-
145-
/// Returns true if there is a path from \p From to \p To crossing a suspend
146-
/// point without crossing \p From a 2nd time. If \p From is the same as \p To
147-
/// this will also check if there is a looping path crossing a suspend point.
148-
bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
149-
BasicBlock *To) const {
150-
size_t const FromIndex = Mapping.blockToIndex(From);
151-
size_t const ToIndex = Mapping.blockToIndex(To);
152-
bool Result = Block[ToIndex].Kills[FromIndex] ||
153-
(From == To && Block[ToIndex].KillLoop);
154-
LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
155-
<< " answer is " << Result << " (path or loop)\n");
156-
return Result;
157-
}
158-
159-
bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
160-
auto *I = cast<Instruction>(U);
161-
162-
// We rewrote PHINodes, so that only the ones with exactly one incoming
163-
// value need to be analyzed.
164-
if (auto *PN = dyn_cast<PHINode>(I))
165-
if (PN->getNumIncomingValues() > 1)
166-
return false;
167-
168-
BasicBlock *UseBB = I->getParent();
169-
170-
// As a special case, treat uses by an llvm.coro.suspend.retcon or an
171-
// llvm.coro.suspend.async as if they were uses in the suspend's single
172-
// predecessor: the uses conceptually occur before the suspend.
173-
if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
174-
UseBB = UseBB->getSinglePredecessor();
175-
assert(UseBB && "should have split coro.suspend into its own block");
176-
}
177-
178-
return hasPathCrossingSuspendPoint(DefBB, UseBB);
179-
}
180-
181-
bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
182-
return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
183-
}
184-
185-
bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
186-
auto *DefBB = I.getParent();
187-
188-
// As a special case, treat values produced by an llvm.coro.suspend.*
189-
// as if they were defined in the single successor: the uses
190-
// conceptually occur after the suspend.
191-
if (isa<AnyCoroSuspendInst>(I)) {
192-
DefBB = DefBB->getSingleSuccessor();
193-
assert(DefBB && "should have split coro.suspend into its own block");
194-
}
195-
196-
return isDefinitionAcrossSuspend(DefBB, U);
197-
}
198-
199-
bool isDefinitionAcrossSuspend(Value &V, User *U) const {
200-
if (auto *Arg = dyn_cast<Argument>(&V))
201-
return isDefinitionAcrossSuspend(*Arg, U);
202-
if (auto *Inst = dyn_cast<Instruction>(&V))
203-
return isDefinitionAcrossSuspend(*Inst, U);
204-
205-
llvm_unreachable(
206-
"Coroutine could only collect Argument and Instruction now.");
207-
}
208-
};
209-
} // end anonymous namespace
210-
211-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
212-
static std::string getBasicBlockLabel(const BasicBlock *BB) {
213-
if (BB->hasName())
214-
return BB->getName().str();
215-
216-
std::string S;
217-
raw_string_ostream OS(S);
218-
BB->printAsOperand(OS, false);
219-
return OS.str().substr(1);
220-
}
221-
222-
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
223-
StringRef Label, BitVector const &BV,
224-
const ReversePostOrderTraversal<Function *> &RPOT) const {
225-
dbgs() << Label << ":";
226-
for (const BasicBlock *BB : RPOT) {
227-
auto BBNo = Mapping.blockToIndex(BB);
228-
if (BV[BBNo])
229-
dbgs() << " " << getBasicBlockLabel(BB);
230-
}
231-
dbgs() << "\n";
232-
}
233-
234-
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
235-
if (Block.empty())
236-
return;
237-
238-
BasicBlock *const B = Mapping.indexToBlock(0);
239-
Function *F = B->getParent();
240-
241-
ReversePostOrderTraversal<Function *> RPOT(F);
242-
for (const BasicBlock *BB : RPOT) {
243-
auto BBNo = Mapping.blockToIndex(BB);
244-
dbgs() << getBasicBlockLabel(BB) << ":\n";
245-
dump(" Consumes", Block[BBNo].Consumes, RPOT);
246-
dump(" Kills", Block[BBNo].Kills, RPOT);
247-
}
248-
dbgs() << "\n";
249-
}
250-
#endif
251-
252-
template <bool Initialize>
253-
bool SuspendCrossingInfo::computeBlockData(
254-
const ReversePostOrderTraversal<Function *> &RPOT) {
255-
bool Changed = false;
256-
257-
for (const BasicBlock *BB : RPOT) {
258-
auto BBNo = Mapping.blockToIndex(BB);
259-
auto &B = Block[BBNo];
260-
261-
// We don't need to count the predecessors when initialization.
262-
if constexpr (!Initialize)
263-
// If all the predecessors of the current Block don't change,
264-
// the BlockData for the current block must not change too.
265-
if (all_of(predecessors(B), [this](BasicBlock *BB) {
266-
return !Block[Mapping.blockToIndex(BB)].Changed;
267-
})) {
268-
B.Changed = false;
269-
continue;
270-
}
271-
272-
// Saved Consumes and Kills bitsets so that it is easy to see
273-
// if anything changed after propagation.
274-
auto SavedConsumes = B.Consumes;
275-
auto SavedKills = B.Kills;
276-
277-
for (BasicBlock *PI : predecessors(B)) {
278-
auto PrevNo = Mapping.blockToIndex(PI);
279-
auto &P = Block[PrevNo];
280-
281-
// Propagate Kills and Consumes from predecessors into B.
282-
B.Consumes |= P.Consumes;
283-
B.Kills |= P.Kills;
284-
285-
// If block P is a suspend block, it should propagate kills into block
286-
// B for every block P consumes.
287-
if (P.Suspend)
288-
B.Kills |= P.Consumes;
289-
}
290-
291-
if (B.Suspend) {
292-
// If block B is a suspend block, it should kill all of the blocks it
293-
// consumes.
294-
B.Kills |= B.Consumes;
295-
} else if (B.End) {
296-
// If block B is an end block, it should not propagate kills as the
297-
// blocks following coro.end() are reached during initial invocation
298-
// of the coroutine while all the data are still available on the
299-
// stack or in the registers.
300-
B.Kills.reset();
301-
} else {
302-
// This is reached when B block it not Suspend nor coro.end and it
303-
// need to make sure that it is not in the kill set.
304-
B.KillLoop |= B.Kills[BBNo];
305-
B.Kills.reset(BBNo);
306-
}
307-
308-
if constexpr (!Initialize) {
309-
B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
310-
Changed |= B.Changed;
311-
}
312-
}
313-
314-
return Changed;
315-
}
316-
317-
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
318-
: Mapping(F) {
319-
const size_t N = Mapping.size();
320-
Block.resize(N);
321-
322-
// Initialize every block so that it consumes itself
323-
for (size_t I = 0; I < N; ++I) {
324-
auto &B = Block[I];
325-
B.Consumes.resize(N);
326-
B.Kills.resize(N);
327-
B.Consumes.set(I);
328-
B.Changed = true;
329-
}
330-
331-
// Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
332-
// the code beyond coro.end is reachable during initial invocation of the
333-
// coroutine.
334-
for (auto *CE : Shape.CoroEnds)
335-
getBlockData(CE->getParent()).End = true;
336-
337-
// Mark all suspend blocks and indicate that they kill everything they
338-
// consume. Note, that crossing coro.save also requires a spill, as any code
339-
// between coro.save and coro.suspend may resume the coroutine and all of the
340-
// state needs to be saved by that time.
341-
auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
342-
BasicBlock *SuspendBlock = BarrierInst->getParent();
343-
auto &B = getBlockData(SuspendBlock);
344-
B.Suspend = true;
345-
B.Kills |= B.Consumes;
346-
};
347-
for (auto *CSI : Shape.CoroSuspends) {
348-
markSuspendBlock(CSI);
349-
if (auto *Save = CSI->getCoroSave())
350-
markSuspendBlock(Save);
351-
}
352-
353-
// It is considered to be faster to use RPO traversal for forward-edges
354-
// dataflow analysis.
355-
ReversePostOrderTraversal<Function *> RPOT(&F);
356-
computeBlockData</*Initialize=*/true>(RPOT);
357-
while (computeBlockData</*Initialize*/ false>(RPOT))
358-
;
359-
360-
LLVM_DEBUG(dump());
361-
}
362-
36355
namespace {
36456

36557
// RematGraph is used to construct a DAG for rematerializable instructions
@@ -438,6 +130,16 @@ struct RematGraph {
438130
}
439131

440132
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
133+
static std::string getBasicBlockLabel(const BasicBlock *BB) {
134+
if (BB->hasName())
135+
return BB->getName().str();
136+
137+
std::string S;
138+
raw_string_ostream OS(S);
139+
BB->printAsOperand(OS, false);
140+
return OS.str().substr(1);
141+
}
142+
441143
void dump() const {
442144
dbgs() << "Entry (";
443145
dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
@@ -3159,7 +2861,7 @@ void coro::buildCoroutineFrame(
31592861
rewritePHIs(F);
31602862

31612863
// Build suspend crossing info.
3162-
SuspendCrossingInfo Checker(F, Shape);
2864+
SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
31632865

31642866
doRematerializations(F, Checker, MaterializableCallback);
31652867

0 commit comments

Comments
 (0)