Skip to content

Commit

Permalink
add more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ttxskk committed Feb 11, 2022
1 parent e94478a commit 119e76b
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions tests/test_evaluation/test_eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@


def tets_accel_error():
target = np.random.rand(2, 5, 3)
target = np.random.rand(10, 5, 3)
output = np.copy(target)
mask = np.ones((output.shape[0], output.shape[1]), dtype=bool)
_ = keypoint_accel_error(output, target, mask)
_ = keypoint_accel_error(output, target)
mask = np.ones((output.shape[0]), dtype=bool)

error = keypoint_accel_error(output, target, mask)
np.testing.assert_almost_equal(error, 0)

error = keypoint_accel_error(output, target)
np.testing.assert_almost_equal(error, 0)


def tets_keypoinyt_mpjpe():
Expand All @@ -25,20 +29,32 @@ def tets_keypoinyt_mpjpe():
with pytest.raises(ValueError):
_ = keypoint_mpjpe(output, target, mask, alignment='norm')

_ = keypoint_mpjpe(output, target, mask, alignment='none')
error = keypoint_mpjpe(output, target, mask, alignment='none')
np.testing.assert_almost_equal(error, 0)

error = keypoint_mpjpe(output, target, mask, alignment='scale')
np.testing.assert_almost_equal(error, 0)

_ = keypoint_mpjpe(output, target, mask, alignment='none')
error = keypoint_mpjpe(output, target, mask, alignment='procrustes')
np.testing.assert_almost_equal(error, 0)

_ = keypoint_mpjpe(output, target, mask, alignment='scale')
R = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
output = np.dot(target, R)
error = keypoint_mpjpe(output, target, mask, alignment='none')
assert error > 1e-10

_ = keypoint_mpjpe(output, target, mask, alignment='procrustes')
R = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
output = np.dot(target, R)
error = keypoint_mpjpe(output, target, mask, alignment='procrustes')
np.testing.assert_almost_equal(error, 0)


def tets_keypoinyt_pve():
target = np.random.rand(2, 6890, 3)
output = np.copy(target)

_ = vertice_pve(output, target)
error = vertice_pve(output, target)
np.testing.assert_almost_equal(error, 0)


def test_keypoint_3d_pck():
Expand Down

0 comments on commit 119e76b

Please sign in to comment.