From ed500b474a9757bbe04c816e88bf3c2b37b22ecf Mon Sep 17 00:00:00 2001 From: Fredrik Kjolstad Date: Tue, 6 Jun 2017 10:00:37 -0400 Subject: [PATCH] Fixes cases where some lattice points doesn't have expressions for the next level This fixes #1 --- src/lower/lower.cpp | 15 ++++++++------- test/expr_storage-tests.cpp | 2 -- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lower/lower.cpp b/src/lower/lower.cpp index f2c791ee9..5a05b0dde 100644 --- a/src/lower/lower.cpp +++ b/src/lower/lower.cpp @@ -310,6 +310,11 @@ static vector lower(const Target& target, IndexExpr childExpr = lqExpr; Target childTarget = target; if (indexVarCase == LAST_FREE || indexVarCase == BELOW_LAST_FREE) { + // Extract the expression to compute at the next level. If there's no + // computation on the next level for this lattice case then skip it + childExpr = getSubExpr(lqExpr, ctx.schedule.getDescendants(child)); + if (!childExpr.defined()) continue; + // Reduce child expression into temporary TensorBase t("t" + child.getName(), ComponentType::Double); Expr tensorVar = Var::make(t.getName(), Type(Type::Float,64)); @@ -320,14 +325,10 @@ static vector lower(const Target& target, caseBody.push_back(VarAssign::make(tensorVar, 0.0, true)); } - // Extract the expression to compute at the next level - childExpr = getSubExpr(lqExpr, ctx.schedule.getDescendants(child)); - // Rewrite lqExpr to substitute the expression computed at the next // level with the temporary lqExpr = replace(lqExpr, {{childExpr,taco::Access(t)}}); } - taco_iassert(childExpr.defined()); auto childCode = lower::lower(childTarget, childExpr, child, ctx); util::append(caseBody, childCode); } @@ -353,17 +354,17 @@ static vector lower(const Target& target, Expr rpos = resultIterator.getPtrVar(); Stmt posInc = VarAssign::make(rpos, Add::make(rpos, 1)); - // Conditionally resize idx and pos + // Conditionally resize result `idx` and `pos` arrays if (emitAssemble) { Expr resize = And::make(Eq::make(0, BitAnd::make(Add::make(rpos, 1), rpos)), Lte::make(ctx.allocSize, Add::make(rpos, 1))); Expr newSize = ir::Mul::make(2, ir::Add::make(rpos, 1)); - // Resize `idx` + // Resize result `idx` array Stmt resizeIndices = resultIterator.resizeIdxStorage(newSize); - // Resize `pos` + // Resize result `pos` array if (indexVarCase == ABOVE_LAST_FREE) { auto nextStep = resultPath.getStep(resultStep.getStep() + 1); Stmt resizePos = ctx.iterators[nextStep].resizePtrStorage(newSize); diff --git a/test/expr_storage-tests.cpp b/test/expr_storage-tests.cpp index 09ed03db7..191f20509 100644 --- a/test/expr_storage-tests.cpp +++ b/test/expr_storage-tests.cpp @@ -673,7 +673,6 @@ INSTANTIATE_TEST_CASE_P(axpy_3x3, expr, }, {6, 0, 16} ), -/* TestData(Tensor("a",{3},Format({Dense})), {i}, d33a("B",Format({Sparse,Sparse}))(i,k) * @@ -687,7 +686,6 @@ INSTANTIATE_TEST_CASE_P(axpy_3x3, expr, }, {6, 0, 16} ), -*/ TestData(Tensor("a",{3},Format({Dense})), {i}, da("alpha",Format())() *