Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxVec][DependencyGraph] Fix dependency node iterators #125616

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,6 @@ class DGNode;
class MemDGNode;
class DependencyGraph;

/// While OpIt points to a Value that is not an Instruction keep incrementing
/// it. \Returns the first iterator that points to an Instruction, or end.
[[nodiscard]] static User::op_iterator skipNonInstr(User::op_iterator OpIt,
User::op_iterator OpItE) {
while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
++OpIt;
return OpIt;
}

/// Iterate over both def-use and mem dependencies.
class PredIterator {
User::op_iterator OpIt;
Expand All @@ -72,6 +63,12 @@ class PredIterator {
friend class DGNode; // For constructor
friend class MemDGNode; // For constructor

/// Skip iterators that don't point instructions or are outside \p DAG,
/// starting from \p OpIt and ending before \p OpItE.n
static User::op_iterator skipBadIt(User::op_iterator OpIt,
User::op_iterator OpItE,
const DependencyGraph &DAG);

public:
using difference_type = std::ptrdiff_t;
using value_type = DGNode *;
Expand Down Expand Up @@ -135,8 +132,9 @@ class DGNode {
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
using iterator = PredIterator;
virtual iterator preds_begin(DependencyGraph &DAG) {
return PredIterator(skipNonInstr(I->op_begin(), I->op_end()), I->op_end(),
this, DAG);
return PredIterator(
PredIterator::skipBadIt(I->op_begin(), I->op_end(), DAG), I->op_end(),
this, DAG);
}
virtual iterator preds_end(DependencyGraph &DAG) {
return PredIterator(I->op_end(), I->op_end(), this, DAG);
Expand Down Expand Up @@ -249,8 +247,8 @@ class MemDGNode final : public DGNode {
}
iterator preds_begin(DependencyGraph &DAG) override {
auto OpEndIt = I->op_end();
return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
MemPreds.begin(), this, DAG);
return PredIterator(PredIterator::skipBadIt(I->op_begin(), OpEndIt, DAG),
OpEndIt, MemPreds.begin(), this, DAG);
}
iterator preds_end(DependencyGraph &DAG) override {
return PredIterator(I->op_end(), I->op_end(), MemPreds.end(), this, DAG);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@

namespace llvm::sandboxir {

User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt,
User::op_iterator OpItE,
const DependencyGraph &DAG) {
auto Skip = [&DAG](auto OpIt) {
auto *I = dyn_cast<Instruction>((*OpIt).get());
return I == nullptr || DAG.getNode(I) == nullptr;
};
while (OpIt != OpItE && Skip(OpIt))
++OpIt;
return OpIt;
}

PredIterator::value_type PredIterator::operator*() {
// If it's a DGNode then we dereference the operand iterator.
if (!isa<MemDGNode>(N)) {
Expand All @@ -35,16 +47,16 @@ PredIterator &PredIterator::operator++() {
if (!isa<MemDGNode>(N)) {
assert(OpIt != OpItE && "Already at end!");
++OpIt;
// Skip operands that are not instructions.
OpIt = skipNonInstr(OpIt, OpItE);
// Skip operands that are not instructions or are outside the DAG.
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
return *this;
}
// It's a MemDGNode, so if we are not at the end of the use-def iterator we
// need to first increment that.
if (OpIt != OpItE) {
++OpIt;
// Skip operands that are not instructions.
OpIt = skipNonInstr(OpIt, OpItE);
// Skip operands that are not instructions or are outside the DAG.
OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
return *this;
}
// It's a MemDGNode with OpIt == end, so we need to increment MemIt.
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
for (DGNode *N : Bndl) {
N->setScheduled(true);
for (auto *DepN : N->preds(DAG)) {
// TODO: preds() should not return nullptr.
if (DepN == nullptr)
continue;
DepN->decrUnscheduledSuccs();
if (DepN->ready())
ReadyList.insert(DepN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,39 @@ define i8 @foo(i8 %v0, i8 %v1) {
EXPECT_EQ(RetN->getNumUnscheduledSuccs(), 0u);
}

// Make sure we don't get null predecessors even if they are outside the DAG.
TEST_F(DependencyGraphTest, NonNullPreds) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %val) {
%gep = getelementptr i8, ptr %ptr, i32 0
store i8 %val, ptr %gep
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
[[maybe_unused]] auto *GEP = cast<sandboxir::GetElementPtrInst>(&*It++);
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
// The DAG doesn't include GEP.
DAG.extend({S0, Ret});

auto *S0N = DAG.getNode(S0);
// S0 has one operand (the GEP) that is outside the DAG and no memory
// predecessors. So pred_begin() should be == pred_end().
auto PredIt = S0N->preds_begin(DAG);
auto PredItE = S0N->preds_end(DAG);
EXPECT_EQ(PredIt, PredItE);
// Check preds().
for (auto *PredN : S0N->preds(DAG))
EXPECT_NE(PredN, nullptr);
}

TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
Expand Down