Skip to content

Commit

Permalink
Fix aggregation metrics with zero tensors (#1070)
Browse files Browse the repository at this point in the history
* Apply suggestions from code review

Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Jun 7, 2022
1 parent f0279bb commit 1cc90c6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070))

-

Expand Down
1 change: 1 addition & 0 deletions tests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_nan_error(value, nan_strategy, metric_class):
(CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])),
(CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])),
(CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])),
(CatMetric, "ignore", torch.zeros(5), torch.zeros(5)),
],
)
def test_nan_expected(metric_class, nan_strategy, value, expected):
Expand Down
24 changes: 16 additions & 8 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BaseAggregator(Metric):
value: Tensor
is_differentiable = None
higher_is_better = None
full_state_update = False

def __init__(
self,
Expand Down Expand Up @@ -116,6 +117,8 @@ class MaxMetric(BaseAggregator):
tensor(3.)
"""

full_state_update = True

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -136,7 +139,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
dimensions will be flattened
"""
value = self._cast_and_nan_check_input(value)
if any(value.flatten()): # make sure tensor not empty
if value.numel(): # make sure tensor not empty
self.value = torch.max(self.value, torch.max(value))


Expand Down Expand Up @@ -165,6 +168,8 @@ class MinMetric(BaseAggregator):
tensor(1.)
"""

full_state_update = True

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -185,7 +190,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
dimensions will be flattened
"""
value = self._cast_and_nan_check_input(value)
if any(value.flatten()): # make sure tensor not empty
if value.numel(): # make sure tensor not empty
self.value = torch.min(self.value, torch.min(value))


Expand Down Expand Up @@ -234,7 +239,8 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
dimensions will be flattened
"""
value = self._cast_and_nan_check_input(value)
self.value += value.sum()
if value.numel():
self.value += value.sum()


class CatMetric(BaseAggregator):
Expand Down Expand Up @@ -277,7 +283,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
dimensions will be flattened
"""
value = self._cast_and_nan_check_input(value)
if any(value.flatten()):
if value.numel():
self.value.append(value)

def compute(self) -> Tensor:
Expand Down Expand Up @@ -339,14 +345,16 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
value = self._cast_and_nan_check_input(value)
weight = self._cast_and_nan_check_input(weight)

# broadcast weight to values shape
if not hasattr(torch, "broadcast_to"):
if value.numel() == 0:
return
# broadcast weight to value shape
if hasattr(torch, "broadcast_to"):
weight = torch.broadcast_to(weight, value.shape)
else:
if weight.shape == ():
weight = torch.ones_like(value) * weight
if weight.shape != value.shape:
raise ValueError("Broadcasting not supported on PyTorch <1.8")
else:
weight = torch.broadcast_to(weight, value.shape)

self.value += (value * weight).sum()
self.weight += weight.sum()
Expand Down

0 comments on commit 1cc90c6

Please sign in to comment.