diff --git a/tests/test_evaluation/test_eval_utils.py b/tests/test_evaluation/test_eval_utils.py index 75d93a7d..6f0dfbf6 100644 --- a/tests/test_evaluation/test_eval_utils.py +++ b/tests/test_evaluation/test_eval_utils.py @@ -11,11 +11,15 @@ def tets_accel_error(): - target = np.random.rand(2, 5, 3) + target = np.random.rand(10, 5, 3) output = np.copy(target) - mask = np.ones((output.shape[0], output.shape[1]), dtype=bool) - _ = keypoint_accel_error(output, target, mask) - _ = keypoint_accel_error(output, target) + mask = np.ones((output.shape[0]), dtype=bool) + + error = keypoint_accel_error(output, target, mask) + np.testing.assert_almost_equal(error, 0) + + error = keypoint_accel_error(output, target) + np.testing.assert_almost_equal(error, 0) def tets_keypoinyt_mpjpe(): @@ -25,20 +29,32 @@ def tets_keypoinyt_mpjpe(): with pytest.raises(ValueError): _ = keypoint_mpjpe(output, target, mask, alignment='norm') - _ = keypoint_mpjpe(output, target, mask, alignment='none') + error = keypoint_mpjpe(output, target, mask, alignment='none') + np.testing.assert_almost_equal(error, 0) + + error = keypoint_mpjpe(output, target, mask, alignment='scale') + np.testing.assert_almost_equal(error, 0) - _ = keypoint_mpjpe(output, target, mask, alignment='none') + error = keypoint_mpjpe(output, target, mask, alignment='procrustes') + np.testing.assert_almost_equal(error, 0) - _ = keypoint_mpjpe(output, target, mask, alignment='scale') + R = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + output = np.dot(target, R) + error = keypoint_mpjpe(output, target, mask, alignment='none') + assert error > 1e-10 - _ = keypoint_mpjpe(output, target, mask, alignment='procrustes') + R = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + output = np.dot(target, R) + error = keypoint_mpjpe(output, target, mask, alignment='procrustes') + np.testing.assert_almost_equal(error, 0) def tets_keypoinyt_pve(): target = np.random.rand(2, 6890, 3) output = np.copy(target) - _ = vertice_pve(output, target) + error = vertice_pve(output, target) + np.testing.assert_almost_equal(error, 0) def test_keypoint_3d_pck():