Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add half precision testing [1/n] #77

Merged
merged 21 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
val3 = self.metric3['accuracy'](preds, target)
val4 = self.metric4(preds, target)

****************************
Metrics and 16-bit precision
****************************

Most metrics in our collection can be used with 16-bit precision (``torch.half``) tensors. However, we have found
the following limitations:

* In general ``pytorch`` had better support for 16-bit precision much earlier on GPU than CPU. Therefore, we
recommend that anyone that want to use metrics with half precision on CPU, upgrade to atleast pytorch v1.6
where support for operations such as addition, subtraction, multiplication ect. was added.
* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their docstring,
but they are also listed below:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`


******************
Metric Arithmetics
******************
Expand Down
64 changes: 64 additions & 0 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,30 @@ def _functional_test(
_assert_allclose(lightning_result, sk_result, atol=atol)


def _assert_half_support(
metric_module: Metric,
metric_functional: Callable,
preds: torch.Tensor,
target: torch.Tensor,
device: str = 'cpu'
):
"""
Test if an metric can be used with half precision tensors

Args:
metric_module: the metric module to test
metric_functional: the metric functional to test
preds: torch tensor with predictions
target: torch tensor with targets
device: determine device, either "cpu" or "cuda"
"""
p = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
t = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
metric_module = metric_module.to(device)
assert metric_module(p, t)
assert metric_functional(p, t)
Borda marked this conversation as resolved.
Show resolved Hide resolved


class MetricTester:
"""Class used for efficiently run alot of parametrized tests in ddp mode.
Makes sure that ddp is only setup once and that pool of processes are
Expand Down Expand Up @@ -283,6 +307,46 @@ def run_class_metric_test(
atol=self.atol,
)

def run_precision_test_cpu(
self, preds: torch.Tensor, target: torch.Tensor,
metric_module: Metric, metric_functional: Callable, metric_args: dict = {}
):
""" Test if an metric can be used with half precision tensors on cpu
Args:
preds: torch tensor with predictions
target: torch tensor with targets
metric_module: the metric module to test
metric_functional: the metric functional to test
metric_args: dict with additional arguments used for class initialization
"""
_assert_half_support(
metric_module(**metric_args),
partial(metric_functional, **metric_args),
preds,
target,
device='cpu'
)

def run_precision_test_gpu(
self, preds: torch.Tensor, target: torch.Tensor,
metric_module: Metric, metric_functional: Callable, metric_args: dict = {}
):
""" Test if an metric can be used with half precision tensors on gpu
Args:
preds: torch tensor with predictions
target: torch tensor with targets
metric_module: the metric module to test
metric_functional: the metric functional to test
metric_args: dict with additional arguments used for class initialization
"""
_assert_half_support(
metric_module(**metric_args),
partial(metric_functional, **metric_args),
preds,
target,
device='cuda'
)


class DummyMetric(Metric):
name = "Dummy"
Expand Down
15 changes: 15 additions & 0 deletions tests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError
from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_6
Borda marked this conversation as resolved.
Show resolved Hide resolved

torch.manual_seed(42)

Expand Down Expand Up @@ -92,6 +93,20 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met
sk_metric=partial(sk_metric, sk_fn=sk_fn),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we really should skip here or show a more meaningful error message.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock you got any proposal?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock more meaningful message in test or the class itself? wouldn't be easier to test for half precision tensors in the base metric class and check if torch version >= 1.6

reason='half support of core operations on not support before pytorch v1.6'
)
def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
if metric_class == MeanSquaredLogError:
# MeanSquaredLogError half + cpu does not work due to missing support in torch.log
pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision")
self.run_precision_test_cpu(preds, target, metric_class, metric_functional)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
self.run_precision_test_gpu(preds, target, metric_class, metric_functional)


@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
def test_error_on_different_shape(metric_class):
Expand Down
11 changes: 11 additions & 0 deletions tests/regression/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc
metric_args=_args,
)

# PSNR half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision")
def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_cpu(preds, target, PSNR, psnr,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_gpu(preds, target, PSNR, psnr,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim})


@pytest.mark.parametrize("reduction", ["none", "sum"])
def test_reduction_for_dim_none(reduction):
Expand Down
3 changes: 2 additions & 1 deletion torchmetrics/functional/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
_check_same_shape(preds, target)
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
diff = preds - target
sum_squared_error = torch.sum(diff * diff)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
n_obs = target.numel()
return sum_squared_error, n_obs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.T
>>> mean_squared_log_error(x, y)
tensor(0.0207)

.. note::
Half precision is only support on GPU for this metric

"""
sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target)
return _mean_squared_log_error_compute(sum_squared_log_error, n_obs)
6 changes: 5 additions & 1 deletion torchmetrics/functional/regression/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def _psnr_update(preds: torch.Tensor,
n_obs = torch.tensor(target.numel(), device=target.device)
return sum_squared_error, n_obs

sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
diff = preds - target
sum_squared_error = torch.sum(diff * diff, dim=dim)

if isinstance(dim, int):
dim_list = [dim]
Expand Down Expand Up @@ -90,6 +91,9 @@ def psnr(
>>> psnr(pred, target)
tensor(2.5527)

.. note::
Half precision is only support on GPU for this metric

"""
if dim is None and reduction != 'elementwise_mean':
rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
Expand Down
3 changes: 3 additions & 0 deletions torchmetrics/regression/mean_squared_log_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class MeanSquaredLogError(Metric):
>>> mean_squared_log_error(preds, target)
tensor(0.0397)

.. note::
Half precision is only support on GPU for this metric

"""

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions torchmetrics/regression/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class PSNR(Metric):
>>> psnr(preds, target)
tensor(2.5527)

.. note::
Half precision is only support on GPU for this metric

"""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from torchmetrics.utilities.data import apply_to_collection # noqa: F401
from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 # noqa: F401
Borda marked this conversation as resolved.
Show resolved Hide resolved
from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401
44 changes: 44 additions & 0 deletions torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities"""
import importlib
import operator
from distutils.version import LooseVersion
from importlib.util import find_spec

from pkg_resources import DistributionNotFound


def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements

>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = LooseVersion(pkg.__version__)
except AttributeError:
return False
if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, LooseVersion(version))


_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")