Skip to content

Commit

Permalink
Fixed a small issue
Browse files Browse the repository at this point in the history
  • Loading branch information
arv-77 committed Apr 16, 2021
1 parent ad2d4ca commit 7457f54
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import THRESHOLD, MetricTester
from tests.helpers.testers import NUM_CLASSES
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy
from torchmetrics.utilities.checks import _input_format_classification
Expand Down Expand Up @@ -255,15 +254,15 @@ def test_wrong_params(
(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]])
)
],
)
def test_different_modes(preds_mc, target_mc, preds_ml, target_ml):
acc = Accuracy()
acc(preds_mc, target_mc)
with pytest.raises(ValueError, match="The `average` has to be one of"):
with pytest.raises(ValueError, match="^[You cannot use]"):
acc(preds_ml, target_ml)


Expand Down

0 comments on commit 7457f54

Please sign in to comment.