Skip to content

Commit

Permalink
Fixes cases where some lattice points doesn't have expressions for th…
Browse files Browse the repository at this point in the history
…e next level

This fixes #1
  • Loading branch information
fredrikbk committed Jun 6, 2017
1 parent cd7b610 commit ed500b4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
15 changes: 8 additions & 7 deletions src/lower/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ static vector<Stmt> lower(const Target& target,
IndexExpr childExpr = lqExpr;
Target childTarget = target;
if (indexVarCase == LAST_FREE || indexVarCase == BELOW_LAST_FREE) {
// Extract the expression to compute at the next level. If there's no
// computation on the next level for this lattice case then skip it
childExpr = getSubExpr(lqExpr, ctx.schedule.getDescendants(child));
if (!childExpr.defined()) continue;

// Reduce child expression into temporary
TensorBase t("t" + child.getName(), ComponentType::Double);
Expr tensorVar = Var::make(t.getName(), Type(Type::Float,64));
Expand All @@ -320,14 +325,10 @@ static vector<Stmt> lower(const Target& target,
caseBody.push_back(VarAssign::make(tensorVar, 0.0, true));
}

// Extract the expression to compute at the next level
childExpr = getSubExpr(lqExpr, ctx.schedule.getDescendants(child));

// Rewrite lqExpr to substitute the expression computed at the next
// level with the temporary
lqExpr = replace(lqExpr, {{childExpr,taco::Access(t)}});
}
taco_iassert(childExpr.defined());
auto childCode = lower::lower(childTarget, childExpr, child, ctx);
util::append(caseBody, childCode);
}
Expand All @@ -353,17 +354,17 @@ static vector<Stmt> lower(const Target& target,
Expr rpos = resultIterator.getPtrVar();
Stmt posInc = VarAssign::make(rpos, Add::make(rpos, 1));

// Conditionally resize idx and pos
// Conditionally resize result `idx` and `pos` arrays
if (emitAssemble) {
Expr resize =
And::make(Eq::make(0, BitAnd::make(Add::make(rpos, 1), rpos)),
Lte::make(ctx.allocSize, Add::make(rpos, 1)));
Expr newSize = ir::Mul::make(2, ir::Add::make(rpos, 1));

// Resize `idx`
// Resize result `idx` array
Stmt resizeIndices = resultIterator.resizeIdxStorage(newSize);

// Resize `pos`
// Resize result `pos` array
if (indexVarCase == ABOVE_LAST_FREE) {
auto nextStep = resultPath.getStep(resultStep.getStep() + 1);
Stmt resizePos = ctx.iterators[nextStep].resizePtrStorage(newSize);
Expand Down
2 changes: 0 additions & 2 deletions test/expr_storage-tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,6 @@ INSTANTIATE_TEST_CASE_P(axpy_3x3, expr,
},
{6, 0, 16}
),
/*
TestData(Tensor<double>("a",{3},Format({Dense})),
{i},
d33a("B",Format({Sparse,Sparse}))(i,k) *
Expand All @@ -687,7 +686,6 @@ INSTANTIATE_TEST_CASE_P(axpy_3x3, expr,
},
{6, 0, 16}
),
*/
TestData(Tensor<double>("a",{3},Format({Dense})),
{i},
da("alpha",Format())() *
Expand Down

0 comments on commit ed500b4

Please sign in to comment.