diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index f4e74fdee84c9..fab456d925526 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -263,6 +263,7 @@ class MemDGNode final : public DGNode { void addMemPred(MemDGNode *PredN) { [[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second; assert(Inserted && "PredN already exists!"); + assert(PredN != this && "Trying to add a dependency to self!"); if (!Scheduled) { ++PredN->UnscheduledSuccs; } diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h index 18cd29e9e14ee..f6c5a20467337 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h @@ -108,6 +108,10 @@ template class Interval { return (Top == I || Top->comesBefore(I)) && (I == Bottom || I->comesBefore(Bottom)); } + /// \Returns true if \p Elm is right before the top or right after the bottom. + bool touches(T *Elm) const { + return Top == Elm->getNextNode() || Bottom == Elm->getPrevNode(); + } T *top() const { return Top; } T *bottom() const { return Bottom; } diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index e03cf32be0244..2680667afc4de 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN, } void DependencyGraph::notifyCreateInstr(Instruction *I) { - auto *MemN = dyn_cast(getOrCreateNode(I)); - // TODO: Update the dependencies for the new node. + // Nothing to do if the node is not in the focus range of the DAG. + if (!(DAGInterval.contains(I) || DAGInterval.touches(I))) + return; + // Include `I` into the interval. + DAGInterval = DAGInterval.getUnionInterval({I, I}); + auto *N = getOrCreateNode(I); + auto *MemN = dyn_cast(N); // Update the MemDGNode chain if this is a memory node. if (MemN != nullptr) { @@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) { NextMemN->PrevMemN = MemN; MemN->NextMemN = NextMemN; } + + // Add Mem dependencies. + // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN. + if (DAGInterval.top()->comesBefore(I)) { + Interval AboveIntvl(DAGInterval.top(), I->getPrevNode()); + auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this); + scanAndAddDeps(*MemN, SrcInterval); + } + // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN. + if (I->comesBefore(DAGInterval.bottom())) { + Interval BelowIntvl(I->getNextNode(), DAGInterval.bottom()); + for (MemDGNode &BelowN : + MemDGNodeIntervalBuilder::make(BelowIntvl, *this)) + scanAndAddDeps(BelowN, Interval(MemN, MemN)); + } } } diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 263a37ac335d2..f1e9afefb4531 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { } } +// Check that the DAG gets updated when we create a new instruction. TEST_F(DependencyGraphTest, CreateInstrCallback) { parseIR(C, R"IR( -define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { +define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) { store i8 %v1, ptr %ptr store i8 %v2, ptr %ptr store i8 %v3, ptr %ptr @@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { auto *S3 = cast(&*It++); auto *Ret = cast(&*It++); - // Check new instruction callback. sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); - DAG.extend({S1, Ret}); - auto *Arg = F->getArg(3); + // Create a DAG spanning S1 to S3. + DAG.extend({S1, S3}); + auto *ArgNew1 = F->getArg(4); + auto *ArgNew2 = F->getArg(5); auto *Ptr = S1->getPointerOperand(); + + auto *S1MemN = cast(DAG.getNode(S1)); + auto *S2MemN = cast(DAG.getNode(S2)); + auto *S3MemN = cast(DAG.getNode(S3)); + sandboxir::MemDGNode *New1MemN = nullptr; + sandboxir::MemDGNode *New2MemN = nullptr; { + // Create a new store before S3 (within the span of the DAG). sandboxir::StoreInst *NewS = - sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(), + sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(), /*IsVolatile=*/true, Ctx); - auto *NewSN = DAG.getNode(NewS); - EXPECT_TRUE(NewSN != nullptr); - // Check the MemDGNode chain. - auto *S2MemN = cast(DAG.getNode(S2)); - auto *NewMemSN = cast(NewSN); - auto *S3MemN = cast(DAG.getNode(S3)); - EXPECT_EQ(S2MemN->getNextNode(), NewMemSN); - EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN); - EXPECT_EQ(NewMemSN->getNextNode(), S3MemN); - EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN); + New1MemN = cast(DAG.getNode(NewS)); + EXPECT_EQ(S2MemN->getNextNode(), New1MemN); + EXPECT_EQ(New1MemN->getPrevNode(), S2MemN); + EXPECT_EQ(New1MemN->getNextNode(), S3MemN); + EXPECT_EQ(S3MemN->getPrevNode(), New1MemN); + + // Check dependencies. + EXPECT_TRUE(memDependency(S1MemN, New1MemN)); + EXPECT_TRUE(memDependency(S2MemN, New1MemN)); + EXPECT_TRUE(memDependency(New1MemN, S3MemN)); } - { - // Also check if new node is at the end of the BB, after Ret. + // Create a new store before Ret (outside the current DAG). sandboxir::StoreInst *NewS = - sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(), + sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(), /*IsVolatile=*/true, Ctx); // Check the MemDGNode chain. - auto *S3MemN = cast(DAG.getNode(S3)); - auto *NewMemSN = cast(DAG.getNode(NewS)); - EXPECT_EQ(S3MemN->getNextNode(), NewMemSN); - EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN); - EXPECT_EQ(NewMemSN->getNextNode(), nullptr); + New2MemN = cast(DAG.getNode(NewS)); + EXPECT_EQ(S3MemN->getNextNode(), New2MemN); + EXPECT_EQ(New2MemN->getPrevNode(), S3MemN); + EXPECT_EQ(New2MemN->getNextNode(), nullptr); + + // Check dependencies. + EXPECT_TRUE(memDependency(S1MemN, New2MemN)); + EXPECT_TRUE(memDependency(S2MemN, New2MemN)); + EXPECT_TRUE(memDependency(New1MemN, New2MemN)); + EXPECT_TRUE(memDependency(S3MemN, New2MemN)); } - - // TODO: Check the dependencies to/from NewSN after they land. } TEST_F(DependencyGraphTest, EraseInstrCallback) { diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp index 32521ed79a314..59498371b4d73 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp @@ -87,6 +87,15 @@ define void @foo(i8 %v0) { EXPECT_FALSE(One.contains(I1)); EXPECT_FALSE(One.contains(I2)); EXPECT_FALSE(One.contains(Ret)); + // Check touches(). + { + sandboxir::Interval Intvl(I2, I2); + EXPECT_TRUE(Intvl.touches(I1)); + EXPECT_TRUE(Intvl.contains(I2)); + EXPECT_FALSE(Intvl.touches(I2)); + EXPECT_TRUE(Intvl.touches(Ret)); + EXPECT_FALSE(Intvl.touches(I0)); + } // Check iterator. auto BBIt = BB->begin(); for (auto &I : Intvl)