diff --git a/src/pass/infer_shape_type.cc b/src/pass/infer_shape_type.cc index a6903ba8b2d5..0332532717e6 100644 --- a/src/pass/infer_shape_type.cc +++ b/src/pass/infer_shape_type.cc @@ -108,7 +108,7 @@ Graph InferAttr(Graph &&ret, uint32_t eid = idx.entry_id(nid, igrad[i].index); if (fis_none(rshape[eid])) { rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; - } else { + } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) << "Backward shape inconsistent with the forward shape"; }