Skip to content

Commit

Permalink
add test to cover int64
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayingGaoo committed Jul 2, 2019
1 parent 0673f31 commit e25ca13
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2476,6 +2476,17 @@ def test_batch_to_spacend_with_dynamic_crop_ends_are_zero(self):
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_for_int64(self):
block_size = [2, 2]
crops_val = np.array([[1, 0], [2, 0]], dtype=np.int64)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int64, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

def test_batch_to_space3d(self):
block_size = [2, 2]
crop = [[0, 1], [2, 1]]
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def version_1(cls, ctx, node, **kwargs):
ends = ctx.make_node("Add", [ends_for_zero.output[0], ends_for_nonzero.output[0]])

# Create slice_axis const node
slice_axis_value = np.array([1, 2]).astype(np.int32)
slice_axis_value = np.array([1, 2]).astype(crops_np_type)
slice_axis_const = ctx.make_const(utils.make_name("Const"), slice_axis_value)

attr = {"axes": slice_axis_const.output[0], "ends": ends.output[0], "starts": starts.output[0]}
Expand Down

0 comments on commit e25ca13

Please sign in to comment.