This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] allow mix integer dtypes for power/add/multiply #17921

Merged: 3 commits merged on Apr 8, 2020
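Before this change, the numpy-compatible power/add/multiply operators did not accept two integer operands of different dtypes. A minimal illustration of what the merged path enables, written against the mxnet.numpy front end (the arrays and dtypes here are illustrative, not taken from the PR; the output dtype is expected to follow NumPy-style promotion, int64 in this example):

# Illustrative only: mixed integer dtypes for add/multiply/power.
from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible array semantics

a = np.array([1, 2, 3], dtype='int32')
b = np.array([4, 5, 6], dtype='int64')

print((a + b).dtype)          # add with mixed integer dtypes
print((a * b).dtype)          # multiply
print(np.power(a, b).dtype)   # power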
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -75,6 +75,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#else
68 changes: 66 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -153,7 +153,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
if (lhs.type_flag_ == out.type_flag_) {
MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
@@ -252,7 +251,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);

if (!ndim) {
MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
@@ -290,6 +288,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
});
}
});
} else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
@@ -320,6 +339,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
@@ -384,6 +424,30 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}
if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
Stream<xpu> *s = ctx.get_stream<xpu>();
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
return;
}

#ifndef _WIN32
MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
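All three new integer/integer branches in this header follow one strategy: whichever operand does not already carry the output dtype is cast (via CastCompute) into the temporary workspace requested through the FResourceRequest change in np_elemwise_broadcast_op.cc, and the ordinary same-dtype BinaryBroadcastCompute kernel is then reused. A plain-NumPy sketch of that control flow, with a made-up helper name, purely for illustration:

# Pure-NumPy analogue of the new integer/integer branch: cast the operand whose
# dtype differs from the promoted output dtype, then run the usual kernel.
import operator
import numpy as onp

def mixed_int_binary(op, lhs, rhs):
    out_dtype = onp.promote_types(lhs.dtype, rhs.dtype)  # stands in for out.type_flag_
    if lhs.dtype == out_dtype:
        rhs = rhs.astype(out_dtype)   # rhs goes through the "temp space" cast
    else:
        lhs = lhs.astype(out_dtype)   # lhs goes through the "temp space" cast
    return op(lhs, rhs)               # same-dtype broadcast compute

print(mixed_int_binary(operator.add,
                       onp.arange(3, dtype='int32'),
                       onp.arange(3, dtype='int64')).dtype)  # int64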
9 changes: 9 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -2424,6 +2424,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

if lgrad:
if (ltype in itypes) and (rtype in itypes):
continue
y.backward()
if ltype not in itypes:
assert_almost_equal(mx_test_x1.grad.asnumpy(),
@@ -2478,6 +2480,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)

if func == 'subtract':
continue
for type1, type2 in itertools.product(itypes, itypes):
if type1 == type2:
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)


@with_seed()
@use_np
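For reference, a standalone approximation of the coverage the new test loop adds: every ordered pair of distinct integer dtypes is run through the mixed-precision funcs, with subtract excluded as in the if func == 'subtract': continue guard, and with backward skipped for integer/integer pairs (presumably because gradient checking is not meaningful for integer-only inputs). The dtype list and shapes below are stand-ins, not the exact values used by the test module:

# Stand-alone sketch of the new integer/integer forward coverage (illustrative).
import itertools
from mxnet import np, npx
npx.set_np()

itypes = ['int32', 'int64']            # stand-in for the module-level itypes list
funcs = ['add', 'multiply', 'power']   # 'subtract' is excluded, as in the diff

for func in funcs:
    for type1, type2 in itertools.product(itypes, itypes):
        if type1 == type2:
            continue
        a = np.random.randint(1, 5, size=(2, 3)).astype(type1)
        b = np.random.randint(1, 5, size=(2, 3)).astype(type2)
        out = getattr(np, func)(a, b)
        print(func, type1, type2, '->', out.dtype)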