Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Mar 26, 2021
1 parent bf78331 commit 569657e
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 84 deletions.
36 changes: 12 additions & 24 deletions tests/classification/test_hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@
torch.manual_seed(42)

_input_binary = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_binary_single = Input(
preds=torch.randn((NUM_BATCHES, 1)),
target=torch.randint(high=2, size=(NUM_BATCHES, 1))
)
_input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1)))

_input_multiclass = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
Expand Down Expand Up @@ -67,7 +63,7 @@ def _sk_hinge(preds, target, squared, multiclass_mode):
measures = np.clip(measures, 0, None)

if squared:
measures = measures ** 2
measures = measures**2
return measures.mean(axis=0)
else:
if multiclass_mode == MulticlassMode.ONE_VS_ALL:
Expand Down Expand Up @@ -119,36 +115,28 @@ def test_hinge_fn(self, preds, target, squared, multiclass_mode):
)


_input_multi_target = Input(
preds=torch.randn(BATCH_SIZE),
target=torch.randint(high=2, size=(BATCH_SIZE, 2))
)
_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2)))

_input_binary_different_sizes = Input(
preds=torch.randn(BATCH_SIZE * 2),
target=torch.randint(high=2, size=(BATCH_SIZE,))
preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE, ))
)

_input_multi_different_sizes = Input(
preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES),
target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,))
preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, ))
)

_input_extra_dim = Input(
preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2),
target=torch.randint(high=2, size=(BATCH_SIZE,))
preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE, ))
)


@pytest.mark.parametrize(
"preds, target, multiclass_mode",
[
(_input_multi_target.preds, _input_multi_target.target, None),
(_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None),
(_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None),
(_input_extra_dim.preds, _input_extra_dim.target, None),
(_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode')
],
[(_input_multi_target.preds, _input_multi_target.target, None),
(_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None),
(_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None),
(_input_extra_dim.preds, _input_extra_dim.target, None),
(_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode')],
)
def test_bad_inputs_fn(preds, target, multiclass_mode):
with pytest.raises(ValueError):
Expand Down
30 changes: 15 additions & 15 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _assert_half_support(
metric_functional: Callable,
preds: torch.Tensor,
target: torch.Tensor,
device: str = 'cpu'
device: str = 'cpu',
):
"""
Test if an metric can be used with half precision tensors
Expand Down Expand Up @@ -313,8 +313,12 @@ def run_class_metric_test(
)

def run_precision_test_cpu(
self, preds: torch.Tensor, target: torch.Tensor,
metric_module: Metric, metric_functional: Callable, metric_args: dict = {}
self,
preds: torch.Tensor,
target: torch.Tensor,
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = {}
):
""" Test if an metric can be used with half precision tensors on cpu
Args:
Expand All @@ -325,16 +329,16 @@ def run_precision_test_cpu(
metric_args: dict with additional arguments used for class initialization
"""
_assert_half_support(
metric_module(**metric_args),
partial(metric_functional, **metric_args),
preds,
target,
device='cpu'
metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cpu'
)

def run_precision_test_gpu(
self, preds: torch.Tensor, target: torch.Tensor,
metric_module: Metric, metric_functional: Callable, metric_args: dict = {}
self,
preds: torch.Tensor,
target: torch.Tensor,
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = {}
):
""" Test if an metric can be used with half precision tensors on gpu
Args:
Expand All @@ -345,11 +349,7 @@ def run_precision_test_gpu(
metric_args: dict with additional arguments used for class initialization
"""
_assert_half_support(
metric_module(**metric_args),
partial(metric_functional, **metric_args),
preds,
target,
device='cuda'
metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cuda'
)


Expand Down
3 changes: 1 addition & 2 deletions tests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6,
reason='half support of core operations on not support before pytorch v1.6'
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
if metric_class == MeanSquaredLogError:
Expand Down
20 changes: 16 additions & 4 deletions tests/regression/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,25 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc
# PSNR half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision")
def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_cpu(preds, target, PSNR, psnr,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim})
self.run_precision_test_cpu(
preds, target, PSNR, psnr, {
"data_range": data_range,
"base": base,
"reduction": reduction,
"dim": dim
}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_gpu(preds, target, PSNR, psnr,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim})
self.run_precision_test_gpu(
preds, target, PSNR, psnr, {
"data_range": data_range,
"base": base,
"reduction": reduction,
"dim": dim
}
)


@pytest.mark.parametrize("reduction", ["none", "sum"])
Expand Down
6 changes: 1 addition & 5 deletions tests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ def _compute_sklearn_metric(


def _test_retrieval_against_sklearn(
sklearn_metric,
torch_metric,
size,
n_documents,
query_without_relevant_docs_options
sklearn_metric, torch_metric, size, n_documents, query_without_relevant_docs_options
) -> None:
""" Compare PL metrics to standard version. """
metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options)
Expand Down
6 changes: 1 addition & 5 deletions tests/retrieval/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
def test_results(size, n_documents, query_without_relevant_docs_options):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
sk_average_precision,
RetrievalMAP,
size,
n_documents,
query_without_relevant_docs_options
sk_average_precision, RetrievalMAP, size, n_documents, query_without_relevant_docs_options
)


Expand Down
6 changes: 1 addition & 5 deletions tests/retrieval/test_mrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ def _reciprocal_rank(target: np.array, preds: np.array):
def test_results(size, n_documents, query_without_relevant_docs_options):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_reciprocal_rank,
RetrievalMRR,
size,
n_documents,
query_without_relevant_docs_options
_reciprocal_rank, RetrievalMRR, size, n_documents, query_without_relevant_docs_options
)


Expand Down
30 changes: 13 additions & 17 deletions torchmetrics/functional/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@ class MulticlassMode(EnumStr):


def _check_shape_and_type_consistency_hinge(
preds: Tensor,
target: Tensor,
preds: Tensor,
target: Tensor,
) -> DataType:
if target.ndim > 1:
raise ValueError(
f"The `target` should be one dimensional, got `target` with shape={target.shape}.",
)
raise ValueError(f"The `target` should be one dimensional, got `target` with shape={target.shape}.", )

if preds.ndim == 1:
if preds.shape != target.shape:
Expand All @@ -55,17 +53,15 @@ def _check_shape_and_type_consistency_hinge(
)
mode = DataType.MULTICLASS
else:
raise ValueError(
f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}."
)
raise ValueError(f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}.")
return mode


def _hinge_update(
preds: Tensor,
target: Tensor,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
preds: Tensor,
target: Tensor,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tuple[Tensor, Tensor]:
if preds.shape[0] == 1:
preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0)
Expand All @@ -84,7 +80,7 @@ def _hinge_update(
target = target.bool()
margin = torch.zeros_like(preds)
margin[target] = preds[target]
margin[~target] = - preds[~target]
margin[~target] = -preds[~target]
else:
raise ValueError(
"The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
Expand All @@ -107,10 +103,10 @@ def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor:


def hinge(
preds: Tensor,
target: Tensor,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
preds: Tensor,
target: Tensor,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tensor:
r"""
Computes the mean `Hinge loss <https://en.wikipedia.org/wiki/Hinge_loss>`_, typically used for Support Vector
Expand Down
8 changes: 2 additions & 6 deletions torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def __init__(

query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
if query_without_relevant_docs not in query_without_relevant_docs_options:
raise ValueError(
f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}."
)
raise ValueError(f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}.")

self.query_without_relevant_docs = query_without_relevant_docs
self.exclude = exclude
Expand Down Expand Up @@ -124,9 +122,7 @@ def compute(self) -> Tensor:

if not mini_target.sum():
if self.query_without_relevant_docs == 'error':
raise ValueError(
"`compute` method was provided with a query with no positive target."
)
raise ValueError("`compute` method was provided with a query with no positive target.")
if self.query_without_relevant_docs == 'pos':
res.append(tensor(1.0, **kwargs))
elif self.query_without_relevant_docs == 'neg':
Expand Down
5 changes: 4 additions & 1 deletion torchmetrics/utilities/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,10 @@ def _check_retrieval_functional_inputs(preds: Tensor, target: Tensor) -> None:


def _check_retrieval_inputs(
indexes: Tensor, preds: Tensor, target: Tensor, ignore: int = None
indexes: Tensor,
preds: Tensor,
target: Tensor,
ignore: int = None,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
Expand Down

0 comments on commit 569657e

Please sign in to comment.