Skip to content

Commit

Permalink
Nested metrics dictionaries now can be passed to the loggers (#1582)
Browse files Browse the repository at this point in the history
* now func merge_dicts works with nested dictionaries

* CHANGELOG.md upd
  • Loading branch information
alexeykarnachev authored Apr 23, 2020
1 parent 94e5344 commit edb8d7a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Added the possibility to pass nested metrics dictionaries to loggers ([#1582](https://github.com/PyTorchLightning/pytorch-lightning/pull/1582))

- Fixed memory leak from opt return ([#1528](https://github.com/PyTorchLightning/pytorch-lightning/pull/1528))

- Fixed saving checkpoint before deleting old ones ([#1453](https://github.com/PyTorchLightning/pytorch-lightning/pull/1453))
Expand Down
27 changes: 18 additions & 9 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class LoggerCollection(LightningLoggerBase):
Args:
logger_iterable: An iterable collection of loggers
"""

def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
super().__init__()
self._logger_iterable = logger_iterable
Expand Down Expand Up @@ -347,20 +348,28 @@ def merge_dicts(
Examples:
>>> import pprint
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
>>> d3 = {'a': 1.1, 'v': 2.3}
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
>>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
>>> dflt_func = min
>>> agg_funcs = {'a': np.mean, 'v': max}
>>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
{'a': 1.3,
'b': 2.0,
'c': 1,
'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
'v': 2.3}
"""

agg_key_funcs = agg_key_funcs or dict()
keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
d_out = {}
for k in keys:
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
d_out[k] = agg_val
fn = agg_key_funcs.get(k)
values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None]

if isinstance(values_to_agg[0], dict):
d_out[k] = merge_dicts(values_to_agg, fn, default_func)
else:
d_out[k] = (fn or default_func)(values_to_agg)

return d_out

0 comments on commit edb8d7a

Please sign in to comment.