-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
Copy pathgraph.cc
121 lines (113 loc) · 4.53 KB
/
graph.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*!
* Copyright (c) 2016 by Contributors
* \file graph.cc
* \brief Implementation of Graph::indexed_graph and the IndexedGraph constructor.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <limits>
namespace nnvm {
// Return the indexed view of this graph.
//
// The index is built lazily: the first call constructs an IndexedGraph from
// *this and stores it in indexed_graph_; subsequent calls reuse that object
// (as long as nothing else resets indexed_graph_).
// NOTE(review): no locking is visible here, so two concurrent first calls
// could race on indexed_graph_ — confirm callers serialize access.
const IndexedGraph& Graph::indexed_graph() const {
  if (indexed_graph_ != nullptr) {
    return *indexed_graph_;
  }
  indexed_graph_.reset(new IndexedGraph(*this));
  return *indexed_graph_;
}
// a subgraph should not refer to any nodes with higher level
// where "level" refers to the nested depth of the subgraph
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
std::unordered_map<nnvm::Node*, uint32_t> node2level;
for (auto &subgraph : subgraphs)
next_level.push_back(&subgraph->outputs);
for (uint32_t level = 0; !next_level.empty(); ++level) {
curr_level.swap(next_level);
next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
CHECK(!node2level.count(node) || node2level[node] == level)
<< "A subgraph should not depend on the outputs of nodes on higher levels";
// otherwise, this node belongs to the current level
node2level[node] = level;
// subgraphs of current node belongs to next level
for (const auto& subgraph : n->attrs.subgraphs) {
next_level.push_back(&subgraph->outputs);
}
});
}
}
}
// Build the indexed (integer-id) view of graph `g`.
//
// Nodes are numbered in the order DFSVisit visits them from g.outputs. The
// CHECKs below require every input and control dependency of a node to be
// indexed before the node itself.
//
// Per-node inputs and control deps are stored flat in input_entries_ /
// control_deps_. inputs_rptr / control_rptr hold offsets, so node k's slice is
// [rptr[k], rptr[k + 1]). Each node's array_view slices are created only after
// all nodes have been visited.
IndexedGraph::IndexedGraph(const Graph &g) {
  entry_rptr_.push_back(0);
  std::vector<size_t> inputs_rptr{0}, control_rptr{0};
  // Subgraphs attached to nodes of this graph; validated after indexing.
  std::vector<std::shared_ptr<Symbol>> subgraphs;
  DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
           (const NodePtr& n) {
    // Node ids are uint32_t; refuse to overflow them.
    CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
    uint32_t nid = static_cast<uint32_t>(nodes_.size());
    for (const auto &subgraph : n->attrs.subgraphs)
      subgraphs.push_back(subgraph);
    // nodes_
    IndexedGraph::Node new_node;
    new_node.source = n.get();
    new_node.weak_ref = n;
    nodes_.emplace_back(std::move(new_node));
    // input_nodes_: variable nodes are recorded as graph inputs
    if (n->is_variable()) {
      input_nodes_.push_back(nid);
    }
    // node2index_
    node2index_[n.get()] = nid;
    // entry_rptr_: running total of outputs, so node nid owns entry ids
    // [entry_rptr_[nid], entry_rptr_[nid + 1]).
    entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
    // input entries (each input node must already have an id)
    for (const auto& e : n->inputs) {
      auto it = node2index_.find(e.node.get());
      CHECK(it != node2index_.end() && it->first == e.node.get());
      input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
    }
    inputs_rptr.push_back(input_entries_.size());
    // control deps (each dependency must already have an id)
    for (const auto& nptr : n->control_deps) {
      auto it = node2index_.find(nptr.get());
      CHECK(it != node2index_.end() && it->first == nptr.get());
      control_deps_.push_back(it->second);
    }
    control_rptr.push_back(control_deps_.size());
  });
  if (!subgraphs.empty())
    SubgraphSanityCheck(subgraphs);
  for (const auto& e : g.outputs) {
    outputs_.emplace_back(NodeEntry{
        node2index_.at(e.node.get()), e.index, e.version});
  }
  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
  // Set up the per-node array views. They point directly into the storage of
  // input_entries_ and control_deps_, so those vectors must not change after
  // this step.
  const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    IndexedGraph::Node& node = nodes_[nid];
    node.inputs = array_view<NodeEntry>(
        iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
    const Op* op = node.source->op();
    if (op != nullptr && fmutate_inputs.count(op)) {
      for (uint32_t i : fmutate_inputs[op](node.source->attrs)) {
        // array_view indexing is unchecked; an out-of-range index from
        // FMutateInputs would read outside this node's input slice.
        CHECK_LT(i, node.inputs.size())
            << "FMutateInputs of operator " << op->name
            << " returned an out-of-range input index";
        mutable_input_nodes_.insert(node.inputs[i].node_id);
      }
    }
  }
  const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    nodes_[nid].control_deps = array_view<uint32_t>(
        cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
  }
}
} // namespace nnvm