Skip to content

Commit

Permalink
set dtype for multilabel test inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Feb 7, 2024
1 parent b6651f4 commit 2834deb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 30 deletions.
28 changes: 16 additions & 12 deletions tests/cyclops/evaluate/metrics/experimental/inputs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Input data for tests of metrics in cyclops/evaluate/metrics/experimental."""

import random
from collections import namedtuple
from types import ModuleType
Expand Down Expand Up @@ -296,43 +297,46 @@ def _multilabel_cases(*, xp: Any):
return (
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels),
preds=xp.asarray(_multilabel_preds),
target=xp.asarray(_multilabel_labels, dtype=xp.int32),
preds=xp.asarray(_multilabel_preds, dtype=xp.int32),
),
id="input[2d-labels]",
),
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels_multidim),
preds=xp.asarray(_multilabel_preds_multidim),
target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32),
preds=xp.asarray(_multilabel_preds_multidim, dtype=xp.int32),
),
id="input[multidim-labels]",
),
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels),
preds=xp.asarray(_multilabel_probs),
target=xp.asarray(_multilabel_labels, dtype=xp.int32),
preds=xp.asarray(_multilabel_probs, dtype=xp.float32),
),
id="input[2d-probs]",
),
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels),
preds=xp.asarray(_inv_sigmoid(_multilabel_probs)),
target=xp.asarray(_multilabel_labels, dtype=xp.int32),
preds=xp.asarray(_inv_sigmoid(_multilabel_probs), dtype=xp.float32),
),
id="input[2d-logits]",
),
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels_multidim),
preds=xp.asarray(_multilabel_probs_multidim),
target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32),
preds=xp.asarray(_multilabel_probs_multidim, dtype=xp.float32),
),
id="input[multidim-probs]",
),
pytest.param(
InputSpec(
target=xp.asarray(_multilabel_labels_multidim),
preds=xp.asarray(_inv_sigmoid(_multilabel_probs_multidim)),
target=xp.asarray(_multilabel_labels_multidim, dtype=xp.int32),
preds=xp.asarray(
_inv_sigmoid(_multilabel_probs_multidim),
dtype=xp.float32,
),
),
id="input[multidim-logits]",
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test precision-recall curve metric."""

from functools import partial
from typing import List, Tuple, Union

Expand Down Expand Up @@ -45,9 +46,11 @@ def _binary_precision_recall_curve_reference(
return tm_binary_precision_recall_curve(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)

Expand Down Expand Up @@ -215,9 +218,11 @@ def _multiclass_precision_recall_curve_reference(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
num_classes,
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)

Expand Down Expand Up @@ -371,16 +376,20 @@ def _multilabel_precision_recall_curve_reference(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
num_labels,
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)


class TestMultilabelPrecisionRecallCurve(MetricTester):
"""Test multilabel precision-recall curve function and class."""

atol: float = 2e-7

@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
Expand Down
27 changes: 18 additions & 9 deletions tests/cyclops/evaluate/metrics/experimental/test_roc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test roc curve metric."""

from functools import partial
from typing import List, Tuple, Union

Expand Down Expand Up @@ -45,9 +46,11 @@ def _binary_roc_reference(
return tm_binary_roc(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)

Expand Down Expand Up @@ -215,9 +218,11 @@ def _multiclass_roc_reference(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
num_classes,
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)

Expand Down Expand Up @@ -371,16 +376,20 @@ def _multilabel_roc_reference(
torch.utils.dlpack.from_dlpack(preds),
torch.utils.dlpack.from_dlpack(target),
num_labels,
thresholds=torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds,
thresholds=(
torch.utils.dlpack.from_dlpack(thresholds)
if apc.is_array_api_obj(thresholds)
else thresholds
),
ignore_index=ignore_index,
)


class TestMultilabelROC(MetricTester):
"""Test multilabel roc curve function and class."""

atol: float = 9e-8

@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
Expand Down

0 comments on commit 2834deb

Please sign in to comment.