Skip to content

Commit

Permalink
Fix corner case in manually specifying compute_groups in `MetricCol…
Browse files Browse the repository at this point in the history
…lection` (#2979)

* implementation
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
(cherry picked from commit a8de07d)
  • Loading branch information
SkafteNicki authored and Borda committed Feb 28, 2025
1 parent 2b12132 commit 969a98d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import _flatten_dict, allclose
from torchmetrics.utilities.data import _flatten, _flatten_dict, allclose
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

Expand Down Expand Up @@ -90,7 +90,9 @@ class name as key for the output dict.
due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
reference and a copy of states are instead returned in this case (reference will be reestablished on the next
call to ``update``).
call to ``update``). Do note that for the time being that if you are manually specifying compute groups in
nested collections, these are not compatible with the compute groups of the parent collection and will be
overridden.
.. important::
Metric collections can be nested at initialization (see last example) but the output of the collection will
Expand Down Expand Up @@ -192,7 +194,6 @@ class name of the metric:
"""

_modules: dict[str, Metric] # type: ignore[assignment]
_groups: Dict[int, List[str]]
__jit_unused_properties__: ClassVar[list[str]] = ["metric_state"]

def __init__(
Expand All @@ -210,7 +211,7 @@ def __init__(
self._enable_compute_groups = compute_groups
self._groups_checked: bool = False
self._state_is_copy: bool = False

self._groups: Dict[int, list[str]] = {}
self.add_metrics(metrics, *additional_metrics)

@property
Expand Down Expand Up @@ -338,7 +339,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
of just passed by reference
"""
if not self._state_is_copy:
if not self._state_is_copy and self._groups_checked:
for cg in self._groups.values():
m0 = getattr(self, cg[0])
for i in range(1, len(cg)):
Expand Down Expand Up @@ -495,7 +496,6 @@ def add_metrics(
"Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
f" previous, but got {metrics}"
)

self._groups_checked = False
if self._enable_compute_groups:
self._init_compute_groups()
Expand All @@ -518,9 +518,15 @@ def _init_compute_groups(self) -> None:
f"Input {metric} in `compute_groups` argument does not match a metric in the collection."
f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}"
)
# add metrics not specified in compute groups as their own group
already_in_group = _flatten(self._groups.values()) # type: ignore
counter = len(self._groups)
for k in self.keys(keep_base=True):
if k not in already_in_group:
self._groups[counter] = [k] # type: ignore
counter += 1
self._groups_checked = True
else:
# Initialize all metrics as their own compute group
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}

@property
Expand Down
22 changes: 22 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,28 @@ def test_compute_group_define_by_user():
assert m.compute()


def test_compute_group_define_by_user_outside_specs():
"""Check that user can provide compute groups with missing metrics in the specs."""
m = MetricCollection(
MulticlassConfusionMatrix(3),
MulticlassRecall(3),
MulticlassPrecision(3),
MulticlassAccuracy(3),
compute_groups=[["MulticlassRecall", "MulticlassPrecision"]],
)
assert m._groups_checked
assert m.compute_groups == {
0: ["MulticlassRecall", "MulticlassPrecision"],
1: ["MulticlassConfusionMatrix"],
2: ["MulticlassAccuracy"],
}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
assert m.compute()


def test_classwise_wrapper_compute_group():
"""Check that user can provide compute groups."""
classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")
Expand Down

0 comments on commit 969a98d

Please sign in to comment.