Skip to content

Commit 063dd59

Browse files
committed
[SandboxVec][DependencyGraph] Fix dependency node iterators
This patch fixes a bug in the dependency node iterators that would incorrectly not skip nodes that are not in the current DAG. This resulted in iterators returning nullptr when dereferenced. The fix is to update the existing "skip" function to not only skip non-instruction values but also to skip instructions not in the DAG.
1 parent c5f99e1 commit 063dd59

File tree

4 files changed

+62
-20
lines changed

4 files changed

+62
-20
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

+11-13
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,6 @@ class DGNode;
4545
class MemDGNode;
4646
class DependencyGraph;
4747

48-
/// While OpIt points to a Value that is not an Instruction keep incrementing
49-
/// it. \Returns the first iterator that points to an Instruction, or end.
50-
[[nodiscard]] static User::op_iterator skipNonInstr(User::op_iterator OpIt,
51-
User::op_iterator OpItE) {
52-
while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
53-
++OpIt;
54-
return OpIt;
55-
}
56-
5748
/// Iterate over both def-use and mem dependencies.
5849
class PredIterator {
5950
User::op_iterator OpIt;
@@ -72,6 +63,12 @@ class PredIterator {
7263
friend class DGNode; // For constructor
7364
friend class MemDGNode; // For constructor
7465

66+
/// Skip iterators that don't point instructions or are outside \p DAG,
67+
/// starting from \p OpIt and ending before \p OpItE.n
68+
static User::op_iterator skipBadIt(User::op_iterator OpIt,
69+
User::op_iterator OpItE,
70+
const DependencyGraph &DAG);
71+
7572
public:
7673
using difference_type = std::ptrdiff_t;
7774
using value_type = DGNode *;
@@ -135,8 +132,9 @@ class DGNode {
135132
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
136133
using iterator = PredIterator;
137134
virtual iterator preds_begin(DependencyGraph &DAG) {
138-
return PredIterator(skipNonInstr(I->op_begin(), I->op_end()), I->op_end(),
139-
this, DAG);
135+
return PredIterator(
136+
PredIterator::skipBadIt(I->op_begin(), I->op_end(), DAG), I->op_end(),
137+
this, DAG);
140138
}
141139
virtual iterator preds_end(DependencyGraph &DAG) {
142140
return PredIterator(I->op_end(), I->op_end(), this, DAG);
@@ -249,8 +247,8 @@ class MemDGNode final : public DGNode {
249247
}
250248
iterator preds_begin(DependencyGraph &DAG) override {
251249
auto OpEndIt = I->op_end();
252-
return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
253-
MemPreds.begin(), this, DAG);
250+
return PredIterator(PredIterator::skipBadIt(I->op_begin(), OpEndIt, DAG),
251+
OpEndIt, MemPreds.begin(), this, DAG);
254252
}
255253
iterator preds_end(DependencyGraph &DAG) override {
256254
return PredIterator(I->op_end(), I->op_end(), MemPreds.end(), this, DAG);

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

+18-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,20 @@
1414

1515
namespace llvm::sandboxir {
1616

17+
User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt,
18+
User::op_iterator OpItE,
19+
const DependencyGraph &DAG) {
20+
auto Skip = [&DAG](auto OpIt) {
21+
auto *I = dyn_cast<Instruction>((*OpIt).get());
22+
if (I == nullptr)
23+
return true;
24+
return DAG.getNode(I) == nullptr;
25+
};
26+
while (OpIt != OpItE && Skip(OpIt))
27+
++OpIt;
28+
return OpIt;
29+
}
30+
1731
PredIterator::value_type PredIterator::operator*() {
1832
// If it's a DGNode then we dereference the operand iterator.
1933
if (!isa<MemDGNode>(N)) {
@@ -35,16 +49,16 @@ PredIterator &PredIterator::operator++() {
3549
if (!isa<MemDGNode>(N)) {
3650
assert(OpIt != OpItE && "Already at end!");
3751
++OpIt;
38-
// Skip operands that are not instructions.
39-
OpIt = skipNonInstr(OpIt, OpItE);
52+
// Skip operands that are not instructions or are outside the DAG.
53+
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
4054
return *this;
4155
}
4256
// It's a MemDGNode, so if we are not at the end of the use-def iterator we
4357
// need to first increment that.
4458
if (OpIt != OpItE) {
4559
++OpIt;
46-
// Skip operands that are not instructions.
47-
OpIt = skipNonInstr(OpIt, OpItE);
60+
// Skip operands that are not instructions or are outside the DAG.
61+
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
4862
return *this;
4963
}
5064
// It's a MemDGNode with OpIt == end, so we need to increment MemIt.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
7979
for (DGNode *N : Bndl) {
8080
N->setScheduled(true);
8181
for (auto *DepN : N->preds(DAG)) {
82-
// TODO: preds() should not return nullptr.
83-
if (DepN == nullptr)
84-
continue;
8582
DepN->decrUnscheduledSuccs();
8683
if (DepN->ready())
8784
ReadyList.insert(DepN);

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,39 @@ define i8 @foo(i8 %v0, i8 %v1) {
313313
EXPECT_EQ(RetN->getNumUnscheduledSuccs(), 0u);
314314
}
315315

316+
// Make sure we don't get null predecessors even if they are outside the DAG.
317+
TEST_F(DependencyGraphTest, NonNullPreds) {
318+
parseIR(C, R"IR(
319+
define void @foo(ptr %ptr, i8 %val) {
320+
%gep = getelementptr i8, ptr %ptr, i32 0
321+
store i8 %val, ptr %gep
322+
ret void
323+
}
324+
)IR");
325+
llvm::Function *LLVMF = &*M->getFunction("foo");
326+
sandboxir::Context Ctx(C);
327+
auto *F = Ctx.createFunction(LLVMF);
328+
auto *BB = &*F->begin();
329+
auto It = BB->begin();
330+
[[maybe_unused]] auto *GEP = cast<sandboxir::GetElementPtrInst>(&*It++);
331+
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
332+
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
333+
334+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
335+
// The DAG doesn't include GEP.
336+
DAG.extend({S0, Ret});
337+
338+
auto *S0N = DAG.getNode(S0);
339+
// S0 has one operand (the GEP) that is outside the DAG and no memory
340+
// predecessors. So pred_begin() should be == pred_end().
341+
auto PredIt = S0N->preds_begin(DAG);
342+
auto PredItE = S0N->preds_end(DAG);
343+
EXPECT_EQ(PredIt, PredItE);
344+
// Check preds().
345+
for (auto *PredN : S0N->preds(DAG))
346+
EXPECT_NE(PredN, nullptr);
347+
}
348+
316349
TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
317350
parseIR(C, R"IR(
318351
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {

0 commit comments

Comments
 (0)