diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 68adb49c510..7f000374a2b 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -24,7 +24,7 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.data import _flatten_dict, allclose +from torchmetrics.utilities.data import _flatten, _flatten_dict, allclose from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -90,7 +90,9 @@ class name as key for the output dict. due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this reference and a copy of states are instead returned in this case (reference will be reestablished on the next - call to ``update``). + call to ``update``). Do note that, for the time being, if you are manually specifying compute groups in + nested collections, these are not compatible with the compute groups of the parent collection and will be + overridden. .. 
important:: Metric collections can be nested at initialization (see last example) but the output of the collection will @@ -192,7 +194,6 @@ class name of the metric: """ _modules: dict[str, Metric] # type: ignore[assignment] - _groups: Dict[int, List[str]] __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"] def __init__( @@ -210,7 +211,7 @@ def __init__( self._enable_compute_groups = compute_groups self._groups_checked: bool = False self._state_is_copy: bool = False - + self._groups: Dict[int, list[str]] = {} self.add_metrics(metrics, *additional_metrics) @property @@ -338,7 +339,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None: of just passed by reference """ - if not self._state_is_copy: + if not self._state_is_copy and self._groups_checked: for cg in self._groups.values(): m0 = getattr(self, cg[0]) for i in range(1, len(cg)): @@ -495,7 +496,6 @@ def add_metrics( "Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the" f" previous, but got {metrics}" ) - self._groups_checked = False if self._enable_compute_groups: self._init_compute_groups() @@ -518,9 +518,15 @@ def _init_compute_groups(self) -> None: f"Input {metric} in `compute_groups` argument does not match a metric in the collection." 
f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}" ) + # add metrics not specified in compute groups as their own group + already_in_group = _flatten(self._groups.values()) # type: ignore + counter = len(self._groups) + for k in self.keys(keep_base=True): + if k not in already_in_group: + self._groups[counter] = [k] # type: ignore + counter += 1 self._groups_checked = True else: - # Initialize all metrics as their own compute group self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))} @property diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 312b0f45523..b2e6ee321f3 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -573,6 +573,28 @@ def test_compute_group_define_by_user(): assert m.compute() +def test_compute_group_define_by_user_outside_specs(): + """Check that user can provide compute groups with missing metrics in the specs.""" + m = MetricCollection( + MulticlassConfusionMatrix(3), + MulticlassRecall(3), + MulticlassPrecision(3), + MulticlassAccuracy(3), + compute_groups=[["MulticlassRecall", "MulticlassPrecision"]], + ) + assert m._groups_checked + assert m.compute_groups == { + 0: ["MulticlassRecall", "MulticlassPrecision"], + 1: ["MulticlassConfusionMatrix"], + 2: ["MulticlassAccuracy"], + } + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + m.update(preds, target) + assert m.compute() + + def test_classwise_wrapper_compute_group(): """Check that user can provide compute groups.""" classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")