diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py index 92af7b9e6..d38d1d852 100644 --- a/tests/cyclops/evaluate/metrics/experimental/inputs.py +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -1,4 +1,5 @@ """Input data for tests of metrics in cyclops/evaluate/metrics/experimental.""" + import random from collections import namedtuple from types import ModuleType @@ -296,43 +297,46 @@ def _multilabel_cases(*, xp: Any): return ( pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels), - preds=xp.asarray(_multilabel_preds), + target=xp.asarray(_multilabel_labels, dtype=xp.int32), + preds=xp.asarray(_multilabel_preds, dtype=xp.int32), ), id="input[2d-labels]", ), pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels_multidim), - preds=xp.asarray(_multilabel_preds_multidim), + target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32), + preds=xp.asarray(_multilabel_preds_multidim, dtype=xp.int32), ), id="input[multidim-labels]", ), pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels), - preds=xp.asarray(_multilabel_probs), + target=xp.asarray(_multilabel_labels, dtype=xp.int32), + preds=xp.asarray(_multilabel_probs, dtype=xp.float32), ), id="input[2d-probs]", ), pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels), - preds=xp.asarray(_inv_sigmoid(_multilabel_probs)), + target=xp.asarray(_multilabel_labels, dtype=xp.int32), + preds=xp.asarray(_inv_sigmoid(_multilabel_probs), dtype=xp.float32), ), id="input[2d-logits]", ), pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels_multidim), - preds=xp.asarray(_multilabel_probs_multidim), + target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32), + preds=xp.asarray(_multilabel_probs_multidim, dtype=xp.float32), ), id="input[multidim-probs]", ), pytest.param( InputSpec( - target=xp.asarray(_multilabel_labels_multidim), - preds=xp.asarray(_inv_sigmoid(_multilabel_probs_multidim)), + target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32), + preds=xp.asarray( + _inv_sigmoid(_multilabel_probs_multidim), + dtype=xp.float32, + ), ), id="input[multidim-logits]", ), diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py index 4dc5989fd..081ebd1e9 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py @@ -1,4 +1,5 @@ """Test precision-recall curve metric.""" + from functools import partial from typing import List, Tuple, Union @@ -45,9 +46,11 @@ def _binary_precision_recall_curve_reference( return tm_binary_precision_recall_curve( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -215,9 +218,11 @@ def _multiclass_precision_recall_curve_reference( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), num_classes, - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -371,9 +376,11 @@ def _multilabel_precision_recall_curve_reference( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), num_labels, - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -381,6 +388,8 @@ def _multilabel_precision_recall_curve_reference( class TestMultilabelPrecisionRecallCurve(MetricTester): """Test multilabel precision-recall curve function and class.""" + atol: float = 2e-7 + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_roc.py b/tests/cyclops/evaluate/metrics/experimental/test_roc.py index ddc4f9556..17a4fff5a 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_roc.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_roc.py @@ -1,4 +1,5 @@ """Test roc curve metric.""" + from functools import partial from typing import List, Tuple, Union @@ -45,9 +46,11 @@ def _binary_roc_reference( return tm_binary_roc( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -215,9 +218,11 @@ def _multiclass_roc_reference( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), num_classes, - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -371,9 +376,11 @@ def _multilabel_roc_reference( torch.utils.dlpack.from_dlpack(preds), torch.utils.dlpack.from_dlpack(target), num_labels, - thresholds=torch.utils.dlpack.from_dlpack(thresholds) - if apc.is_array_api_obj(thresholds) - else thresholds, + thresholds=( + torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds + ), ignore_index=ignore_index, ) @@ -381,6 +388,8 @@ def _multilabel_roc_reference( class TestMultilabelROC(MetricTester): """Test multilabel roc curve function and class.""" + atol: float = 9e-8 + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1])