Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] FFI binary bitwise ops (#17812)
Browse files Browse the repository at this point in the history
* ffi_bitwise  binary

* retrigger ci
  • Loading branch information
Yiyan66 authored Apr 8, 2020
1 parent 178b98e commit d1616c9
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
7 changes: 7 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def prepare_workloads():
OpArgMngr.add_workload("full_like", pool['2x2'], 2)
OpArgMngr.add_workload("zeros_like", pool['2x2'])
OpArgMngr.add_workload("ones_like", pool['2x2'])
OpArgMngr.add_workload("bitwise_and", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_xor", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("bitwise_or", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("copysign", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("arctan2", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("hypot", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("ldexp", pool['2x2'].astype(int), pool['2x2'].astype(int))
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("fmax", pool['2x2'], pool['2x2'])
Expand Down
29 changes: 21 additions & 8 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5052,7 +5052,9 @@ def copysign(x1, x2, out=None, **kwargs):
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.copysign(x1, x2, out=out)
return _api_internal.copysign(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5808,8 +5810,9 @@ def arctan2(x1, x2, out=None, **kwargs):
>>> np.arctan2(x, y)
array([ 1.5707964, -1.5707964])
"""
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.arctan2(x1, x2, out=out)
return _api_internal.arctan2(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5857,7 +5860,9 @@ def hypot(x1, x2, out=None, **kwargs):
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.hypot(x1, x2, out=out)
return _api_internal.hypot(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5897,7 +5902,9 @@ def bitwise_and(x1, x2, out=None, **kwargs):
>>> np.bitwise_and(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([False, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_and(x1, x2, out=out)
return _api_internal.bitwise_and(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5935,7 +5942,9 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_xor(x1, x2, out=out)
return _api_internal.bitwise_xor(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5973,7 +5982,9 @@ def bitwise_or(x1, x2, out=None, **kwargs):
>>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_or(x1, x2, out=out)
return _api_internal.bitwise_or(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6012,7 +6023,9 @@ def ldexp(x1, x2, out=None, **kwargs):
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.ldexp(x1, x2, out=out)
return _api_internal.ldexp(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
59 changes: 59 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,63 @@ MXNET_REGISTER_API("_npi.lcm")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_or")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_or");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_or_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_xor")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_xor");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_xor_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.bitwise_and")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_and");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_and_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.copysign")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_copysign");
const nnvm::Op* op_scalar = Op::Get("_npi_copysign_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rcopysign_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.arctan2")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_arctan2");
const nnvm::Op* op_scalar = Op::Get("_npi_arctan2_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rarctan2_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.hypot")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_hypot");
const nnvm::Op* op_scalar = Op::Get("_npi_hypot_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.ldexp")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_ldexp");
const nnvm::Op* op_scalar = Op::Get("_npi_ldexp_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rldexp_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet

0 comments on commit d1616c9

Please sign in to comment.