Skip to content

Commit

Permalink
[SandboxVec][DependencyGraph] Fix dependency node iterators (#125616)
Browse files Browse the repository at this point in the history
This patch fixes a bug in the dependency node iterators that would
incorrectly not skip nodes that are not in the current DAG. This
resulted in iterators returning nullptr when dereferenced.

The fix is to update the existing "skip" function to not only skip
non-instruction values but also to skip instructions not in the DAG.
  • Loading branch information
vporpo authored Feb 6, 2025
1 parent 6dc41a6 commit 788c88e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,6 @@ class DGNode;
class MemDGNode;
class DependencyGraph;

/// While OpIt points to a Value that is not an Instruction keep incrementing
/// it. \Returns the first iterator that points to an Instruction, or end.
[[nodiscard]] static User::op_iterator skipNonInstr(User::op_iterator OpIt,
User::op_iterator OpItE) {
while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
++OpIt;
return OpIt;
}

/// Iterate over both def-use and mem dependencies.
class PredIterator {
User::op_iterator OpIt;
Expand All @@ -72,6 +63,12 @@ class PredIterator {
friend class DGNode; // For constructor
friend class MemDGNode; // For constructor

/// Skip iterators that don't point instructions or are outside \p DAG,
/// starting from \p OpIt and ending before \p OpItE.n
static User::op_iterator skipBadIt(User::op_iterator OpIt,
User::op_iterator OpItE,
const DependencyGraph &DAG);

public:
using difference_type = std::ptrdiff_t;
using value_type = DGNode *;
Expand Down Expand Up @@ -135,8 +132,9 @@ class DGNode {
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
using iterator = PredIterator;
virtual iterator preds_begin(DependencyGraph &DAG) {
return PredIterator(skipNonInstr(I->op_begin(), I->op_end()), I->op_end(),
this, DAG);
return PredIterator(
PredIterator::skipBadIt(I->op_begin(), I->op_end(), DAG), I->op_end(),
this, DAG);
}
virtual iterator preds_end(DependencyGraph &DAG) {
return PredIterator(I->op_end(), I->op_end(), this, DAG);
Expand Down Expand Up @@ -249,8 +247,8 @@ class MemDGNode final : public DGNode {
}
iterator preds_begin(DependencyGraph &DAG) override {
auto OpEndIt = I->op_end();
return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
MemPreds.begin(), this, DAG);
return PredIterator(PredIterator::skipBadIt(I->op_begin(), OpEndIt, DAG),
OpEndIt, MemPreds.begin(), this, DAG);
}
iterator preds_end(DependencyGraph &DAG) override {
return PredIterator(I->op_end(), I->op_end(), MemPreds.end(), this, DAG);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@

namespace llvm::sandboxir {

User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt,
User::op_iterator OpItE,
const DependencyGraph &DAG) {
auto Skip = [&DAG](auto OpIt) {
auto *I = dyn_cast<Instruction>((*OpIt).get());
return I == nullptr || DAG.getNode(I) == nullptr;
};
while (OpIt != OpItE && Skip(OpIt))
++OpIt;
return OpIt;
}

PredIterator::value_type PredIterator::operator*() {
// If it's a DGNode then we dereference the operand iterator.
if (!isa<MemDGNode>(N)) {
Expand All @@ -35,16 +47,16 @@ PredIterator &PredIterator::operator++() {
if (!isa<MemDGNode>(N)) {
assert(OpIt != OpItE && "Already at end!");
++OpIt;
// Skip operands that are not instructions.
OpIt = skipNonInstr(OpIt, OpItE);
// Skip operands that are not instructions or are outside the DAG.
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
return *this;
}
// It's a MemDGNode, so if we are not at the end of the use-def iterator we
// need to first increment that.
if (OpIt != OpItE) {
++OpIt;
// Skip operands that are not instructions.
OpIt = skipNonInstr(OpIt, OpItE);
// Skip operands that are not instructions or are outside the DAG.
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
return *this;
}
// It's a MemDGNode with OpIt == end, so we need to increment MemIt.
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
for (DGNode *N : Bndl) {
N->setScheduled(true);
for (auto *DepN : N->preds(DAG)) {
// TODO: preds() should not return nullptr.
if (DepN == nullptr)
continue;
DepN->decrUnscheduledSuccs();
if (DepN->ready())
ReadyList.insert(DepN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,39 @@ define i8 @foo(i8 %v0, i8 %v1) {
EXPECT_EQ(RetN->getNumUnscheduledSuccs(), 0u);
}

// Make sure we don't get null predecessors even if they are outside the DAG.
TEST_F(DependencyGraphTest, NonNullPreds) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %val) {
%gep = getelementptr i8, ptr %ptr, i32 0
store i8 %val, ptr %gep
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
[[maybe_unused]] auto *GEP = cast<sandboxir::GetElementPtrInst>(&*It++);
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
// The DAG doesn't include GEP.
DAG.extend({S0, Ret});

auto *S0N = DAG.getNode(S0);
// S0 has one operand (the GEP) that is outside the DAG and no memory
// predecessors. So pred_begin() should be == pred_end().
auto PredIt = S0N->preds_begin(DAG);
auto PredItE = S0N->preds_end(DAG);
EXPECT_EQ(PredIt, PredItE);
// Check preds().
for (auto *PredN : S0N->preds(DAG))
EXPECT_NE(PredN, nullptr);
}

TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
Expand Down

0 comments on commit 788c88e

Please sign in to comment.