Skip to content

Commit f8e5e64

Browse files
committed
[SandboxVec][DAG] Update DAG when a new instruction is created
The DAG will now receive a callback whenever a new instruction is created and will update itself accordingly.
1 parent 788c88e commit f8e5e64

File tree

5 files changed

+71
-27
lines changed

5 files changed

+71
-27
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

+1
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ class MemDGNode final : public DGNode {
263263
void addMemPred(MemDGNode *PredN) {
264264
[[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second;
265265
assert(Inserted && "PredN already exists!");
266+
assert(PredN != this && "Trying to add a dependency to self!");
266267
if (!Scheduled) {
267268
++PredN->UnscheduledSuccs;
268269
}

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h

+4
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ template <typename T> class Interval {
108108
return (Top == I || Top->comesBefore(I)) &&
109109
(I == Bottom || I->comesBefore(Bottom));
110110
}
111+
/// \Returns true if \p Elm is right before the top or right after the bottom.
112+
bool touches(T *Elm) const {
113+
return Top == Elm->getNextNode() || Bottom == Elm->getPrevNode();
114+
}
111115
T *top() const { return Top; }
112116
T *bottom() const { return Bottom; }
113117

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

+22-2
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
368368
}
369369

370370
void DependencyGraph::notifyCreateInstr(Instruction *I) {
371-
auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
372-
// TODO: Update the dependencies for the new node.
371+
// Nothing to do if the node is not in the focus range of the DAG.
372+
if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
373+
return;
374+
// Include `I` into the interval.
375+
DAGInterval = DAGInterval.getUnionInterval({I, I});
376+
auto *N = getOrCreateNode(I);
377+
auto *MemN = dyn_cast<MemDGNode>(N);
373378

374379
// Update the MemDGNode chain if this is a memory node.
375380
if (MemN != nullptr) {
@@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
381386
NextMemN->PrevMemN = MemN;
382387
MemN->NextMemN = NextMemN;
383388
}
389+
390+
// Add Mem dependencies.
391+
// 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
392+
if (DAGInterval.top()->comesBefore(I)) {
393+
Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
394+
auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
395+
scanAndAddDeps(*MemN, SrcInterval);
396+
}
397+
// 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
398+
if (I->comesBefore(DAGInterval.bottom())) {
399+
Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
400+
for (MemDGNode &BelowN :
401+
MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
402+
scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
403+
}
384404
}
385405
}
386406

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

+36-25
Original file line numberDiff line numberDiff line change
@@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
832832
}
833833
}
834834

835+
// Check that the DAG gets updated when we create a new instruction.
835836
TEST_F(DependencyGraphTest, CreateInstrCallback) {
836837
parseIR(C, R"IR(
837-
define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
838+
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) {
838839
store i8 %v1, ptr %ptr
839840
store i8 %v2, ptr %ptr
840841
store i8 %v3, ptr %ptr
@@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
851852
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
852853
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
853854

854-
// Check new instruction callback.
855855
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
856-
DAG.extend({S1, Ret});
857-
auto *Arg = F->getArg(3);
856+
// Create a DAG spanning S1 to S3.
857+
DAG.extend({S1, S3});
858+
auto *ArgNew1 = F->getArg(4);
859+
auto *ArgNew2 = F->getArg(5);
858860
auto *Ptr = S1->getPointerOperand();
861+
862+
auto *S1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
863+
auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
864+
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
865+
sandboxir::MemDGNode *New1MemN = nullptr;
866+
sandboxir::MemDGNode *New2MemN = nullptr;
859867
{
868+
// Create a new store before S3 (within the span of the DAG).
860869
sandboxir::StoreInst *NewS =
861-
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
870+
sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(),
862871
/*IsVolatile=*/true, Ctx);
863-
auto *NewSN = DAG.getNode(NewS);
864-
EXPECT_TRUE(NewSN != nullptr);
865-
866872
// Check the MemDGNode chain.
867-
auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
868-
auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
869-
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
870-
EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
871-
EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
872-
EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
873-
EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
873+
New1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
874+
EXPECT_EQ(S2MemN->getNextNode(), New1MemN);
875+
EXPECT_EQ(New1MemN->getPrevNode(), S2MemN);
876+
EXPECT_EQ(New1MemN->getNextNode(), S3MemN);
877+
EXPECT_EQ(S3MemN->getPrevNode(), New1MemN);
878+
879+
// Check dependencies.
880+
EXPECT_TRUE(memDependency(S1MemN, New1MemN));
881+
EXPECT_TRUE(memDependency(S2MemN, New1MemN));
882+
EXPECT_TRUE(memDependency(New1MemN, S3MemN));
874883
}
875-
876884
{
877-
// Also check if new node is at the end of the BB, after Ret.
885+
// Create a new store before Ret (outside the current DAG).
878886
sandboxir::StoreInst *NewS =
879-
sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
887+
sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(),
880888
/*IsVolatile=*/true, Ctx);
881889
// Check the MemDGNode chain.
882-
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
883-
auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
884-
EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
885-
EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
886-
EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
890+
New2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
891+
EXPECT_EQ(S3MemN->getNextNode(), New2MemN);
892+
EXPECT_EQ(New2MemN->getPrevNode(), S3MemN);
893+
EXPECT_EQ(New2MemN->getNextNode(), nullptr);
894+
895+
// Check dependencies.
896+
EXPECT_TRUE(memDependency(S1MemN, New2MemN));
897+
EXPECT_TRUE(memDependency(S2MemN, New2MemN));
898+
EXPECT_TRUE(memDependency(New1MemN, New2MemN));
899+
EXPECT_TRUE(memDependency(S3MemN, New2MemN));
887900
}
888-
889-
// TODO: Check the dependencies to/from NewSN after they land.
890901
}
891902

892903
TEST_F(DependencyGraphTest, EraseInstrCallback) {

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ define void @foo(i8 %v0) {
8787
EXPECT_FALSE(One.contains(I1));
8888
EXPECT_FALSE(One.contains(I2));
8989
EXPECT_FALSE(One.contains(Ret));
90+
// Check touches().
91+
{
92+
sandboxir::Interval<sandboxir::Instruction> Intvl(I2, I2);
93+
EXPECT_TRUE(Intvl.touches(I1));
94+
EXPECT_FALSE(Intvl.touches(I2));
95+
EXPECT_TRUE(Intvl.touches(Ret));
96+
EXPECT_FALSE(Intvl.touches(I0));
97+
}
9098
// Check iterator.
9199
auto BBIt = BB->begin();
92100
for (auto &I : Intvl)

0 commit comments

Comments
 (0)