diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 6e3f99d78b932..f4e74fdee84c9 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -45,15 +45,6 @@ class DGNode; class MemDGNode; class DependencyGraph; -/// While OpIt points to a Value that is not an Instruction keep incrementing -/// it. \Returns the first iterator that points to an Instruction, or end. -[[nodiscard]] static User::op_iterator skipNonInstr(User::op_iterator OpIt, - User::op_iterator OpItE) { - while (OpIt != OpItE && !isa((*OpIt).get())) - ++OpIt; - return OpIt; -} - /// Iterate over both def-use and mem dependencies. class PredIterator { User::op_iterator OpIt; @@ -72,6 +63,12 @@ class PredIterator { friend class DGNode; // For constructor friend class MemDGNode; // For constructor + /// Skip iterators that don't point instructions or are outside \p DAG, + /// starting from \p OpIt and ending before \p OpItE.n + static User::op_iterator skipBadIt(User::op_iterator OpIt, + User::op_iterator OpItE, + const DependencyGraph &DAG); + public: using difference_type = std::ptrdiff_t; using value_type = DGNode *; @@ -135,8 +132,9 @@ class DGNode { bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); } using iterator = PredIterator; virtual iterator preds_begin(DependencyGraph &DAG) { - return PredIterator(skipNonInstr(I->op_begin(), I->op_end()), I->op_end(), - this, DAG); + return PredIterator( + PredIterator::skipBadIt(I->op_begin(), I->op_end(), DAG), I->op_end(), + this, DAG); } virtual iterator preds_end(DependencyGraph &DAG) { return PredIterator(I->op_end(), I->op_end(), this, DAG); @@ -249,8 +247,8 @@ class MemDGNode final : public DGNode { } iterator preds_begin(DependencyGraph &DAG) override { auto OpEndIt = I->op_end(); - return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt, - MemPreds.begin(), this, DAG); + return PredIterator(PredIterator::skipBadIt(I->op_begin(), OpEndIt, DAG), + OpEndIt, MemPreds.begin(), this, DAG); } iterator preds_end(DependencyGraph &DAG) override { return PredIterator(I->op_end(), I->op_end(), MemPreds.end(), this, DAG); diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 7aa8794d26b20..e03cf32be0244 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -14,6 +14,18 @@ namespace llvm::sandboxir { +User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt, + User::op_iterator OpItE, + const DependencyGraph &DAG) { + auto Skip = [&DAG](auto OpIt) { + auto *I = dyn_cast((*OpIt).get()); + return I == nullptr || DAG.getNode(I) == nullptr; + }; + while (OpIt != OpItE && Skip(OpIt)) + ++OpIt; + return OpIt; +} + PredIterator::value_type PredIterator::operator*() { // If it's a DGNode then we dereference the operand iterator. if (!isa(N)) { @@ -35,16 +47,16 @@ PredIterator &PredIterator::operator++() { if (!isa(N)) { assert(OpIt != OpItE && "Already at end!"); ++OpIt; - // Skip operands that are not instructions. - OpIt = skipNonInstr(OpIt, OpItE); + // Skip operands that are not instructions or are outside the DAG. + OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG); return *this; } // It's a MemDGNode, so if we are not at the end of the use-def iterator we // need to first increment that. if (OpIt != OpItE) { ++OpIt; - // Skip operands that are not instructions. - OpIt = skipNonInstr(OpIt, OpItE); + // Skip operands that are not instructions or are outside the DAG. + OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG); return *this; } // It's a MemDGNode with OpIt == end, so we need to increment MemIt. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp index 9ec5d830d8b4a..e54e74e1bbecd 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp @@ -79,9 +79,6 @@ void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) { for (DGNode *N : Bndl) { N->setScheduled(true); for (auto *DepN : N->preds(DAG)) { - // TODO: preds() should not return nullptr. - if (DepN == nullptr) - continue; DepN->decrUnscheduledSuccs(); if (DepN->ready()) ReadyList.insert(DepN); diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 29fc05a7f256a..263a37ac335d2 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -313,6 +313,39 @@ define i8 @foo(i8 %v0, i8 %v1) { EXPECT_EQ(RetN->getNumUnscheduledSuccs(), 0u); } +// Make sure we don't get null predecessors even if they are outside the DAG. +TEST_F(DependencyGraphTest, NonNullPreds) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %val) { + %gep = getelementptr i8, ptr %ptr, i32 0 + store i8 %val, ptr %gep + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + [[maybe_unused]] auto *GEP = cast(&*It++); + auto *S0 = cast(&*It++); + auto *Ret = cast(&*It++); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + // The DAG doesn't include GEP. + DAG.extend({S0, Ret}); + + auto *S0N = DAG.getNode(S0); + // S0 has one operand (the GEP) that is outside the DAG and no memory + // predecessors. So pred_begin() should be == pred_end(). + auto PredIt = S0N->preds_begin(DAG); + auto PredItE = S0N->preds_end(DAG); + EXPECT_EQ(PredIt, PredItE); + // Check preds(). + for (auto *PredN : S0N->preds(DAG)) + EXPECT_NE(PredN, nullptr); +} + TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) { parseIR(C, R"IR( define void @foo(ptr %ptr, i8 %v0, i8 %v1) {