This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API] Add logaddexp #20673

Merged
merged 9 commits on Oct 27, 2021
2 changes: 2 additions & 0 deletions docs/python_docs/python/api/np/routines.math.rst
@@ -105,6 +105,7 @@ Exponents and logarithms
log10
log2
log1p
logaddexp


Other special functions
@@ -133,6 +134,7 @@ Rational routines
:toctree: generated/

lcm
gcd


Arithmetic operations
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -486,6 +486,8 @@
'_npi_expm1',
'_npi_ldexp',
'_npi_ldexp_scalar',
'_npi_logaddexp',
'_npi_logaddexp_scalar',
'_npi_log',
'_npi_log10',
'_npi_log1p',
44 changes: 43 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -51,7 +51,7 @@
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'where', 'bincount', 'rollaxis', 'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'diag', 'diagonal',
'positive']
'positive', 'logaddexp']


@set_module('mxnet.ndarray.numpy')
@@ -6914,6 +6914,48 @@ def ldexp(x1, x2, out=None, **kwargs):
return _api_internal.ldexp(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def logaddexp(x1, x2, out=None, **kwargs):
"""
Logarithm of the sum of exponentiations of the inputs.

Calculates log(exp(x1) + exp(x2)). This function is useful in statistics where
the calculated probabilities of events may be so small as to exceed the range of
normal floating point numbers. In such cases the logarithm of the calculated
probability is stored. This function allows adding probabilities stored
in such a fashion.

Parameters
----------
x1 : ndarray or scalar
Input array.
x2 : ndarray or scalar
Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.

Returns
-------
y : ndarray or scalar
Logarithm of exp(x1) + exp(x2). This is a scalar if both x1 and x2 are scalars.

Examples
--------
>>> prob1 = np.log(1e-50)
>>> prob2 = np.log(2.5e-50)
>>> prob12 = np.logaddexp(prob1, prob2)
>>> prob12
-113.87649168120691
>>> np.exp(prob12)
3.5000000000000057e-50
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.logaddexp(x1, x2, out=out)
return _api_internal.logaddexp(x1, x2, out)
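
A note on the formula itself: the direct `log(exp(x1) + exp(x2))` form used by the operator can overflow `exp` for large positive inputs. Below is a minimal sketch of the numerically stable rewrite commonly used elsewhere (for example, in NumPy's C implementation); this is for illustration only and is not the code added by this PR:

```python
import math

def stable_logaddexp(a, b):
    # Rewrite log(exp(a) + exp(b)) as max(a, b) + log1p(exp(-|a - b|)):
    # the exponent is always <= 0, so exp never overflows.
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

# stable_logaddexp(1000.0, 1000.0) -> 1000.6931471805599, whereas
# math.log(math.exp(1000.0) + math.exp(1000.0)) raises OverflowError.
```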


@set_module('mxnet.ndarray.numpy')
def vdot(a, b):
r"""
43 changes: 42 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -80,7 +80,8 @@
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal', 'positive']
'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
'positive', 'logaddexp']

__all__ += fallback.__all__

@@ -9504,6 +9505,46 @@ def ldexp(x1, x2, out=None, **kwargs):
return _mx_nd_np.ldexp(x1, x2, out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def logaddexp(x1, x2, out=None, **kwargs):
"""
Logarithm of the sum of exponentiations of the inputs.

Calculates log(exp(x1) + exp(x2)). This function is useful in statistics where
the calculated probabilities of events may be so small as to exceed the range of
normal floating point numbers. In such cases the logarithm of the calculated
probability is stored. This function allows adding probabilities stored
in such a fashion.

Parameters
----------
x1 : ndarray or scalar
Input array.
x2 : ndarray or scalar
Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.

Returns
-------
y : ndarray or scalar
Logarithm of exp(x1) + exp(x2). This is a scalar if both x1 and x2 are scalars.

Examples
--------
>>> prob1 = np.log(1e-50)
>>> prob2 = np.log(2.5e-50)
>>> prob12 = np.logaddexp(prob1, prob2)
>>> prob12
-113.87649168120691
>>> np.exp(prob12)
3.5000000000000057e-50
"""
return _mx_nd_np.logaddexp(x1, x2, out)


@set_module('mxnet.numpy')
def vdot(a, b):
r"""
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -251,6 +251,7 @@ def _register_array_function():
'lcm',
'gcd',
# 'ldexp',
'logaddexp',
'subtract',
'multiply',
'true_divide',
8 changes: 8 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -139,6 +139,14 @@ MXNET_REGISTER_API("_npi.bitwise_and")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.logaddexp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_logaddexp");
const nnvm::Op* op_scalar = Op::Get("_npi_logaddexp_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.copysign")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
14 changes: 14 additions & 0 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -462,6 +462,20 @@ rldexp_grad(const DType val,
return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_grad(const DType val,
const DType2 val2) {
return op::exp(val) / (op::exp(val) + op::exp(val2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_rgrad(const DType val,
const DType2 val2) {
return op::exp(val2) / (op::exp(val) + op::exp(val2));
}
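
For reference, the two device functions above are the partial derivatives of the forward expression:

$$
\frac{\partial}{\partial x_1}\log\left(e^{x_1}+e^{x_2}\right)=\frac{e^{x_1}}{e^{x_1}+e^{x_2}}=\sigma(x_1-x_2),
\qquad
\frac{\partial}{\partial x_2}\log\left(e^{x_1}+e^{x_2}\right)=\sigma(x_2-x_1),
$$

where $\sigma$ is the logistic sigmoid; `logaddexp_grad` and `logaddexp_rgrad` compute exactly these two quantities.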

template <typename DType, typename DType2>
__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
auto bsq = scalar * scalar;
10 changes: 10 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -621,6 +621,16 @@ rldexp(const DType a, const DType2 b) {
return ldexp(b, a);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp(const DType a, const DType2 b) {
if (type_util::has_double_or_integral<DType, DType2>::value) {
return ::log(::exp(static_cast<double>(a)) + ::exp(static_cast<double>(b)));
} else {
return ::log(::expf(static_cast<float>(a)) + ::expf(static_cast<float>(b)));
}
}

#undef DEFINE_BINARY_MATH_FUNC

template <typename DType, typename DType2>
7 changes: 7 additions & 0 deletions src/operator/mshadow_op.h
@@ -726,6 +726,13 @@ MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a)); // swap a and b

MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));

/*! \brief used for generate element of logaddexp */
MXNET_BINARY_MATH_OP(logaddexp, math::log(math::exp(a) + math::exp(b)));

MXNET_BINARY_MATH_OP(logaddexp_grad, math::exp(a) / (math::exp(a) + math::exp(b)));

MXNET_BINARY_MATH_OP(logaddexp_rgrad, math::exp(b) / (math::exp(a) + math::exp(b)));

/*! \brief used for generate element of round */
MXNET_SIMPLE_UNARY_MATH_OP(round);

60 changes: 60 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_lae.cc
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_elemwise_broadcast_op_lae.cc
* \brief CPU Implementation of basic functions for elementwise numpy binary logaddexp.
*/

#include "./np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_logaddexp)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logaddexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_logaddexp_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logaddexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp_scalar"});

NNVM_REGISTER_OP(_backward_npi_logaddexp)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>(
"FCompute<cpu>",
BinaryBroadcastBackwardUseIn<cpu, mshadow_op::logaddexp_grad, mshadow_op::logaddexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_logaddexp_scalar)
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::logaddexp_grad>);

} // namespace op
} // namespace mxnet
44 changes: 44 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_lae.cu
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_elemwise_broadcast_op_lae.cu
* \brief GPU Implementation of basic functions for elementwise binary broadcast logaddexp operator.
*/

#include "./np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_logaddexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"logaddexp"});

NNVM_REGISTER_OP(_npi_logaddexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"logaddexp"});

NNVM_REGISTER_OP(_backward_npi_logaddexp)
.set_attr<FCompute>("FCompute<gpu>",
BinaryBroadcastRTCBackwardUseIn{"logaddexp_grad", "logaddexp_rgrad"});

NNVM_REGISTER_OP(_backward_npi_logaddexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCBackward{"logaddexp_grad"});

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/operator_tune.cc
@@ -463,6 +463,9 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp);
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logaddexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logaddexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logaddexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone); // NOLINT()
/*!
8 changes: 8 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
@@ -1541,6 +1541,13 @@ def _add_workload_ldexp():
OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(-9223372036854775808, np.int64))


def _add_workload_logaddexp(array_pool):
OpArgMngr.add_workload('logaddexp', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('logaddexp', array_pool['4x1'], 2)
OpArgMngr.add_workload('logaddexp', 2, array_pool['4x1'])
OpArgMngr.add_workload('logaddexp', array_pool['4x1'], array_pool['1x1x0'])
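
Registering `'logaddexp'` in `numpy_dispatch_protocol.py` (see the change above) is what allows these workloads to be driven through official NumPy. Roughly, under the `__array_function__` protocol it behaves as in this sketch (illustration only, not part of the test code):

```python
import numpy as onp          # official NumPy
from mxnet import np as mnp  # MXNet's NumPy-compatible frontend

a = mnp.zeros((4, 1))
b = mnp.ones((1, 2))

out = onp.logaddexp(a, b)    # dispatches to MXNet's implementation
print(type(out))             # mxnet.numpy.ndarray, broadcast to shape (4, 2)
```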


def _add_workload_subtract(array_pool):
OpArgMngr.add_workload('subtract', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('subtract', array_pool['4x1'], 2)
@@ -3082,6 +3089,7 @@ def _prepare_workloads():
_add_workload_bitwise_xor()
_add_workload_bitwise_or()
_add_workload_ldexp()
_add_workload_logaddexp(array_pool)
_add_workload_subtract(array_pool)
_add_workload_multiply(array_pool)
_add_workload_power(array_pool)
2 changes: 2 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -3113,6 +3113,8 @@ def forward(self, a, b, *args, **kwargs):
'hypot': (-1, 1, [lambda y, x1, x2: x1 / y],
[lambda y, x1, x2: x2 / y]),
'ldexp': (-3, 3, [None], None, [[onp.int32]]),
'logaddexp': (-10, 10, [lambda y, x1, x2: onp.exp(x1) / (onp.exp(x1) + onp.exp(x2))],
[lambda y, x1, x2: onp.exp(x2) / (onp.exp(x1) + onp.exp(x2))])
}
if is_op_runnable():
funcs['logical_and'] = (-100, 100, [None], None, [[onp.float32, onp.float64]])