[Numpy] allow mix integer dtypes for power/add/multiply (apache#17921)
* resolution

* fix sanity error

* remove func 'is_integer'
JiangZhaoh authored and sxjscience committed Jul 1, 2020
1 parent b816d43 commit 34b4708
Showing 3 changed files with 79 additions and 2 deletions.
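For context, the user-facing effect of this commit: the NumPy-compatible add, multiply, and power operators now accept operands whose integer dtypes differ, instead of falling through to the mixed-precision error path. A minimal usage sketch (illustrative values; assumes a build that includes this commit):

    from mxnet import np, npx
    npx.set_np()

    a = np.array([1, 2, 3], dtype='int32')
    b = np.array([4, 5, 6], dtype='int64')

    # Each of these now dispatches through the new temp-space casting path
    # instead of hitting PrintErrorMessage.
    print(np.add(a, b))
    print(np.multiply(a, b))
    print(np.power(a, b))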
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -76,6 +76,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#else
68 changes: 66 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -153,7 +153,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
if (lhs.type_flag_ == out.type_flag_) {
MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
@@ -252,7 +251,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);

if (!ndim) {
MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
@@ -290,6 +288,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
});
}
});
} else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
@@ -320,6 +339,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
@@ -384,6 +424,30 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}
if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
Stream<xpu> *s = ctx.get_stream<xpu>();
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
return;
}

#ifndef _WIN32
MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
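The hunks above add the same integer/integer branch at three places in this header (the broadcast compute paths, including NumpyBinaryBroadcastComputeWithBool): when neither operand is a float and their dtypes differ, the operand whose dtype does not match the output dtype is cast with CastCompute into a temporary workspace blob (the scratch space supplied by the new kTempSpace FResourceRequest registered in np_elemwise_broadcast_op.cc), and the ordinary same-dtype BinaryBroadcastCompute kernel then runs on the casted pair. A rough NumPy sketch of that dispatch logic (illustrative only; the names below are not from the source):

    import numpy as np

    def mixed_int_broadcast(op, lhs, rhs, out_dtype):
        # Whichever operand does not already carry the output dtype is cast
        # into scratch space first (CastCompute into temp_tblob), then the
        # plain same-dtype broadcast kernel (BinaryBroadcastCompute) runs.
        if lhs.dtype == out_dtype:
            rhs = rhs.astype(out_dtype)
        else:
            lhs = lhs.astype(out_dtype)
        return op(lhs, rhs)

    # e.g.:
    # a = np.arange(3, dtype=np.int32)
    # b = np.arange(3, dtype=np.int64)
    # mixed_int_broadcast(np.add, a, b, np.dtype('int64'))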
9 changes: 9 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -2490,6 +2490,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

if lgrad:
if (ltype in itypes) and (rtype in itypes):
continue
y.backward()
if ltype not in itypes:
assert_almost_equal(mx_test_x1.grad.asnumpy(),
@@ -2544,6 +2546,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)

if func == 'subtract':
continue
for type1, type2 in itertools.product(itypes, itypes):
if type1 == type2:
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)


@with_seed()
@use_np
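The added test block exercises only the forward pass for integer/integer dtype pairs: backward is skipped for them in the first hunk, since integer gradients are not computed, and 'subtract' is skipped, consistent with the commit covering only power/add/multiply. A standalone sketch of the kind of forward check being added, with illustrative dtypes and shapes rather than the helper's actual arguments:

    import itertools
    import numpy as onp
    from mxnet import np, npx
    npx.set_np()

    itypes = ['int8', 'int32', 'int64']   # assumed integer dtype list

    for ltype, rtype in itertools.product(itypes, itypes):
        if ltype == rtype:
            continue
        a = np.array([[1, 2, 3]], dtype=ltype)    # shape (1, 3)
        b = np.array([[4], [5]], dtype=rtype)     # shape (2, 1), broadcasts to (2, 3)
        for name in ('add', 'multiply', 'power'):
            got = getattr(np, name)(a, b).asnumpy()
            expected = getattr(onp, name)(a.asnumpy(), b.asnumpy())
            assert onp.array_equal(got, expected)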
