Skip to content

Commit

Permalink
Add half precision testing [2/n] (#135)
Browse files Browse the repository at this point in the history
* ssim

* explained_variance

* r2

* changelog

* minimum

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
SkafteNicki and Borda authored Mar 29, 2021
1 parent 53d5701 commit b7312a2
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 9 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))


- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77))
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),
[#135](https://github.com/PyTorchLightning/metrics/pull/135)
)


- Added `BootStrapper` to easely calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))
Expand Down
1 change: 1 addition & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ the following limitations:
but they are also listed below:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`
- :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]`

******************
Metric Arithmetics
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def _assert_half_support(
y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
metric_module = metric_module.to(device)
assert metric_module(y_hat, y)
assert metric_functional(y_hat, y)
_assert_tensor(metric_module(y_hat, y))
_assert_tensor(metric_functional(y_hat, y))


class MetricTester:
Expand Down
11 changes: 11 additions & 0 deletions tests/regression/test_explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import explained_variance
from torchmetrics.regression import ExplainedVariance
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Expand Down Expand Up @@ -84,6 +85,16 @@ def test_explained_variance_functional(self, multioutput, preds, target, sk_metr
metric_args=dict(multioutput=multioutput),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_explained_variance_half_cpu(self, multioutput, preds, target, sk_metric):
self.run_precision_test_cpu(preds, target, ExplainedVariance, explained_variance)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_explained_variance_half_gpu(self, multioutput, preds, target, sk_metric):
self.run_precision_test_gpu(preds, target, ExplainedVariance, explained_variance)


def test_error_on_different_shape(metric_class=ExplainedVariance):
metric = metric_class()
Expand Down
13 changes: 13 additions & 0 deletions tests/regression/test_r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import r2score
from torchmetrics.regression import R2Score
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Expand Down Expand Up @@ -92,6 +93,18 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, nu
metric_args=dict(adjusted=adjusted, multioutput=multioutput),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_r2_half_cpu(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_precision_test_cpu(preds, target, partial(R2Score, num_outputs=num_outputs), r2score,
{'adjusted': adjusted, 'multioutput': multioutput})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_r2_half_gpu(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_precision_test_gpu(preds, target, partial(R2Score, num_outputs=num_outputs), r2score,
{'adjusted': adjusted, 'multioutput': multioutput})


def test_error_on_different_shape(metric_class=R2Score):
metric = metric_class()
Expand Down
9 changes: 9 additions & 0 deletions tests/regression/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def test_ssim_functional(self, preds, target, multichannel):
metric_args={"data_range": 1.0},
)

# SSIM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SSIM metric does not support cpu + half precision")
def test_ssim_half_cpu(self, preds, target, multichannel):
self.run_precision_test_cpu(preds, target, SSIM, ssim, {"data_range": 1.0})

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_ssim_half_gpu(self, preds, target, multichannel):
self.run_precision_test_gpu(preds, target, SSIM, ssim, {"data_range": 1.0})


@pytest.mark.parametrize(
['pred', 'target', 'kernel', 'sigma'],
Expand Down
9 changes: 5 additions & 4 deletions torchmetrics/functional/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tens

n_obs = preds.size(0)
sum_error = torch.sum(target - preds, dim=0)
sum_squared_error = torch.sum((target - preds)**2, dim=0)
diff = target - preds
sum_squared_error = torch.sum(diff * diff, dim=0)

sum_target = torch.sum(target, dim=0)
sum_squared_target = torch.sum(target**2, dim=0)
sum_squared_target = torch.sum(target * target, dim=0)

return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target

Expand All @@ -41,10 +42,10 @@ def _explained_variance_compute(
multioutput: str = "uniform_average",
) -> Union[Tensor, Sequence[Tensor]]:
diff_avg = sum_error / n_obs
numerator = sum_squared_error / n_obs - diff_avg**2
numerator = sum_squared_error / n_obs - (diff_avg * diff_avg)

target_avg = sum_target / n_obs
denominator = sum_squared_target / n_obs - target_avg**2
denominator = sum_squared_target / n_obs - (target_avg * target_avg)

# Take care of division by zero
nonzero_numerator = numerator != 0
Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/functional/regression/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tens
raise ValueError('Needs at least two samples to calculate r2 score.')

sum_error = torch.sum(target, dim=0)
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
sum_squared_error = torch.sum(target * target, dim=0)
diff = target - preds
residual = torch.sum(diff * diff, dim=0)
total = target.size(0)

return sum_squared_error, sum_error, residual, total
Expand Down

0 comments on commit b7312a2

Please sign in to comment.