
[MXNET-978] Higher Order Gradient Support arcsin, arccos. #15515

Merged
53 changes: 51 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -188,7 +188,31 @@ The storage type of ``arcsin`` output depends upon the input storage type:
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsin" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arcsin,
unary_bwd<mshadow_op::arcsin_grad>);
unary_bwd<mshadow_op::arcsin_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dxgrad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = arcsin(x)
      // n: f'(x) = 1/(1-x^2)^(1/2)
      // f''(x) = f'(x) * x/(1-x^2)
      // Note: x/(1-x^2) = x * f'(x)^2
      auto dydx = n->inputs[0];
      auto x = n->inputs[1];
      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
      auto op = mxnet::util::NodeOpGen{n};

      auto x_grad = op.div(dydx_mul_grad_x, dydx);
      auto x_grad_square = op.square(x_grad);
      auto x_grad_square_mul_x = op.mul(x_grad_square, x);
      auto x_grad_grad = op.mul(dydx_mul_grad_x, x_grad_square_mul_x);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(op.mul(ograds[0], x_grad));
      ret.emplace_back(op.mul(ograds[0], x_grad_grad));
      return ret;
    });
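The lambda above leans on the identity noted in the comments, f''(x) = f'(x) * x/(1-x^2): since f'(x)^2 = 1/(1-x^2), the second derivative can be assembled from x_grad (which recovers f'(x) by dividing the node's output by dL/dy) without ever forming 1/(1-x^2) explicitly. A quick symbolic check of that identity for arcsin, as a standalone sketch assuming SymPy is available (not part of this PR):

import sympy as sp

x = sp.symbols('x')
f1 = sp.diff(sp.asin(x), x)     # f'(x)  = 1/(1 - x**2)**(1/2)
f2 = sp.diff(sp.asin(x), x, 2)  # f''(x) = x/(1 - x**2)**(3/2)
# identity used by the FGradient lambda: f''(x) = f'(x) * x/(1 - x**2)
assert sp.simplify(f2 - f1 * x / (1 - x**2)) == 0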

// arccos
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(arccos, cpu, mshadow_op::arccos)
@@ -207,7 +231,32 @@ The storage type of ``arccos`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccos" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arccos,
unary_bwd<mshadow_op::arccos_grad>);
unary_bwd<mshadow_op::arccos_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dxgrad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = arccos(x)
      // n: f'(x) = -1/(1-x^2)^(1/2)
      // f''(x) = f'(x) * x/(1-x^2)
      // Note: x/(1-x^2) = x * f'(x)^2
      auto dydx = n->inputs[0];
      auto x = n->inputs[1];
      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
      auto op = mxnet::util::NodeOpGen{n};

      auto x_grad = op.div(dydx_mul_grad_x, dydx);
      auto x_grad_square = op.square(x_grad);
      auto x_grad_square_mul_x = op.mul(x_grad_square, x);
      auto x_grad_grad = op.mul(dydx_mul_grad_x, x_grad_square_mul_x);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(op.mul(ograds[0], x_grad));
      ret.emplace_back(op.mul(ograds[0], x_grad_grad));
      return ret;
    });
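The arccos case reuses exactly the same structure: the minus sign lives entirely in f'(x), so the shared factor x/(1-x^2) = x * f'(x)^2, and hence the lambda body, is unchanged from arcsin. The corresponding symbolic check, again a standalone sketch assuming SymPy is available (not part of this PR):

import sympy as sp

x = sp.symbols('x')
f1 = sp.diff(sp.acos(x), x)     # f'(x)  = -1/(1 - x**2)**(1/2)
f2 = sp.diff(sp.acos(x), x, 2)  # f''(x) = -x/(1 - x**2)**(3/2)
# same identity, with the sign carried by f'(x): f''(x) = f'(x) * x/(1 - x**2)
assert sp.simplify(f2 - f1 * x / (1 - x**2)) == 0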


// arctan
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arctan, cpu, mshadow_op::arctan)
38 changes: 38 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -114,6 +114,44 @@ def grad_grad_op(x):
        array, tanh, grad_grad_op, rtol=1e-6, atol=1e-6)


@with_seed()
def test_arcsin():
    def arcsin(x):
        return nd.arcsin(x)

    def grad_grad_op(x):
        return x / nd.sqrt((1-x**2)**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        # Hack: decrease std_dev to make sure all elements
        # are in the range [-1, 1], i.e. the domain of arcsin
        array *= 0.2
        check_second_order_unary(array, arcsin, grad_grad_op)


@with_seed()
def test_arccos():
    def arccos(x):
        return nd.arccos(x)

    def grad_grad_op(x):
        return -x / nd.sqrt((1-x**2)**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        # Hack: decrease std_dev to make sure all elements
        # are in the range [-1, 1], i.e. the domain of arccos
        array *= 0.2
        check_second_order_unary(array, arccos, grad_grad_op)
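Both tests go through check_second_order_unary, whose body lies outside this diff; conceptually it records the op, differentiates once while keeping the graph, backpropagates through that first gradient, and compares x.grad against grad_grad_op. A minimal sketch of that double-differentiation pattern for arccos (variable names and the exact call sequence are illustrative assumptions, not the helper's actual code):

from mxnet import nd, autograd

x = nd.array([0.3, -0.5])
x.attach_grad()
with autograd.record():
    y = nd.arccos(x)
    # first-order gradient, kept in the graph so it can be differentiated again
    x_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
x_grad.backward()
# x.grad now holds the elementwise second derivative; compare with the analytic form
expected = -x / nd.sqrt((1 - x**2)**3)
print(x.grad, expected)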


@with_seed()
def test_arctan():
    def arctan(x):