Metric bootstrapper (#101)
* add bootstrapping

* tests

* pep8

* move args to init

* fix tests

* fix tests

* mypy

* remove pdb

* add bootstrapping

* tests

* pep8

* move args to init

* fix tests

* fix tests

* mypy

* remove pdb

* versions

* versions

* Update docs/source/references/modules.rst

Co-authored-by: thomas chaton <[email protected]>

* isort

* Apply suggestions from code review

* Update torchmetrics/wrappers/bootstrapping.py

Co-authored-by: Jirka Borovec <[email protected]>

* update

* update

* add poisson

* pep8

* revert

* link

* isort

* roc changes remove

* fix

* fix tests

* pep8

* Apply suggestions from code review

* suggestions

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Mar 27, 2021
1 parent 386b7e4 commit 165cff0
Showing 9 changed files with 323 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77))


- Added `BootStrapper` to easily calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
11 changes: 11 additions & 0 deletions docs/source/references/modules.rst
@@ -334,3 +334,14 @@ RetrievalMRR

.. autoclass:: torchmetrics.RetrievalMRR
:noindex:


********
Wrappers
********

Modular wrapper metrics are not metrics in themselves, but instead take a metric and alter the internal logic
of the base metric.

.. autoclass:: torchmetrics.BootStrapper
:noindex:
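
As a quick illustration of the wrapper concept documented above, a minimal usage sketch follows (it assumes the ``BootStrapper`` API added in this PR; the returned numbers depend on the data and the random seed):

import torch
from torchmetrics import Accuracy, BootStrapper

# keep 10 independently resampled copies of the base metric and report their mean/std
bootstrap = BootStrapper(Accuracy(), num_bootstraps=10, mean=True, std=True)
preds = torch.randint(5, (20,))
target = torch.randint(5, (20,))
bootstrap.update(preds, target)   # each internal copy sees a resampled version of the batch
result = bootstrap.compute()      # {'mean': tensor(...), 'std': tensor(...)}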
Empty file added tests/wrappers/__init__.py
108 changes: 108 additions & 0 deletions tests/wrappers/test_bootstrapping.py
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator

import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from torch import Tensor

from torchmetrics.classification import Precision, Recall
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

_preds = torch.randint(10, (10, 32))
_target = torch.randint(10, (10, 32))


class TestBootStrapper(BootStrapper):
""" For testing purpose, we subclass the bootstrapper class so we can get the exact permutation
the class is creating
"""
def update(self, *args) -> None:
self.out = []
for idx in range(self.num_bootstraps):
size = len(args[0])
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args)
self.out.append(new_args)


def _sample_checker(old_samples, new_samples, op: operator, threshold: int):
found_one = False
for os in old_samples:
cond = op(os, new_samples)
if cond.sum() > threshold:
found_one = True
break
return found_one


@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
def test_bootstrap_sampler(sampling_strategy):
""" make sure that the bootstrap sampler works as intended """
old_samples = torch.randn(10, 2)

# make sure that the new samples are only made up of old samples
idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
new_samples = old_samples[idx]
for ns in new_samples:
assert ns in old_samples

found_one = _sample_checker(old_samples, new_samples, operator.eq, 2)
assert found_one, "resampling did not work because no samples were sampled twice"

found_zero = _sample_checker(old_samples, new_samples, operator.ne, 0)
assert found_zero, "resampling did not work because all samples were sampled at least once"


@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
@pytest.mark.parametrize(
"metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]]
)
def test_bootstrap(sampling_strategy, metric, sk_metric):
""" Test that the different bootstraps gets updated as we expected and that the compute method works """
_kwargs = {'base_metric': metric, 'mean': True, 'std': True, 'raw': True, 'sampling_strategy': sampling_strategy}
if _TORCH_GREATER_EQUAL_1_7:
_kwargs.update(dict(quantile=torch.tensor([0.05, 0.95])))

bootstrapper = TestBootStrapper(**_kwargs)

collected_preds = [[] for _ in range(10)]
collected_target = [[] for _ in range(10)]
for p, t in zip(_preds, _target):
bootstrapper.update(p, t)

for i, o in enumerate(bootstrapper.out):

collected_preds[i].append(o[0])
collected_target[i].append(o[1])

collected_preds = [torch.cat(cp) for cp in collected_preds]
collected_target = [torch.cat(ct) for ct in collected_target]

sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)]

output = bootstrapper.compute()
# quantile only available for pytorch v1.7 and forward
if _TORCH_GREATER_EQUAL_1_7:
assert np.allclose(output['quantile'][0], np.quantile(sk_scores, 0.05))
assert np.allclose(output['quantile'][1], np.quantile(sk_scores, 0.95))

assert np.allclose(output['mean'], np.mean(sk_scores))
assert np.allclose(output['std'], np.std(sk_scores, ddof=1))
assert np.allclose(output['raw'], sk_scores)
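
For reference, the behaviour these tests rely on can be sketched as follows: ``_bootstrap_sampler`` returns indices rather than values, resampling is then done by indexing along dim 0, and because sampling is with replacement some indices appear several times while others are missing. The values shown in the comments are illustrative only.

import torch
from torchmetrics.wrappers.bootstrapping import _bootstrap_sampler

old_samples = torch.randn(5, 2)
idx = _bootstrap_sampler(5, sampling_strategy='multinomial')  # e.g. tensor([3, 0, 0, 4, 2])
new_samples = old_samples[idx]  # rows of old_samples, possibly repeated
idx_poisson = _bootstrap_sampler(5, sampling_strategy='poisson')
# the poisson strategy repeats index i ~Poisson(1) times, so the number of
# returned indices is itself random and need not equal 5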
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -50,3 +50,4 @@
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
9 changes: 5 additions & 4 deletions torchmetrics/collections.py
@@ -49,6 +49,7 @@ class MetricCollection(nn.ModuleDict):
Example:
>>> # input as list
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall
>>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
>>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
@@ -62,10 +63,10 @@
>>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
... 'macro_recall': Recall(num_classes=3, average='macro')})
>>> same_metric = metrics.clone()
>>> metrics(preds, target)
{'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
>>> same_metric(preds, target)
{'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
>>> pprint(metrics(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
>>> pprint(same_metric(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
>>> metrics.persistent()
"""
12 changes: 7 additions & 5 deletions torchmetrics/utilities/imports.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities"""
import operator
from distutils.version import LooseVersion
from importlib import import_module
from importlib.util import find_spec

import torch
from pkg_resources import DistributionNotFound


@@ -60,7 +61,8 @@ def _compare_version(package: str, op, version) -> bool:
return op(pkg_version, LooseVersion(version))


_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")
_TORCH_GREATER_EQUAL_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0")
_TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0")
_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
14 changes: 14 additions & 0 deletions torchmetrics/wrappers/__init__.py
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
174 changes: 174 additions & 0 deletions torchmetrics/wrappers/bootstrapping.py
@@ -0,0 +1,174 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import Tensor, nn

from torchmetrics.metric import Metric
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7


def _bootstrap_sampler(
size: int,
sampling_strategy: str = 'poisson'
) -> Tensor:
""" Resample a tensor along its first dimension with replacement
Args:
size: number of samples
sampling_strategy: the strategy to use for sampling, either ``'poisson'`` or ``'multinomial'``
generator: a instance of ``torch.Generator`` that controls the sampling
Returns:
resampled tensor
"""
if sampling_strategy == 'poisson':
p = torch.distributions.Poisson(1)
n = p.sample((size,))
return torch.arange(size).repeat_interleave(n.long(), dim=0)
elif sampling_strategy == 'multinomial':
idx = torch.multinomial(
torch.ones(size),
num_samples=size,
replacement=True
)
return idx
raise ValueError('Unknown sampling strategy')


class BootStrapper(Metric):

def __init__(
self,
base_metric: Metric,
num_bootstraps: int = 10,
mean: bool = True,
std: bool = True,
quantile: Optional[Union[float, Tensor]] = None,
raw: bool = False,
sampling_strategy: str = 'poisson',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None
) -> None:
r"""
Use this wrapper to turn a metric into a `bootstrapped <https://en.wikipedia.org/wiki/Bootstrapping_(statistics)>`_
metric that can automate the process of getting confidence intervals for metric values. The wrapper
class keeps multiple copies of the same base metric in memory and whenever ``update`` or
``forward`` is called, all input tensors are resampled (with replacement) along the first dimension.
Args:
base_metric:
base metric class to wrap
num_bootstraps:
number of copies to make of the base metric for bootstrapping
mean:
if ``True`` return the mean of the bootstraps
std:
if ``True`` return the standard deviation of the bootstraps
quantile:
if given, returns the quantile of the bootstraps. Can only be used with
pytorch version 1.7 or higher
raw:
if ``True``, return all bootstrapped values
sampling_strategy:
Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``.
If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap
will be given by :math:`n\sim Poisson(\lambda=1)`, which approximates the true bootstrap distribution
when the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping
at the batch level to approximate bootstrapping over the whole dataset.
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
Example::
>>> from pprint import pprint
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
"""
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
if not isinstance(base_metric, Metric):
raise ValueError(
"Expected base metric to be an instance of torchmetrics.Metric"
f" but received {base_metric}"
)

self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
self.num_bootstraps = num_bootstraps

self.mean = mean
self.std = std
if quantile is not None and not _TORCH_GREATER_EQUAL_1_7:
raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
self.quantile = quantile
self.raw = raw

allowed_sampling = ('poisson', 'multinomial')
if sampling_strategy not in allowed_sampling:
raise ValueError(
f"Expected argument ``sampling_strategy`` to be one of {allowed_sampling}"
f" but recieved {sampling_strategy}"
)
self.sampling_strategy = sampling_strategy

def update(self, *args: Any, **kwargs: Any) -> None:
""" Updates the state of the base metric. Any tensor passed in will be bootstrapped along dimension 0 """
for idx in range(self.num_bootstraps):
args_sizes = apply_to_collection(args, Tensor, len)
kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
if len(args_sizes) > 0:
size = args_sizes[0]
elif len(kwargs_sizes) > 0:
size = kwargs_sizes[0]
else:
raise ValueError('None of the inputs contained tensors, so the sampling size could not be determined')
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args, **new_kwargs)

def compute(self) -> Dict[str, Tensor]:
""" Computes the bootstrapped metric values. Allways returns a dict of tensors, which can contain the
following keys: ``mean``, ``std``, ``quantile`` and ``raw`` depending on how the class was initialized
"""
computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
output_dict = {}
if self.mean:
output_dict['mean'] = computed_vals.mean(dim=0)
if self.std:
output_dict['std'] = computed_vals.std(dim=0)
if self.quantile is not None:
output_dict['quantile'] = torch.quantile(computed_vals, self.quantile)
if self.raw:
output_dict['raw'] = computed_vals
return output_dict
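
Putting the wrapper together with the quantile and raw outputs, a hedged end-to-end sketch (it assumes torch>=1.7 so that ``quantile`` is supported; the comments describe shapes, not exact values):

import torch
from torchmetrics import Precision, BootStrapper

base = Precision(average='micro')
bootstrap = BootStrapper(
    base, num_bootstraps=20, mean=True, std=True, raw=True,
    quantile=torch.tensor([0.05, 0.95]), sampling_strategy='poisson',
)
for _ in range(5):
    preds = torch.randint(10, (32,))
    target = torch.randint(10, (32,))
    bootstrap.update(preds, target)
out = bootstrap.compute()
# out['mean'], out['std']: scalar tensors aggregated over the 20 bootstrap copies
# out['quantile']: the 5% and 95% quantiles, i.e. an approximate 90% confidence interval
# out['raw']: tensor of shape (20,) holding the per-copy metric values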
