Making BatchToSpaceND support dynamic crops
add tests to cover the cases where the crop end or right value is zero or non-zero

(cherry picked from commit 2fa6a5c)

add test to cover int64

(cherry picked from commit e25ca13)

remove useless empty lines and commands

(cherry picked from commit fe82719)

add rs6 for int64 test

(cherry picked from commit 27b193c)
JiayingGaoo committed Jul 2, 2019
1 parent 7fbab3e commit 663241d
Showing 2 changed files with 110 additions and 23 deletions.
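
The tensor.py change below maps dynamic crops to the Slice inputs 'starts' and 'ends' using only ONNX Equal, Cast, Mul and Add ops, since the crop values are not known until runtime. A minimal NumPy sketch of that element-wise mapping (the function name is illustrative; the crop values are taken from the new tests):

import numpy as np

def crops_to_slice_params(crops):
    """Map BatchToSpaceND crops [[top, bottom], [left, right]] to Slice
    'starts'/'ends' for the H and W axes, branch-free so it also works
    when the values are only known at runtime."""
    crops = np.asarray(crops)
    int_max = np.iinfo(crops.dtype).max
    starts = crops[:, 0]                       # [top, left]
    ends_ori = crops[:, 1]                     # [bottom, right]
    eq_zero = (ends_ori == 0).astype(crops.dtype)
    # INT_MAX where the crop end is 0 (slice to the end of the axis), else -crop_end
    ends = eq_zero * int_max + (eq_zero - 1) * ends_ori
    return starts, ends

starts, ends = crops_to_slice_params(np.array([[0, 1], [2, 1]], dtype=np.int32))
assert list(starts) == [0, 2] and list(ends) == [-1, -1]

starts, ends = crops_to_slice_params(np.array([[1, 0], [2, 0]], dtype=np.int64))
assert list(starts) == [1, 2]
assert list(ends) == [np.iinfo(np.int64).max] * 2

The constant-crops branch below does the equivalent computation in plain Python at conversion time.
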
45 changes: 45 additions & 0 deletions tests/test_backend.py
@@ -2443,6 +2443,51 @@ def test_batch_to_spacend(self):
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: input_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_top_is_zero(self):
block_size = [2, 2]
crop_value = np.array([[0, 1], [2, 1]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crop = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crop_value})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_right_is_zero(self):
block_size = [2, 2]
crops_val = np.array([[0, 1], [2, 0]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_ends_are_zero(self):
block_size = [2, 2]
crops_val = np.array([[1, 0], [2, 0]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
@check_target('rs6', 'batch_to_space_nd')
def test_batch_to_spacend_with_dynamic_crop_for_int64(self):
block_size = [2, 2]
crops_val = np.array([[1, 0], [2, 0]], dtype=np.int64)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int64, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

def test_batch_to_space3d(self):
block_size = [2, 2]
crop = [[0, 1], [2, 1]]
88 changes: 65 additions & 23 deletions tf2onnx/onnx_opset/tensor.py
@@ -1080,12 +1080,13 @@ def version_1(cls, ctx, node, **kwargs):
input_tensor = node.inputs[0]
input_shape = ctx.get_shape(input_tensor.output[0])
blocksize = node.inputs[1].get_tensor_value()
crops = node.inputs[2].get_tensor_value()

utils.make_sure(len(input_shape) in (4, 3),
"only supports 3D and 4D for now")
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
"only support same blocksize at different dims")
utils.make_sure(node.inputs[2].output_shapes == [[2, 2]],
"only support the crops with shape [[2,2]]")

# NHWC to CNHW, so the ONNX op works on "N", the same dimension TensorFlow operates on
if len(input_shape) == 3:
@@ -1097,36 +1098,77 @@ def version_1(cls, ctx, node, **kwargs):
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})

# implement crop logic, the data format is NHWC
slice_axis = [1, 2]
top, bottom = crops[0]
left, right = crops[1]
starts = [top, left]
ends = []
for end in [bottom, right]:
if end != 0:
ends.append(-end)
else:
ends.append(np.iinfo(np.int32).max)

attr = {"axes": slice_axis, "ends": ends, "starts": starts}
inputs_map = {"data": trans2.output[0], **attr}
dtypes = node.output_dtypes
shapes = node.output_shapes

if len(input_shape) == 3:
# add a squeeze op to convert output into 3d
kwargs = {**inputs_map}
ctx.remove_node(node.name)
slice1 = GraphBuilder(ctx).make_slice(kwargs)
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
if ctx.get_node_by_output(node.input[2]).is_const():
crops = node.inputs[2].get_tensor_value()
# implement crop logic, the data format is NHWC
slice_axis = [1, 2]
top, bottom = crops[0]
left, right = crops[1]
starts = [top, left]
ends = []
for end in [bottom, right]:
if end != 0:
ends.append(-end)
else:
ends.append(np.iinfo(np.int32).max)

attr = {"axes": slice_axis, "ends": ends, "starts": starts}
inputs_map = {"data": trans2.output[0], **attr}

if len(input_shape) == 3:
# add a squeeze op to convert output into 3d
kwargs = {**inputs_map}
ctx.remove_node(node.name)
slice1 = GraphBuilder(ctx).make_slice(kwargs)
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
else:
kwargs = {**inputs_map, "outputs": node.output}
ctx.remove_node(node.name)
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)

else:
crops_type = ctx.get_dtype(node.input[2])
crops_np_type = utils.map_onnx_to_numpy_type(crops_type)
# Transpose crops and split it into its two rows: [top, left] (starts) and [bottom, right] (ends)
crops_trans = ctx.make_node("Transpose", node.inputs[2].output, {"perm": [1, 0]})
trans_split = ctx.make_node("Split", [crops_trans.output[0]], output_count=2, attr={"axis": 0})
starts = ctx.make_node("Squeeze", [trans_split.output[0]], attr={"axes": [0]})
# Create the node holding the unadjusted ends values
ends_ori = ctx.make_node("Squeeze", [trans_split.output[1]], attr={"axes": [0]})
# create zero const node
zero_value = np.array([0]).astype(crops_np_type)
zero_const = ctx.make_const(utils.make_name("Const"), zero_value)
# Compare ends_ori with zero element-wise; the boolean result is cast back to the crops dtype
ends_ori_equal_zero = ctx.make_node("Equal", [ends_ori.output[0], zero_const.output[0]])
ends_ori_equal_zero = ctx.make_node("Cast", [ends_ori_equal_zero.output[0]], attr={"to": crops_type})
# Create int max const node
int_max_value = np.array([utils.get_max_value(crops_np_type)]).astype(crops_np_type)
int_max_const = ctx.make_const(utils.make_name("largest_int_val"), int_max_value)

# The ends_ori_equal_zero mask is then used to build two parts, ends_for_zero and ends_for_nonzero
# (1). Create ends_for_zero node
ends_for_zero = ctx.make_node("Mul", [ends_ori_equal_zero.output[0], int_max_const.output[0]])
# (2). Create ends_for_nonzero node
neg_one_value = np.array([-1]).astype(crops_np_type)
neg_one_const = ctx.make_const(utils.make_name("const"), neg_one_value)
ends_ori_equal_zero_inv = ctx.make_node("Add", [ends_ori_equal_zero.output[0], neg_one_const.output[0]])
ends_for_nonzero = ctx.make_node("Mul", [ends_ori_equal_zero_inv.output[0], ends_ori.output[0]])
# Add (1) and (2) to obtain ends
ends = ctx.make_node("Add", [ends_for_zero.output[0], ends_for_nonzero.output[0]])
# Create slice_axis const node
slice_axis_value = np.array([1, 2]).astype(crops_np_type)
slice_axis_const = ctx.make_const(utils.make_name("Const"), slice_axis_value)

attr = {"axes": slice_axis_const.output[0], "ends": ends.output[0], "starts": starts.output[0]}
inputs_map = {"data": trans2.output[0], **attr}
kwargs = {**inputs_map, "outputs": node.output}
ctx.remove_node(node.name)
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)


@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
class SpaceToBatch:
@classmethod
