-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
On my machine, before optimization:
After optimization:
Using below script: import time
import mxnet as mx
src_shape = [(1, ), (1, 128), (1, 256), (1, 512), (1, 1024), (1, 16383), (1, 65535), (1, 256, 256), (256, 1, 256)]
dst_shape = [(128,), (128, 128), (256, 256), (512, 512), (1024, 1024), (24, 16383), (24, 65535), (24, 256, 256), (256, 24, 256)]
axes = [0, 0, 0, 0, 0, 0, 0, 0, 1]
for idx, sh in enumerate(src_shape):
size = dst_shape[idx][axes[idx]]
# mxnet
x = mx.random.uniform(shape=sh)
y = mx.nd.broadcast_axis(x, axis=axes[idx], size=size)
y.wait_to_read()
tic = time.time()
for _ in range(3000):
y = mx.nd.broadcast_axis(x, axis=axes[idx], size=size)
y.wait_to_read()
mx_time = time.time() - tic
print("broadcast %s to %s, time: %.5f ms" %(sh, dst_shape[idx], mx_time*1000/3000.0)) |
@sxjscience Please take a look. Thanks! |
…to broadcast-axis-opt
const std::vector<TBlob>& outputs) { | ||
using namespace mshadow; | ||
const BroadcastAxesParam& param = nnvm::get<BroadcastAxesParam>(attrs.parsed); | ||
if (param.axis.ndim() == 1 && inputs[0].shape_[param.axis[0]] == 1 && req[0] == kWriteTo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should deal with negative value in param.axis
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems broadcast_axis doesn't support negative axis even before this change. See https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/broadcast_reduce_op.h#L386.
import mxnet as mx
a = mx.random.uniform(shape=(4, 5, 1))
b = mx.nd.broadcast_axis(a, axis=-1, size=6)
print(b)
Will cause error:
Traceback (most recent call last):
File "test_bcast.py", line 4, in <module>
b = mx.nd.broadcast_axis(a, axis=-1, size=6)
File "<string>", line 60, in broadcast_axis
File "/home/lvtao/miniconda3/envs/mxnet/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py", line 107, in _imperative_invoke
ctypes.byref(out_stypes)))
File "/home/lvtao/miniconda3/envs/mxnet/lib/python3.6/site-packages/mxnet/base.py", line 278, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [18:53:34] include/mxnet/./tuple.h:206: Check failed: i >= 0 && i < ndim(): index = -1 must be in range [0, 3)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, then LGTM.
Description
Improve the performance of single axis broadcasting.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments