Skip to content

Commit

Permalink
Add Tensor::FlatTo1D and work around weird clang error (apache#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored and tqchen committed Sep 12, 2016
1 parent 478e5fd commit ca4d4f6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
19 changes: 19 additions & 0 deletions mshadow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ struct Shape {
MSHADOW_XINLINE bool operator!=(const Shape<kDimension> &s) const {
return !(*this == s);
}
/*!
* flatten the tensor, return a 1D shape
* \return the flat 1d shape
*/
MSHADOW_XINLINE Shape<1> FlatTo1D(void) const {
Shape<1> s;
s[0] = this->Size();
return s;
}
/*!
* flatten the higher dimension to second dimension, return a 2D shape
* \return the flat 2d shape
Expand Down Expand Up @@ -360,6 +369,13 @@ struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
MSHADOW_XINLINE index_t size(index_t idx) const {
return shape_[idx];
}
/*!
* \brief flatten the tensor to 1 dimension
* \return tensor after flatten
*/
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
return Tensor<Device, 1, DType>(dptr_, shape_.FlatTo1D(), stride_, stream_);
}
/*!
* \brief flatten the tensor to 2 dimension, collapse the higher dimensions together
* \return tensor after flatten
Expand Down Expand Up @@ -434,6 +450,9 @@ struct Tensor<Device, 1, DType>:
inline void set_stream(Stream<Device> *stream) {
this->stream_ = stream;
}
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
return *this;
}
MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
}
Expand Down
6 changes: 3 additions & 3 deletions mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
for (index_t y = 0; y < shape[0]; ++y) {
for (index_t x = 0; x < shape[1]; ++x) {
// trust your compiler! -_- they will optimize it
Saver::Save(dplan.REval(y, x), plan.Eval(y, x));
Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
}
}
}
Expand Down Expand Up @@ -217,7 +217,7 @@ inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
for (index_t y = 1; y < eshape[0]; ++y) {
Reducer::Reduce(res, splan.Eval(y, x));
}
Saver::Save(dplan.REval(0, x), res * scale);
Saver::template Save<DType>(dplan.REval(0, x), res * scale);
}
}

Expand Down Expand Up @@ -254,7 +254,7 @@ inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
}
Reducer::Reduce(res, tres);
}
Saver::Save(dplan.REval(0, c), DType(res * scale));
Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
}
}

Expand Down

0 comments on commit ca4d4f6

Please sign in to comment.