-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SandboxVec][DAG] Update DAG when a new instruction is created #126124
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: vporpo (vporpo) ChangesThe DAG will now receive a callback whenever a new instruction is created and will update itself accordingly. Full diff: https://github.com/llvm/llvm-project/pull/126124.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index f4e74fdee84c919..fab456d925526c2 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -263,6 +263,7 @@ class MemDGNode final : public DGNode {
void addMemPred(MemDGNode *PredN) {
[[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second;
assert(Inserted && "PredN already exists!");
+ assert(PredN != this && "Trying to add a dependency to self!");
if (!Scheduled) {
++PredN->UnscheduledSuccs;
}
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
index 18cd29e9e14ee40..f6c5a204673372f 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
@@ -108,6 +108,10 @@ template <typename T> class Interval {
return (Top == I || Top->comesBefore(I)) &&
(I == Bottom || I->comesBefore(Bottom));
}
+ /// \Returns true if \p Elm is right before the top or right after the bottom.
+ bool touches(T *Elm) const {
+ return Top == Elm->getNextNode() || Bottom == Elm->getPrevNode();
+ }
T *top() const { return Top; }
T *bottom() const { return Bottom; }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index e03cf32be024406..2680667afc4de29 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
}
void DependencyGraph::notifyCreateInstr(Instruction *I) {
- auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
- // TODO: Update the dependencies for the new node.
+ // Nothing to do if the node is not in the focus range of the DAG.
+ if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
+ return;
+ // Include `I` into the interval.
+ DAGInterval = DAGInterval.getUnionInterval({I, I});
+ auto *N = getOrCreateNode(I);
+ auto *MemN = dyn_cast<MemDGNode>(N);
// Update the MemDGNode chain if this is a memory node.
if (MemN != nullptr) {
@@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
NextMemN->PrevMemN = MemN;
MemN->NextMemN = NextMemN;
}
+
+ // Add Mem dependencies.
+ // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
+ if (DAGInterval.top()->comesBefore(I)) {
+ Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
+ auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
+ scanAndAddDeps(*MemN, SrcInterval);
+ }
+ // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
+ if (I->comesBefore(DAGInterval.bottom())) {
+ Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
+ for (MemDGNode &BelowN :
+ MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
+ scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
+ }
}
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 263a37ac335d2ae..f1e9afefb45311b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
}
}
+// Check that the DAG gets updated when we create a new instruction.
TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
@@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- // Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
- DAG.extend({S1, Ret});
- auto *Arg = F->getArg(3);
+ // Create a DAG spanning S1 to S3.
+ DAG.extend({S1, S3});
+ auto *ArgNew1 = F->getArg(4);
+ auto *ArgNew2 = F->getArg(5);
auto *Ptr = S1->getPointerOperand();
+
+ auto *S1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ sandboxir::MemDGNode *New1MemN = nullptr;
+ sandboxir::MemDGNode *New2MemN = nullptr;
{
+ // Create a new store before S3 (within the span of the DAG).
sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(),
/*IsVolatile=*/true, Ctx);
- auto *NewSN = DAG.getNode(NewS);
- EXPECT_TRUE(NewSN != nullptr);
-
// Check the MemDGNode chain.
- auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
- auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
- auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
- EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
- EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
- EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
- EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+ New1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S2MemN->getNextNode(), New1MemN);
+ EXPECT_EQ(New1MemN->getPrevNode(), S2MemN);
+ EXPECT_EQ(New1MemN->getNextNode(), S3MemN);
+ EXPECT_EQ(S3MemN->getPrevNode(), New1MemN);
+
+ // Check dependencies.
+ EXPECT_TRUE(memDependency(S1MemN, New1MemN));
+ EXPECT_TRUE(memDependency(S2MemN, New1MemN));
+ EXPECT_TRUE(memDependency(New1MemN, S3MemN));
}
-
{
- // Also check if new node is at the end of the BB, after Ret.
+ // Create a new store before Ret (outside the current DAG).
sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+ sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(),
/*IsVolatile=*/true, Ctx);
// Check the MemDGNode chain.
- auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
- auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
- EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
- EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
- EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+ New2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S3MemN->getNextNode(), New2MemN);
+ EXPECT_EQ(New2MemN->getPrevNode(), S3MemN);
+ EXPECT_EQ(New2MemN->getNextNode(), nullptr);
+
+ // Check dependencies.
+ EXPECT_TRUE(memDependency(S1MemN, New2MemN));
+ EXPECT_TRUE(memDependency(S2MemN, New2MemN));
+ EXPECT_TRUE(memDependency(New1MemN, New2MemN));
+ EXPECT_TRUE(memDependency(S3MemN, New2MemN));
}
-
- // TODO: Check the dependencies to/from NewSN after they land.
}
TEST_F(DependencyGraphTest, EraseInstrCallback) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
index 32521ed79a314be..50b5a592c9daa75 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
@@ -87,6 +87,14 @@ define void @foo(i8 %v0) {
EXPECT_FALSE(One.contains(I1));
EXPECT_FALSE(One.contains(I2));
EXPECT_FALSE(One.contains(Ret));
+ // Check touches().
+ {
+ sandboxir::Interval<sandboxir::Instruction> Intvl(I2, I2);
+ EXPECT_TRUE(Intvl.touches(I1));
+ EXPECT_FALSE(Intvl.touches(I2));
+ EXPECT_TRUE(Intvl.touches(Ret));
+ EXPECT_FALSE(Intvl.touches(I0));
+ }
// Check iterator.
auto BBIt = BB->begin();
for (auto &I : Intvl)
|
EXPECT_FALSE(Intvl.touches(I2)); | ||
EXPECT_TRUE(Intvl.touches(Ret)); | ||
EXPECT_FALSE(Intvl.touches(I0)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a EXPECT_TRUE(Intvl.contains(I2)); here for clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
The DAG will now receive a callback whenever a new instruction is created and will update itself accordingly.
…126124) The DAG will now receive a callback whenever a new instruction is created and will update itself accordingly.
The DAG will now receive a callback whenever a new instruction is created and will update itself accordingly.