From ea05dbd5d49347abbc2c7acdb197ed47366951db Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 25 May 2021 09:18:43 +0100 Subject: [PATCH 1/3] add reset_forward_cache --- torchmetrics/metric.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 108c9d52ced..61f04965fe0 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -59,6 +59,8 @@ class Metric(nn.Module, ABC): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + reset_forward_cache: + Wether to reset `forward_cache` after update when `dist_sync_on_step` is False. """ __jit_ignored_attributes__ = ["is_differentiable"] @@ -69,6 +71,7 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + reset_forward_cache: bool = True, ): super().__init__() @@ -82,6 +85,7 @@ def __init__( self.compute_on_step = compute_on_step self.process_group = process_group self.dist_sync_fn = dist_sync_fn + self.reset_forward_cache = reset_forward_cache self._to_sync = True self._update_signature = inspect.signature(self.update) @@ -166,7 +170,9 @@ def forward(self, *args, **kwargs): # add current step with torch.no_grad(): self.update(*args, **kwargs) - self._forward_cache = None + + if self.reset_forward_cache: + self._forward_cache = None if self.compute_on_step: self._to_sync = self.dist_sync_on_step From c993e20ac97d6ea8b9f1b7fe5ea9b3f0fb72bb6b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 May 2021 10:45:39 +0200 Subject: [PATCH 2/3] update + tests --- CHANGELOG.md | 1 + tests/bases/test_metric.py | 8 ++++++++ torchmetrics/metric.py | 7 +------ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d7f4377d51..1568c58abcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260)) ### Deprecated diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 26165e9de11..736740aab5d 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -303,3 +303,11 @@ def test_warning_on_compute_before_update(): def test_metric_scripts(): torch.jit.script(DummyMetric()) torch.jit.script(DummyMetricSum()) + + +def test_metric_forward_cache_reset(): + metric = DummyMetricSum() + _ = metric(2.0) + assert metric._forward_cache == 2.0 + metric.reset() + assert metric._forward_cache is None diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 61f04965fe0..f3fc6fe1676 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -59,8 +59,6 @@ class Metric(nn.Module, ABC): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. - reset_forward_cache: - Wether to reset `forward_cache` after update when `dist_sync_on_step` is False. """ __jit_ignored_attributes__ = ["is_differentiable"] @@ -85,7 +83,6 @@ def __init__( self.compute_on_step = compute_on_step self.process_group = process_group self.dist_sync_fn = dist_sync_fn - self.reset_forward_cache = reset_forward_cache self._to_sync = True self._update_signature = inspect.signature(self.update) @@ -171,9 +168,6 @@ def forward(self, *args, **kwargs): with torch.no_grad(): self.update(*args, **kwargs) - if self.reset_forward_cache: - self._forward_cache = None - if self.compute_on_step: self._to_sync = self.dist_sync_on_step @@ -288,6 +282,7 @@ def reset(self): This method automatically resets the metric state variables to their default value. """ self._update_called = False + self._forward_cache = None # lower lightning versions requires this implicitly to log metric objects correctly in self.log if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3: self._computed = None From 8880ceef539ab1d93df3fe96cf7f9a1ecbb5a520 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 25 May 2021 10:46:31 +0200 Subject: [PATCH 3/3] Update torchmetrics/metric.py --- torchmetrics/metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index f3fc6fe1676..b3dddaea3bd 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -69,7 +69,6 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - reset_forward_cache: bool = True, ): super().__init__()