Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

R2 score metric #1274

Merged
merged 27 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56c953f
r2_score added
asteyo Aug 6, 2021
767430c
catalyst-make-codestyle _r2_score.py
asteyo Aug 9, 2021
faadc56
Merge branch 'master' into r2_score
asteyo Aug 9, 2021
85d8da6
r2 score LoaderMetric API is added
asteyo Aug 30, 2021
622c01e
r2 score renamed to r2 squared
asteyo Sep 11, 2021
29da20f
functional r2 metric name fix to r2_squared
asteyo Sep 11, 2021
62d65f9
test for functional r2 squared is added
asteyo Sep 11, 2021
9991a0c
compute key-value fix
asteyo Sep 20, 2021
ea0d905
args order in update fixed
asteyo Sep 20, 2021
3dabd39
args order fix
asteyo Sep 20, 2021
c4622ac
r2squared import is added to functional metrics init
asteyo Sep 20, 2021
84b0435
r2squared callback is added
asteyo Sep 20, 2021
a4a4de1
r2squared callback is added to metrics callbacks init
asteyo Sep 20, 2021
80be323
r2squared metric is added to metrics init
asteyo Sep 20, 2021
9d03acb
tests for r2squared is added
asteyo Sep 20, 2021
eb343f3
regression test update
asteyo Sep 20, 2021
2759cda
metrics docs update
asteyo Sep 20, 2021
2824561
Merge branch 'master' into r2_score
asteyo Sep 21, 2021
5f9713f
codestyle fix
asteyo Sep 25, 2021
932816b
Merge branch 'master' into r2_score
asteyo Sep 25, 2021
7656b3b
Merge branch 'master' of https://github.com/catalyst-team/catalyst in…
asteyo Sep 25, 2021
8abe5ed
torch.square to torch.pow fix)
asteyo Sep 25, 2021
f75a2ef
Merge branch 'r2_score' of https://github.com/asteyo/catalyst into r2…
asteyo Sep 25, 2021
e7d7623
codestyle update
asteyo Sep 26, 2021
8693c6b
spaces codestyle fix
asteyo Sep 26, 2021
72fb072
codestyle fix
asteyo Sep 26, 2021
b60060b
Update _r2_squared.py
Scitator Sep 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions catalyst/callbacks/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from catalyst.callbacks.metrics.functional_metric import FunctionalMetricCallback

from catalyst.callbacks.metrics.r2_squared import R2SquaredCallback

from catalyst.callbacks.metrics.recsys import (
HitrateCallback,
MAPCallback,
Expand Down
75 changes: 75 additions & 0 deletions catalyst/callbacks/metrics/r2_squared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from catalyst.callbacks.metric import LoaderMetricCallback
from catalyst.metrics._r2_squared import R2Squared


class R2SquaredCallback(LoaderMetricCallback):
"""R2 Squared metric callback.

Args:
input_key: input key to use for r2squared calculation, specifies our ``y_true``.
target_key: output key to use for r2squared calculation, specifies our ``y_pred``.
prefix: metric prefix
suffix: metric suffix

Examples:

.. code-block:: python

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

# data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

# model training
runner = dl.SupervisedRunner()
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
logdir="./logdir",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
num_epochs=8,
verbose=True,
callbacks=[
dl.R2SquaredCallback(input_key="logits", target_key="targets")
]
)

.. note::
Please follow the `minimal examples`_ sections for more use cases.

.. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
"""

def __init__(
self,
input_key: str,
target_key: str,
prefix: str = None,
suffix: str = None,
):
"""Init."""
super().__init__(
metric=R2Squared(prefix=prefix, suffix=suffix),
input_key=input_key,
target_key=target_key,
)


__all__ = ["R2SquaredCallback"]
1 change: 1 addition & 0 deletions catalyst/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from catalyst.metrics._map import MAPMetric
from catalyst.metrics._mrr import MRRMetric
from catalyst.metrics._ndcg import NDCGMetric
from catalyst.metrics._r2_squared import R2Squared
from catalyst.metrics._segmentation import (
RegionBasedMetric,
IOUMetric,
Expand Down
64 changes: 64 additions & 0 deletions catalyst/metrics/_r2_squared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Optional

import torch

from catalyst.metrics._metric import ICallbackLoaderMetric


class R2Squared(ICallbackLoaderMetric):
"""This metric accumulates r2 score along loader

Args:
compute_on_call: if True, allows compute metric's value on call
prefix: metric prefix
suffix: metric suffix
"""

def __init__(
self,
compute_on_call: bool = True,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
) -> None:
"""Init R2Squared"""
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
self.metric_name = f"{self.prefix}r2squared{self.suffix}"
self.num_examples = 0
self.delta_sum = 0
self.y_sum = 0
self.y_sq_sum = 0

def reset(self, num_batches: int, num_samples: int) -> None:
"""
Reset metrics fields
"""
self.num_examples = 0
self.delta_sum = 0
self.y_sum = 0
self.y_sq_sum = 0

def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
"""
Update accumulated data with new batch
"""
self.num_examples += len(y_true)
self.delta_sum += torch.sum(torch.pow(y_pred - y_true, 2))
self.y_sum += torch.sum(y_true)
self.y_sq_sum += torch.sum(torch.pow(y_true, 2))

def compute(self) -> torch.Tensor:
"""
Return accumulated metric
"""
return 1 - self.delta_sum / (self.y_sq_sum - (self.y_sum ** 2) / self.num_examples)

def compute_key_value(self) -> torch.Tensor:
Scitator marked this conversation as resolved.
Show resolved Hide resolved
"""
Return key-value
"""
r2squared = self.compute()
output = {self.metric_name: r2squared}
return output


__all__ = ["R2Squared"]
1 change: 1 addition & 0 deletions catalyst/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from catalyst.metrics.functional._mrr import reciprocal_rank, mrr
from catalyst.metrics.functional._ndcg import dcg, ndcg
from catalyst.metrics.functional._precision import precision
from catalyst.metrics.functional._r2_squared import r2_squared
from catalyst.metrics.functional._recall import recall
from catalyst.metrics.functional._segmentation import (
iou,
Expand Down
50 changes: 50 additions & 0 deletions catalyst/metrics/functional/_r2_squared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Sequence

import torch


def r2_squared(outputs: torch.Tensor, targets: torch.Tensor) -> Sequence[torch.Tensor]:
"""
Computes regression r2 squared.

Args:
outputs: model outputs
with shape [bs; 1]
targets: ground truth
with shape [bs; 1]

Returns:
float of computed r2 squared

Examples:

.. code-block:: python

import torch
from catalyst import metrics
metrics.r2_squared(
outputs=torch.tensor([0, 1, 2]),
targets=torch.tensor([0, 1, 2]),
)
# tensor([1.])


.. code-block:: python

import torch
from catalyst import metrics
metrics.r2_squared(
outputs=torch.tensor([2.5, 0.0, 2, 8]),
targets=torch.tensor([3, -0.5, 2, 7]),
)
# tensor([0.9486])
"""
total_sum_of_squares = torch.sum(
torch.pow(targets.float() - torch.mean(targets.float()), 2)
).view(-1)
residual_sum_of_squares = torch.sum(torch.pow(targets.float() - outputs.float(), 2)).view(-1)
output = 1 - residual_sum_of_squares / total_sum_of_squares
return output


__all__ = ["r2_squared"]
14 changes: 14 additions & 0 deletions docs/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ RecSys – NDCGMetric
:undoc-members:
:show-inheritance:

Regression – R2Squared
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.metrics._r2_squared.R2Squared
:exclude-members: __init__
:undoc-members:
:show-inheritance:

Segmentation – RegionBasedMetric
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.metrics._segmentation.RegionBasedMetric
Expand Down Expand Up @@ -272,6 +279,13 @@ Precision
:undoc-members:
:show-inheritance:

R2Squared
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: catalyst.metrics.functional._r2_squared
:members:
:undoc-members:
:show-inheritance:

Recall
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: catalyst.metrics.functional._recall
Expand Down
16 changes: 16 additions & 0 deletions tests/catalyst/metrics/functional/test_r2_squared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# flake8: noqa
import numpy as np

import torch

from catalyst.metrics.functional._r2_squared import r2_squared


def test_r2_squared():
"""
Tests for catalyst.metrics.r2_squared metric.
"""
y_true = torch.tensor([3, -0.5, 2, 7])
y_pred = torch.tensor([2.5, 0.0, 2, 8])
val = r2_squared(y_pred, y_true)
assert torch.isclose(val, torch.Tensor([0.9486]))
83 changes: 83 additions & 0 deletions tests/catalyst/metrics/test_r2squared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# flake8: noqa
from typing import Dict, Iterable, Union

import pytest

import torch

from catalyst.metrics._r2_squared import R2Squared


@pytest.mark.parametrize(
"outputs,targets,true_values",
(
(
torch.Tensor([2.5, 0.0, 2, 8]),
torch.Tensor([3, -0.5, 2, 7]),
{
"r2squared": torch.Tensor([0.9486]),
},
),
),
)
def test_r2_squared(
outputs: torch.Tensor,
targets: torch.Tensor,
true_values: Dict[str, torch.Tensor],
) -> None:
"""
Test r2 squared metric

Args:
outputs: tensor of outputs
targets: tensor of targets
true_values: true metric values
"""
metric = R2Squared()
metric.update(y_pred=outputs, y_true=targets)
metrics = metric.compute_key_value()
for key in true_values.keys():
assert torch.isclose(true_values[key], metrics[key])


@pytest.mark.parametrize(
"outputs_list,targets_list,true_values",
(
(
(
torch.Tensor([2.5, 0.0, 2, 8]),
torch.Tensor([2.5, 0.0, 2, 8]),
torch.Tensor([2.5, 0.0, 2, 8]),
torch.Tensor([2.5, 0.0, 2, 8]),
),
(
torch.Tensor([3, -0.5, 2, 7]),
torch.Tensor([3, -0.5, 2, 7]),
torch.Tensor([3, -0.5, 2, 7]),
torch.Tensor([3, -0.5, 2, 7]),
),
{
"r2squared": torch.Tensor([0.9486]),
},
),
),
)
def test_r2_squared_update(
outputs_list: Iterable[torch.Tensor],
targets_list: Iterable[torch.Tensor],
true_values: Dict[str, torch.Tensor],
):
"""
Test r2 squared metric computation

Args:
outputs_list: list of outputs
targets_list: list of targets
true_values: true metric values
"""
metric = R2Squared()
for outputs, targets in zip(outputs_list, targets_list):
metric.update(y_pred=outputs, y_true=targets)
metrics = metric.compute_key_value()
for key in true_values.keys():
assert torch.isclose(true_values[key], metrics[key])
1 change: 1 addition & 0 deletions tests/pipelines/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def train_experiment(device, engine=None):
minimize_valid_metric=True,
num_epochs=1,
verbose=False,
callbacks=[dl.R2SquaredCallback(input_key="logits", target_key="targets")],
)


Expand Down