Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ffi_bitwise binary
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyan66 committed Mar 19, 2020
1 parent 4fa4e65 commit 8190f73
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
7 changes: 7 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def prepare_workloads():
OpArgMngr.add_workload("full_like", pool['2x2'], 2)
OpArgMngr.add_workload("zeros_like", pool['2x2'])
OpArgMngr.add_workload("ones_like", pool['2x2'])
OpArgMngr.add_workload("bitwise_or", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_xor", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_and", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("copysign", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("arctan2", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("hypot", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("ldexp", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
Expand Down
29 changes: 21 additions & 8 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5002,7 +5002,9 @@ def copysign(x1, x2, out=None, **kwargs):
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.copysign(x1, x2, out=out)
return _api_internal.copysign(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5760,8 +5762,9 @@ def arctan2(x1, x2, out=None, **kwargs):
>>> np.arctan2(x, y)
array([ 1.5707964, -1.5707964])
"""
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.arctan2(x1, x2, out=out)
return _api_internal.arctan2(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5809,7 +5812,9 @@ def hypot(x1, x2, out=None, **kwargs):
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.hypot(x1, x2, out=out)
return _api_internal.hypot(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5849,7 +5854,9 @@ def bitwise_and(x1, x2, out=None, **kwargs):
>>> np.bitwise_and(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([False, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_and(x1, x2, out=out)
return _api_internal.bitwise_and(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5887,7 +5894,9 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_xor(x1, x2, out=out)
return _api_internal.bitwise_xor(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5925,7 +5934,9 @@ def bitwise_or(x1, x2, out=None, **kwargs):
>>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_or(x1, x2, out=out)
return _api_internal.bitwise_or(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5964,7 +5975,9 @@ def ldexp(x1, x2, out=None, **kwargs):
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.ldexp(x1, x2, out=out)
return _api_internal.ldexp(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
59 changes: 59 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,63 @@ MXNET_REGISTER_API("_npi.lcm")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_or")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_or");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_or_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_xor")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_xor");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_xor_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_and")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_and");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_and_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.copysign")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_copysign");
const nnvm::Op* op_scalar = Op::Get("_npi_copysign_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rcopysign_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.arctan2")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_arctan2");
const nnvm::Op* op_scalar = Op::Get("_npi_arctan2_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rarctan2_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.hypot")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_hypot");
const nnvm::Op* op_scalar = Op::Get("_npi_hypot_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.ldexp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_ldexp");
const nnvm::Op* op_scalar = Op::Get("_npi_ldexp_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rldexp_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet

0 comments on commit 8190f73

Please sign in to comment.