This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Add feature of retain_grad #20500

Merged: 36 commits, Sep 23, 2021
Commits (36)
8db2ad6
Replace "CloneGradient" with "ElemwiseGradUseNone"
KexinFeng Jun 23, 2021
f7ab33d
fix issue elemwise_add
KexinFeng Jun 23, 2021
0ae6e9b
fix elemwise_add issue with `ElemwiseGradUseNone`
KexinFeng Jun 23, 2021
3f4f93c
reverse_to_CloneGradient
KexinFeng Aug 8, 2021
6117ec0
add_retain_grad
KexinFeng Aug 8, 2021
aeefab9
unit_test
KexinFeng Aug 9, 2021
3b8a02b
tidy_up
KexinFeng Aug 9, 2021
d909216
tidy_up
KexinFeng Aug 9, 2021
6fa5076
sanity
KexinFeng Aug 11, 2021
b00528a
const_reference
KexinFeng Aug 11, 2021
9bd0881
const_ref
KexinFeng Aug 11, 2021
cfea958
Merge branch 'apache:master' into rg_branch
KexinFeng Aug 12, 2021
4593e6f
merge_rg_to_ag
KexinFeng Aug 12, 2021
3970e4b
Merge remote-tracking branch 'upstream/master' into rg_branch
KexinFeng Aug 12, 2021
6ea449f
Merge branch 'rg_branch' of https://github.com/KexinFeng/incubator-mx…
KexinFeng Aug 12, 2021
d77c2f3
sanity
KexinFeng Aug 12, 2021
f7cbc7d
sanity
KexinFeng Aug 12, 2021
e82a49c
add_drop_grad
KexinFeng Aug 12, 2021
4af58f8
sanity_check
KexinFeng Aug 12, 2021
5d8fd98
sanity_check
KexinFeng Aug 12, 2021
b148c38
sanity_check
KexinFeng Aug 12, 2021
f0a7286
Merge remote-tracking branch 'upstream/master' into rg_branch
KexinFeng Aug 13, 2021
2d39523
Merge remote-tracking branch 'upstream/master' into rg_branch
KexinFeng Aug 13, 2021
db048b9
build_err
KexinFeng Aug 13, 2021
2748a33
build_err
KexinFeng Aug 13, 2021
9d9d8cd
skip_remark_variable
KexinFeng Aug 13, 2021
ad46c11
repetitive_mark
KexinFeng Aug 14, 2021
85ccb80
ReInit_in_dropgrad
KexinFeng Aug 14, 2021
e357fbb
ReInit_in_dropgrad
KexinFeng Aug 14, 2021
042a9e7
sanity_check
KexinFeng Aug 14, 2021
6de23de
add drop and tests to gluon
KexinFeng Aug 17, 2021
0c25c94
Merge remote-tracking branch 'upstream/master' into rg_branch
KexinFeng Aug 17, 2021
283b7ac
sanity
KexinFeng Aug 17, 2021
7716929
Merge remote-tracking branch 'upstream/master' into rg_branch
KexinFeng Aug 18, 2021
ec036de
Merge branch 'master' into rg_branch
barry-jin Sep 17, 2021
f316a4f
update exec_pass.h
barry-jin Sep 17, 2021
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
@@ -1274,6 +1274,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
NDArrayHandle *var_handles,
uint32_t *reqs_array,
NDArrayHandle *grad_handles);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
* \param var_handles variable NDArrays
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXAutogradDropGrads(uint32_t num_var,
NDArrayHandle *var_handles);
/*!
 * \brief compute the gradient of outputs w.r.t variables
* \param num_output number of output NDArray
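For illustration, here is a minimal sketch of how this new C API entry point could be driven directly through ctypes, mirroring the Python binding added in ndarray.py below. `_LIB` and `check_call` are the existing helpers in `mxnet.base`; the call assumes a build that contains this change.

```python
import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call

x = mx.nd.ones((2, 2))
x.attach_grad()  # mark x and allocate its gradient buffer

# Release the gradient buffer of a single NDArray through the raw C API:
# num_var = 1, var_handles = pointer to the array's handle.
check_call(_LIB.MXAutogradDropGrads(1, ctypes.pointer(x.handle)))
```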
4 changes: 4 additions & 0 deletions include/mxnet/imperative.h
@@ -272,12 +272,16 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
const std::vector<NDArray*>& variables,
bool is_train, bool retain_graph,
bool create_graph);
/*! \brief Return the marked nonleaf nodes. */
std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
/*! \return AutogradRuntime singleton */
static Imperative* Get();
/*! \brief Should op execution bulking be employed during inference. */
5 changes: 5 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -2885,6 +2885,11 @@ def attach_grad(self, grad_req='write', stype=None):
ctypes.pointer(mx_uint(grad_req)),
ctypes.pointer(grad.handle)))

def drop_grad(self):
"""Free the memory of the marked ndarray."""
check_call(_LIB.MXAutogradDropGrads(
1, ctypes.pointer(self.handle)))

@property
def grad(self):
"""Returns gradient buffer attached to this NDArray."""
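A minimal usage sketch of the new method, under the semantics this PR introduces: calling `attach_grad` on a recorded intermediate retains its gradient during `backward`, and `drop_grad` releases that buffer afterwards. The expected values follow from z = 2 * x**2.

```python
import mxnet as mx
from mxnet import nd

x = nd.array([1., 2., 3.])
x.attach_grad()
with mx.autograd.record():
    u = x * x          # intermediate (non-leaf) result
    z = 2 * u

u.attach_grad()        # retain the gradient of the intermediate
z.backward()

print(u.grad)          # expected: [2. 2. 2.]   (dz/du = 2)
print(x.grad)          # expected: [4. 8. 12.]  (dz/dx = 4x)

u.drop_grad()          # free the retained gradient buffer of u
```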
5 changes: 5 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -1410,6 +1410,11 @@ def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ
ctypes.pointer(mx_uint(grad_req)),
ctypes.pointer(grad.handle)))

def drop_grad(self):
"""Free the memory of the marked ndarray."""
check_call(_LIB.MXAutogradDropGrads(
1, ctypes.pointer(self.handle)))

@property
def grad(self):
"""Returns gradient buffer attached to this ndarray."""
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -335,6 +335,18 @@ int MXAutogradMarkVariables(uint32_t num_var,
API_END();
}

int MXAutogradDropGrads(uint32_t num_var,
NDArrayHandle *var_handles) {
API_BEGIN();
std::vector<NDArray*> variables;
variables.reserve(num_var);
for (uint32_t i = 0; i < num_var; ++i) {
variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
}
Imperative::Get()->DropGrads(variables);
API_END();
}

int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) {
return MXAutogradBackward(num_output, output_handles, nullptr, 0);
}
4 changes: 3 additions & 1 deletion src/imperative/exec_pass.h
@@ -287,12 +287,14 @@ inline Graph MXGradient(
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string(),
mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
DTypeVector in_arg_dtypes = DTypeVector()) {
DTypeVector in_arg_dtypes = DTypeVector(),
std::vector<NodeEntry> us = std::vector<NodeEntry>()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
graph.attrs["in_arg_shapes"] = std::make_shared<any>(std::move(in_arg_shapes));
graph.attrs["in_arg_dtypes"] = std::make_shared<any>(std::move(in_arg_dtypes));
graph.attrs["grad_us"] = std::make_shared<any>(std::move(us));

if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
101 changes: 83 additions & 18 deletions src/imperative/imperative.cc
@@ -142,29 +142,54 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients) {
for (uint32_t i = 0; i < variables.size(); ++i) {
std::string str_c(std::to_string(variable_count_++));

variables[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();

gradients[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
// Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't.
if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) {
std::string str_c(std::to_string(variable_count_++));
variables[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();

gradients[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
} else {
AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node);
CHECK_EQ(info.out_grads.size(), 0)
<<"The node has already been marked. Cannot mark it again.";
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();
}
}
}

// Unmark the variables to free the memory.
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
for (auto variable : variables) {
if (variable->autograd_entry_.node) {
AGInfo& info = AGInfo::Get(variable->autograd_entry_.node);
CHECK_NE(info.out_grads.size(), 0)
<<"The node has empty out_grads already. Cannot DropGrads again.";
for (auto grad : info.out_grads) {
grad.ReInit();
}
info.out_grads.clear();
info.grad_req = kNullOp;
}
}
}

void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node,
uint32_t num_inputs,
uint32_t num_outputs,
std::vector<bool>* p_save_inputs,
std::vector<bool>* p_save_outputs) {
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs) {
static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;
@@ -488,6 +513,12 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
}
CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
}
std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
std::vector<NodeEntry> us;
us.reserve(nleaf_vars.size());
for (const auto& i : nleaf_vars) {
us.emplace_back(NodeEntry{i, 0, 0});
}

Graph g_graph = pass::MXGradient(graph,
graph.outputs,
@@ -496,7 +527,10 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
mxnet::AggregateGradient,
nullptr,
zero_ops,
"_copy");
"_copy",
ShapeVector(),
DTypeVector(),
us);
CHECK_EQ(g_graph.outputs.size(), xs.size());
for (const auto& e : g_graph.outputs) {
if (e.node->op() == nullptr) {
@@ -575,6 +609,20 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
arrays[eid] = x_grads[i - num_forward_outputs];
ref_count[eid] = 1;
}
const std::vector<NodeEntry>& us_grads =
g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
CHECK_EQ(us_grads.size(), us.size())
<< "Size of queried nleaf_vars and size of their gradients don't match.";
for (size_t i = 0; i < us_grads.size(); i++) {
size_t eid = idx.entry_id(us_grads[i]);
AGInfo& info = AGInfo::Get(us[i].node);
if (arrays[eid]->dtype_ == -1) {
arrays[eid] = &info.out_grads[0];
} else {
info.out_grads[0] = *arrays[eid];
}
ref_count[eid] = 1;
}

// Assign context
auto vctx = PlaceDevice(idx);
@@ -627,6 +675,11 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
size_t eid = idx.entry_id(idx.outputs()[i]);
array_reqs[eid] = x_reqs[i - num_forward_outputs];
}
for (size_t i = 0; i < us_grads.size(); i++) {
size_t eid = idx.entry_id(us_grads[i]);
AGInfo& info = AGInfo::Get(us[i].node);
array_reqs[eid] = info.grad_req;
}

const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
@@ -766,4 +819,16 @@ void Imperative::DCInfo::Compute(const NDArray& arr) {
info.outputs_.clear();
}

std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const {
using namespace nnvm;
std::vector<ObjectPtr> ret;
DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) {
AGInfo& info = AGInfo::Get(node);
if (info.out_grads.size() > 0 && !node->is_variable()) {
ret.push_back(node);
}
});
return ret;
}

} // namespace mxnet
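The CHECKs added to MarkVariables and DropGrads above surface as runtime errors on the Python side. A sketch of the two failure modes they guard against; the exception type is `MXNetError` as raised by `check_call`, and the printed messages are only illustrative.

```python
import mxnet as mx
from mxnet import nd
from mxnet.base import MXNetError

x = nd.ones((2, 2))
x.attach_grad()
with mx.autograd.record():
    u = x + 1          # recorded intermediate (non-leaf)

u.attach_grad()
try:
    u.attach_grad()    # re-marking the same non-leaf trips CHECK_EQ(out_grads.size(), 0)
except MXNetError as err:
    print('cannot mark the node twice:', err)

u.drop_grad()
try:
    u.drop_grad()      # dropping again trips CHECK_NE(out_grads.size(), 0)
except MXNetError as err:
    print('cannot drop an already-empty grad:', err)
```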
30 changes: 25 additions & 5 deletions src/nnvm/gradient.cc
@@ -62,7 +62,8 @@ Graph BuildGradientGraph(const Graph& src,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
const std::vector<NodeEntry>& us = std::vector<NodeEntry>());

/*!
* \brief Auxiliary function that maps the forward node of the source graph to
@@ -88,6 +89,8 @@ Graph Gradient(Graph src) {
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
CHECK_EQ(ys.size(), ys_out_grad.size());
const std::vector<NodeEntry>& us =
src.GetAttr<std::vector<NodeEntry> >("grad_us");

// initialize a topological order of the graph nodes and `output_grads`
// that maps every operator node to its gradient entries
@@ -120,7 +123,7 @@ Graph Gradient(Graph src) {
std::unordered_map<const Node*, ObjectPtr> mirror_map;

// complete the backward graph of the src, but without backward mirroring
nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map);
nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map, us);
if (mirror_fun == nullptr) {
return gsrc; // Gradient pass without mirroring ends here.
}
@@ -504,12 +507,14 @@ inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
return true;
}


Graph BuildGradientGraph(const Graph& src,
const std::vector<NodeEntry>& xs,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
const std::vector<NodeEntry>& us) {
static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");

// gradient aggregation function
@@ -617,7 +622,7 @@ Graph BuildGradientGraph(const Graph& src,
CHECK(src_fwd_node->inputs.size() <= input_grads.size());
for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end();
++input_iter, ++input_grad_iter) {
// propagate the input gradients to the output gradients of the input nodes
// propagate the input_grads to the corresponding GradEntries mapped by output_grads
output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back(
std::move(*input_grad_iter));
}
@@ -661,6 +666,20 @@ Graph BuildGradientGraph(const Graph& src,
ret.outputs[kv.second.second] = kv.first;
}
}

// Take the gradient NodeEntry of each non-leaf in `us` and store them in graph.attrs
std::vector<NodeEntry> nleaf_grads;
nleaf_grads.reserve(us.size());
for (const NodeEntry& e : us) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate sum if it hasn't been
if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads));
}
nleaf_grads.push_back(entry.sum);
}
ret.attrs["nleaf_grads"] = std::make_shared<any>(std::move(nleaf_grads));

return ret;
}

@@ -673,7 +692,8 @@ NNVM_REGISTER_PASS(MXGradient)
.depend_graph_attr("grad_xs")
.depend_graph_attr("in_arg_shapes")
.depend_graph_attr("in_arg_dtypes")
.depend_graph_attr("grad_ys_out_grad");
.depend_graph_attr("grad_ys_out_grad")
.depend_graph_attr("grad_us");

} // namespace

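The `entry.sum = agg_fun(std::move(entry.grads))` step above matters when a retained non-leaf feeds more than one consumer: its gradient contributions are summed before being exposed through `nleaf_grads`. A small numeric sketch of that behaviour at the Python level; the expected value is derived from dz/du = 2u + 1 and is an assumption about this PR's behaviour, not a quoted test result.

```python
import mxnet as mx
from mxnet import nd

x = nd.array([1., 2., 3.])
x.attach_grad()
with mx.autograd.record():
    u = x * x          # non-leaf: u = [1, 4, 9]
    z = u * u + u      # u feeds two consumers; its grad entries get aggregated

u.attach_grad()
z.backward()

print(u.grad)          # expected: 2*u + 1 -> [3. 9. 19.]
```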
67 changes: 66 additions & 1 deletion tests/python/unittest/test_autograd.py
@@ -243,7 +243,7 @@ def test_detach_updated_grad():
assert x._fresh_grad == False


def test_retain_grad():
def test_retain_graph():
x = mx.nd.ones((2, 2))
dx = mx.nd.zeros((2, 2))
mark_variables([x], [dx], grad_reqs='add')
@@ -519,3 +519,68 @@ def test_gradient():
dx.backward()
assert abs(x.grad.asscalar() - 2.71828175) < 1e-7

def test_retain_grad_drop_grad():
x = nd.array([1,2,3,4])
x.attach_grad()
y = nd.array([5,6,7,8])
y.attach_grad()

with mx.autograd.record():
u = x * y
z = u * x

u.attach_grad()
z.attach_grad()
out_grad = nd.array([10, 10, 10, 10])
z.backward(out_grad, retain_graph=True)

assert (u.grad == out_grad * x).asnumpy().all()
assert (z.grad == out_grad).asnumpy().all()
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
assert (y.grad == out_grad * x*x).asnumpy().all()

u.drop_grad()
z.drop_grad()
y.drop_grad()
out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
z.backward(out_grad)

assert u.grad is None and z.grad is None and y.grad is None
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()

def test_retain_grad_drop_grad_gluon():
class CompBlock(mx.gluon.HybridBlock):
def __init__(self):
super().__init__()
self.marked_var = None
def forward(self, a, b):
out1 = a*b
out2 = out1 * a
self.marked_var = out1
return out2
x = mx.np.array([1,2,3,4])
y = mx.np.array([5,6,7,8])
x.attach_grad()
y.attach_grad()
block2 = CompBlock()
block2.initialize()
# block2.hybridize()
with mx.autograd.record():
z = block2(x, y)
u = block2.marked_var
u.attach_grad()
z.attach_grad()
z.backward(retain_graph=True)

assert (u.grad == x).all()
assert (z.grad == mx.np.array([1,1,1,1])).all()
assert (x.grad == 2 * x * y).all()
assert (y.grad == x*x).all()

u.drop_grad()
z.drop_grad()
y.drop_grad()
z.backward()

assert u.grad is None and z.grad is None and y.grad is None
assert (x.grad == 2 * x * y).all()