Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ class DependencyGraph {
/// The DAG spans across all instructions in this interval.
Interval<Instruction> DAGInterval;

Context *Ctx = nullptr;
std::optional<Context::CallbackID> CreateInstrCB;

std::unique_ptr<BatchAAResults> BatchAA;

enum class DependencyType {
Expand Down Expand Up @@ -325,9 +328,24 @@ class DependencyGraph {
/// chain.
void createNewNodes(const Interval<Instruction> &NewInterval);

/// Called by the callbacks when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I) {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
// TODO: Update the MemDGNode chain to include the new node if needed.
}

public:
DependencyGraph(AAResults &AA)
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
/// This constructor also registers callbacks.
DependencyGraph(AAResults &AA, Context &Ctx)
: Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
CreateInstrCB = Ctx.registerCreateInstrCallback(
[this](Instruction *I) { notifyCreateInstr(I); });
}
~DependencyGraph() {
if (CreateInstrCB)
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
}

DGNode *getNode(Instruction *I) const {
auto It = InstrToNodeMap.find(I);
Expand All @@ -354,11 +372,6 @@ class DependencyGraph {
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
/// \Returns the range of instructions included in the DAG.
Interval<Instruction> getInterval() const { return DAGInterval; }
/// Called by the scheduler when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I) {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
}
void clear() {
InstrToNodeMap.clear();
DAGInterval = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ class Scheduler {
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
Context &Ctx;
Context::CallbackID CreateInstrCB;

/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
Expand Down Expand Up @@ -137,11 +135,8 @@ class Scheduler {
Scheduler &operator=(const Scheduler &) = delete;

public:
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
CreateInstrCB = Ctx.registerCreateInstrCallback(
[this](Instruction *I) { DAG.notifyCreateInstr(I); });
}
~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
~Scheduler() {}

bool trySchedule(ArrayRef<Instruction *> Instrs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) {
auto *Call = cast<sandboxir::CallInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
Expand Down Expand Up @@ -224,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
// Check extend().
EXPECT_EQ(Span.top(), &*BB->begin());
Expand Down Expand Up @@ -285,7 +285,7 @@ define i8 @foo(i8 %v0, i8 %v1) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});

auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
Expand Down Expand Up @@ -332,7 +332,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});

auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
Expand Down Expand Up @@ -366,7 +366,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});

auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
Expand Down Expand Up @@ -436,7 +436,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
Expand All @@ -461,7 +461,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
Expand All @@ -487,7 +487,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Ld0N = cast<sandboxir::MemDGNode>(
Expand All @@ -512,7 +512,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
Expand Down Expand Up @@ -542,7 +542,7 @@ define void @foo(float %v1, float %v2) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand Down Expand Up @@ -574,7 +574,7 @@ define void @foo() {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand Down Expand Up @@ -606,7 +606,7 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand Down Expand Up @@ -637,7 +637,7 @@ define void @foo(ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand All @@ -664,7 +664,7 @@ define void @foo(ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand Down Expand Up @@ -695,7 +695,7 @@ define void @foo() {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
Expand Down Expand Up @@ -728,7 +728,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
auto *S4 = cast<sandboxir::StoreInst>(&*It++);
auto *S5 = cast<sandboxir::StoreInst>(&*It++);
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
{
// Scenario 1: Build new DAG
auto NewIntvl = DAG.extend({S3, S3});
Expand Down Expand Up @@ -788,7 +788,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {

{
// Check UnscheduledSuccs when a node is scheduled
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({S2, S2});
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
S2N->setScheduled(true);
Expand All @@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
}
}

TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
[[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
auto *S3 = cast<sandboxir::StoreInst>(&*It++);

// Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({S1, S3});
auto *Arg = F->getArg(3);
auto *Ptr = S1->getPointerOperand();
sandboxir::StoreInst *NewS =
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
/*IsVolatile=*/true, Ctx);
auto *NewSN = DAG.getNode(NewS);
EXPECT_TRUE(NewSN != nullptr);
// TODO: Check the dependencies to/from NewSN after they land.
// TODO: Check the MemDGNode chain.
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF));
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *SN0 = DAG.getNode(S0);
auto *SN1 = DAG.getNode(S1);
Expand Down
Loading