add MultiTaskMetrics
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed May 8, 2024
1 parent 48ef85b commit b1b8a21
Showing 3 changed files with 98 additions and 10 deletions.
3 changes: 2 additions & 1 deletion danling/__init__.py
@@ -1,6 +1,6 @@
from danling import metrics, modules, optim, registry, runner, tensors, typing, utils

from .metrics import AverageMeter, MultiTaskAverageMeter
from .metrics import AverageMeter, Metrics, MultiTaskAverageMeter, MultiTaskMetrics
from .registry import GlobalRegistry, Registry
from .runner import AccelerateRunner, BaseRunner, TorchRunner
from .tensors import NestedTensor, PNTensor
@@ -21,6 +21,7 @@
"Registry",
"GlobalRegistry",
"Metrics",
"MultiTaskMetrics",
"AverageMeter",
"MultiTaskAverageMeter",
"NestedTensor",
3 changes: 2 additions & 1 deletion danling/metrics/__init__.py
@@ -6,10 +6,11 @@

with try_import():
from .functional import accuracy, auprc, auroc, matthews_corrcoef, pearson, r2_score, rmse, spearman
from .metrics import Metrics
from .metrics import Metrics, MultiTaskMetrics

__all__ = [
"Metrics",
"MultiTaskMetrics",
"AverageMeter",
"MultiTaskAverageMeter",
"regression_metrics",
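An editorial aside, not part of the diff: with these re-exports, the new class should be importable both from the package root and from the metrics subpackage. A quick sanity check, assuming the package and the optional metric backends guarded by try_import are installed:

# Hypothetical check; both paths should resolve to the same class object.
from danling import MultiTaskMetrics as from_root
from danling.metrics import MultiTaskMetrics as from_submodule

assert from_root is from_submodule
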
102 changes: 94 additions & 8 deletions danling/metrics/metrics.py
@@ -13,6 +13,7 @@

from danling.tensors import NestedTensor

from .multitask import MultiTaskDict
from .utils import flist, get_world_size

try:
@@ -26,10 +27,13 @@ class Metrics(Metric):
Metric class wraps around multiple metrics that share the same states.
Typically, there are many metrics that we want to compute for a single task.
For example, we usually need to compute `accuracy`, `auroc`, `auprc` for a classification task.
Computing them one by one is inefficient, especially when evaluating in a distributed environment.
For example, we usually need to compute `pearson` and `spearman` for a regression task.
Unlike `accuracy`, which can use an average meter to compute the average accuracy,
`pearson` and `spearman` cannot be computed by averaging the results of multiple batches.
They need access to all the data to compute the correct results.
And saving all intermediate results for each task is quite inefficient.
To solve this problem, Metrics maintains a shared state for multiple metric functions.
`Metrics` solves this problem by maintaining a shared state for multiple metric functions.
Attributes:
metrics: A dictionary of metrics to be computed.
@@ -98,9 +102,8 @@ class Metrics(Metric):
('auroc'): 0.6666666666666666
('auprc'): 0.5555555820465088
)
>>> print(f"{metrics:.4f}")
auroc: 0.6667 (0.6667)
auprc: 0.5000 (0.5556)
>>> f"{metrics:.4f}"
'auroc: 0.6667 (0.6667)\tauprc: 0.5000 (0.5556)'
"""

metrics: FlatDict[str, Callable]
@@ -307,12 +310,12 @@ def __repr__(self):

def __format__(self, format_spec):
val, avg = self.value(), self.average()
return "\n".join(
return "\t".join(
[f"{key}: {val[key].__format__(format_spec)} ({avg[key].__format__(format_spec)})" for key in self.metrics]
)

def reset(self: Self) -> Self: # pragma: no cover
"""
r"""
Reset the metric state variables to their default value.
The tensors in the default values are also moved to the device of
the last ``self.to(device)`` call.
@@ -405,3 +408,86 @@ def score_name(self, name) -> None:
if name not in self.metrics:
raise ValueError(f"score_name must be in {self.metrics.keys()}, but got {name}")
self._score_name = name
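
An editorial aside, not part of this commit: the docstring above explains that correlation metrics such as pearson cannot be recovered by averaging per-batch results, which is why Metrics keeps a shared state over all seen data. A minimal, dependency-free sketch of that point (the pearson helper below is written inline purely for illustration):

import statistics

def pearson(xs, ys):
    # Plain-Python Pearson correlation, for illustration only.
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    return cov / (sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys)) ** 0.5

batch1 = ([0.1, 0.4], [0.2, 0.3])
batch2 = ([0.6, 0.8], [0.5, 0.7])

# Averaging per-batch correlations is not the same as computing over all data,
# so intermediate state must be shared (or merged) before evaluation.
per_batch_average = (pearson(*batch1) + pearson(*batch2)) / 2          # 1.0
over_all_data = pearson(batch1[0] + batch2[0], batch1[1] + batch2[1])  # ~0.9691
print(per_batch_average, over_all_data)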


class MultiTaskMetrics(MultiTaskDict):
r"""
Examples:
>>> from danling.metrics.functional import auroc, auprc, pearson, spearman, accuracy, matthews_corrcoef
>>> metrics = MultiTaskMetrics()
>>> metrics.dataset1.cls = Metrics(auroc=auroc, auprc=auprc)
>>> metrics.dataset1.reg = Metrics(pearson=pearson, spearman=spearman)
>>> metrics.dataset2 = Metrics(auroc=auroc, auprc=auprc)
>>> metrics
MultiTaskMetrics(<class 'danling.metrics.metrics.MultiTaskMetrics'>,
('dataset1'): MultiTaskMetrics(<class 'danling.metrics.metrics.MultiTaskMetrics'>,
('cls'): Metrics('auroc', 'auprc')
('reg'): Metrics('pearson', 'spearman')
)
('dataset2'): Metrics('auroc', 'auprc')
)
>>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset1.reg": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0.2, 0.3, 0.5, 0.7]}, "dataset2": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 1, 0, 1]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)\ndataset1.reg: pearson: 0.9691 (0.9691)\tspearman: 1.0000 (1.0000)\ndataset2: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)'
>>> metrics.setattr("return_average", True)
>>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset1.reg": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0.2, 0.4, 0.6, 0.8]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 1, 0]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)'
""" # noqa: E501

def __init__(self, *args, **kwargs):
super().__init__(*args, default_factory=MultiTaskMetrics, **kwargs)

def update(self, values: Mapping[str, Mapping[str, Tensor]]) -> None: # pylint: disable=W0237
r"""
Updates the average and current value in all metrics.
Args:
values: Dict of values to be added to the average.
Raises:
ValueError: If the value is not an instance of `Mapping`.
Examples:
>>> from danling.metrics.functional import auroc, auprc, pearson, spearman
>>> metrics = MultiTaskMetrics()
>>> metrics.dataset1.cls = Metrics(auroc=auroc, auprc=auprc)
>>> metrics.dataset1.reg = Metrics(pearson=pearson, spearman=spearman)
>>> metrics.dataset2 = Metrics(auroc=auroc, auprc=auprc)
>>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset1.reg": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0.2, 0.3, 0.5, 0.7]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)\ndataset1.reg: pearson: 0.9691 (0.9691)\tspearman: 1.0000 (1.0000)\ndataset2: auroc: nan (nan)\tauprc: nan (nan)'
>>> metrics.update({"dataset2": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 1, 0, 1]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)\ndataset1.reg: pearson: 0.9691 (0.9691)\tspearman: 1.0000 (1.0000)\ndataset2: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)'
>>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9691 (0.9691)\tspearman: 1.0000 (1.0000)\ndataset2: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)'
>>> metrics.update({"dataset1.reg": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0.2, 0.4, 0.6, 0.8]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.7500 (0.7500)\tauprc: 0.8333 (0.8333)'
>>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8]}}})
Traceback (most recent call last):
ValueError: Expected values to be a flat dictionary, but got <class 'dict'>
This is likely due to nested dictionary in the values.
Nested dictionaries cannot be processed due to the method's design, which uses Mapping to pass both input and target. Ensure your input is a flat dictionary or a single value.
>>> metrics.update(dict(loss=""))
Traceback (most recent call last):
ValueError: Expected values to be a flat dictionary, but got <class 'str'>
""" # noqa: E501

for metric, value in values.items():
if isinstance(value, Mapping):
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
try:
self[metric].update(**value)
except TypeError:
raise ValueError(
f"Expected values to be a flat dictionary, but got {type(value)}\n"
"This is likely due to nested dictionary in the values.\n"
"Nested dictionaries cannot be processed due to the method's design, which uses Mapping "
"to pass both input and target. Ensure your input is a flat dictionary or a single value."
) from None
else:
raise ValueError(f"Expected values to be a flat dictionary, but got {type(value)}")
