Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
more binary
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyan66 committed Mar 12, 2020
1 parent 9f992e0 commit 1cd1fbf
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
4 changes: 4 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def prepare_workloads():
OpArgMngr.add_workload("bitwise_or", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_xor", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_and", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("copysign", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("arctan2", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("hypot", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("ldexp", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)


Expand Down
17 changes: 12 additions & 5 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5001,7 +5001,9 @@ def copysign(x1, x2, out=None, **kwargs):
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.copysign(x1, x2, out=out)
return _api_internal.copysign(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5759,8 +5761,9 @@ def arctan2(x1, x2, out=None, **kwargs):
>>> np.arctan2(x, y)
array([ 1.5707964, -1.5707964])
"""
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.arctan2(x1, x2, out=out)
return _api_internal.arctan2(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5808,7 +5811,9 @@ def hypot(x1, x2, out=None, **kwargs):
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.hypot(x1, x2, out=out)
return _api_internal.hypot(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5969,7 +5974,9 @@ def ldexp(x1, x2, out=None, **kwargs):
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.ldexp(x1, x2, out=out)
return _api_internal.ldexp(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
32 changes: 32 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,36 @@ MXNET_REGISTER_API("_npi.bitwise_and")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.copysign")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_copysign");
const nnvm::Op* op_scalar = Op::Get("_npi_copysign_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.arctan2")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_arctan2");
const nnvm::Op* op_scalar = Op::Get("_npi_arctan2_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.hypot")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_hypot");
const nnvm::Op* op_scalar = Op::Get("_npi_hypot_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.ldexp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_ldexp");
const nnvm::Op* op_scalar = Op::Get("_npi_ldexp_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

} // namespace mxnet

0 comments on commit 1cd1fbf

Please sign in to comment.