Skip to content

Commit

Permalink
add test to cover crop end or right is zero or not zero
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayingGaoo committed Jul 2, 2019
1 parent 30de251 commit 2fa6a5c
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,9 +2444,31 @@ def test_batch_to_spacend(self):
self._run_test_case([_OUTPUT], {_INPUT: input_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop(self):
def test_batch_to_spacend_with_dynamic_crop_top_is_zero(self):
block_size = [2, 2]
crops_val = np.array([[0, 1], [2, 1]], dtype=np.int32)
crop_value = np.array([[0, 1], [2, 1]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crop = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crop_value})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_right_is_zero(self):
block_size = [2, 2]
crops_val = np.array([[0, 1], [2, 0]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
crops = tf.placeholder(dtype=tf.int32, shape=[2, 2], name=_TFINPUT1)
_ = tf.batch_to_space_nd(input_x, block_size, crops, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], feed_dict={_INPUT: input_val, _INPUT1: crops_val})

@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
def test_batch_to_spacend_with_dynamic_crop_ends_are_zero(self):
block_size = [2, 2]
crops_val = np.array([[1, 0], [2, 0]], dtype=np.int32)
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)

input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)
Expand Down

0 comments on commit 2fa6a5c

Please sign in to comment.