-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement MultioutputWrapper * Add wrapper to __all__ * Address deepsource flagged issues * Update docs & make squeeze_outputs actually work * Update tests to be randomized and parametrized * Apply suggestions from code review Co-authored-by: Stephen Malina <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
- Loading branch information
1 parent
9ef98b4
commit 5893793
Showing
6 changed files
with
322 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from collections import namedtuple | ||
from functools import partial | ||
from typing import Any, Callable, Optional | ||
|
||
import pytest | ||
import torch | ||
from sklearn.metrics import accuracy_score | ||
from sklearn.metrics import r2_score as sk_r2score | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester | ||
from torchmetrics import Metric | ||
from torchmetrics.classification import Accuracy | ||
from torchmetrics.regression import R2Score | ||
from torchmetrics.wrappers.multioutput import MultioutputWrapper | ||
|
||
seed_all(42) | ||
|
||
|
||
class _MultioutputMetric(Metric): | ||
"""Test class that allows passing base metric as a class rather than its instantiation to the wrapper.""" | ||
|
||
def __init__( | ||
self, | ||
base_metric_class, | ||
num_outputs: int = 1, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Any = None, | ||
dist_sync_fn: Optional[Callable] = None, | ||
**base_metric_kwargs, | ||
) -> None: | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn, | ||
) | ||
self.metric = MultioutputWrapper( | ||
base_metric_class( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn, | ||
**base_metric_kwargs, | ||
), | ||
num_outputs=num_outputs, | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
dist_sync_fn=dist_sync_fn, | ||
) | ||
|
||
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: | ||
"""Update the each pair of outputs and predictions.""" | ||
return self.metric.update(preds, target) | ||
|
||
def compute(self) -> torch.Tensor: | ||
"""Compute the R2 score between each pair of outputs and predictions.""" | ||
return self.metric.compute() | ||
|
||
@torch.jit.unused | ||
def forward(self, *args, **kwargs): | ||
"""Run forward on the underlying metric.""" | ||
return self.metric(*args, **kwargs) | ||
|
||
def reset(self) -> None: | ||
"""Reset the underlying metric state.""" | ||
self.metric.reset() | ||
|
||
|
||
num_targets = 2 | ||
|
||
Input = namedtuple("Input", ["preds", "target"]) | ||
|
||
_multi_target_regression_inputs = Input( | ||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), | ||
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), | ||
) | ||
_multi_target_classification_inputs = Input( | ||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets), | ||
target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, num_targets)), | ||
) | ||
|
||
|
||
def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values"): | ||
"""Compute R2 score over multiple outputs.""" | ||
sk_preds = preds.view(-1, num_targets).numpy() | ||
sk_target = target.view(-1, num_targets).numpy() | ||
r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) | ||
if adjusted != 0: | ||
r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) | ||
return r2_score | ||
|
||
|
||
def _multi_target_sk_accuracy(preds, target, num_outputs): | ||
"""Compute accuracy over multiple outputs.""" | ||
accs = [] | ||
for i in range(num_outputs): | ||
accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i])) | ||
return accs | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs", | ||
[ | ||
( | ||
R2Score, | ||
_multi_target_sk_r2score, | ||
_multi_target_regression_inputs.preds, | ||
_multi_target_regression_inputs.target, | ||
num_targets, | ||
{}, | ||
), | ||
( | ||
Accuracy, | ||
partial(_multi_target_sk_accuracy, num_outputs=2), | ||
_multi_target_classification_inputs.preds, | ||
_multi_target_classification_inputs.target, | ||
num_targets, | ||
dict(num_classes=NUM_CLASSES), | ||
), | ||
], | ||
) | ||
class TestMultioutputWrapper(MetricTester): | ||
"""Test the MultioutputWrapper class with regression and classification inner metrics.""" | ||
|
||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_multioutput_wrapper( | ||
self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step | ||
): | ||
"""Test that the multioutput wrapper properly slices and computes outputs along the output dimension for | ||
both classification and regression metrics.""" | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
_MultioutputMetric, | ||
compare_metric, | ||
dist_sync_on_step, | ||
metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
from copy import deepcopy | ||
from typing import Any, Callable, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from torchmetrics import Metric | ||
from torchmetrics.utilities import apply_to_collection | ||
|
||
|
||
def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor: | ||
"""Get indices of rows along dim 0 which have NaN values.""" | ||
if len(tensors) == 0: | ||
raise ValueError("Must pass at least one tensor as argument") | ||
sentinel = tensors[0] | ||
nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device) | ||
for tensor in tensors: | ||
permuted_tensor = tensor.flatten(start_dim=1) | ||
nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1) | ||
return nan_idxs | ||
|
||
|
||
class MultioutputWrapper(Metric): | ||
"""Wrap a base metric to enable it to support multiple outputs. | ||
Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for | ||
multioutput mode. This class wraps such metrics to support computing one metric per output. | ||
Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs. | ||
This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension | ||
(2, ...) where ... represents the dimensions the metric returns when not wrapped. | ||
In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude | ||
fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will remove the | ||
intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses | ||
`MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally | ||
has missing labels for classes like ``R2Score`` is that this class supports removing NaN values | ||
(parameter ``remove_nans``) on a per-output basis. When ``remove_nans`` is passed the wrapper will remove all rows | ||
Args: | ||
base_metric: | ||
Metric being wrapped. | ||
num_outputs: | ||
Expected dimensionality of the output dimension. This parameter is | ||
used to determine the number of distinct metrics we need to track. | ||
output_dim: | ||
Dimension on which output is expected. Note that while this provides some flexibility, the output dimension | ||
must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels | ||
can have a different number of dimensions than the predictions. This can be worked around if the output | ||
dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. | ||
remove_nans: | ||
Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying | ||
metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N | ||
represents the length of the batch or dataset being passed in. | ||
squeeze_outputs: | ||
If true, will squeeze the 1-item dimensions left after `index_select` is applied. | ||
This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful | ||
for certain classification metrics that can't handle additional 1-item dimensions. | ||
compute_on_step: | ||
Whether to recompute the metric value on each update step. | ||
dist_sync_on_step: | ||
Required for distributed training support. | ||
process_group: | ||
Specify the process group on which synchronization is called. | ||
The default: None (which selects the entire world) | ||
dist_sync_fn: | ||
Required for distributed training support. | ||
Example: | ||
>>> # Mimic R2Score in `multioutput`, `raw_values` mode: | ||
>>> import torch | ||
>>> from torchmetrics import MultioutputWrapper, R2Score | ||
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) | ||
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) | ||
>>> r2score = MultioutputWrapper(R2Score(), 2) | ||
>>> r2score(preds, target) | ||
[tensor(0.9654), tensor(0.9082)] | ||
>>> # Classification metric where prediction and label tensors have different shapes. | ||
>>> from torchmetrics import BinnedAveragePrecision | ||
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]]) | ||
>>> preds = torch.tensor([ | ||
... [[.1, .8], [.8, .05], [.1, .15]], | ||
... [[.1, .1], [.2, .3], [.7, .6]], | ||
... [[.002, .4], [.95, .45], [.048, .15]] | ||
... ]) | ||
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2) | ||
>>> binned_avg_precision(preds, target) | ||
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]] | ||
""" | ||
|
||
def __init__( | ||
self, | ||
base_metric: Metric, | ||
num_outputs: int, | ||
output_dim: int = -1, | ||
remove_nans: bool = True, | ||
squeeze_outputs: bool = True, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
dist_sync_fn: Callable = None, | ||
): | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn, | ||
) | ||
self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)]) | ||
self.output_dim = output_dim | ||
self.remove_nans = remove_nans | ||
self.squeeze_outputs = squeeze_outputs | ||
|
||
def _get_args_kwargs_by_output( | ||
self, *args: torch.Tensor, **kwargs: torch.Tensor | ||
) -> List[Tuple[torch.Tensor, torch.Tensor]]: | ||
"""Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" | ||
args_kwargs_by_output = [] | ||
for i in range(len(self.metrics)): | ||
selected_args = apply_to_collection( | ||
args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device) | ||
) | ||
selected_kwargs = apply_to_collection( | ||
kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device) | ||
) | ||
if self.remove_nans: | ||
args_kwargs = selected_args + tuple(selected_kwargs.values()) | ||
nan_idxs = _get_nan_indices(*args_kwargs) | ||
selected_args = [arg[~nan_idxs] for arg in selected_args] | ||
selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()} | ||
|
||
if self.squeeze_outputs: | ||
selected_args = [arg.squeeze(self.output_dim) for arg in selected_args] | ||
args_kwargs_by_output.append((selected_args, selected_kwargs)) | ||
return args_kwargs_by_output | ||
|
||
def update(self, *args: Any, **kwargs: Any) -> None: | ||
"""Update each underlying metric with the corresponding output.""" | ||
reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) | ||
for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): | ||
metric.update(*selected_args, **selected_kwargs) | ||
|
||
def compute(self) -> List[torch.Tensor]: | ||
"""Compute metrics.""" | ||
return [m.compute() for m in self.metrics] | ||
|
||
@torch.jit.unused | ||
def forward(self, *args: Any, **kwargs: Any) -> Any: | ||
"""Call underlying forward methods and aggregate the results if they're non-null. | ||
We override this method to ensure that state variables get copied over on the underlying metrics. | ||
""" | ||
results = [] | ||
reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) | ||
for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): | ||
results.append(metric(*selected_args, **selected_kwargs)) | ||
if results[0] is None: | ||
return None | ||
return results | ||
|
||
@property | ||
def is_differentiable(self) -> bool: | ||
return False | ||
|
||
def reset(self) -> None: | ||
"""Reset all underlying metrics.""" | ||
for metric in self.metrics: | ||
metric.reset() |