Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add half precision testing [2/n] #135

Merged
merged 7 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))


- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77))
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),
[#135](https://github.com/PyTorchLightning/metrics/pull/135)
)


- Added `BootStrapper` to easely calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))
Expand Down
1 change: 1 addition & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ the following limitations:
but they are also listed below:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`
- :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]`

******************
Metric Arithmetics
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def _assert_half_support(
y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
metric_module = metric_module.to(device)
assert metric_module(y_hat, y)
assert metric_functional(y_hat, y)
_assert_tensor(metric_module(y_hat, y))
_assert_tensor(metric_functional(y_hat, y))


class MetricTester:
Expand Down
11 changes: 11 additions & 0 deletions tests/regression/test_explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import explained_variance
from torchmetrics.regression import ExplainedVariance
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Expand Down Expand Up @@ -84,6 +85,16 @@ def test_explained_variance_functional(self, multioutput, preds, target, sk_metr
metric_args=dict(multioutput=multioutput),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_explained_variance_half_cpu(self, multioutput, preds, target, sk_metric):
self.run_precision_test_cpu(preds, target, ExplainedVariance, explained_variance)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_explained_variance_half_gpu(self, multioutput, preds, target, sk_metric):
self.run_precision_test_gpu(preds, target, ExplainedVariance, explained_variance)


def test_error_on_different_shape(metric_class=ExplainedVariance):
metric = metric_class()
Expand Down
13 changes: 13 additions & 0 deletions tests/regression/test_r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import r2score
from torchmetrics.regression import R2Score
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Expand Down Expand Up @@ -92,6 +93,18 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, nu
metric_args=dict(adjusted=adjusted, multioutput=multioutput),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_r2_half_cpu(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_precision_test_cpu(preds, target, partial(R2Score, num_outputs=num_outputs), r2score,
{'adjusted': adjusted, 'multioutput': multioutput})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_r2_half_gpu(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_precision_test_gpu(preds, target, partial(R2Score, num_outputs=num_outputs), r2score,
{'adjusted': adjusted, 'multioutput': multioutput})


def test_error_on_different_shape(metric_class=R2Score):
metric = metric_class()
Expand Down
9 changes: 9 additions & 0 deletions tests/regression/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def test_ssim_functional(self, preds, target, multichannel):
metric_args={"data_range": 1.0},
)

# SSIM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SSIM metric does not support cpu + half precision")
def test_ssim_half_cpu(self, preds, target, multichannel):
self.run_precision_test_cpu(preds, target, SSIM, ssim, {"data_range": 1.0})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_ssim_half_gpu(self, preds, target, multichannel):
self.run_precision_test_gpu(preds, target, SSIM, ssim, {"data_range": 1.0})


@pytest.mark.parametrize(
['pred', 'target', 'kernel', 'sigma'],
Expand Down
9 changes: 5 additions & 4 deletions torchmetrics/functional/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tens

n_obs = preds.size(0)
sum_error = torch.sum(target - preds, dim=0)
sum_squared_error = torch.sum((target - preds)**2, dim=0)
diff = target - preds
sum_squared_error = torch.sum(diff * diff, dim=0)

sum_target = torch.sum(target, dim=0)
sum_squared_target = torch.sum(target**2, dim=0)
sum_squared_target = torch.sum(target * target, dim=0)

return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target

Expand All @@ -41,10 +42,10 @@ def _explained_variance_compute(
multioutput: str = "uniform_average",
) -> Union[Tensor, Sequence[Tensor]]:
diff_avg = sum_error / n_obs
numerator = sum_squared_error / n_obs - diff_avg**2
numerator = sum_squared_error / n_obs - (diff_avg * diff_avg)

target_avg = sum_target / n_obs
denominator = sum_squared_target / n_obs - target_avg**2
denominator = sum_squared_target / n_obs - (target_avg * target_avg)

# Take care of division by zero
nonzero_numerator = numerator != 0
Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/functional/regression/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tens
raise ValueError('Needs at least two samples to calculate r2 score.')

sum_error = torch.sum(target, dim=0)
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
sum_squared_error = torch.sum(target * target, dim=0)
diff = target - preds
residual = torch.sum(diff * diff, dim=0)
total = target.size(0)

return sum_squared_error, sum_error, residual, total
Expand Down