Making BatchToSpaceND support dynamic crops
coding improvements

(cherry picked from commit 612bb7ea56636c4ee7328acf583cc05b804d141e)

temp: improve code in tensor.py

temp: delete test_Gao.py
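
TensorFlow's BatchToSpaceND first moves block elements out of the batch dimension and then removes crops[i] = [crop_start, crop_end] values from each spatial dimension; the converter expresses that cropping as an ONNX Slice over axes 1 and 2 (NHWC). A rough numpy sketch of the crops-to-Slice mapping, for illustration only (the helper name is made up, not code from this commit):

import numpy as np

def crops_to_slice_bounds(crops, int_max=np.iinfo(np.int64).max):
    # crops has shape [2, 2]: [[top, bottom], [left, right]]
    starts = crops[:, 0]                # crop this many elements from the front of H and W
    ends = np.where(crops[:, 1] == 0,   # a trailing crop of 0 means "keep everything to the end",
                    int_max,            # which Slice encodes with a very large 'ends' value
                    -crops[:, 1])       # otherwise drop the last crop_end elements
    return starts, ends

# e.g. crops_to_slice_bounds(np.array([[1, 0], [2, 0]])) -> (array([1, 2]), array([int_max, int_max]))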
JiayingGaoo committed Jul 10, 2019
1 parent 4aa7ecc commit bff38ea
Showing 4 changed files with 110 additions and 46 deletions.
39 changes: 35 additions & 4 deletions tests/test_backend.py
@@ -2439,6 +2439,33 @@ def test_batch_to_spacend(self):
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: input_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop(self):
block_size = [2, 2]
crop_value_1 = np.array([[0, 1], [2, 1]], dtype=np.int32)  # top crop (a 'starts' entry) is zero
crop_value_2 = np.array([[0, 1], [2, 0]], dtype=np.int32)  # right crop (an 'ends' entry) is zero
crop_value_3 = np.array([[1, 0], [2, 0]], dtype=np.int32)  # both 'ends' entries (bottom and right) are zero
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crop = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)

for crop_value in [crop_value_1, crop_value_2, crop_value_3]:
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crop_value})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
@check_target('rs6', 'batch_to_space_nd')
def test_batch_to_spacend_with_dynamic_crop_for_int64(self):
block_size = [2, 2]
crops_val = np.array([[1, 0], [2, 0]], dtype=np.int64)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int64, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

def test_batch_to_space3d(self):
block_size = [2, 2]
crop = [[0, 1], [2, 1]]
@@ -2585,7 +2612,8 @@ def test_gemm_pattern0(self):
mul2 = tf.multiply(beta, c)
x_ = mul1 + mul2
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3}, graph_validator=lambda g: check_op_count(g, "Gemm", 1))
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=lambda g: check_op_count(g, "Gemm", 1))

# test for gemm pattern1: alpha*A*B + C
def test_gemm_pattern1(self):
@@ -2602,7 +2630,8 @@ def test_gemm_pattern1(self):
alpha = tf.constant(1.0)
x_ = tf.multiply(alpha, tf.matmul(a, b)) + c
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3}, graph_validator=lambda g: check_op_count(g, "Gemm", 1))
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=lambda g: check_op_count(g, "Gemm", 1))

# test for gemm pattern2: A*B + beta*C
def test_gemm_pattern2(self):
@@ -2619,7 +2648,8 @@ def test_gemm_pattern2(self):
beta = tf.constant(2.0)
x_ = tf.matmul(a, b) + tf.multiply(beta, c)
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3}, graph_validator=lambda g: check_op_count(g, "Gemm", 1))
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=lambda g: check_op_count(g, "Gemm", 1))

# test for gemm pattern3: A*B + C
def test_gemm_pattern3(self):
@@ -2635,7 +2665,8 @@ def test_gemm_pattern3(self):
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
x_ = tf.matmul(a, b) + c
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3}, graph_validator=lambda g: check_op_count(g, "Gemm", 1))
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=lambda g: check_op_count(g, "Gemm", 1))

def test_graph_matcher(self):
shape = [2, 6]
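
Outside the test harness, converting such a graph would look roughly like the sketch below. It assumes the tf2onnx Python API of this era (tfonnx.process_tf_graph / Graph.make_model) and made-up tensor names, so details may differ:

import tensorflow as tf
from tf2onnx import tfonnx

# Build a TF graph whose crops are fed at runtime rather than baked in as a constant.
with tf.Graph().as_default() as tf_graph:
    x = tf.placeholder(tf.float32, shape=[40, 3, 5, 100], name="input")
    crops = tf.placeholder(tf.int32, shape=[2, 2], name="crops")
    _ = tf.batch_to_space_nd(x, [2, 2], crops, name="output")

# Convert with opset >= 10 so Slice can take runtime 'starts' and 'ends'.
onnx_graph = tfonnx.process_tf_graph(tf_graph, opset=10,
                                     input_names=["input:0", "crops:0"],
                                     output_names=["output:0"])
model_proto = onnx_graph.make_model("batch_to_space_nd_dynamic_crops")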
89 changes: 64 additions & 25 deletions tf2onnx/onnx_opset/tensor.py
@@ -1080,12 +1080,13 @@ def version_1(cls, ctx, node, **kwargs):
input_tensor = node.inputs[0]
input_shape = ctx.get_shape(input_tensor.output[0])
blocksize = node.inputs[1].get_tensor_value()
crops = node.inputs[2].get_tensor_value()

utils.make_sure(len(input_shape) in (4, 3),
"only supports 3D and 4D for now")
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
"only support same blocksize at different dims")
utils.make_sure(node.inputs[2].output_shapes == [[2, 2]],
"only support the crops with shape [[2,2]]")

# NHWC to CNHW, so the ONNX op works on TF's "N" dimension, which matches TensorFlow's behavior
if len(input_shape) == 3:
@@ -1097,34 +1098,72 @@ def version_1(cls, ctx, node, **kwargs):
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})

# implement crop logic, the data format is NHWC
slice_axis = [1, 2]
top, bottom = crops[0]
left, right = crops[1]
starts = [top, left]
ends = []
for end in [bottom, right]:
if end != 0:
ends.append(-end)
else:
ends.append(np.iinfo(np.int32).max)

attr = {"axes": slice_axis, "ends": ends, "starts": starts}
inputs_map = {"data": trans2.output[0], **attr}
dtypes = node.output_dtypes
shapes = node.output_shapes

if len(input_shape) == 3:
# add a squeeze op to convert output into 3d
kwargs = {**inputs_map}
ctx.remove_node(node.name)
slice1 = GraphBuilder(ctx).make_slice(kwargs)
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
if node.inputs[2].is_const():
crops = node.inputs[2].get_tensor_value()
# implement crop logic, the data format is NHWC
slice_axis = [1, 2]
top, bottom = crops[0]
left, right = crops[1]
starts = [top, left]
ends = []
for end in [bottom, right]:
if end != 0:
ends.append(-end)
else:
ends.append(np.iinfo(np.int32).max)

attr = {"axes": slice_axis, "ends": ends, "starts": starts}
inputs_map = {"data": trans2.output[0], **attr}

if len(input_shape) == 3:
# add a squeeze op to convert output into 3d
kwargs = {**inputs_map}
ctx.remove_node(node.name)
slice1 = GraphBuilder(ctx).make_slice(kwargs)
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
return

else:
kwargs = {**inputs_map, "outputs": node.output}
ctx.remove_node(node.name)
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
# when node.inputs[2] is not const, we need to build nodes that compute the 'starts' and 'ends' attrs at runtime.
# Transpose the crops tensor and split it into its two rows with a Split op:
# the first row provides 'starts', the second row provides 'ends'.
crops_type = ctx.get_dtype(node.input[2])
crops_np_type = utils.map_onnx_to_numpy_type(crops_type)

crops_trans = ctx.make_node("Transpose", [node.input[2]], {"perm": [1, 0]})
trans_split = ctx.make_node("Split", [crops_trans.output[0]], output_count=2, attr={"axis": 0})
starts = ctx.make_node("Squeeze", [trans_split.output[0]], attr={"axes": [0]})

ends_ori = ctx.make_node("Squeeze", [trans_split.output[1]], attr={"axes": [0]})
zero_value = np.array([0]).astype(crops_np_type)
zero_const = ctx.make_const(utils.make_name("Const"), zero_value)
ends_ori_equal_zero = ctx.make_node("Equal", [ends_ori.output[0], zero_const.output[0]])
ends_ori_equal_zero = ctx.make_node("Cast", [ends_ori_equal_zero.output[0]], attr={"to": crops_type})
int_max_value = np.array([utils.get_max_value(crops_np_type)]).astype(crops_np_type)
int_max_const = ctx.make_const(utils.make_name("largest_int_val"), int_max_value)
ends_for_zero = ctx.make_node("Mul", [ends_ori_equal_zero.output[0], int_max_const.output[0]])

neg_one_value = np.array([-1]).astype(crops_np_type)
neg_one_const = ctx.make_const(utils.make_name("const"), neg_one_value)
ends_ori_equal_zero_inv = ctx.make_node("Add", [ends_ori_equal_zero.output[0], neg_one_const.output[0]])
ends_for_nonzero = ctx.make_node("Mul", [ends_ori_equal_zero_inv.output[0], ends_ori.output[0]])

ends = ctx.make_node("Add", [ends_for_zero.output[0], ends_for_nonzero.output[0]])

slice_axis_value = np.array([1, 2]).astype(crops_np_type)
slice_axis_const = ctx.make_const(utils.make_name("Const"), slice_axis_value)

attr = {"axes": slice_axis_const.output[0], "ends": ends.output[0], "starts": starts.output[0]}
inputs_map = {"data": trans2.output[0], **attr}

kwargs = {**inputs_map, "outputs": node.output}
ctx.remove_node(node.name)
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)


@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
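
The non-constant branch above assembles the 'ends' arithmetic out of ONNX Equal/Cast/Mul/Add nodes so the values are computed at runtime. A numpy mirror of that subgraph, as a sanity check of the logic (illustrative only; the function name is invented):

import numpy as np

def dynamic_ends(ends_ori, crops_np_type=np.int64):
    # Mirrors the Equal -> Cast -> Mul/Add chain built for non-const crops.
    is_zero = (ends_ori == 0).astype(crops_np_type)        # Equal + Cast: 1 where crop_end == 0, else 0
    ends_for_zero = is_zero * np.iinfo(crops_np_type).max  # Mul: INT_MAX where crop_end == 0, else 0
    ends_for_nonzero = (is_zero - 1) * ends_ori            # Add(-1) then Mul: -crop_end where crop_end != 0, else 0
    return ends_for_zero + ends_for_nonzero                # final Add

# e.g. dynamic_ends(np.array([1, 0], dtype=np.int64)) -> [-1, 9223372036854775807]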
26 changes: 10 additions & 16 deletions tf2onnx/rewriter/gemm_rewriter.py
@@ -4,21 +4,16 @@
"""
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx gemm op
"""

from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
from onnx import onnx_pb
import logging
from onnx import onnx_pb
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher

# pylint: disable=missing-docstring

def rewrite_gemm(g, ops):
if g.opset <= 6:
return ops

"""
4 candidate patterns are listed as follows: pattern0, pattern1, pattern2 and pattern3,
where A, B and C represent the three inputs, and alpha and beta represent the two attributes.
"""
# pattern0: alpha*A*B + beta*C
pattern0 = \
OpTypePattern('Add', name='add', inputs=[
@@ -64,7 +59,7 @@ def rewrite_gemm(g, ops):
for pattern_id, pattern in enumerate(pattern_list):
matcher = GraphMatcher(pattern, allow_reorder=True)
match_results = list(matcher.match_ops(ops))
if len(match_results) > 0:
if match_results:
for match in match_results:
add_node = match.get_op('add')
matmul_node = match.get_op("matmul")
@@ -73,37 +68,36 @@ def rewrite_gemm(g, ops):
if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
logging.warning(u"For now, onnxruntime only support float type for Gemm rewriter")
return ops
else:
a_edge_name = matmul_node.input[0]
b_edge_name = matmul_node.input[1]
c_edge_name = input_c_node.output[0]
a_edge_name = matmul_node.input[0]
b_edge_name = matmul_node.input[1]
c_edge_name = input_c_node.output[0]
attr = {}

# For each pattern, alpha and beta (when present) must be scalars; otherwise return ops unchanged
if pattern_id == 0: # pattern 0: alpha*A*B + beta*C
alpha = match.get_op("alpha").get_tensor_value()
beta = match.get_op("beta").get_tensor_value()
if isinstance(alpha, float) or isinstance(alpha, int):
if isinstance(alpha, (float, int)):
alpha = float(alpha)
else:
return ops
if isinstance(beta, float) or isinstance(beta, int):
if isinstance(beta, (float, int)):
beta = float(beta)
else:
return ops
attr = {"alpha": alpha, "beta": beta}

if pattern_id == 1: # pattern1: alpha*A*B + C
alpha = match.get_op("alpha").get_tensor_value()
if isinstance(alpha, float) or isinstance(alpha, int):
if isinstance(alpha, (float, int)):
alpha = float(alpha)
else:
return ops
attr = {"alpha": alpha}

if pattern_id == 2: # pattern2: A*B + beta*C
beta = match.get_op("beta").get_tensor_value()
if isinstance(beta, float) or isinstance(beta, int):
if isinstance(beta, (float, int)):
beta = float(beta)
else:
return ops
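
For reference, the Gemm node this rewriter fuses the matched subgraphs into computes Y = alpha * (A @ B) + beta * C, with alpha and beta defaulting to 1.0 for the patterns that omit them. A small numpy illustration of that equivalence (not part of the commit; shapes are arbitrary):

import numpy as np

def gemm_reference(a, b, c, alpha=1.0, beta=1.0):
    # Semantics of ONNX Gemm (ignoring the transA/transB attributes).
    return alpha * np.matmul(a, b) + beta * c

a = np.random.rand(6, 3).astype(np.float32)
b = np.random.rand(3, 4).astype(np.float32)
c = np.random.rand(6, 4).astype(np.float32)

# pattern0: alpha*A*B + beta*C collapses into a single Gemm with both attributes set
assert np.allclose(2.0 * np.matmul(a, b) + 3.0 * c, gemm_reference(a, b, c, alpha=2.0, beta=3.0))
# pattern3: A*B + C is the same Gemm with alpha == beta == 1.0
assert np.allclose(np.matmul(a, b) + c, gemm_reference(a, b, c))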
2 changes: 1 addition & 1 deletion tf2onnx/version.py
@@ -1,3 +1,3 @@

version = '1.6.0'
git_version = '7e40cd105db894df07a22dbf8494ff01c1d042a3'
git_version = '82f805f8fe7d2fa91e6ca9d39e153712f6887fec'

0 comments on commit bff38ea

Please sign in to comment.