Commit
Added references in symbol.md and ndarray.md. Improved test cases and added block_size check
access2rohit committed Jul 23, 2018
1 parent a14d7c5 commit ce29f04
Showing 3 changed files with 94 additions and 35 deletions.
61 changes: 35 additions & 26 deletions src/operator/tensor/matrix_op-inl.h
@@ -2172,11 +2172,10 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
}

struct DepthToSpaceParam : public dmlc::Parameter<DepthToSpaceParam> {
int blockSize;
int block_size;
DMLC_DECLARE_PARAMETER(DepthToSpaceParam) {
DMLC_DECLARE_FIELD(blockSize)
.describe("The size of chunks that need to be taken from depth and spread across to the"
" shape dimension of the tensor and vice versa");
DMLC_DECLARE_FIELD(block_size)
.describe("Blocks of [block_size, block_size] are moved");
}
};

@@ -2191,7 +2190,8 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
TShape expected_out(4);

TShape& in_shape = in_attrs->at(0);
int block = param.blockSize;
int block = param.block_size;
CHECK_GT(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
CHECK_EQ(in_shape[1] % (block * block), 0)
<< "Cannot perform Depth To Space operation on the specified tensor."
@@ -2226,10 +2226,18 @@ inline bool DepthToSpaceOpType(const nnvm::NodeAttrs& attrs,
return out_attrs->at(0) != -1;
}

#define UPDATE_INDEX_USING_OFFSET(X) \
next_idx_val = idx / dim_size; \
inp_index += (idx - next_idx_val * dim_size) * offset_arr[X]; \
idx = next_idx_val;
MSHADOW_XINLINE void update_index(int index_position, int dim_size, int* idx,
                                  int* inp_index, const int* offset_arr) {
int next_idx_val = *idx / dim_size;
*inp_index += (*idx - next_idx_val * dim_size) * offset_arr[index_position];
*idx = next_idx_val;
}

/*
* #define UPDATE_INDEX_USING_OFFSET(X) \
* next_idx_val = idx / dim_size; \
* inp_index += (idx - next_idx_val * dim_size) * offset_arr[X]; \
* idx = next_idx_val;
*/
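
The new update_index helper does what the old UPDATE_INDEX_USING_OFFSET macro did: it peels one coordinate off a flat output index and accumulates coordinate * stride into the input index. A rough Python sketch of the same decomposition (the function name mirrors the C++ helper; the example shape and strides are illustrative, not taken from the source):

def update_index(index_position, dim_size, idx, inp_index, offset_arr):
    """Peel the fastest-varying coordinate (of size dim_size) off the flat
    index idx and add coordinate * stride to the accumulated input index."""
    next_idx_val = idx // dim_size           # flat index with this dimension removed
    coord = idx - next_idx_val * dim_size    # coordinate along this dimension
    return next_idx_val, inp_index + coord * offset_arr[index_position]

# Decompose flat index 7 of a tensor with logical shape (2, 3, 4).
# With plain row-major strides the mapping is the identity; the kernels below
# pass permuted strides in offset_arr, which is what realizes the transpose.
offset_arr = [12, 4, 1]                      # row-major strides for shape (2, 3, 4)
idx, inp_index = 7, 0
for pos, dim in ((2, 4), (1, 3), (0, 2)):    # innermost dimension first
    idx, inp_index = update_index(pos, dim, idx, inp_index, offset_arr)
print(inp_index)                             # 7, the identity mapping in this toy case
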

/*!
* \brief This function performs the tensor transpose (0, 1, 2, 3, 4, 5) ->
@@ -2247,19 +2255,19 @@ struct depth_to_space_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
const int block, const int* size, const int* offset_arr) {
int inp_index = 0, idx = i, next_idx_val, dim_size;
int inp_index = 0, idx = i, dim_size;
dim_size = block;
UPDATE_INDEX_USING_OFFSET(2)
update_index(2, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[3];
UPDATE_INDEX_USING_OFFSET(5)
update_index(5, dim_size, &idx, &inp_index, offset_arr);
dim_size = block;
UPDATE_INDEX_USING_OFFSET(1)
update_index(1, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[2];
UPDATE_INDEX_USING_OFFSET(4)
update_index(4, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[1] / (block * block);
UPDATE_INDEX_USING_OFFSET(3)
update_index(3, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[0];
UPDATE_INDEX_USING_OFFSET(0)
update_index(0, dim_size, &idx, &inp_index, offset_arr);
KERNEL_ASSIGN(out_data[i], req, in_data[inp_index]);
}
};
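
For reference, the transpose that this kernel applies element-by-element, (0, 1, 2, 3, 4, 5) -> (0, 3, 4, 1, 5, 2), is the reshape/transpose/reshape of the ONNX DepthToSpace reference cited in the matrix_op.cc docstring below. A numpy sketch as an illustrative cross-check only; this is not the operator's actual code path:

import numpy as np

def depth_to_space_ref(x, block):
    # x has layout [N, C, H, W]; C must be divisible by block**2
    n, c, h, w = x.shape
    tmp = x.reshape(n, block, block, c // (block ** 2), h, w)
    tmp = tmp.transpose(0, 3, 4, 1, 5, 2)
    return tmp.reshape(n, c // (block ** 2), h * block, w * block)

x = np.arange(24).reshape(1, 4, 2, 3)
print(depth_to_space_ref(x, 2).shape)        # (1, 1, 4, 6)
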
@@ -2311,7 +2319,7 @@ void DepthToSpaceOpForward(const nnvm::NodeAttrs& attrs,
const TBlob& out_data = outputs[0];
const DepthToSpaceParam& param = nnvm::get<DepthToSpaceParam>(attrs.parsed);
using namespace mxnet_op;
int block = param.blockSize;
int block = param.block_size;

mshadow::Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(int32_t) * 10), s);
@@ -2343,7 +2351,8 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
TShape expected_out(in_attrs->at(0).ndim());

TShape& in_shape = in_attrs->at(0);
int block = param.blockSize;
int block = param.block_size;
CHECK_GT(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[0], 0)
<< "Operation requires a 4D tensor. Size of dimension:0 cannot be 0";
CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2397,19 +2406,19 @@ struct space_to_depth_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const int block,
const int* size, const int* offset_arr) {
int inp_index = 0, idx = i, next_idx_val, dim_size;
int inp_index = 0, idx = i, dim_size;
dim_size = size[3] / block;
UPDATE_INDEX_USING_OFFSET(4)
update_index(4, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[2] / block;
UPDATE_INDEX_USING_OFFSET(2)
update_index(2, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[1];
UPDATE_INDEX_USING_OFFSET(1)
update_index(1, dim_size, &idx, &inp_index, offset_arr);
dim_size = block;
UPDATE_INDEX_USING_OFFSET(5)
update_index(5, dim_size, &idx, &inp_index, offset_arr);
dim_size = block;
UPDATE_INDEX_USING_OFFSET(3)
update_index(3, dim_size, &idx, &inp_index, offset_arr);
dim_size = size[0];
UPDATE_INDEX_USING_OFFSET(0)
update_index(0, dim_size, &idx, &inp_index, offset_arr);
KERNEL_ASSIGN(out_data[i], req, in_data[inp_index]);
}
};
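
The corresponding numpy sketch for space_to_depth, following the ONNX SpaceToDepth reference (again only an illustrative cross-check of the kernel above, not its implementation):

import numpy as np

def space_to_depth_ref(x, block):
    # x has layout [N, C, H, W]; H and W must be divisible by block
    n, c, h, w = x.shape
    tmp = x.reshape(n, c, h // block, block, w // block, block)
    tmp = tmp.transpose(0, 3, 5, 1, 2, 4)
    return tmp.reshape(n, c * block ** 2, h // block, w // block)

x = np.arange(24).reshape(1, 1, 4, 6)
print(space_to_depth_ref(x, 2).shape)        # (1, 4, 2, 3)
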
@@ -2465,7 +2474,7 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs,
const TBlob& out_data = outputs[0];
const DepthToSpaceParam& param = nnvm::get<DepthToSpaceParam>(attrs.parsed);
using namespace mxnet_op;
int block = param.blockSize;
int block = param.block_size;

mshadow::Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(int32_t) * 10), s);
28 changes: 21 additions & 7 deletions src/operator/tensor/matrix_op.cc
@@ -910,13 +910,20 @@ NNVM_REGISTER_OP(_backward_squeeze)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>);

NNVM_REGISTER_OP(depth_to_space)
.describe(R"code(This operators implements the depthToSpace function:
.describe(R"code(Similar to ONNX DepthToSpace operator:
https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace.
Rearranges(permutes) data from depth into blocks of spatial data.
The output is a new tensor where the values from depth dimension are moved in spatial blocks
to height and width dimension. The reverse of this operation is ``space_to_depth``.
.. math::
f(x, block) = \tilde{x}
x \prime = reshape(x, [N, block_size, block_size, C / (block_size ^ 2), H, W]),
x \prime \prime = transpose(x \prime, [0, 3, 4, 1, 5, 2])
y = reshape(x \prime \prime, [N, C / (block_size ^ 2), H * block_size, W * block_size])
where :math:`x` is an input tensor of shape [N,C,H,W] and :math:`\tilde{x}` is the output tensor of shape :math:`[N, C/(block^2), H*block, W*block]`
where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
and :math:`y` is the output tensor of layout :math:`[N, C / (block_size ^ 2), H * block_size, W * block_size]`
Example::
@@ -928,7 +935,7 @@ Example::
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23]]]]
depth_to_space(x, 2) = [[[[0, 6, 1, 7, 2, 8],
[12, 18, 13, 19, 14, 20],
[3, 9, 4, 10, 5, 11],
@@ -953,13 +960,20 @@ Example::
.add_arguments(DepthToSpaceParam::__FIELDS__());
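
A short usage sketch matching the docstring example above (assumes an MXNet build that includes this commit):

import mxnet as mx

x = mx.nd.arange(24).reshape((1, 4, 2, 3))
y = mx.nd.depth_to_space(x, 2)               # block_size = 2
print(y.shape)                               # (1, 1, 4, 6)
print(y.asnumpy()[0, 0, 0])                  # [0. 6. 1. 7. 2. 8.], as in the example
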

NNVM_REGISTER_OP(space_to_depth)
.describe(R"code(This operators implements the spacetodepth function:
.describe(R"code(Similar to ONNX SpaceToDepth operator:
https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth
Rearranges(permutes) blocks of spatial data into depth.
The output is a new tensor where the values from height and width dimension are
moved to the depth dimension. The reverse of this operation is ``depth_to_space``.
.. math::
f(x, blockSize) = \tilde{x}
x \prime = reshape(x, [N, C, H / block_size, block_size, W / block_size, block_size]),
x \prime \prime = transpose(x \prime, [0, 3, 5, 1, 2, 4])
y = reshape(x \prime \prime, [N, C * (block_size ^ 2), H / block_size, W / block_size])
where :math:`x` is an input tensor of shape [N,C,H,W] and :math:`\tilde{x}` is the output tensor of shape :math:`[N, C*(block^2), H/block, W/block]`
where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
and :math:`y` is the output tensor of layout :math:`[N, C * (block_size ^ 2), H / block_size, W / block_size]`
Example::
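
The rest of the space_to_depth example is not shown in this hunk. As an illustrative round-trip sketch relating the two operators (not taken from the docstring; assumes the same build):

import mxnet as mx

x = mx.nd.arange(24).reshape((1, 4, 2, 3))
y = mx.nd.depth_to_space(x, 2)               # shape (1, 1, 4, 6)
z = mx.nd.space_to_depth(y, 2)               # back to shape (1, 4, 2, 3)
assert (z.asnumpy() == x.asnumpy()).all()    # the two operators invert each other
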
40 changes: 38 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -6696,7 +6696,7 @@ def f(x, blocksize):
data_np = data.asnumpy()
expected = f(data_np, block)
output = mx.nd.depth_to_space(data, block)
assert_almost_equal(output.asnumpy(), expected)
assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)

shape_out = (1,1,4,6)
data = mx.sym.Variable('data')
@@ -6706,6 +6706,24 @@ def f(x, blocksize):
check_symbolic_forward(dts_sym, [data_np], [expected])
check_symbolic_backward(dts_sym, [data_np], [np.ones(shape_out)], [np.ones(shape_inp)])

def test_invalid_depth_dim():
invalid_shape_inp = (1,3,2,3)
block = 2
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)

def test_invalid_space_dim():
invalid_shape_inp = (1,4,2,3)
block = 2
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)

def test_invalid_block_size():
invalid_shape_inp = (1,0,2,3)
block = 2
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)

@with_seed()
def test_spacetodepth():
def f(x, blocksize):
@@ -6721,7 +6739,7 @@ def f(x, blocksize):
data_np = data.asnumpy()
expected = f(data_np, block)
output = mx.nd.space_to_depth(data, block)
assert_almost_equal(output.asnumpy(), expected)
assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)

shape_out = (1,4,2,3)
data = mx.sym.Variable('data')
@@ -6731,6 +6749,24 @@ def f(x, blocksize):
check_symbolic_forward(dts_sym, [data_np], [expected])
check_symbolic_backward(dts_sym, [data_np], [np.ones(shape_out)], [np.ones(shape_inp)])

def test_invalid_space_dim():
invalid_shape_inp = (1,1,2,3)
block = 2
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

def test_invalid_block_size():
invalid_shape_inp = (1,1,4,2)
block = 0
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

def test_invalid_depth_dim():
invalid_shape_inp = (1,0,4,2)
block = 2
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

if __name__ == '__main__':
import nose
nose.runmodule()
