From df777524a41b8b5cb5206dd4536d70724025f08e Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Thu, 7 Nov 2024 09:25:34 -0800 Subject: [PATCH] [SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG This is a refactoring patch that moves the callback registration for getting notified about new instructions from the scheduler to the DAG. This makes sense from a design and testing point of view: - the DAG should not rely on the scheduler for getting notified - the notifiers don't need to be public - it's easier to test the notifiers directly from within the DAG unit tests --- .../SandboxVectorizer/DependencyGraph.h | 27 ++++++-- .../Vectorize/SandboxVectorizer/Scheduler.h | 9 +-- .../SandboxVectorizer/DependencyGraphTest.cpp | 66 ++++++++++++++----- .../SandboxVectorizer/SchedulerTest.cpp | 2 +- 4 files changed, 72 insertions(+), 32 deletions(-) diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 5211c7922ea2f..765b65c4971be 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -290,6 +290,9 @@ class DependencyGraph { /// The DAG spans across all instructions in this interval. Interval DAGInterval; + Context *Ctx = nullptr; + std::optional CreateInstrCB; + std::unique_ptr BatchAA; enum class DependencyType { @@ -325,9 +328,24 @@ class DependencyGraph { /// chain. void createNewNodes(const Interval &NewInterval); + /// Called by the callbacks when a new instruction \p I has been created. + void notifyCreateInstr(Instruction *I) { + getOrCreateNode(I); + // TODO: Update the dependencies for the new node. + // TODO: Update the MemDGNode chain to include the new node if needed. + } + public: - DependencyGraph(AAResults &AA) - : BatchAA(std::make_unique(AA)) {} + /// This constructor also registers callbacks. + DependencyGraph(AAResults &AA, Context &Ctx) + : Ctx(&Ctx), BatchAA(std::make_unique(AA)) { + CreateInstrCB = Ctx.registerCreateInstrCallback( + [this](Instruction *I) { notifyCreateInstr(I); }); + } + ~DependencyGraph() { + if (CreateInstrCB) + Ctx->unregisterCreateInstrCallback(*CreateInstrCB); + } DGNode *getNode(Instruction *I) const { auto It = InstrToNodeMap.find(I); @@ -354,11 +372,6 @@ class DependencyGraph { Interval extend(ArrayRef Instrs); /// \Returns the range of instructions included in the DAG. Interval getInterval() const { return DAGInterval; } - /// Called by the scheduler when a new instruction \p I has been created. - void notifyCreateInstr(Instruction *I) { - getOrCreateNode(I); - // TODO: Update the dependencies for the new node. - } void clear() { InstrToNodeMap.clear(); DAGInterval = {}; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 9c11b5dbc1643..022fd71df67dc 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -106,8 +106,6 @@ class Scheduler { std::optional ScheduleTopItOpt; // TODO: This is wasting memory in exchange for fast removal using a raw ptr. DenseMap> Bndls; - Context &Ctx; - Context::CallbackID CreateInstrCB; /// \Returns a scheduling bundle containing \p Instrs. SchedBundle *createBundle(ArrayRef Instrs); @@ -137,11 +135,8 @@ class Scheduler { Scheduler &operator=(const Scheduler &) = delete; public: - Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) { - CreateInstrCB = Ctx.registerCreateInstrCallback( - [this](Instruction *I) { DAG.notifyCreateInstr(I); }); - } - ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); } + Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {} + ~Scheduler() {} bool trySchedule(ArrayRef Instrs); diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 061d57c31ce23..206f6c5b4c135 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -194,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) { auto *Call = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); EXPECT_TRUE(isa(DAG.getNode(Store))); EXPECT_TRUE(isa(DAG.getNode(Load))); @@ -224,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S0 = cast(&*It++); auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()}); // Check extend(). EXPECT_EQ(Span.top(), &*BB->begin()); @@ -285,7 +285,7 @@ define i8 @foo(i8 %v0, i8 %v1) { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); auto It = BB->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *AddN0 = DAG.getNode(cast(&*It++)); @@ -332,7 +332,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S1 = cast(&*It++); [[maybe_unused]] auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *S0N = cast(DAG.getNode(S0)); @@ -366,7 +366,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *S0N = cast(DAG.getNode(S0)); @@ -436,7 +436,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto It = BB->begin(); auto *Store0N = cast( @@ -461,7 +461,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) { sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto It = BB->begin(); auto *Store0N = cast( @@ -487,7 +487,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto It = BB->begin(); auto *Ld0N = cast( @@ -512,7 +512,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) { sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto It = BB->begin(); auto *Store0N = cast( @@ -542,7 +542,7 @@ define void @foo(float %v1, float %v2) { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -574,7 +574,7 @@ define void @foo() { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -606,7 +606,7 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -637,7 +637,7 @@ define void @foo(ptr %ptr) { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -664,7 +664,7 @@ define void @foo(ptr %ptr) { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -695,7 +695,7 @@ define void @foo() { auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); auto It = BB->begin(); @@ -728,7 +728,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { auto *S3 = cast(&*It++); auto *S4 = cast(&*It++); auto *S5 = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); { // Scenario 1: Build new DAG auto NewIntvl = DAG.extend({S3, S3}); @@ -788,7 +788,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { { // Check UnscheduledSuccs when a node is scheduled - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({S2, S2}); auto *S2N = cast(DAG.getNode(S2)); S2N->setScheduled(true); @@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled } } + +TEST_F(DependencyGraphTest, CreateInstrCallback) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { + store i8 %v1, ptr %ptr + store i8 %v2, ptr %ptr + store i8 %v3, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *S1 = cast(&*It++); + [[maybe_unused]] auto *S2 = cast(&*It++); + auto *S3 = cast(&*It++); + + // Check new instruction callback. + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + DAG.extend({S1, S3}); + auto *Arg = F->getArg(3); + auto *Ptr = S1->getPointerOperand(); + sandboxir::StoreInst *NewS = + sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(), + /*IsVolatile=*/true, Ctx); + auto *NewSN = DAG.getNode(NewS); + EXPECT_TRUE(NewSN != nullptr); + // TODO: Check the dependencies to/from NewSN after they land. + // TODO: Check the MemDGNode chain. +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index 94a5791442974..c5e44a97976a7 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -70,7 +70,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *SN0 = DAG.getNode(S0); auto *SN1 = DAG.getNode(S1);