Skip to content

Commit

Permalink
enable shape inference with hint func (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Nov 22, 2016
1 parent f11f3ef commit 008aef3
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret,
std::vector<AttrType> ishape, oshape;

// inference step function for nid
auto infer_step = [&](uint32_t nid) {
auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
Expand Down Expand Up @@ -113,16 +113,20 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false;
}
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
CHECK(finfer != nullptr)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name;
// Call inference function of the operator.
try {
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} catch (const std::exception& e) {
throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name);
if (finfer != nullptr) {
// Call inference function of the operator.
try {
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} catch (const std::exception& e) {
throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name);
}
} else {
CHECK(!last_iter)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this";
}
}
// Save to the result map.
Expand All @@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret,
for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
infer_step(nid, i + 1 == kMaxStep);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
infer_step(i - 1, i + 1 == kMaxStep);
}
}
num_unknown = 0;
Expand Down

0 comments on commit 008aef3

Please sign in to comment.