From 8c7c6ad8197e7a9ced234fae325486cc8a73922f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Nov 2022 14:42:50 +0100 Subject: [PATCH] replace assert torch.allclose with torch.testing.assert_allclose --- test/test_architecture_ops.py | 4 ++-- test/test_ops.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_architecture_ops.py b/test/test_architecture_ops.py index 9f254c7942b..32ad1a32f89 100644 --- a/test/test_architecture_ops.py +++ b/test/test_architecture_ops.py @@ -20,7 +20,7 @@ def test_maxvit_window_partition(self): x_hat = partition(x, partition_size) x_hat = departition(x_hat, partition_size, n_partitions, n_partitions) - assert torch.allclose(x, x_hat) + torch.testing.assert_close(x, x_hat) def test_maxvit_grid_partition(self): input_shape = (1, 3, 224, 224) @@ -39,7 +39,7 @@ def test_maxvit_grid_partition(self): x_hat = post_swap(x_hat) x_hat = departition(x_hat, n_partitions, partition_size, partition_size) - assert torch.allclose(x, x_hat) + torch.testing.assert_close(x, x_hat) if __name__ == "__main__": diff --git a/test/test_ops.py b/test/test_ops.py index d76e57faecf..99b58bb93a7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -630,7 +630,7 @@ def test_nms_ref(self, iou, seed): boxes, scores = self._create_tensors_with_iou(1000, iou) keep_ref = self._reference_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) - assert torch.allclose(keep, keep_ref), err_msg.format(iou) + torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou)) def test_nms_input_errors(self): with pytest.raises(RuntimeError): @@ -661,7 +661,7 @@ def test_qnms(self, iou, scale, zero_point): keep = ops.nms(boxes, scores, iou) qkeep = ops.nms(qboxes, qscores, iou) - assert torch.allclose(qkeep, keep), err_msg.format(iou) + torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou)) @needs_cuda @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @@ -1237,7 +1237,7 @@ def _run_cartesian_test(target_fn: Callable): boxes2 = gen_box(7) a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) b = target_fn(boxes1, boxes2) - assert torch.allclose(a, b) + torch.testing.assert_close(a, b) class TestBoxIou(TestIouBase):