[metrics] change default behaviour of state dict (#4685)
* fix state dict

* Update docs/source/metrics.rst

Co-authored-by: Rohit Gupta <[email protected]>

* changelog

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
3 people authored Nov 16, 2020
1 parent be60efb commit 5109766
Showing 4 changed files with 21 additions and 6 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))


- Metric states are no longer added to `state_dict` by default ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685))


### Deprecated


@@ -81,7 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
- Added `manual_optimizer_step`, which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
- Added congratulations at the end of our notebooks ([#4555](https://github.com/PyTorchLightning/pytorch-lightning/pull/4555))
4 changes: 2 additions & 2 deletions docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

.. note::

- Metric states will as default add their internal state to the models ``state_dict``.
- To change this after initializing the metric the method ``.persistent(mode)`` can
+ Metric states are **not** added to the models ``state_dict`` by default.
+ To change this, after initializing the metric, the method ``.persistent(mode)`` can
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
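
A minimal sketch of the behaviour this note describes, assuming the built-in ``Accuracy`` metric and the ``.persistent(mode)`` API referenced above (the state names ``correct``/``total`` in the comment are an assumption, not stated in this commit):

    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()

    # new default: internal metric states are excluded from the state dict
    assert len(metric.state_dict()) == 0

    # opt back in after initialization
    metric.persistent(True)
    assert len(metric.state_dict()) > 0  # states such as ``correct``/``total`` now appear

    # and opt out again
    metric.persistent(False)
    assert len(metric.state_dict()) == 0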

*********************
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -86,7 +86,7 @@ def __init__(
self._reductions = {}

def add_state(
- self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+ self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False
):
"""
Adds metric state variable. Only used by subclasses.
@@ -100,6 +100,7 @@ def add_state(
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
+ Default is ``False``.
Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
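
For context, a sketch of a subclass opting one of its states back into the checkpoint via this parameter — ``RunningSum`` is a hypothetical example, not part of this commit:

    import torch
    from pytorch_lightning.metrics import Metric

    class RunningSum(Metric):
        def __init__(self):
            super().__init__()
            # persistent=True opts this state into state_dict,
            # overriding the new default of False
            self.add_state("total", default=torch.tensor(0.0),
                           dist_reduce_fx="sum", persistent=True)

        def update(self, x: torch.Tensor):
            self.total = self.total + x.sum()

        def compute(self):
            return self.total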
@@ -278,7 +279,7 @@ def _apply(self, fn):
f'or a list of torch.Tensor, but encountered {current_val}')
return self

- def persistent(self, mode: bool = True):
+ def persistent(self, mode: bool = False):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
13 changes: 12 additions & 1 deletion tests/metrics/test_metric.py
@@ -1,6 +1,7 @@
import pickle

+ from collections import OrderedDict
from distutils.version import LooseVersion

import cloudpickle
import numpy as np
import pytest
@@ -167,3 +168,13 @@ def test_pickle(tmpdir):
metric_loaded = cloudpickle.loads(metric_pickled)

assert metric_loaded.compute() == 1


+ def test_state_dict(tmpdir):
+     """ test that metric states can be removed and added to state dict """
+     metric = Dummy()
+     assert metric.state_dict() == OrderedDict()
+     metric.persistent(True)
+     assert metric.state_dict() == OrderedDict(x=0)
+     metric.persistent(False)
+     assert metric.state_dict() == OrderedDict()
