Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

DFS switched to back-order (post-order) traversal (by linmin) #70

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/mxnet/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ struct Resource {
};
/*! \brief pointer to the resource */
void *ptr;
/*! \brief engine variable */
void *var;
};

/*!
Expand Down
35 changes: 35 additions & 0 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
}
}

for (const Resource& r : op_node.op_ctx.requested) {
exec.mutate_vars.push_back(static_cast<DAGEngine::Variable>(r.var));
}
// start setup exec function.
Operator* op = op_node.op.get();
OpContext* op_ctx_ptr = &op_node.op_ctx;
Expand Down Expand Up @@ -398,6 +401,20 @@ void GraphExecutor::InitDataEntryMemory() {
out->type = kInternalAllocated;
}
}
// resource
const std::vector<ResourceRequest>& reqs = GetResource(nid);
op_nodes_[nid].resources.resize(reqs.size());
op_nodes_[nid].op_ctx.requested.resize(reqs.size());
for (uint32_t i = 0; i < reqs.size(); ++i) {
op_nodes_[nid].resources[i].req = reqs[i];
}
// allocate resource
for (ResourceEntry& entry : op_nodes_[nid].resources) {
if (entry.req.type == Resource::kTempSpace) {
entry.storage_id =
allocator.Request(op_nodes_[nid].ctx, mshadow::Shape1(entry.req.space_size), nid);
}
}
// then free inputs
for (DataEntryInfo *in : in_data) {
// temp_ref_count == 0 means it is taken by inplace op
Expand All @@ -417,6 +434,12 @@ void GraphExecutor::InitDataEntryMemory() {
allocator.Release(out->storage_id, nid);
}
}
// release the resource, as soon as the forward is finished we can release it.
for (ResourceEntry& res : op_nodes_[nid].resources) {
if (res.req.type == Resource::kTempSpace) {
allocator.Release(res.storage_id, nid);
}
}
}
// one pass complete, allocate real memory
allocator.InitStorages();
Expand All @@ -430,6 +453,18 @@ void GraphExecutor::InitDataEntryMemory() {
out.data = allocator.Get(out.storage_id, out.shape);
}
}
// get the pointer to the tempspace
std::vector<Resource>& resources = op_nodes_[nid].op_ctx.requested;
for (uint32_t i = 0; i < resources.size(); ++i) {
ResourceEntry& entry = op_nodes_[nid].resources[i];
if (entry.req.type == Resource::kTempSpace) {
entry.data = allocator.Get(entry.storage_id,
mshadow::Shape1(entry.req.space_size));
}
entry.tblob = entry.data.data();
resources[i].ptr = &entry.tblob;
resources[i].var = static_cast<void*>(entry.data.var());
}
}
for (StaticGraph::DataEntry e : graph_.heads) {
DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index];
Expand Down
13 changes: 13 additions & 0 deletions src/symbol/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ class GraphExecutor : public Executor {
storage_id(GraphStorageAllocator::kBadStorageID),
temp_ref_count(0), ref_count(0) {}
};
// information of the resource
// information of a requested resource (e.g. temp workspace) attached to an op
struct ResourceEntry {
  /*! \brief actual data for the entry if it is a temp space */
  NArray data;
  /*! \brief TBlob view over data, exposed via Resource::ptr (this is quite ugly) */
  TBlob tblob;
  /*! \brief the resource request */
  ResourceRequest req;
  /*!
   * \brief storage id from allocator if it is a temp space;
   *  initialized to kBadStorageID so a never-allocated entry is
   *  distinguishable (mirrors DataEntryInfo's initialization).
   */
  GraphStorageAllocator::StorageID storage_id = GraphStorageAllocator::kBadStorageID;
};
// all the information needed to push the op to engine
struct OpExecEntry {
// execution function for
Expand All @@ -111,6 +122,8 @@ class GraphExecutor : public Executor {
std::vector<DataEntryInfo> outputs;
// auxiliary data information of op
std::vector<DataEntryInfo> aux_states;
// resource entry
std::vector<ResourceEntry> resources;
// The following parts are constructed in InitOpNodes
// the real operator
std::shared_ptr<Operator> op;
Expand Down
44 changes: 27 additions & 17 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

namespace mxnet {
std::vector<uint32_t> StaticGraph::TopoSort() const {
std::vector<std::pair<uint32_t, uint32_t> > stack;
std::unordered_set<uint32_t> visited;
std::vector<uint32_t> ret(nodes.size());
std::vector<uint32_t> head_node;
// out degree
std::vector<int> out_degree(nodes.size(), 0);
for (const Node& n : nodes) {
for (const DataEntry& e : n.inputs) {
Expand All @@ -21,28 +26,33 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
++out_degree[n.backward_source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
auto result = ret.rbegin();
std::queue<uint32_t> queue;
for (size_t i = 0; i < nodes.size(); ++i) {
if (out_degree[i] == 0) {
queue.push(static_cast<uint32_t>(i));
stack.push_back(std::make_pair(static_cast<uint32_t>(i), 0));
}
}
while (!queue.empty()) {
uint32_t node_id = queue.front();
queue.pop();
*result = node_id;
++result;
const Node& n = nodes[node_id];
for (const DataEntry& e : n.inputs) {
if (--out_degree[e.source_id] == 0) {
queue.push(e.source_id);
// heads
for (auto &head : head_node) {
stack.push_back(std::make_pair(head, 0));
}
int count = 0;
while (!stack.empty()) {
std::pair<uint32_t, uint32_t>& back = stack.back();
const Node& n = nodes[back.first];
if (back.second == n.inputs.size() + (n.is_backward() ? 1 : 0)) {
ret[count++] = back.first;
visited.insert(back.first);
stack.pop_back();
} else {
uint32_t input;
if (back.second == n.inputs.size() && n.is_backward()) {
input = n.backward_source_id;
back.second++;
} else {
input = n.inputs[back.second++].source_id;
}
}
if (n.is_backward()) {
if (--out_degree[n.backward_source_id] == 0) {
queue.push(n.backward_source_id);
if (visited.count(input) == 0) {
stack.push_back(std::make_pair(input, 0));
}
}
}
Expand Down
23 changes: 12 additions & 11 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,26 @@ inline bool Symbol::is_atomic() const {
// implementation of template functions
template<typename FVisit>
inline void Symbol::DFSVisit(FVisit fvisit) const {
std::vector<const std::shared_ptr<Node>*> stack;
std::vector<std::pair<const std::shared_ptr<Node>*, uint32_t> > stack;
std::unordered_set<Node*> visited;
// put the head into the graph
for (auto &head : heads_) {
Node *ptr = head.source.get();
if (visited.count(ptr) == 0) {
stack.push_back(&head.source);
visited.insert(ptr);
stack.push_back(std::make_pair(&head.source, 0));
}
}
while (!stack.empty()) {
const std::shared_ptr<Node> *back = stack.back();
stack.pop_back();
fvisit(*back);
for (auto it = back->get()->inputs.rbegin(); it != back->get()->inputs.rend(); ++it) {
Node *ptr = it->source.get();
if (visited.count(ptr) == 0) {
stack.push_back(&it->source);
visited.insert(ptr);
std::pair<const std::shared_ptr<Node> *, uint32_t>& back = stack.back();
if (back.second == back.first->get()->inputs.size()) {
fvisit(*(back.first));
visited.insert(back.first->get());
stack.pop_back();
} else {
std::vector<Symbol::DataEntry>& inputs = back.first->get()->inputs;
Symbol::DataEntry& input = inputs.at(back.second++);
if (visited.count(input.source.get()) == 0) {
stack.push_back(std::make_pair(&input.source, 0));
}
}
}
Expand Down