[metrics] change default behaviour of state dict (#4685)
* fix state dict

* Update docs/source/metrics.rst

Co-authored-by: Rohit Gupta <[email protected]>

* changelog

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
3 people committed Nov 17, 2020
1 parent 52a0781 commit f722018
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



- Metric states are no longer added to `state_dict` by default ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/))


### Deprecated


4 changes: 2 additions & 2 deletions docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

.. note::

Metric states will as default add their internal state to the models ``state_dict``.
To change this after initializing the metric the method ``.persistent(mode)`` can
Metric states are **not** added to the models ``state_dict`` by default.
To change this, after initializing the metric, the method ``.persistent(mode)`` can
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

*********************
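To make the updated note concrete, here is a minimal usage sketch of the new default and the post-init toggle. It assumes the `pytorch_lightning.metrics.Accuracy` class from this release line; the exact state names (`correct`, `total`) are illustrative.

```python
# Minimal sketch of the behaviour documented in the note above, assuming the
# pytorch_lightning.metrics API from this release line.
from pytorch_lightning.metrics import Accuracy

metric = Accuracy()

# New default: internal metric states are left out of the state dict.
print(list(metric.state_dict().keys()))  # expected: []

# Opt back in after initialization with the post-init toggle from the note.
metric.persistent(True)
print(list(metric.state_dict().keys()))  # expected: e.g. ['correct', 'total']
```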
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -86,7 +86,7 @@ def __init__(
self._reductions = {}

def add_state(
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False
):
"""
Adds metric state variable. Only used by subclasses.
@@ -100,6 +100,7 @@ def add_state(
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Default is ``False``.
Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -278,7 +279,7 @@ def _apply(self, fn):
f'or a list of torch.Tensor, but encountered {current_val}')
return self

def persistent(self, mode: bool = True):
def persistent(self, mode: bool = False):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
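The change above flips the per-state default, but a state can still be opted into the `state_dict` at construction time. A hedged sketch of a custom metric doing so follows; `SumMetric` and its update/compute logic are illustrative names, not part of the library.

```python
# Illustrative custom metric (not part of the library) showing the per-state
# `persistent` flag of `add_state` after this change.
import torch
from pytorch_lightning.metrics import Metric

class SumMetric(Metric):
    def __init__(self):
        super().__init__()
        # Opt this state into the module's state_dict despite the new default.
        self.add_state("total", default=torch.tensor(0.0),
                       dist_reduce_fx="sum", persistent=True)

    def update(self, x: torch.Tensor):
        self.total += x.sum()

    def compute(self):
        return self.total
```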
13 changes: 12 additions & 1 deletion tests/metrics/test_metric.py
@@ -1,6 +1,7 @@
import pickle

from collections import OrderedDict
from distutils.version import LooseVersion

import cloudpickle
import numpy as np
import pytest
Expand Down Expand Up @@ -167,3 +168,13 @@ def test_pickle(tmpdir):
metric_loaded = cloudpickle.loads(metric_pickled)

assert metric_loaded.compute() == 1


def test_state_dict(tmpdir):
""" test that metric states can be removed and added to state dict """
metric = Dummy()
assert metric.state_dict() == OrderedDict()
metric.persistent(True)
assert metric.state_dict() == OrderedDict(x=0)
metric.persistent(False)
assert metric.state_dict() == OrderedDict()

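As a follow-up usage sketch, the constructor-time flag and the post-init toggle interact the same way `test_state_dict` demonstrates for `Dummy`; `SumMetric` is the illustrative class sketched above.

```python
# Follow-up sketch using the illustrative SumMetric defined above.
import torch

metric = SumMetric()
metric.update(torch.tensor([1.0, 2.0, 3.0]))

# `total` was registered with persistent=True, so it shows up here.
print(list(metric.state_dict().keys()))   # expected: ['total']

# The post-init toggle can still remove it again, mirroring test_state_dict.
metric.persistent(False)
print(list(metric.state_dict().keys()))   # expected: []
```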