From a908cc65c3b60ad26a8542b51704c65b8ac03141 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sun, 18 Apr 2021 09:48:59 +0200
Subject: [PATCH 1/8] allow MetricCollection with args

---
 torchmetrics/collections.py | 28 +++++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 3f0e0933c69..72e0dd62edc 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -13,11 +13,12 @@
 # limitations under the License.

 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Sequence

 from torch import nn

 from torchmetrics.metric import Metric
+from torchmetrics.utilities import rank_zero_warn


 class MetricCollection(nn.ModuleDict):
@@ -58,6 +59,12 @@ class MetricCollection(nn.ModuleDict):
         >>> metrics(preds, target)
         {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

+    Example (input as arguments):
+        >>> metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'),
+        ...                            Recall(num_classes=3, average='macro'))
+        >>> metrics(preds, target)
+        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
+
     Example (input as dict):
         >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
         ...                             'macro_recall': Recall(num_classes=3, average='macro')})
@@ -72,10 +79,25 @@ class MetricCollection(nn.ModuleDict):

     def __init__(
         self,
-        metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
+        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
+        *metric: Metric,
         prefix: Optional[str] = None,
     ):
         super().__init__()
+        if isinstance(metrics, Metric):
+            # set compatible with original type expectations
+            metrics = [metrics]
+        elif isinstance(metrics, Sequence):
+            # prepare for optional additions
+            metrics = list(metrics)
+        if metric:
+            metrics += [m for m in metric if isinstance(m, Metric)]
+            remain = [m for m in metric if not isinstance(m, Metric)]
+            if remain:
+                rank_zero_warn(
+                    f"You have passed extra arguments {remain} which are not `Metric` so they will be ignored."
+                )
+
         if isinstance(metrics, dict):
             # Check all values are metrics
             for name, metric in metrics.items():
@@ -85,7 +107,7 @@ def __init__(
                         " is not an instance of `pl.metrics.Metric`"
                     )
                 self[name] = metric
-        elif isinstance(metrics, (tuple, list)):
+        elif isinstance(metrics, Sequence):
             for metric in metrics:
                 if not isinstance(metric, Metric):
                     raise ValueError(

From 432bf653adf96a550178583c626a2bfcca6681bc Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sun, 18 Apr 2021 09:58:16 +0200
Subject: [PATCH 2/8] format

---
 tests/bases/test_average.py                   |  3 +++
 torchmetrics/average.py                       |  4 +---
 .../classification/binned_precision_recall.py | 11 ++++-------
 torchmetrics/collections.py                   | 10 +++-------
 4 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/tests/bases/test_average.py b/tests/bases/test_average.py
index 9c84caf8ddc..164e6845292 100644
--- a/tests/bases/test_average.py
+++ b/tests/bases/test_average.py
@@ -15,11 +15,13 @@ def average_ignore_weights(values, weights):


 class DefaultWeightWrapper(AverageMeter):
+
     def update(self, values, weights):
         super().update(values)


 class ScalarWrapper(AverageMeter):
+
     def update(self, values, weights):
         # torch.ravel is PyTorch 1.8 only, so use np.ravel instead
         values = values.cpu().numpy()
@@ -37,6 +39,7 @@ def update(self, values, weights):
     ],
 )
 class TestAverageMeter(MetricTester):
+
     @pytest.mark.parametrize("ddp", [False, True])
     @pytest.mark.parametrize("dist_sync_on_step", [False, True])
     def test_average_fn(self, ddp, dist_sync_on_step, values, weights):

diff --git a/torchmetrics/average.py b/torchmetrics/average.py
index 98621fa87e8..c13cb60ac64 100644
--- a/torchmetrics/average.py
+++ b/torchmetrics/average.py
@@ -78,9 +78,7 @@ def __init__(

     # TODO: need to be strings because Unions are not pickleable in Python 3.6
     def update(  # type: ignore
-        self,
-        value: "Union[Tensor, float]",
-        weight: "Union[Tensor, float]" = 1.0
+        self, value: "Union[Tensor, float]", weight: "Union[Tensor, float]" = 1.0
     ) -> None:
         """Updates the average with.

diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py
index eff2bd3998c..702e182fd56 100644
--- a/torchmetrics/classification/binned_precision_recall.py
+++ b/torchmetrics/classification/binned_precision_recall.py
@@ -157,13 +157,10 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
         recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)

         # Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve
-        precisions = torch.cat([
-            precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)
-        ],
-                               dim=1)
-        recalls = torch.cat([recalls,
-                             torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)],
-                            dim=1)
+        t_ones = torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)
+        precisions = torch.cat([precisions, t_ones], dim=1)
+        t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)
+        recalls = torch.cat([recalls, t_zeros], dim=1)
         if self.num_classes == 1:
             return (precisions[0, :], recalls[0, :], self.thresholds)
         else:

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 72e0dd62edc..ab247edebda 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 from torch import nn

@@ -103,17 +103,13 @@ def __init__(
             for name, metric in metrics.items():
                 if not isinstance(metric, Metric):
                     raise ValueError(
-                        f"Value {metric} belonging to key {name}"
-                        " is not an instance of `pl.metrics.Metric`"
+                        f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
                     )
                 self[name] = metric
         elif isinstance(metrics, Sequence):
             for metric in metrics:
                 if not isinstance(metric, Metric):
-                    raise ValueError(
-                        f"Input {metric} to `MetricCollection` is not a instance"
-                        " of `pl.metrics.Metric`"
-                    )
+                    raise ValueError(f"Input {metric} to `MetricCollection` is not an instance of `pl.metrics.Metric`")
             name = metric.__class__.__name__
             if name in self:
                 raise ValueError(f"Encountered two metrics both named {name}")

From 4676d72ee3be2997b1c56fe639af2b66ef0b7c2c Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sun, 18 Apr 2021 10:02:05 +0200
Subject: [PATCH 3/8] chlog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a2be2c4936..9a75f8b7e46 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134))
 - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111))
 - Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163))
+- Allowed passing metrics as positional arguments to `MetricCollection` ([#176](https://github.com/PyTorchLightning/metrics/pull/176))

 ### Deprecated

From 48e2707d438aa0a71393d562e05386380046d3d6 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sun, 18 Apr 2021 10:15:44 +0200
Subject: [PATCH 4/8] imports

---
 torchmetrics/collections.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index ab247edebda..1457621ac19 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Union

 from torch import nn

From b35d660744b77ab794cbd1350306c1bbb1f682ee Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sun, 18 Apr 2021 23:31:12 +0200
Subject: [PATCH 5/8] tests

---
 tests/bases/test_collections.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py
index 68206debd66..f3532838407 100644
--- a/tests/bases/test_collections.py
+++ b/tests/bases/test_collections.py
@@ -85,23 +85,23 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):

 def test_metric_collection_wrong_input(tmpdir):
     """ Check that errors are raised on wrong input """
-    m1 = DummyMetricSum()
+    dms = DummyMetricSum()

     # Not all input are metrics (list)
     with pytest.raises(ValueError):
-        _ = MetricCollection([m1, 5])
+        _ = MetricCollection([dms, 5])

     # Not all input are metrics (dict)
     with pytest.raises(ValueError):
-        _ = MetricCollection({'metric1': m1, 'metric2': 5})
+        _ = MetricCollection({'metric1': dms, 'metric2': 5})

     # Same metric passed in multiple times
     with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
-        _ = MetricCollection([m1, m1])
+        _ = MetricCollection([dms, dms])

     # Not a list or dict passed in
-    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
-        _ = MetricCollection(m1)
+    with pytest.warns(Warning, match=' which are not `Metric` so they will be ignored.'):
+        _ = MetricCollection(dms, [dms])


 def test_metric_collection_args_kwargs(tmpdir):

From 8be3eb6f18dd57330919ad61c2d51aab70cfea47 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 19 Apr 2021 10:10:55 +0200
Subject: [PATCH 6/8] Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 torchmetrics/collections.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 1457621ac19..b463ec44bee 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -80,7 +80,7 @@ class MetricCollection(nn.ModuleDict):
     def __init__(
         self,
         metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
-        *metric: Metric,
+        *additional_metrics: Metric,
         prefix: Optional[str] = None,
     ):
         super().__init__()
@@ -90,13 +90,17 @@ def __init__(
         elif isinstance(metrics, Sequence):
             # prepare for optional additions
             metrics = list(metrics)
-        if metric:
-            metrics += [m for m in metric if isinstance(m, Metric)]
-            remain = [m for m in metric if not isinstance(m, Metric)]
-            if remain:
-                rank_zero_warn(
-                    f"You have passed extra arguments {remain} which are not `Metric` so they will be ignored."
-                )
+        remain = []
+        for m in additional_metrics:
+            if isinstance(m, Metric):
+                metrics.append(m)
+            else:
+                remain.append(m)
+
+        if remain:
+            rank_zero_warn(
+                f"You have passed extra arguments {remain} which are not `Metric` so they will be ignored."
+            )

From 5f61b90a185c7ccc135417dbd0efaa9db6c81b58 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 19 Apr 2021 10:23:30 +0200
Subject: [PATCH 7/8] docs

---
 torchmetrics/collections.py | 44 +++++++++++++++++++++----------------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index b463ec44bee..3da0c899fad 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -23,19 +23,18 @@ class MetricCollection(nn.ModuleDict):
     """
-    MetricCollection class can be used to chain metrics that have the same
-    call pattern into one single class.
+    MetricCollection class can be used to chain metrics that have the same call pattern into one single class.

     Args:
         metrics: One of the following

-            * list or tuple: if metrics are passed in as a list, will use the
-              metrics class name as key for output dict. Therefore, two metrics
-              of the same class cannot be chained this way.
+            * list or tuple (sequence): if metrics are passed in as a list, will use the metrics class name
+              as key for output dict. Therefore, two metrics of the same class cannot be chained this way.

-            * dict: if metrics are passed in as a dict, will use each key in the
-              dict as key for output dict. Use this format if you want to chain
-              together multiple of the same metric with different parameters.
+            * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
+              Use this format if you want to chain together multiple of the same metric with different parameters.
+
+        additional_metrics: additional metrics to add, if the first argument was a single metric or a sequence of metrics.

         prefix: a string to append in front of the keys of the output dict

@@ -46,6 +45,8 @@ class MetricCollection(nn.ModuleDict):
             If two elements in ``metrics`` have the same ``name``.
         ValueError:
             If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
+        ValueError:
+            If ``metrics`` is a ``dict`` and any additional_metrics are passed.

     Example (input as list):
         >>> import torch
@@ -87,19 +88,24 @@ def __init__(
         if isinstance(metrics, Metric):
             # set compatible with original type expectations
             metrics = [metrics]
-        elif isinstance(metrics, Sequence):
+        if isinstance(metrics, Sequence):
             # prepare for optional additions
             metrics = list(metrics)
-        remain = []
-        for m in additional_metrics:
-            if isinstance(m, Metric):
-                metrics.append(m)
-            else:
-                remain.append(m)
-
-        if remain:
-            rank_zero_warn(
-                f"You have passed extra arguments {remain} which are not `Metric` so they will be ignored."
+            remain = []
+            for m in additional_metrics:
+                if isinstance(m, Metric):
+                    metrics.append(m)
+                else:
+                    remain.append(m)
+
+            if remain:
+                rank_zero_warn(
+                    f"You have passed extra arguments {remain} which are not `Metric` so they will be ignored."
+                )
+        elif additional_metrics:
+            raise ValueError(
+                f"You have passed extra arguments {additional_metrics} which are not compatible"
+                f" with the first passed dictionary {metrics}."
             )

         if isinstance(metrics, dict):

From bfd4ae3bfbaf81643a45d3d8898578770f67f63f Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 19 Apr 2021 10:29:36 +0200
Subject: [PATCH 8/8] ...
---
 torchmetrics/collections.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 3da0c899fad..cbbe70ac358 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -93,10 +93,7 @@ def __init__(
             metrics = list(metrics)
             remain = []
             for m in additional_metrics:
-                if isinstance(m, Metric):
-                    metrics.append(m)
-                else:
-                    remain.append(m)
+                (metrics if isinstance(m, Metric) else remain).append(m)

             if remain:
                 rank_zero_warn(
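
A minimal usage sketch of the constructor behaviour this series converges on after [PATCH 8/8]; the snippet is illustrative and not part of the patches, and the metric constructors mirror the docstring examples earlier in the diff (exact tensor values depend on the random inputs):

import torch
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

# metrics passed positionally are keyed by their class names in the output dict
metrics = MetricCollection(
    Accuracy(),
    Precision(num_classes=3, average='macro'),
    Recall(num_classes=3, average='macro'),
)
metrics(preds, target)  # -> {'Accuracy': ..., 'Precision': ..., 'Recall': ...}

# positional extras that are not Metric instances are dropped with a warning
MetricCollection(Accuracy(), 5)

# a dict as the first argument combined with positional metrics raises ValueError
try:
    MetricCollection({'acc': Accuracy()}, Recall(num_classes=3, average='macro'))
except ValueError as err:
    print(err)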