From 008aef31c3b78c9c5cd112696af2e3c345eee582 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 21 Nov 2016 20:30:00 -0800 Subject: [PATCH] enable shape inference with hint func (#84) --- src/pass/infer_shape_type.cc | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/pass/infer_shape_type.cc b/src/pass/infer_shape_type.cc index fd79190ea042..644adb066603 100644 --- a/src/pass/infer_shape_type.cc +++ b/src/pass/infer_shape_type.cc @@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret, std::vector ishape, oshape; // inference step function for nid - auto infer_step = [&](uint32_t nid) { + auto infer_step = [&](uint32_t nid, bool last_iter) { const auto& inode = idx[nid]; const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_outputs = inode.source->num_outputs(); @@ -113,16 +113,20 @@ Graph InferAttr(Graph &&ret, oshape[i] = rshape[idx.entry_id(nid, i)]; if (fis_none(oshape[i])) forward_known = false; } + auto finfer = finfer_shape.get(inode.source->op(), fdefault); if (!forward_known) { - auto finfer = finfer_shape.get(inode.source->op(), fdefault); - CHECK(finfer != nullptr) - << "Attribute " << infer_name - << " is not registed by op " << inode.source->op()->name; - // Call inference function of the operator. - try { - forward_known = finfer(inode.source->attrs, &ishape, &oshape); - } catch (const std::exception& e) { - throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name); + if (finfer != nullptr) { + // Call inference function of the operator. + try { + forward_known = finfer(inode.source->attrs, &ishape, &oshape); + } catch (const std::exception& e) { + throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name); + } + } else { + CHECK(!last_iter) + << "Attribute " << infer_name + << " is not registed by op " << inode.source->op()->name + << " we are not able to complete the inference because of this"; } } // Save to the result map. @@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret, for (int i = 0; i < kMaxStep; ++i) { if (i % 2 == 0) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { - infer_step(nid); + infer_step(nid, i + 1 == kMaxStep); } } else { // backward inference for (uint32_t i = idx.num_nodes(); i != 0; --i) { - infer_step(i - 1); + infer_step(i - 1, i + 1 == kMaxStep); } } num_unknown = 0;