From 80575deba9bc32627532c6476dae48f2f3828898 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 17 Jun 2021 10:39:05 +0100 Subject: [PATCH 01/36] wip --- torchmetrics/metric.py | 46 ++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 82262af26d1..4a7828c5f77 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -221,6 +221,31 @@ def wrapped_func(*args, **kwargs): return wrapped_func + def _apply_sync(self, fn: Optional[Callable] = None, *args, **kwargs) -> Any: + dist_sync_fn = self.dist_sync_fn + if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): + # User provided a bool, so we assume DDP if available + dist_sync_fn = gather_all_tensors + + synced = False + cache = [] + if self._to_sync and dist_sync_fn is not None: + # cache prior to syncing + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # sync + self._sync_dist(dist_sync_fn) + synced = True + + value = fn(*args, **kwargs) if fn else None + + if synced: + # if we synced, restore to cache so that we can continue to accumulate un-synced state + for attr, val in cache.items(): + setattr(self, attr, val) + + return value + def _wrap_compute(self, compute): @functools.wraps(compute) @@ -236,26 +261,7 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - cache = [] - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) + self._computed = self._apply_sync(fn=self.compute, *args, **kwargs) return self._computed From 904ecec7050bd736c30464753023973bb3ad54cf Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 06:00:18 -0400 Subject: [PATCH 02/36] add _apply_sync to nn.Metric --- torchmetrics/metric.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 4a7828c5f77..1b7879bccd7 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -221,8 +221,23 @@ def wrapped_func(*args, **kwargs): return wrapped_func - def _apply_sync(self, fn: Optional[Callable] = None, *args, **kwargs) -> Any: - dist_sync_fn = self.dist_sync_fn + def _apply_sync( + self, + fn: Optional[Callable] = None, + dist_sync_fn: Optional[Callable] = None, + *args, + **kwargs + ) -> Any: + """ + Automatically perform synchronization when running in distributed setting, + apply a function, restore cache states and return the output of provided fn function. 
+ + Args: + fn: Function to be applied after metric states synchronization + dist_sync_fn: Function to be used to perform metric states synchronization + args: Arguments to be passed to the fn function + kwargs: Keywords arguments to be passed to the fn function + """ if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors @@ -261,7 +276,7 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - self._computed = self._apply_sync(fn=self.compute, *args, **kwargs) + self._computed = self._apply_sync(fn=compute, dist_sync_fn=self.dist_sync_fn, *args, **kwargs) return self._computed From 71ca9be6b5bdf73efbbccffc3eee5776255666b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 10:07:13 +0000 Subject: [PATCH 03/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 1b7879bccd7..c847fbeed27 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -222,11 +222,7 @@ def wrapped_func(*args, **kwargs): return wrapped_func def _apply_sync( - self, - fn: Optional[Callable] = None, - dist_sync_fn: Optional[Callable] = None, - *args, - **kwargs + self, fn: Optional[Callable] = None, dist_sync_fn: Optional[Callable] = None, *args, **kwargs ) -> Any: """ Automatically perform synchronization when running in distributed setting, From e4d99d86e15f55b989f83922f3084b1e37d2d3b2 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 06:13:07 -0400 Subject: [PATCH 04/36] move to context manager --- torchmetrics/metric.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 1b7879bccd7..7134d4f748c 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -21,7 +21,7 @@ import torch from torch import Tensor, nn - +from contextlib import contextmanager from torchmetrics.utilities import apply_to_collection, rank_zero_warn from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors @@ -221,22 +221,17 @@ def wrapped_func(*args, **kwargs): return wrapped_func + @contextmanager def _apply_sync( self, - fn: Optional[Callable] = None, dist_sync_fn: Optional[Callable] = None, - *args, - **kwargs - ) -> Any: + ) -> None: """ - Automatically perform synchronization when running in distributed setting, - apply a function, restore cache states and return the output of provided fn function. + Context manager to synchronize the states between processes when running in a distributed setting + and restore the local cache states after yielding. 
Args: - fn: Function to be applied after metric states synchronization - dist_sync_fn: Function to be used to perform metric states synchronization - args: Arguments to be passed to the fn function - kwargs: Keywords arguments to be passed to the fn function + dist_sync_fn: Function to be used to perform states synchronization """ if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): # User provided a bool, so we assume DDP if available @@ -252,15 +247,13 @@ def _apply_sync( self._sync_dist(dist_sync_fn) synced = True - value = fn(*args, **kwargs) if fn else None + yield if synced: # if we synced, restore to cache so that we can continue to accumulate un-synced state for attr, val in cache.items(): setattr(self, attr, val) - return value - def _wrap_compute(self, compute): @functools.wraps(compute) @@ -276,7 +269,8 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - self._computed = self._apply_sync(fn=compute, dist_sync_fn=self.dist_sync_fn, *args, **kwargs) + with self._apply_sync(dist_sync_fn=self.dist_sync_fn): + self._computed = compute(*args, **kwargs) return self._computed From bfcbc74ec585ec99c54bb00e5308e6b4df41de5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 10:14:14 +0000 Subject: [PATCH 05/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 7134d4f748c..89c4f4c9b7a 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -16,12 +16,13 @@ import operator from abc import ABC, abstractmethod from collections.abc import Sequence +from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor, nn -from contextlib import contextmanager + from torchmetrics.utilities import apply_to_collection, rank_zero_warn from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors @@ -227,8 +228,8 @@ def _apply_sync( dist_sync_fn: Optional[Callable] = None, ) -> None: """ - Context manager to synchronize the states between processes when running in a distributed setting - and restore the local cache states after yielding. + Context manager to synchronize the states between processes when running in a distributed setting + and restore the local cache states after yielding. 
Args: dist_sync_fn: Function to be used to perform states synchronization From 41a60e7a6e014023692faea975d6c9f412069155 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 17 Jun 2021 11:15:07 +0100 Subject: [PATCH 06/36] resolve flake8 --- torchmetrics/metric.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 7134d4f748c..89c4f4c9b7a 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -16,12 +16,13 @@ import operator from abc import ABC, abstractmethod from collections.abc import Sequence +from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor, nn -from contextlib import contextmanager + from torchmetrics.utilities import apply_to_collection, rank_zero_warn from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors @@ -227,8 +228,8 @@ def _apply_sync( dist_sync_fn: Optional[Callable] = None, ) -> None: """ - Context manager to synchronize the states between processes when running in a distributed setting - and restore the local cache states after yielding. + Context manager to synchronize the states between processes when running in a distributed setting + and restore the local cache states after yielding. Args: dist_sync_fn: Function to be used to perform states synchronization From 94fab1b0602d8f4c9c2d3e7cb29afa236216f4bb Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 13:06:24 -0400 Subject: [PATCH 07/36] add sync --- torchmetrics/metric.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 89c4f4c9b7a..1371410cf34 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,7 +19,6 @@ from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, List, Optional, Union - import torch from torch import Tensor, nn @@ -202,7 +201,7 @@ def _sync_dist(self, dist_sync_fn=gather_all_tensors): for attr, reduction_fn in self._reductions.items(): # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], Tensor): + if isinstance(output_dict[attr][0], Tensor) and isinstance(output_dict[attr], Sequence): output_dict[attr] = torch.stack(output_dict[attr]) elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) @@ -234,7 +233,14 @@ def _apply_sync( Args: dist_sync_fn: Function to be used to perform states synchronization """ - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): + + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + + if not is_distributed: + yield + return + + if dist_sync_fn is None: # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors @@ -315,7 +321,8 @@ def clone(self): def __getstate__(self): # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} + with self._apply_sync(dist_sync_fn=self.dist_sync_fn): + return deepcopy({k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}) def __setstate__(self, state): # manually restore update and compute functions for pickling From 31498ef20a9985049d671fea69a0cafbc731b7b4 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 17:07:02 +0000 Subject: [PATCH 08/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 1371410cf34..2431d5aa66b 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,6 +19,7 @@ from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, List, Optional, Union + import torch from torch import Tensor, nn @@ -233,13 +234,13 @@ def _apply_sync( Args: dist_sync_fn: Function to be used to perform states synchronization """ - + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() if not is_distributed: yield return - + if dist_sync_fn is None: # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors From 31563dc0cd231713e4a2a1b7ccfe0277122cbc39 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 14:44:08 -0400 Subject: [PATCH 09/36] update --- torchmetrics/metric.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 1371410cf34..ea044191d03 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -321,8 +321,7 @@ def clone(self): def __getstate__(self): # ignore update and compute functions for pickling - with self._apply_sync(dist_sync_fn=self.dist_sync_fn): - return deepcopy({k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}) + return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} def __setstate__(self, state): # manually restore update and compute functions for pickling @@ -354,6 +353,20 @@ def _apply(self, fn): ) return this + @contextmanager + def _apply_persistent( + self, + mode: bool = False, + ) -> None: + """ + Context manager for post-init to change if metric states should be saved to + its state_dict + """ + persistent = self._persistent + self.persistent(mode) + yield + self._persistent = persistent + def persistent(self, mode: bool = False): """Method for post-init to change if metric states should be saved to its state_dict @@ -364,16 +377,17 @@ def persistent(self, mode: bool = False): def state_dict(self, destination=None, prefix="", keep_vars=False): destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # Register metric states to be part of the state_dict - for key in self._defaults: - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination + with self._apply_sync(dist_sync_fn=self.dist_sync_fn): + for key in self._defaults: + if self._persistent[key]: + current_val = getattr(self, key) + if not keep_vars: + if torch.is_tensor(current_val): + current_val = current_val.detach() + elif isinstance(current_val, list): + current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] + destination[prefix + key] = deepcopy(current_val) + return destination def _load_from_state_dict( self, From 
15e6d9a61ec32412196ffcfbfd1dc05094d3f9f8 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 03:29:39 -0400 Subject: [PATCH 10/36] update on comments --- torchmetrics/metric.py | 71 ++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 8192d4b75d9..dde42fbd01d 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, Dict import torch from torch import Tensor, nn @@ -187,17 +187,19 @@ def forward(self, *args, **kwargs): return self._forward_cache - def _sync_dist(self, dist_sync_fn=gather_all_tensors): + def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None): input_dict = {attr: getattr(self, attr) for attr in self._reductions} + for attr, reduction_fn in self._reductions.items(): # pre-concatenate metric states that are lists to reduce number of all_gather operations if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1: input_dict[attr] = [dim_zero_cat(input_dict[attr])] + output_dict = apply_to_collection( input_dict, Tensor, dist_sync_fn, - group=self.process_group, + group=process_group or self.process_group, ) for attr, reduction_fn in self._reductions.items(): @@ -222,42 +224,71 @@ def wrapped_func(*args, **kwargs): return wrapped_func - @contextmanager - def _apply_sync( + def sync( self, dist_sync_fn: Optional[Callable] = None, - ) -> None: + process_group: Optional[Any] = None, + should_sync: bool = True, + ) -> Dict[str, Tensor]: """ - Context manager to synchronize the states between processes when running in a distributed setting - and restore the local cache states after yielding. + Sync function to control Args: dist_sync_fn: Function to be used to perform states synchronization - """ + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + should_sync: Whether to apply to state synchronization. + Returns: + cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. + """ is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() - if not is_distributed: - yield - return - if dist_sync_fn is None: # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors - synced = False - cache = [] - if self._to_sync and dist_sync_fn is not None: + cache = {} + + if is_distributed and should_sync and dist_sync_fn is not None: # cache prior to syncing cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} # sync - self._sync_dist(dist_sync_fn) + self._sync_dist(dist_sync_fn, process_group=process_group) + + return cache + + @contextmanager + def sync_context( + self, + dist_sync_fn: Optional[Callable] = None, + process_group: Optional[Any] = None, + should_sync: bool = True, + restore_cache: bool = True, + ) -> None: + """ + Context manager to synchronize the states between processes when running in a distributed setting + and restore the local cache states after yielding. + + Args: + dist_sync_fn: Function to be used to perform states synchronization + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) + should_sync: Whether to apply to state synchronization. + restore_cache: Whether to restore the cache state so that the metrics can + continue to be accumulated. + """ + synced = False + + if should_sync: + + cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=self._to_sync) synced = True yield - if synced: + if synced and restore_cache: # if we synced, restore to cache so that we can continue to accumulate un-synced state for attr, val in cache.items(): setattr(self, attr, val) @@ -277,7 +308,7 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - with self._apply_sync(dist_sync_fn=self.dist_sync_fn): + with self.sync_context(dist_sync_fn=self.dist_sync_fn): self._computed = compute(*args, **kwargs) return self._computed @@ -378,7 +409,7 @@ def persistent(self, mode: bool = False): def state_dict(self, destination=None, prefix="", keep_vars=False): destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # Register metric states to be part of the state_dict - with self._apply_sync(dist_sync_fn=self.dist_sync_fn): + with self.sync_context(dist_sync_fn=self.dist_sync_fn): for key in self._defaults: if self._persistent[key]: current_val = getattr(self, key) From b3d5ec5126b6d25149a5e4a8ae64d4fea6698676 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 07:30:21 +0000 Subject: [PATCH 11/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index dde42fbd01d..c2db7c9a07f 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, List, Optional, Union, Dict +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import Tensor, nn @@ -189,12 +189,12 @@ def forward(self, *args, **kwargs): def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None): input_dict = {attr: getattr(self, attr) for attr in self._reductions} - + for attr, reduction_fn in self._reductions.items(): # pre-concatenate metric states that are lists to reduce number of all_gather operations if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1: input_dict[attr] = [dim_zero_cat(input_dict[attr])] - + output_dict = apply_to_collection( input_dict, Tensor, @@ -240,7 +240,7 @@ def sync( should_sync: Whether to apply to state synchronization. Returns: - cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. + cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. """ is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() @@ -249,7 +249,7 @@ def sync( dist_sync_fn = gather_all_tensors cache = {} - + if is_distributed and should_sync and dist_sync_fn is not None: # cache prior to syncing cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} @@ -276,7 +276,7 @@ def sync_context( process_group: Specify the process group on which synchronization is called. 
default: None (which selects the entire world) should_sync: Whether to apply to state synchronization. - restore_cache: Whether to restore the cache state so that the metrics can + restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. """ synced = False From fc42bbeed0d214c90b473c5d827e4dad7e1fe8ac Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 18 Jun 2021 08:36:05 +0100 Subject: [PATCH 12/36] update --- torchmetrics/metric.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index c2db7c9a07f..ba93a5fa0f2 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -236,7 +236,8 @@ def sync( Args: dist_sync_fn: Function to be used to perform states synchronization process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. + default: None (which selects the entire world) should_sync: Whether to apply to state synchronization. Returns: @@ -274,7 +275,8 @@ def sync_context( Args: dist_sync_fn: Function to be used to perform states synchronization process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. + default: None (which selects the entire world) should_sync: Whether to apply to state synchronization. restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. From 0dcb041cb925f59922eb841cc8c660411d029907 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 03:41:52 -0400 Subject: [PATCH 13/36] update --- torchmetrics/metric.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index dde42fbd01d..41b866ff17b 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -279,16 +279,13 @@ def sync_context( restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. 
""" - synced = False - + cache = {} if should_sync: - cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=self._to_sync) - synced = True - + yield - - if synced and restore_cache: + + if cache and restore_cache: # if we synced, restore to cache so that we can continue to accumulate un-synced state for attr, val in cache.items(): setattr(self, attr, val) From 3b4e83882867240cd3a4832b7e5c4cc05ef6b675 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 07:43:32 +0000 Subject: [PATCH 14/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 97638d4689e..a3baf146fd1 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -284,9 +284,9 @@ def sync_context( cache = {} if should_sync: cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=self._to_sync) - + yield - + if cache and restore_cache: # if we synced, restore to cache so that we can continue to accumulate un-synced state for attr, val in cache.items(): From 140aeeb36206b98866e4406fa5511959ca37c143 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 04:04:28 -0400 Subject: [PATCH 15/36] add restore_cache --- torchmetrics/metric.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 97638d4689e..5eb181c5e9d 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -84,6 +84,7 @@ def __init__( self.process_group = process_group self.dist_sync_fn = dist_sync_fn self._to_sync = True + self._restore_cache = True self._update_signature = inspect.signature(self.update) self.update = self._wrap_update(self.update) @@ -170,6 +171,9 @@ def forward(self, *args, **kwargs): if self.compute_on_step: self._to_sync = self.dist_sync_on_step + # skip restore cache operation from compute + # as cache is stored below. + self._restore_cache = False # save context before switch cache = {attr: getattr(self, attr) for attr in self._defaults} @@ -182,6 +186,8 @@ def forward(self, *args, **kwargs): # restore context for attr, val in cache.items(): setattr(self, attr, val) + + self._restore_cache = True self._to_sync = True self._computed = None @@ -281,9 +287,7 @@ def sync_context( restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. 
""" - cache = {} - if should_sync: - cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=self._to_sync) + cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=should_sync) yield @@ -307,7 +311,7 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - with self.sync_context(dist_sync_fn=self.dist_sync_fn): + with self.sync_context(dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache): self._computed = compute(*args, **kwargs) return self._computed From 3ab11cd727e1a843fbe6a3085f8f6d3f9253fec9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 08:05:37 +0000 Subject: [PATCH 16/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 91f15a26dce..f29a83b46e5 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -171,8 +171,8 @@ def forward(self, *args, **kwargs): if self.compute_on_step: self._to_sync = self.dist_sync_on_step - # skip restore cache operation from compute - # as cache is stored below. + # skip restore cache operation from compute + # as cache is stored below. self._restore_cache = False # save context before switch @@ -186,7 +186,7 @@ def forward(self, *args, **kwargs): # restore context for attr, val in cache.items(): setattr(self, attr, val) - + self._restore_cache = True self._to_sync = True self._computed = None @@ -288,7 +288,7 @@ def sync_context( continue to be accumulated. """ cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=should_sync) - + yield if cache and restore_cache: @@ -311,7 +311,9 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - with self.sync_context(dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache): + with self.sync_context( + dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache + ): self._computed = compute(*args, **kwargs) return self._computed From ca04cfbc8ceb9e2afdaacd3820d54038a4272b21 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 04:19:43 -0400 Subject: [PATCH 17/36] add a sync test --- tests/bases/test_ddp.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 28aca7b1173..fcdcac510f1 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import sys +from typing import OrderedDict import pytest import torch @@ -116,3 +117,40 @@ def compute(self): def test_non_contiguous_tensors(): """ Test that gather_all operation works for non contiguous tensors """ torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) + + +def _test_state_dict_is_synced(rank, worldsize): + setup_ddp(rank, worldsize) + + class DummyCatMetric(Metric): + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, x): + self.x += x + self.c += 1 + + def compute(self): + return self.x / self.c + + metric = DummyCatMetric() + metric.persistent(True) + + for i in range(5): + metric(i) + state_dict = metric.state_dict() + + assert state_dict["x"] == sum(range(i + 1)) * 2 + assert metric.x == sum(range(i + 1)) + + assert state_dict["c"] == (i + 1) * 2 + assert metric.c == (i + 1) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_state_dict_is_synced(): + """ This test asserts taht metric are synced while creating the state dict but restored after to continue accumulation. """ + torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, ), nprocs=2) \ No newline at end of file From 45b1b1fea52004b9be8f70c582105ea84e5b1992 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 08:20:40 +0000 Subject: [PATCH 18/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index fcdcac510f1..1062500df79 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -153,4 +153,4 @@ def compute(self): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_state_dict_is_synced(): """ This test asserts taht metric are synced while creating the state dict but restored after to continue accumulation. """ - torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, ), nprocs=2) \ No newline at end of file + torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, ), nprocs=2) From 8aa2a74fc68b959a5ec3aa535d05257de55fec75 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 18 Jun 2021 09:22:11 +0100 Subject: [PATCH 19/36] resolve flake8 --- tests/bases/test_ddp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 1062500df79..d7d59e7e54f 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import OrderedDict import pytest import torch @@ -152,5 +151,8 @@ def compute(self): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_state_dict_is_synced(): - """ This test asserts taht metric are synced while creating the state dict but restored after to continue accumulation. """ + """ + This test asserts taht metric are synced while creating the state + dict but restored after to continue accumulation. 
+ """ torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, ), nprocs=2) From 0f2ed93d8dded2870e40216363dab5ef244d9e50 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 05:53:24 -0400 Subject: [PATCH 20/36] resolve loading --- tests/bases/test_ddp.py | 23 ++++++++++++++++++----- torchmetrics/metric.py | 9 +++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index fcdcac510f1..f8b2adae9ee 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys +import os from typing import OrderedDict - import pytest import torch from torch import tensor - +from unittest import mock from tests.helpers import seed_all from tests.helpers.testers import DummyMetric, setup_ddp from torchmetrics import Metric from torchmetrics.utilities.distributed import gather_all_tensors +from copy import deepcopy seed_all(42) @@ -119,7 +120,7 @@ def test_non_contiguous_tensors(): torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) -def _test_state_dict_is_synced(rank, worldsize): +def _test_state_dict_is_synced(rank, worldsize, tmpdir): setup_ddp(rank, worldsize) class DummyCatMetric(Metric): @@ -149,8 +150,20 @@ def compute(self): assert state_dict["c"] == (i + 1) * 2 assert metric.c == (i + 1) + def reload_state_dict(state_dict, expected_x, expected_c): + metric = DummyCatMetric() + #metric.persistent(True) + metric.load_state_dict(state_dict) + assert metric.x == expected_x + assert metric.c == expected_c + + with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}): + reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0 ) + + reload_state_dict(deepcopy(state_dict), 20, 10) + @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -def test_state_dict_is_synced(): +def test_state_dict_is_synced(tmpdir): """ This test asserts taht metric are synced while creating the state dict but restored after to continue accumulation. 
""" - torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, ), nprocs=2) \ No newline at end of file + torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2) \ No newline at end of file diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index f29a83b46e5..9b3014a263f 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union - +import os import torch from torch import Tensor, nn @@ -437,10 +437,15 @@ def _load_from_state_dict( error_msgs: List[str], ) -> None: """ Loads metric states from state_dict """ + + # only global rank 0 should be reloading the values present in the ``state_dict`` + # as the state contains synced values across all progress_group for key in self._defaults: name = prefix + key if name in state_dict: - setattr(self, key, state_dict.pop(name)) + value = state_dict.pop(name) + if os.getenv("GLOBAL_RANK", "0") == "0": + setattr(self, key, value) super()._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs ) From 25bfbf35749e86aa07e4d00f7fe6b217228796e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 09:55:52 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_ddp.py | 10 ++++++---- torchmetrics/metric.py | 7 ++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index d3c02e51d9c..85bee2c8c9a 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -11,17 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import sys import os +import sys +from copy import deepcopy +from unittest import mock + import pytest import torch from torch import tensor -from unittest import mock + from tests.helpers import seed_all from tests.helpers.testers import DummyMetric, setup_ddp from torchmetrics import Metric from torchmetrics.utilities.distributed import gather_all_tensors -from copy import deepcopy seed_all(42) @@ -157,7 +159,7 @@ def reload_state_dict(state_dict, expected_x, expected_c): assert metric.c == expected_c with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}): - reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0 ) + reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0) reload_state_dict(deepcopy(state_dict), 20, 10) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 9b3014a263f..0aeff66cbd8 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -14,12 +14,13 @@ import functools import inspect import operator +import os from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union -import os + import torch from torch import Tensor, nn @@ -437,8 +438,8 @@ def _load_from_state_dict( error_msgs: List[str], ) -> None: """ Loads metric states from state_dict """ - - # only global rank 0 should be reloading the values present in the ``state_dict`` + + # only global rank 0 should be reloading the values present in the ``state_dict`` # as the state contains synced values across all progress_group for key in self._defaults: name = prefix + key From 7222aa3aeeea981b2b0c05ce45f006e6b243f21f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 18 Jun 2021 10:57:14 +0100 Subject: [PATCH 22/36] resolve flake8 --- tests/bases/test_ddp.py | 1 - torchmetrics/metric.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 85bee2c8c9a..ba3844bbe95 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -153,7 +153,6 @@ def compute(self): def reload_state_dict(state_dict, expected_x, expected_c): metric = DummyCatMetric() - #metric.persistent(True) metric.load_state_dict(state_dict) assert metric.x == expected_x assert metric.c == expected_c diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 0aeff66cbd8..75c7cf41adc 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -439,7 +439,7 @@ def _load_from_state_dict( ) -> None: """ Loads metric states from state_dict """ - # only global rank 0 should be reloading the values present in the ``state_dict`` + # only global rank 0 should be reloading the values present in the ``state_dict`` # as the state contains synced values across all progress_group for key in self._defaults: name = prefix + key From 6dd77052e954e717582478acdd04faf8081660fa Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 18 Jun 2021 11:39:22 +0100 Subject: [PATCH 23/36] Update torchmetrics/metric.py Co-authored-by: Nicki Skafte --- torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 75c7cf41adc..7e9f66fe521 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -238,7 +238,7 @@ def sync( should_sync: bool = True, ) -> Dict[str, Tensor]: """ - Sync function to control + Sync function for manually controlling when metrics states should be synced 
across processes Args: dist_sync_fn: Function to be used to perform states synchronization From da215adbb3f80511d69830b8695ba7ac38c6a08a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 06:43:44 -0400 Subject: [PATCH 24/36] remove _update_signature --- torchmetrics/metric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 9b3014a263f..29c4ffa9643 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -358,11 +358,12 @@ def clone(self): def __getstate__(self): # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} + return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]} def __setstate__(self, state): # manually restore update and compute functions for pickling self.__dict__.update(state) + self._update_signature = inspect.signature(self.update) self.update = self._wrap_update(self.update) self.compute = self._wrap_compute(self.compute) From fe456f2478103e818d5326a28c1ecbc113c1a5d1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 18 Jun 2021 13:20:41 +0200 Subject: [PATCH 25/36] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/bases/test_ddp.py | 9 +++++---- torchmetrics/metric.py | 8 ++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index ba3844bbe95..39e25b02f0f 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -141,15 +141,16 @@ def compute(self): metric = DummyCatMetric() metric.persistent(True) - for i in range(5): + steps = 5 + for i in range(steps): metric(i) state_dict = metric.state_dict() - assert state_dict["x"] == sum(range(i + 1)) * 2 - assert metric.x == sum(range(i + 1)) + assert metric.x == steps * (steps + 1) / 2 + assert state_dict["x"] == metric.x * 2 - assert state_dict["c"] == (i + 1) * 2 assert metric.c == (i + 1) + assert state_dict["c"] == metric.c * 2 def reload_state_dict(state_dict, expected_x, expected_c): metric = DummyCatMetric() diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 290ff8e76a8..2c7d7c38d37 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -211,7 +211,7 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: for attr, reduction_fn in self._reductions.items(): # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], Tensor) and isinstance(output_dict[attr], Sequence): + if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor): output_dict[attr] = torch.stack(output_dict[attr]) elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) @@ -258,7 +258,7 @@ def sync( cache = {} - if is_distributed and should_sync and dist_sync_fn is not None: + if is_distributed and should_sync: # cache prior to syncing cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} @@ -421,10 +421,10 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): if self._persistent[key]: current_val = getattr(self, key) if not keep_vars: - if torch.is_tensor(current_val): + if isinstance(current_val, torch.Tensor): current_val = current_val.detach() elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for 
cur_v in current_val] + current_val = [cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val] destination[prefix + key] = deepcopy(current_val) return destination From 303e8298ab25770e1f2a8eb72d6a15c33cb237b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 11:21:38 +0000 Subject: [PATCH 26/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 2c7d7c38d37..3c153f1b879 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -424,7 +424,9 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): if isinstance(current_val, torch.Tensor): current_val = current_val.detach() elif isinstance(current_val, list): - current_val = [cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val] + current_val = [ + cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val + ] destination[prefix + key] = deepcopy(current_val) return destination From 586ae75a24c0ea0db876394e80eef10dad3e8a8e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:42:30 -0400 Subject: [PATCH 27/36] update on comments --- tests/bases/test_ddp.py | 12 +++++++----- torchmetrics/metric.py | 43 +++++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index ba3844bbe95..f926b372a99 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -141,15 +141,17 @@ def compute(self): metric = DummyCatMetric() metric.persistent(True) - for i in range(5): + steps = 5 + for i in range(steps): metric(i) state_dict = metric.state_dict() + print(state_dict) - assert state_dict["x"] == sum(range(i + 1)) * 2 - assert metric.x == sum(range(i + 1)) - - assert state_dict["c"] == (i + 1) * 2 + sum = i * (i + 1) / 2 + assert state_dict["x"] == sum * worldsize + assert metric.x == sum assert metric.c == (i + 1) + assert state_dict["c"] == metric.c * worldsize def reload_state_dict(state_dict, expected_x, expected_c): metric = DummyCatMetric() diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 290ff8e76a8..18e20ed746d 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -29,6 +29,8 @@ from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version +def is_distributed_fn() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() class Metric(nn.Module, ABC): """ @@ -211,7 +213,7 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: for attr, reduction_fn in self._reductions.items(): # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], Tensor) and isinstance(output_dict[attr], Sequence): + if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor): output_dict[attr] = torch.stack(output_dict[attr]) elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) @@ -236,6 +238,7 @@ def sync( dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None, should_sync: bool = True, + is_distributed_fn: Optional[Callable] = is_distributed_fn, ) -> Dict[str, Tensor]: """ 
Sync function for manually controlling when metrics states should be synced across processes @@ -250,10 +253,9 @@ def sync( Returns: cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. """ - is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + is_distributed = is_distributed_fn() if dist_sync_fn is None: - # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors cache = {} @@ -274,6 +276,7 @@ def sync_context( process_group: Optional[Any] = None, should_sync: bool = True, restore_cache: bool = True, + is_distributed_fn: Optional[Callable] = is_distributed_fn, ) -> None: """ Context manager to synchronize the states between processes when running in a distributed setting @@ -288,7 +291,12 @@ def sync_context( restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. """ - cache = self.sync(dist_sync_fn=dist_sync_fn, process_group=process_group, should_sync=should_sync) + cache = self.sync( + dist_sync_fn=dist_sync_fn, + process_group=process_group, + should_sync=should_sync, + is_distributed_fn=is_distributed_fn + ) yield @@ -392,20 +400,6 @@ def _apply(self, fn): ) return this - @contextmanager - def _apply_persistent( - self, - mode: bool = False, - ) -> None: - """ - Context manager for post-init to change if metric states should be saved to - its state_dict - """ - persistent = self._persistent - self.persistent(mode) - yield - self._persistent = persistent - def persistent(self, mode: bool = False): """Method for post-init to change if metric states should be saved to its state_dict @@ -421,13 +415,18 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): if self._persistent[key]: current_val = getattr(self, key) if not keep_vars: - if torch.is_tensor(current_val): + if isinstance(current_val, torch.Tensor): current_val = current_val.detach() elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] + current_val = [cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val] destination[prefix + key] = deepcopy(current_val) return destination + def _on_load_from_state_dict(self, state_dict, key, name) -> None: + value = state_dict.pop(name) + if os.getenv("GLOBAL_RANK", "0") == "0": + setattr(self, key, value) + def _load_from_state_dict( self, state_dict: dict, @@ -445,9 +444,7 @@ def _load_from_state_dict( for key in self._defaults: name = prefix + key if name in state_dict: - value = state_dict.pop(name) - if os.getenv("GLOBAL_RANK", "0") == "0": - setattr(self, key, value) + self._on_load_from_state_dict(state_dict, key, name) super()._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs ) From 409e20a94ddcf9c41e898d906c05e3302ca8eb8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 11:45:16 +0000 Subject: [PATCH 28/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_ddp.py | 2 +- torchmetrics/metric.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index f926b372a99..241145f412d 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -148,7 +148,7 @@ def compute(self): print(state_dict) sum = i * (i + 1) / 2 - 
assert state_dict["x"] == sum * worldsize + assert state_dict["x"] == sum * worldsize assert metric.x == sum assert metric.c == (i + 1) assert state_dict["c"] == metric.c * worldsize diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index e1613ad8026..0789259777a 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -29,9 +29,11 @@ from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version + def is_distributed_fn() -> bool: return torch.distributed.is_available() and torch.distributed.is_initialized() + class Metric(nn.Module, ABC): """ Base class for all metrics present in the Metrics API. @@ -424,7 +426,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): destination[prefix + key] = deepcopy(current_val) return destination - def _on_load_from_state_dict(self, state_dict, key, name) -> None: + def _on_load_from_state_dict(self, state_dict, key, name) -> None: value = state_dict.pop(name) if os.getenv("GLOBAL_RANK", "0") == "0": setattr(self, key, value) From 71fad5220aed4dd6ed39af732b33c8f615bc7c06 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:50:31 -0400 Subject: [PATCH 29/36] add missing is_distributed_fn --- torchmetrics/metric.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index e1613ad8026..ecb4651f91c 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -249,6 +249,7 @@ def sync( Specify the process group on which synchronization is called. default: None (which selects the entire world) should_sync: Whether to apply to state synchronization. + is_distributed_fn: Function to determine if we are running inside a distributed setting Returns: cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. @@ -290,6 +291,7 @@ def sync_context( should_sync: Whether to apply to state synchronization. restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. 
+ is_distributed_fn: Function to determine if we are running inside a distributed setting """ cache = self.sync( dist_sync_fn=dist_sync_fn, From e72de7d2cb274ef31737ae7b55ca0660cb682440 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 08:03:26 -0400 Subject: [PATCH 30/36] update on comments --- tests/bases/test_ddp.py | 1 - torchmetrics/metric.py | 42 +++++++++++++++++++---------------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py index 241145f412d..c6bdae43d11 100644 --- a/tests/bases/test_ddp.py +++ b/tests/bases/test_ddp.py @@ -145,7 +145,6 @@ def compute(self): for i in range(steps): metric(i) state_dict = metric.state_dict() - print(state_dict) sum = i * (i + 1) / 2 assert state_dict["x"] == sum * worldsize diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index fbbe5d4b4d1..50513ff5af4 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -30,7 +30,7 @@ from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version -def is_distributed_fn() -> bool: +def distributed_available() -> bool: return torch.distributed.is_available() and torch.distributed.is_initialized() @@ -240,7 +240,7 @@ def sync( dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None, should_sync: bool = True, - is_distributed_fn: Optional[Callable] = is_distributed_fn, + distributed_available: Optional[Callable] = distributed_available, ) -> Dict[str, Tensor]: """ Sync function for manually controlling when metrics states should be synced across processes @@ -251,24 +251,20 @@ def sync( Specify the process group on which synchronization is called. default: None (which selects the entire world) should_sync: Whether to apply to state synchronization. - is_distributed_fn: Function to determine if we are running inside a distributed setting + distributed_available: Function to determine if we are running inside a distributed setting Returns: - cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen. + cache: A dictionary containing the local metric states. The cache will be empty if sync didn't happen. """ - is_distributed = is_distributed_fn() - + is_distributed = distributed_available() + if not should_sync or not is_distributed: + return {} if dist_sync_fn is None: dist_sync_fn = gather_all_tensors - - cache = {} - - if is_distributed and should_sync: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # sync - self._sync_dist(dist_sync_fn, process_group=process_group) + # cache prior to syncing + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + # sync + self._sync_dist(dist_sync_fn, process_group=process_group) return cache @@ -279,7 +275,7 @@ def sync_context( process_group: Optional[Any] = None, should_sync: bool = True, restore_cache: bool = True, - is_distributed_fn: Optional[Callable] = is_distributed_fn, + distributed_available: Optional[Callable] = distributed_available, ) -> None: """ Context manager to synchronize the states between processes when running in a distributed setting @@ -293,13 +289,13 @@ def sync_context( should_sync: Whether to apply to state synchronization. restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated. 
From 11a3ab8287341c3640f64982c648a797ad035eb0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 18 Jun 2021 14:09:42 +0200
Subject: [PATCH 31/36] Update torchmetrics/metric.py

---
 torchmetrics/metric.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 50513ff5af4..89c4a7f038a 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -421,6 +421,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
                 current_val = [
                     cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val
                 ]
+            # the tensors will be synced across processes so deepcopy to drop the references
             destination[prefix + key] = deepcopy(current_val)
         return destination

From d9c0a53b93ce540bb23b7dc7e1cb6a4fdcab7d53 Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Fri, 18 Jun 2021 09:28:49 -0400
Subject: [PATCH 32/36] resolve failing test

---
 tests/bases/test_ddp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py
index c6bdae43d11..30e9eb56c63 100644
--- a/tests/bases/test_ddp.py
+++ b/tests/bases/test_ddp.py
@@ -136,7 +136,7 @@ def update(self, x):
             self.c += 1
 
         def compute(self):
-            return self.x / self.c
+            return self.x // self.c
 
     metric = DummyCatMetric()
     metric.persistent(True)

From 6e7e3a82bd8597fbcaeca2bc79c6206f2b4e5e77 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 18 Jun 2021 16:23:46 +0200
Subject: [PATCH 33/36] Deepsource smells

---
 torchmetrics/metric.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 89c4a7f038a..0e8ff0654f6 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -198,7 +198,7 @@ def forward(self, *args, **kwargs):
 
         return self._forward_cache
 
-    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None):
+    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
         input_dict = {attr: getattr(self, attr) for attr in self._reductions}
 
         for attr, reduction_fn in self._reductions.items():
@@ -262,10 +262,9 @@ def sync(
         if dist_sync_fn is None:
             dist_sync_fn = gather_all_tensors
         # cache prior to syncing
-        cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
+        cache = {attr: getattr(self, attr) for attr in self._defaults}
         # sync
         self._sync_dist(dist_sync_fn, process_group=process_group)
-
         return cache
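PATCH 32 swaps true division for floor division in `DummyCatMetric.compute`. One plausible reading is that `/` on the integer-valued states promotes the result to floating point, while `//` keeps it integral, so the values checked in the DDP test compare exactly. A standalone illustration of the dtype difference, not taken from the PR:

```python
import torch

# integer states, as produced by add_state(default=torch.tensor(0), dist_reduce_fx="sum")
x = torch.tensor(10)
c = torch.tensor(4)

print(x / c)   # tensor(2.5000)  true division promotes to float
print(x // c)  # tensor(2)       floor division stays integral
```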
From 7ee31d0fbf0253bd84ffb724ed57dfb8e3824a27 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 21 Jun 2021 10:59:14 +0200
Subject: [PATCH 34/36] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 tests/bases/test_ddp.py | 2 +-
 torchmetrics/metric.py  | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/bases/test_ddp.py b/tests/bases/test_ddp.py
index 30e9eb56c63..20eeee517fc 100644
--- a/tests/bases/test_ddp.py
+++ b/tests/bases/test_ddp.py
@@ -167,7 +167,7 @@ def reload_state_dict(state_dict, expected_x, expected_c):
 @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_state_dict_is_synced(tmpdir):
     """
-    This test asserts taht metric are synced while creating the state
+    This test asserts that metrics are synced while creating the state
     dict but restored after to continue accumulation.
     """
     torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 0e8ff0654f6..abdb0bbfa99 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -176,8 +176,7 @@ def forward(self, *args, **kwargs):
         if self.compute_on_step:
             self._to_sync = self.dist_sync_on_step
 
-            # skip restore cache operation from compute
-            # as cache is stored below.
+            # skip restore cache operation from compute as cache is stored below.
             self._restore_cache = False
 
             # save context before switch

From 99d99e162a110d468661f48b5dda7dbb045f38cb Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Mon, 21 Jun 2021 05:40:38 -0400
Subject: [PATCH 35/36] update

---
 torchmetrics/metric.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 0e8ff0654f6..5ae94197079 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -250,7 +250,8 @@ def sync(
         process_group:
             Specify the process group on which synchronization is called.
             default: None (which selects the entire world)
-        should_sync: Whether to apply to state synchronization.
+        should_sync: Whether to apply to state synchronization. This will have an impact
+            only when running in a distributed setting.
         distributed_available: Function to determine if we are running inside a distributed setting
 
     Returns:
@@ -285,7 +286,8 @@ def sync_context(
         process_group:
             Specify the process group on which synchronization is called.
             default: None (which selects the entire world)
-        should_sync: Whether to apply to state synchronization.
+        should_sync: Whether to apply to state synchronization. This will have an impact
+            only when running in a distributed setting.
         restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated.
         distributed_available: Function to determine if we are running inside a distributed setting
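With the docstrings from PATCH 35 in place, `sync_context` is the user-facing entry point: it syncs, yields, and then restores the cached local states. A usage fragment under stated assumptions, namely an already-initialized DDP job and an accumulating `metric` instance; `metric`, `global_value` and `next_value` are placeholders, not names from the PR:

```python
# inside a running DDP process, after several metric.update(...) calls
with metric.sync_context(should_sync=True, restore_cache=True):
    # states are reduced across processes only while the block is active
    global_value = metric.compute()

# on exit the pre-sync states are restored from the cache,
# so per-process accumulation continues unaffected
metric.update(next_value)
```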
From 9872ddd9bc53e2cf79624ca3b11e41a49107b4a3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 21 Jun 2021 09:41:20 +0000
Subject: [PATCH 36/36] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torchmetrics/metric.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 978610368b6..e250608bfd5 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -249,7 +249,7 @@ def sync(
         process_group:
             Specify the process group on which synchronization is called.
             default: None (which selects the entire world)
-        should_sync: Whether to apply to state synchronization. This will have an impact 
+        should_sync: Whether to apply to state synchronization. This will have an impact
             only when running in a distributed setting.
         distributed_available: Function to determine if we are running inside a distributed setting
@@ -285,7 +285,7 @@ def sync_context(
         process_group:
             Specify the process group on which synchronization is called.
             default: None (which selects the entire world)
-        should_sync: Whether to apply to state synchronization. This will have an impact 
+        should_sync: Whether to apply to state synchronization. This will have an impact
             only when running in a distributed setting.
         restore_cache: Whether to restore the cache state so that the metrics can continue to be accumulated.
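Taken together, the series makes `state_dict()` save globally reduced values while leaving the live states local, which is what `test_state_dict_is_synced` asserts. A condensed sketch of that behaviour, assuming two DDP processes and the `DummyCatMetric` defined in tests/bases/test_ddp.py:

```python
# run identically on each of the two ranks
metric = DummyCatMetric()
metric.persistent(True)

metric(1)                              # local accumulation: x == 1, c == 1 on every rank
state_dict = metric.state_dict()

# the checkpointed values are summed across both ranks ...
assert state_dict["x"] == 2 and state_dict["c"] == 2
# ... while the in-memory states stay local, so accumulation can continue
assert metric.x == 1 and metric.c == 1
```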