[MXNET-381] Enhancement of take operator #11326
Conversation
@reminisce @piiswrong @anirudh2290 @rahul003 @eric-haibin-lin Please give a review when you have time, thanks!
@junrushao1994 you might want to keep an eye on this PR.
@piiswrong @reminisce @anirudh2290 @rahul003 ping for review
.set_default(0)
.describe("The axis of input array to be taken.");
.describe("The axis of input array to be taken."
NumPy's take currently has `raise` as the default mode, but it doesn't seem to be supported in MXNet yet. We could also make `raise` the default, but that would be a breaking change. We should open an issue to add it for the 2.0 release.
I totally agree, but it's worth noting that adding 'raise' mode may impact performance a bit, since you need another kernel to check that all indices are within the legal range.
You don't need another kernel; you can do it inside the same one. You already have the bounds check for indices inside the Take kernel, so you can just maintain state recording whether the bounds check passed or failed.
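A minimal sketch of that idea, with hypothetical names rather than code from this PR: the kernel that already clips (or wraps) each index can also record an out-of-range hit in a one-element flag that the host checks after the launch.

```cpp
#include <mshadow/base.h>  // MSHADOW_XINLINE (assumed include path)

// Hypothetical axis-0 take kernel: K rows of length M in `data`, one output
// element per flattened index i. `out_of_bound` is an assumed one-element
// flag, zero-initialized by the host before the launch.
struct TakeClipWithFlag {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
                                  const IType* idx, const int M, const int K,
                                  int* out_of_bound) {
    int j = static_cast<int>(idx[i / M]);
    if (j < 0 || j >= K) {
      *out_of_bound = 1;        // benign race: every writer stores the same value
      j = (j < 0) ? 0 : K - 1;  // existing clip behaviour is kept
    }
    out[i] = data[j * M + i % M];
  }
};
// After the launch, the host reads the flag back and raises if it is nonzero.
```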
src/operator/tensor/indexing_op.h
Outdated
oshape[i + idxshape.ndim()] = arrshape[i + 1];
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
for (int i = 0; i < static_cast<int>(idxshape.ndim()); ++i) {
We can use index_t here and avoid the static_cast.
Good catch, will do.
src/operator/tensor/indexing_op.h
Outdated
for (int i = 0; i < static_cast<int>(idxshape.ndim()); ++i) {
  oshape[i + actual_axis] = idxshape[i];
}
for (int i = 0; i < static_cast<int>(arrshape.ndim()); i++) {
We can use index_t here and avoid the static_cast.
Good catch, will do.
src/operator/tensor/indexing_op.h
Outdated
}
for (size_t i = 0; i < arrshape.ndim() - 1; i++) {
oshape[i + idxshape.ndim()] = arrshape[i + 1];
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
index_t here?
Will do
const int in_ndims, const int out_ndims, const int idx_ndims,
const int axis_dim, const int axis) {
// i is the global flattened index in the output
const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
IType can be used for all indexes here.
There's a possibility that IType is a floating-point type, so the compiler would complain. That's also why the legacy Map function above uses a cast.
Okay, makes sense.
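For context, a small hypothetical helper (not the PR's code) showing why the cast is needed: the index tensor's dtype may be a floating-point type such as float32, and a floating-point value cannot be used directly as an array subscript.

```cpp
#include <mshadow/base.h>  // MSHADOW_XINLINE (assumed include path)

// Hypothetical row gather: IType may be float32, so each index is cast to int
// before it is used for addressing; using IType itself for the subscript would
// not compile for floating-point index types.
template <typename DType, typename IType>
MSHADOW_XINLINE void GatherRow(const int i, DType* out, const DType* data,
                               const IType* idx, const int row_len) {
  const int j = static_cast<int>(idx[i]);   // cast: IType could be float
  for (int k = 0; k < row_len; ++k) {
    out[i * row_len + k] = data[j * row_len + k];
  }
}
```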
@@ -389,7 +389,7 @@ Examples::
)code" ADD_FILELINE)
Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output
This only holds true for axis = 0, right?
Will update that doc.
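For reference, the general rule the updated doc needs to state is that the indices' shape replaces the input's axis dimension. Below is a standalone sketch of that shape computation, mirroring (not copying) the PR's shape-inference logic and using the index_t loops suggested above; include paths and the function name are assumptions.

```cpp
#include <mshadow/base.h>  // mshadow::index_t (assumed include path)
#include <nnvm/tuple.h>    // nnvm::TShape    (assumed include path)
using mshadow::index_t;
using nnvm::TShape;

// out.shape = arr.shape[:axis] + idx.shape + arr.shape[axis+1:]
// e.g. arr (d0, d1, d2), idx (i0, i1), axis=1  ->  out (d0, i0, i1, d2);
// only for axis == 0 does it reduce to idx.shape + arr.shape[1:].
TShape TakeOutShape(const TShape& arrshape, const TShape& idxshape, int axis) {
  const index_t r = arrshape.ndim();
  const index_t actual_axis = (axis < 0) ? axis + r : axis;  // normalize negative axis
  TShape oshape(idxshape.ndim() + r - 1);
  for (index_t i = 0; i < actual_axis; ++i)       oshape[i] = arrshape[i];
  for (index_t i = 0; i < idxshape.ndim(); ++i)   oshape[actual_axis + i] = idxshape[i];
  for (index_t i = actual_axis + 1; i < r; ++i)   oshape[idxshape.ndim() + i - 1] = arrshape[i];
  return oshape;
}
```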
src/operator/tensor/indexing_op.h
Outdated
MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
// get size of temporary storage for sort
char* temp_storage_ptr = nullptr;
int* src_indptr_ptr = nullptr;
Can we use dim_t instead of int here?
Since we're using cub's DeviceHistogram to histogram the indices here, we need to stick to int32, which should suffice for now. Alternatively, we could switch to our own histogram kernel that supports all types, but that would be slower than cub's implementation.
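To make the constraint concrete, here is a rough sketch of the kind of cub call involved, with illustrative names and a standalone allocation instead of the op's requested workspace: DeviceHistogram writes its bin counts through a CounterT pointer, and int is the counter type used here, which is why the per-row count / indptr buffer stays int32.

```cpp
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Histogram already-sorted, int-converted indices into per-row counts.
// d_counts has num_rows bins; cub's CounterT is int here, hence the int32 buffers.
void HistogramSortedIndices(const int* d_sorted_idx, int* d_counts,
                            int num_rows, int num_idx, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // First call (null temp pointer) only queries the required workspace size.
  cub::DeviceHistogram::HistogramEven(nullptr, temp_bytes,
                                      d_sorted_idx, d_counts,
                                      num_rows + 1,   // num_levels = #bins + 1
                                      0, num_rows,    // sample range [0, num_rows)
                                      num_idx, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);  // the real op carves this from its workspace
  cub::DeviceHistogram::HistogramEven(d_temp, temp_bytes,
                                      d_sorted_idx, d_counts,
                                      num_rows + 1, 0, num_rows,
                                      num_idx, stream);
  cudaFree(d_temp);
}
```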
src/operator/tensor/indexing_op.h
Outdated
s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis]));
}
Tensor<cpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
Tensor<cpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
We can move this to the start, use temp_storage.dptr_ to reuse it, and remove temp_storage_ptr.
Sorry that I did not notice this comment earlier. The tensor is purely for the SortByKey call, so keeping its declaration closer to that call makes more sense.
You can also keep it in the same place. I am essentially suggesting that temp_storage_ptr seems unnecessary and can be removed.
It's used as a shorthand for the calculated pointer within the whole workspace pool: https://github.com/apache/incubator-mxnet/pull/11326/files#diff-ed06b8d9798aca630313f2a9dd3fcd68R950
You can do the following:
Tensor<cpu, 1, char> temp_storage(workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s);
and use temp_storage or temp_storage.dptr_ for the pointer.
Okay.
});
});
}

#ifdef __CUDACC__
Can this GPU-specific code be moved to the .cuh file?
I would like to reuse the kernel here; if I move this to the .cuh file, the CPU compiler will not see that kernel.
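In other words, the element-wise body stays in indexing_op.h so the Kernel<Op, xpu>::Launch pattern can instantiate it for both devices, and only genuinely GPU-only pieces (the cub calls, the sort-based backward) sit behind #ifdef __CUDACC__. A hedged sketch with a hypothetical kernel name:

```cpp
#include <mshadow/base.h>  // MSHADOW_XINLINE (assumed include path)

// Hypothetical flat gather kernel kept in the .h: the same Map body is compiled
// by the host compiler for the cpu build and by nvcc for the gpu build.
struct GatherFlat {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
                                  const IType* idx) {
    out[i] = data[static_cast<int>(idx[i])];
  }
};
// cpu path: mxnet_op::Kernel<GatherFlat, cpu>::Launch(s, n, out, data, idx);
// gpu path: mxnet_op::Kernel<GatherFlat, gpu>::Launch(s, n, out, data, idx);
```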
src/operator/tensor/indexing_op.cc
Outdated
- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
- `axis`- Could be from -r to r-1 where r is the rank of input tensor
- `mode`- Could be either `clip` or `wrap`.
You can move this explanation to the respective arguments and delete the note.
Done
src/operator/tensor/indexing_op.cc
Outdated
- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
- `axis`- Could be from -r to r-1 where r is the rank of input tensor
- `mode`- Could be either `clip` or `wrap`.

Examples::
This was due to a lack of extra blank lines; fixed.
@piiswrong @reminisce please give a review when you have time, thanks!
@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
mshadow::Tensor<gpu, 1, char>* workspace) {
mshadow::Tensor<gpu, 1, char>* workspace = NULL) {
Nit: `NULL` -> `nullptr`. `NULL` has more semantic meanings than `nullptr` and should be treated as deprecated in C++11.
Will change.
@piiswrong Please give a review once you have a minute.
* take forward for any axis with enhanced test
* general take backward on gpu
* backward of enhanced take op
Description
Previously our `take` operator only supported the axis=0 and mode='clip' case; this PR adds support for axis in the range [-r, r-1] and an additional mode, 'wrap'.

Checklist
Essentials
Changes
Comments
The legacy implementation for axis=0 and mode='clip' is still preserved to ensure there's no performance or accuracy regression after the enhancement.
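Roughly, that dispatch looks like the sketch below; the helper names and the mode enum are illustrative, not the PR's exact identifiers.

```cpp
// Illustrative dispatch only: the pre-existing fast path still serves the
// common axis == 0, mode == 'clip' case; everything else takes the new
// general-axis implementation.
enum class TakeMode { kClip, kWrap };

template <typename xpu>
void TakeForwardDispatch(int actual_axis, TakeMode mode) {
  if (actual_axis == 0 && mode == TakeMode::kClip) {
    // legacy implementation, unchanged, so this case keeps its old performance
    // TakeLegacyAxisZeroClip<xpu>(/* ctx, inputs, outputs */);
  } else {
    // new general path: arbitrary axis in [-r, r-1], 'clip' or 'wrap'
    // TakeGeneralAxis<xpu>(/* ctx, inputs, outputs, actual_axis, mode */);
  }
}
```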