Skip to content

Commit

Permalink
Some lowering refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed Jun 5, 2017
1 parent 049bc54 commit cd7b610
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions src/lower/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,46 +347,42 @@ static vector<Stmt> lower(const Target& target,
}
}

// Emit code to increment the results iterator variable
// Emit code to increment the result `pos` variable and to allocate
// additional storage for result `idx` and `pos` arrays
if (resultIterator.defined() && resultIterator.isSequentialAccess()) {
Expr resultPos = resultIterator.getPtrVar();
Stmt posInc = VarAssign::make(resultPos, Add::make(resultPos, 1));

Expr doResize = ir::And::make(
Eq::make(0, BitAnd::make(Add::make(resultPos, 1), resultPos)),
Lte::make(ctx.allocSize, Add::make(resultPos, 1)));
Expr newSize = ir::Mul::make(2, ir::Add::make(resultPos, 1));
Stmt resizeIndices = resultIterator.resizeIdxStorage(newSize);

if (resultStep != resultPath.getLastStep()) {
// Emit code to resize idx and pos
if (emitAssemble) {
auto nextStep = resultPath.getStep(resultStep.getStep()+1);
Iterator iterNext = ctx.iterators[nextStep];
Stmt resizePos = iterNext.resizePtrStorage(newSize);
if (resizePos.defined()) {
resizeIndices = Block::make({resizeIndices, resizePos});
}
resizeIndices = IfThenElse::make(doResize, resizeIndices);
posInc = Block::make({posInc, resizeIndices});
Expr rpos = resultIterator.getPtrVar();
Stmt posInc = VarAssign::make(rpos, Add::make(rpos, 1));

// Conditionally resize idx and pos
if (emitAssemble) {
Expr resize =
And::make(Eq::make(0, BitAnd::make(Add::make(rpos, 1), rpos)),
Lte::make(ctx.allocSize, Add::make(rpos, 1)));
Expr newSize = ir::Mul::make(2, ir::Add::make(rpos, 1));

// Resize `idx`
Stmt resizeIndices = resultIterator.resizeIdxStorage(newSize);

// Resize `pos`
if (indexVarCase == ABOVE_LAST_FREE) {
auto nextStep = resultPath.getStep(resultStep.getStep() + 1);
Stmt resizePos = ctx.iterators[nextStep].resizePtrStorage(newSize);
resizeIndices = Block::make({resizeIndices, resizePos});
}
posInc = Block::make({posInc,IfThenElse::make(resize,resizeIndices)});
}

// Only increment `pos` if values were produced at the next level
if (indexVarCase == ABOVE_LAST_FREE) {
int step = resultStep.getStep() + 1;
string resultTensorName = resultIterator.getTensor().as<Var>()->name;
string posArrName = resultTensorName + util::toString(step) +
"_pos_arr";
string posArrName = resultTensorName + to_string(step) + "_pos_arr";
Expr posArr = GetProperty::make(resultIterator.getTensor(),
TensorProperty::Indices,
step, 0, posArrName);

Expr producedVals =
Gt::make(Load::make(posArr, Add::make(resultPos,1)),
Load::make(posArr, resultPos));
Expr producedVals = Gt::make(Load::make(posArr, Add::make(rpos,1)),
Load::make(posArr, rpos));
posInc = IfThenElse::make(producedVals, posInc);
} else if (emitAssemble) {
// Emit code to resize idx (at result store loop nest)
resizeIndices = IfThenElse::make(doResize, resizeIndices);
posInc = Block::make({posInc, resizeIndices});
}
util::append(caseBody, {posInc});
}
Expand Down

0 comments on commit cd7b610

Please sign in to comment.