Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix_uninitialized_issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Mar 18, 2020
1 parent 2158ff9 commit f333b7b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 18 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6296,9 +6296,9 @@ def interp(x, xp, fp, left=None, right=None, period=None): # pylint: disable=to
"""
if isinstance(x, numeric_types):
return _npi.interp(xp.astype(float), fp.astype(float), left=left,
right=right, period=period, x_scalar=x)
return _npi.interp(xp.astype(float), fp.astype(float), x.astype(float),
left=left, right=right, period=period, x_scalar=None)
right=right, period=period, x_scalar=x, x_is_scalar=True)
return _npi.interp(xp.astype(float), fp.astype(float), x.astype(float), left=left,
right=right, period=period, x_scalar=0.0, x_is_scalar=False)


@set_module('mxnet.symbol.numpy')
Expand Down
4 changes: 3 additions & 1 deletion src/api/operator/numpy/np_interp_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ MXNET_REGISTER_API("_npi.interp")
}
if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) {
param.x_scalar = args[2].operator double();
param.x_is_scalar = true;
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
Expand All @@ -60,7 +61,8 @@ MXNET_REGISTER_API("_npi.interp")
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
} else {
param.x_scalar = dmlc::nullopt;
param.x_scalar = 0.0;
param.x_is_scalar = false;
attrs.op = op;
attrs.parsed = std::move(param);
SetAttrDict<op::NumpyInterpParam>(&attrs);
Expand Down
26 changes: 15 additions & 11 deletions src/operator/numpy/np_interp_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct NumpyInterpParam : public dmlc::Parameter<NumpyInterpParam> {
dmlc::optional<double> left;
dmlc::optional<double> right;
dmlc::optional<double> period;
dmlc::optional<double> x_scalar;
double x_scalar;
bool x_is_scalar;
DMLC_DECLARE_PARAMETER(NumpyInterpParam) {
DMLC_DECLARE_FIELD(left)
.set_default(dmlc::optional<double>())
Expand All @@ -56,20 +57,23 @@ struct NumpyInterpParam : public dmlc::Parameter<NumpyInterpParam> {
.describe("A period for the x-coordinates. This parameter allows"
"the proper interpolation of angular x-coordinates. Parameters"
"left and right are ignored if period is specified.");
DMLC_DECLARE_FIELD(x_scalar)
.set_default(dmlc::optional<double>())
.describe("Input x is a scalar");
DMLC_DECLARE_FIELD(x_scalar).set_default(0.0)
.describe("x is a scalar input");
DMLC_DECLARE_FIELD(x_is_scalar).set_default(false)
.describe("Flag that determines whether input is a scalar");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream left_s, right_s, period_s, x_scalar_s;
std::ostringstream left_s, right_s, period_s, x_scalar_s, x_is_scalar_s;
left_s << left;
right_s << right;
period_s << period;
x_scalar_s << x_scalar;
x_is_scalar_s << x_is_scalar;
(*dict)["left"] = left_s.str();
(*dict)["right"] = right_s.str();
(*dict)["period"] = period_s.str();
(*dict)["x_scalar"] = x_scalar_s.str();
(*dict)["x_is_scalar"] = x_is_scalar_s.str();
}
};

Expand Down Expand Up @@ -193,7 +197,7 @@ void NumpyInterpForward(const nnvm::NodeAttrs& attrs,
dmlc::optional<double> left = param.left;
dmlc::optional<double> right = param.right;
dmlc::optional<double> period = param.period;
dmlc::optional<double> x_scalar = param.x_scalar;
bool x_is_scalar = param.x_is_scalar;

TBlob xp = inputs[0];
const TBlob &fp = inputs[1];
Expand All @@ -213,8 +217,8 @@ void NumpyInterpForward(const nnvm::NodeAttrs& attrs,

size_t topk_temp_size; // Used by Sort
size_t topk_workspace_size = TopKWorkspaceSize<xpu, double>(xp, topk_param, &topk_temp_size);
size_t size_x = x_scalar.has_value() ? 8 : 0;
size_t size_norm_x = x_scalar.has_value() ? 8 : inputs[2].Size() * sizeof(double);
size_t size_x = x_is_scalar ? 8 : 0;
size_t size_norm_x = x_is_scalar ? 8 : inputs[2].Size() * sizeof(double);
size_t size_norm_xp = xp.Size() * sizeof(double);
size_t size_norm = period.has_value()? size_norm_x + size_norm_xp : 0;
size_t size_idx = period.has_value()? xp.Size() * sizeof(index_t) : 0;
Expand All @@ -227,9 +231,9 @@ void NumpyInterpForward(const nnvm::NodeAttrs& attrs,
char* workspace_curr_ptr = temp_mem.dptr_;

TBlob x, idx;
if (x_scalar.has_value()) {
double x_value = x_scalar.value();
Tensor<cpu, 1, double> host_x(&x_value, Shape1(1), ctx.get_stream<cpu>());
if (x_is_scalar) {
double x_scalar = param.x_scalar;
Tensor<cpu, 1, double> host_x(&x_scalar, Shape1(1), ctx.get_stream<cpu>());
Tensor<xpu, 1, double> device_x(reinterpret_cast<double*>(workspace_curr_ptr),
Shape1(1), ctx.get_stream<xpu>());
Copy(device_x, host_x, ctx.get_stream<xpu>());
Expand Down
6 changes: 3 additions & 3 deletions src/operator/numpy/np_interp_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ inline bool NumpyInterpShape(const nnvm::NodeAttrs& attrs,
<< "ValueError: Data points must be 1-D array";
CHECK_EQ(in_attrs->at(0)[0], in_attrs->at(1)[0])
<< "ValueError: fp and xp are not of the same length";
oshape = param.x_scalar.has_value() ? TShape(0, 1) : in_attrs->at(2);
oshape = param.x_is_scalar ? TShape(0, 1) : in_attrs->at(2);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return shape_is_known(out_attrs->at(0));
}
Expand All @@ -62,7 +62,7 @@ NNVM_REGISTER_OP(_npi_interp)
.set_num_inputs([](const NodeAttrs& attrs) {
const NumpyInterpParam& param =
nnvm::get<NumpyInterpParam>(attrs.parsed);
return param.x_scalar.has_value()? 2 : 3;
return param.x_is_scalar ? 2 : 3;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyInterpParam>)
Expand All @@ -72,7 +72,7 @@ NNVM_REGISTER_OP(_npi_interp)
[](const NodeAttrs& attrs) {
const NumpyInterpParam& param =
nnvm::get<NumpyInterpParam>(attrs.parsed);
return param.x_scalar.has_value() ?
return param.x_is_scalar ?
std::vector<std::string>{"xp", "fp"} :
std::vector<std::string>{"xp", "fp", "x"};
})
Expand Down

0 comments on commit f333b7b

Please sign in to comment.