Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace threshold argument to binned metrics #322

Merged
merged 11 commits into from
Jun 28, 2021
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))
- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
- Added `sync` and `sync_context` methods for manually controlling when metric states are synced ([#302](https://github.com/PyTorchLightning/metrics/pull/302))
- Added `thresholds` argument to binned metrics for manually controlling the thresholds ([#322](https://github.com/PyTorchLightning/metrics/pull/322))

### Changed

Expand Down
27 changes: 24 additions & 3 deletions tests/classification/test_binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision
from torchmetrics.classification.binned_precision_recall import (
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
)

seed_all(42)

Expand Down Expand Up @@ -112,8 +116,10 @@ class TestBinnedAveragePrecision(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("num_thresholds", [101, 301])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
@pytest.mark.parametrize(
"num_thresholds, thresholds", ([101, None], [301, None], [None, torch.linspace(0.0, 1.0, 101)])
)
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds, thresholds):
# rounding will simulate binning for both implementations
preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6

Expand All @@ -127,5 +133,20 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o
metric_args={
"num_classes": num_classes,
"num_thresholds": num_thresholds,
"thresholds": thresholds
},
)


@pytest.mark.parametrize(
    "metric_class", [BinnedAveragePrecision, BinnedRecallAtFixedPrecision, BinnedPrecisionRecallCurve]
)
def test_raises_errors(metric_class):
    """Check that invalid `num_thresholds`/`thresholds` combinations raise ``ValueError``."""
    constructor = metric_class
    if metric_class is BinnedRecallAtFixedPrecision:
        # this metric additionally requires a minimum precision to construct
        constructor = partial(metric_class, min_precision=0.5)

    # passing both arguments at once is ambiguous and must be rejected
    with pytest.raises(ValueError):
        constructor(num_classes=10, num_thresholds=100, thresholds=[0.1, 0.5, 0.9])

    # `thresholds` must be a list or a tensor, not a bare number
    with pytest.raises(ValueError):
        constructor(num_classes=10, num_thresholds=None, thresholds=1)
58 changes: 50 additions & 8 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ class BinnedPrecisionRecallCurve(Metric):
Args:
num_classes: integer with number of classes. For binary, set to 1.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
curve and accurate estimates, but will be slower and consume more memory. Default 100.
Mutually exclusive with `thresholds` argument.
thresholds: list or tensor with specific thresholds. Mutually exclusive with `num_thresholds`
argument.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Expand All @@ -62,6 +65,12 @@ class BinnedPrecisionRecallCurve(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If both ``num_thresholds`` and ``thresholds`` are not ``None``
Borda marked this conversation as resolved.
Show resolved Hide resolved
ValueError:
If ``thresholds`` is not a list or tensor
Borda marked this conversation as resolved.
Show resolved Hide resolved

Example (binary case):
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
Expand Down Expand Up @@ -106,7 +115,8 @@ class BinnedPrecisionRecallCurve(Metric):
def __init__(
self,
num_classes: int,
num_thresholds: int = 100,
num_thresholds: Optional[int] = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
Expand All @@ -118,14 +128,26 @@ def __init__(
)

self.num_classes = num_classes
self.num_thresholds = num_thresholds
thresholds = torch.linspace(0, 1.0, num_thresholds)
self.register_buffer("thresholds", thresholds)
if num_thresholds is not None:
if thresholds is not None:
raise ValueError(
'Arguments `num_thresholds` and `thresholds` are mutually exclusive, but got'
f'{num_thresholds} and {thresholds}'
)
self.num_thresholds = num_thresholds
thresholds = torch.linspace(0, 1.0, num_thresholds)
self.register_buffer("thresholds", thresholds)
elif thresholds is not None:
Borda marked this conversation as resolved.
Show resolved Hide resolved
if not (isinstance(thresholds, list) or isinstance(thresholds, Tensor)):
raise ValueError('Expected argument `thresholds` to either be a list of floats or a tensor')
thresholds = torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
self.num_thresholds = thresholds.numel()
self.register_buffer("thresholds", thresholds)

for name in ("TPs", "FPs", "FNs"):
self.add_state(
name=name,
default=torch.zeros(num_classes, num_thresholds, dtype=torch.float32),
default=torch.zeros(num_classes, self.num_thresholds, dtype=torch.float32),
dist_reduce_fx="sum",
)

Expand Down Expand Up @@ -186,12 +208,21 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
num_classes: integer with number of classes. Not nessesary to provide
for binary problems.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
curve and accurate estimates, but will be slower and consume more memory. Default 100.
Mutually exclusive with `thresholds` argument.
thresholds: list or tensor with specific thresholds. Mutually exclusive with `num_thresholds`
Borda marked this conversation as resolved.
Show resolved Hide resolved
argument.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If both ``num_thresholds`` and ``thresholds`` are not ``None``
Borda marked this conversation as resolved.
Show resolved Hide resolved
ValueError:
If ``thresholds`` is not a list or tensor

Example (binary case):
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
Expand Down Expand Up @@ -234,12 +265,21 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
num_classes: integer with number of classes. Provide 1 for for binary problems.
min_precision: float value specifying minimum precision threshold.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
curve and accurate estimates, but will be slower and consume more memory. Default 100.
Mutually exclusive with `thresholds` argument.
thresholds: list or tensor with specific thresholds. Mutually exclusive with `num_thresholds`
Borda marked this conversation as resolved.
Show resolved Hide resolved
argument.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises:
ValueError:
If both ``num_thresholds`` and ``thresholds`` are not ``None``
Borda marked this conversation as resolved.
Show resolved Hide resolved
ValueError:
If ``thresholds`` is not a list or tensor

Example (binary case):
>>> from torchmetrics import BinnedRecallAtFixedPrecision
>>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
Expand All @@ -265,13 +305,15 @@ def __init__(
num_classes: int,
min_precision: float,
num_thresholds: int = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
num_classes=num_classes,
num_thresholds=num_thresholds,
thresholds=thresholds,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
Expand Down