diff --git a/CHANGELOG.md b/CHANGELOG.md index db69e788c12..d453910147d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623)) - `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641)) - `TranslationEditRate` ([#646](https://github.com/PyTorchLightning/metrics/pull/646)) + - `ExtendedEditDistance` ([#668](https://github.com/PyTorchLightning/metrics/pull/668)) + - Added `MultiScaleSSIM` into image metrics ([#679](https://github.com/PyTorchLightning/metrics/pull/679)) diff --git a/docs/source/links.rst b/docs/source/links.rst index e2839e34fa2..eeaced1807c 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -76,4 +76,5 @@ .. _chrF score: https://aclanthology.org/W15-3049.pdf .. _chrF++ score: https://aclanthology.org/W17-4770.pdf .. _TER: https://aclanthology.org/2006.amta-papers.25.pdf +.. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index b7a4817f682..4602cdeaed9 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -460,6 +460,12 @@ chrf_score [func] .. autofunction:: torchmetrics.functional.chrf_score :noindex: +extended_edit_distance [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.extended_edit_distance + :noindex: + match_error_rate [func] ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 5e7028a33c0..0134eb32ca9 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -642,6 +642,12 @@ CHRFScore .. autoclass:: torchmetrics.CHRFScore :noindex: +ExtendedEditDistance +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.ExtendedEditDistance + :noindex: + MatchErrorRate ~~~~~~~~~~~~~~ diff --git a/tests/text/inputs.py b/tests/text/inputs.py index dfa906abf31..a7536fb8bca 100644 --- a/tests/text/inputs.py +++ b/tests/text/inputs.py @@ -63,3 +63,7 @@ _inputs_error_rate_batch_size_1 = Input(**ERROR_RATES_BATCHES_1) _inputs_error_rate_batch_size_2 = Input(**ERROR_RATES_BATCHES_2) + +# single reference +TUPLE_OF_SINGLE_REFERENCES = (((REFERENCE_1A), (REFERENCE_1B)), ((REFERENCE_1B), (REFERENCE_1C))) +_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES) diff --git a/tests/text/test_eed.py b/tests/text/test_eed.py new file mode 100644 index 00000000000..5abaee90d3f --- /dev/null +++ b/tests/text/test_eed.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from functools import partial
+
+import pytest
+from torch import Tensor, tensor
+
+from tests.text.helpers import TextTester
+from tests.text.inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references
+from torchmetrics.functional.text.eed import extended_edit_distance
+from torchmetrics.text.eed import ExtendedEditDistance
+
+
+def rwth_manual_metric(preds, targets) -> Tensor:
+    """Baseline scores obtained for the examples defined in `tests.text.inputs` with the reference script from
+    https://github.com/rwth-i6/ExtendedEditDistance.
+    """
+    ans_1 = tensor(0.24248056001808083)
+    ans_2 = tensor(0.19152276295133436)
+
+    HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
+
+    # If hypotheses A and B are both in preds, the average of ans_1 and ans_2 is returned
+    if len(preds) == 4:
+        return (ans_1 + ans_2) / 2
+    # If only hypothesis A or only hypothesis B is given, ans_1 or ans_2 is returned, respectively
+    if HYPOTHESIS_A in preds:
+        return ans_1
+    return ans_2
+
+
+@pytest.mark.parametrize(
+    ["preds", "targets"],
+    [(_inputs_single_reference.preds, _inputs_single_reference.targets)],
+)
+class TestExtendedEditDistance(TextTester):
+    @pytest.mark.parametrize("ddp", [False, True])
+    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
+    def test_eed_class(self, preds, targets, ddp, dist_sync_on_step):
+        rwth_metric = partial(rwth_manual_metric)
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            targets=targets,
+            metric_class=ExtendedEditDistance,
+            sk_metric=rwth_metric,
+            dist_sync_on_step=dist_sync_on_step,
+        )
+
+    def test_eed_functional(self, preds, targets):
+        rwth_metric = partial(rwth_manual_metric)
+        self.run_functional_metric_test(
+            preds,
+            targets,
+            metric_functional=extended_edit_distance,
+            sk_metric=rwth_metric,
+        )
+
+    def test_eed_differentiability(self, preds, targets):
+        self.run_differentiability_test(
+            preds=preds,
+            targets=targets,
+            metric_module=ExtendedEditDistance,
+            metric_functional=extended_edit_distance,
+        )
+
+
+# test blank edge cases
+def test_eed_empty_functional():
+    hyp = []
+    ref = [[]]
+    assert extended_edit_distance(hyp, ref) == tensor(0.0)
+
+
+def test_eed_empty_class():
+    eed_metric = ExtendedEditDistance()
+    hyp = []
+    ref = [[]]
+    assert eed_metric(hyp, ref) == tensor(0.0)
+
+
+def test_eed_empty_with_non_empty_hyp_functional():
+    hyp = ["python"]
+    ref = [[]]
+    assert extended_edit_distance(hyp, ref) == tensor(0.0)
+
+
+def test_eed_empty_with_non_empty_hyp_class():
+    eed_metric = ExtendedEditDistance()
+    hyp = ["python"]
+    ref = [[]]
+    assert eed_metric(hyp, ref) == tensor(0.0)
+
+
+def test_eed_return_sentence_level_score_functional():
+    hyp = _inputs_single_sentence_multiple_references.preds
+    ref = _inputs_single_sentence_multiple_references.targets
+    _, sentence_eed = extended_edit_distance(hyp, ref, return_sentence_level_score=True)
+    assert isinstance(sentence_eed, Tensor)
+
+
+def test_eed_return_sentence_level_class():
+    metric = ExtendedEditDistance(return_sentence_level_score=True)
+    hyp = _inputs_single_sentence_multiple_references.preds
+    ref = _inputs_single_sentence_multiple_references.targets
+    _, sentence_eed = metric(hyp, ref)
+    assert isinstance(sentence_eed, Tensor)
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 56293457904..af536defd62 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -86,6 +86,7 @@
     BLEUScore,
     CharErrorRate,
     CHRFScore,
+    ExtendedEditDistance,
     MatchErrorRate,
     SacreBLEUScore,
     SQuAD,
@@ -115,6 +116,7
@@ "CosineSimilarity", "TweedieDevianceScore", "ExplainedVariance", + "ExtendedEditDistance", "F1", "F1Score", "FBeta", diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 7d6f9942bf2..2401979dd35 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -69,6 +69,7 @@ from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.functional.text.chrf import chrf_score +from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score @@ -93,6 +94,7 @@ "tweedie_deviance_score", "dice_score", "explained_variance", + "extended_edit_distance", "f1", "f1_score", "fbeta", diff --git a/torchmetrics/functional/text/eed.py b/torchmetrics/functional/text/eed.py new file mode 100644 index 00000000000..eccaded44f9 --- /dev/null +++ b/torchmetrics/functional/text/eed.py @@ -0,0 +1,436 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# referenced from +# Library Name: torchtext +# Authors: torchtext authors +# Date: 2021-12-07 +# Link: + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# The RWTH Extended Edit Distance (EED) License + +# Copyright (c) 2019, RWTH. +# All rights reserved. + +# This license is derived from the Q Public License v1.0 and the Qt Non-Commercial License v1.0 which are both Copyright +# by Trolltech AS, Norway. The aim of this license is to lay down the conditions enabling you to use, modify and +# circulate the SOFTWARE, use of third-party application programs based on the Software and publication of results +# obtained through the use of modified and unmodified versions of the SOFTWARE. However, RWTH remain the authors of the +# SOFTWARE and so retain property rights and the use of all ancillary rights. The SOFTWARE is defined as all successive +# versions of EED software and their documentation that have been developed by RWTH. +# +# When you access and use the SOFTWARE, you are presumed to be aware of and to have accepted all the rights and +# obligations of the present license: +# +# 1. You are granted the non-exclusive rights set forth in this license provided you agree to and comply with any all +# conditions in this license. Whole or partial distribution of the Software, or software items that link with the +# Software, in any form signifies acceptance of this license for non-commercial use only. +# 2. You may copy and distribute the Software in unmodified form provided that the entire package, including - but not +# restricted to - copyright, trademark notices and disclaimers, as released by the initial developer of the +# Software, is distributed. +# 3. 
You may make modifications to the Software and distribute your modifications, in a form that is separate from the +# Software, such as patches. The following restrictions apply to modifications: +# a. Modifications must not alter or remove any copyright notices in the Software. +# b When modifications to the Software are released under this license, a non-exclusive royalty-free right is +# granted to the initial developer of the Software to distribute your modification in future versions of the +# Software provided such versions remain available under these terms in addition to any other license(s) of the +# initial developer. +# 4. You may distribute machine-executable forms of the Software or machine-executable forms of modified versions of +# the Software, provided that you meet these restrictions: +# a. You must include this license document in the distribution. +# b. You must ensure that all recipients of the machine-executable forms are also able to receive the complete +# machine-readable source code to the distributed Software, including all modifications, without any charge +# beyond the costs of data transfer, and place prominent notices in the distribution explaining this. +# c. You must ensure that all modifications included in the machine-executable forms are available under the terms +# of this license. +# 5. You may use the original or modified versions of the Software to compile, link and run application programs +# legally developed by you or by others. +# 6. You may develop application programs, reusable components and other software items, in a non-commercial setting, +# that link with the original or modified versions of the Software. These items, when distributed, are subject to +# the following requirements: +# a. You must ensure that all recipients of machine-executable forms of these items are also able to receive and use +# the complete machine-readable source code to the items without any charge beyond the costs of data transfer. +# b. You must explicitly license all recipients of your items to use and re-distribute original and modified +# versions of the items in both machine-executable and source code forms. The recipients must be able to do so +# without any charges whatsoever, and they must be able to re-distribute to anyone they choose. +# c. If an application program gives you access to functionality of the Software for development of application +# programs, reusable components or other software components (e.g. an application that is a scripting wrapper), +# usage of the application program is considered to be usage of the Software and is thus bound by this license. +# d. If the items are not available to the general public, and the initial developer of the Software requests a copy +# of the items, then you must supply one. +# 7. Users must cite the authors of the Software upon publication of results obtained through the use of original or +# modified versions of the Software by referring to the following publication: +# P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, submitted to WMT +# 2019. +# 8. In no event shall the initial developers or copyright holders be liable for any damages whatsoever, including - +# but not restricted to - lost revenue or profits or other direct, indirect, special, incidental or consequential +# damages, even if they have been advised of the possibility of such damages, except to the extent invariable law, +# if any, provides otherwise. +# 9. 
You assume all risks concerning the quality or the effects of the SOFTWARE and its use. If the SOFTWARE is +# defective, you will bear the costs of all required services, corrections or repairs. +# 10. This license has the binding value of a contract. +# 11. The present license and its effects are subject to German law and the competent German Courts. +# +# The Software and this license document are provided "AS IS" with NO EXPLICIT OR IMPLICIT WARRANTY OF ANY KIND, +# INCLUDING WARRANTY OF DESIGN, ADAPTION, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. + +import re +import unicodedata +from math import inf +from typing import List, Optional, Sequence, Tuple, Union + +from torch import Tensor, stack, tensor +from typing_extensions import Literal + +from torchmetrics.functional.text.helper import _validate_inputs + + +def _distance_between_words(preds_word: str, target_word: str) -> int: + """Distance measure used for substitutions/identity operation. Code adapted from + https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py. + + Args: + preds_word: hypothesis word string + target_word: reference word string + + Return: + 0 for match, 1 for no match + """ + return int(preds_word != target_word) + + +def _eed_function( + hyp: str, + ref: str, + alpha: float = 2.0, + rho: float = 0.3, + deletion: float = 0.2, + insertion: float = 1.0, +) -> float: + """Computes extended edit distance score for two lists of strings: hyp and ref. Code adapted from: + https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py. + + Args: + hyp: + A hypothesis string + ref: + A reference string + alpha: + optimal jump penalty, penalty for jumps between characters + rho: + coverage cost, penalty for repetition of characters + deletion: + penalty for deletion of character + insertion: + penalty for insertion or substitution of character + + Return: + Extended edit distance score as float + """ + number_of_visits = [-1] * (len(hyp) + 1) + + # row[i] stores cost of cheapest path from (0,0) to (i,l) in CDER aligment grid. + row = [1.0] * (len(hyp) + 1) + + row[0] = 0.0 # CDER initialisation 0,0 = 0.0, rest 1.0 + next_row = [inf] * (len(hyp) + 1) + + for w in range(1, len(ref) + 1): + for i in range(0, len(hyp) + 1): + + if i > 0: + next_row[i] = min( + next_row[i - 1] + deletion, + row[i - 1] + _distance_between_words(hyp[i - 1], ref[w - 1]), + row[i] + insertion, + ) + else: + next_row[i] = row[i] + 1.0 + + min_index = next_row.index(min(next_row)) + number_of_visits[min_index] += 1 + + # Long Jumps + if ref[w - 1] == " ": + jump = alpha + next_row[min_index] + next_row = [min(x, jump) for x in next_row] + + row = next_row + next_row = [inf] * (len(hyp) + 1) + + coverage = rho * sum(x if x >= 0 else 1 for x in number_of_visits) + + return min(1, (row[-1] + coverage) / (float(len(ref)) + coverage)) + + +def _preprocess_en(sentence: str) -> str: + """Copied from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py. + + Raises: + ValueError: If input sentence is not of a type `str`. 
+ """ + if not isinstance(sentence, str): + raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead") + + sentence = sentence.rstrip() # trailing space, tab, or newline + + # Add space before interpunctions + rules_interpunction = [ + (".", " ."), + ("!", " !"), + ("?", " ?"), + (",", " ,"), + ] + for pattern, replacement in rules_interpunction: + sentence = sentence.replace(pattern, replacement) + + rules_re = [ + (r"\s+", r" "), # get rid of extra spaces + (r"(\d) ([.,]) (\d)", r"\1\2\3"), # 0 . 1 -> 0.1 + (r"(Dr|Jr|Prof|Rev|Gen|Mr|Mt|Mrs|Ms) .", r"\1."), # Mr . -> Mr. + ] + for pattern, replacement in rules_re: + sentence = re.sub(pattern, replacement, sentence) + + # Add space between abbreviations + rules_interpunction = [ + ("e . g .", "e.g."), + ("i . e .", "i.e."), + ("U . S .", "U.S."), + ] + for pattern, replacement in rules_interpunction: + sentence = sentence.replace(pattern, replacement) + + # add space to beginning and end of string + sentence = " " + sentence + " " + + return sentence + + +def _preprocess_ja(sentence: str) -> str: + """Copied from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py. + + Raises: + ValueError: If input sentence is not of a type `str`. + """ + if not isinstance(sentence, str): + raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead") + + sentence = sentence.rstrip() # trailing space, tab, newline + # characters which look identical actually are identical + sentence = unicodedata.normalize("NFKC", sentence) + return sentence + + +def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor: + """Final step in extended edit distance. + + Args: + sentence_level_scores: + list of sentence-level scores as floats + + Return: + average of scores as a tensor + """ + if len(sentence_level_scores) == 0: + return tensor(0.0) + + average = sum(sentence_level_scores) / tensor(len(sentence_level_scores)) + return average + + +def _preprocess_sentences( + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], + language: Union[Literal["en"], Literal["ja"]], +) -> Tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: + """Proprocess strings according to language requirements. + + Args: + preds: An iterable of hypothesis corpus. + target: An iterable of iterables of reference corpus. + language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. 
Defaults to en
+
+    Return:
+        Tuple of lists that contain the cleaned strings for preds and target
+
+    Raises:
+        ValueError: If a language other than 'en' or 'ja' is used
+        ValueError: If the length of target is not equal to the length of preds
+        ValueError: If objects in reference and hypothesis corpus are not strings
+    """
+    # sanity checks
+    target, preds = _validate_inputs(hypothesis_corpus=preds, reference_corpus=target)
+
+    # preprocess strings
+    if language == "en":
+        preprocess_function = _preprocess_en
+    elif language == "ja":
+        preprocess_function = _preprocess_ja
+    else:
+        raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}")
+
+    preds = [preprocess_function(pred) for pred in preds]
+    target = [[preprocess_function(ref) for ref in reference] for reference in target]
+
+    return preds, target
+
+
+def _compute_sentence_statistics(
+    preds_word: str,
+    target_words: Union[str, Sequence[str]],
+    alpha: float = 2.0,
+    rho: float = 0.3,
+    deletion: float = 0.2,
+    insertion: float = 1.0,
+) -> Tensor:
+    """Compute the best (lowest) sentence-level EED score for a hypothesis against its references.
+
+    Args:
+        preds_word:
+            A hypothesis word
+        target_words:
+            An iterable of reference words
+        alpha:
+            An optimal jump penalty, penalty for jumps between characters
+        rho:
+            coverage cost, penalty for repetition of characters
+        deletion:
+            penalty for deletion of character
+        insertion:
+            penalty for insertion or substitution of character
+
+    Return:
+        best_score:
+            best (lowest) sentence-level score as a Tensor
+    """
+    best_score = inf
+
+    for reference in target_words:
+        score = _eed_function(preds_word, reference, alpha, rho, deletion, insertion)
+        if score < best_score:
+            best_score = score
+
+    return tensor(best_score)
+
+
+def _eed_update(
+    preds: Union[str, Sequence[str]],
+    target: Sequence[Union[str, Sequence[str]]],
+    language: Literal["en", "ja"] = "en",
+    alpha: float = 2.0,
+    rho: float = 0.3,
+    deletion: float = 0.2,
+    insertion: float = 1.0,
+    sentence_eed: Optional[List[Tensor]] = None,
+) -> List[Tensor]:
+    """Compute scores for ExtendedEditDistance.
+
+    Args:
+        preds:
+            An iterable of hypothesis corpus
+        target:
+            An iterable of iterables of reference corpus
+        language:
+            Language used in sentences. Only supports English (en) and Japanese (ja) for now.
Defaults to en + alpha: + optimal jump penalty, penalty for jumps between characters + rho: + coverage cost, penalty for repetition of characters + deletion: + penalty for deletion of character + insertion: + penalty for insertion or substitution of character + sentence_eed: + list of sentence-level scores + + Return: + individual sentence scores as a list of Tensors + """ + preds, target = _preprocess_sentences(preds, target, language) + + if sentence_eed is None: + sentence_eed = [] + + # return tensor(0.0) if target or preds is empty + if 0 in (len(preds), len(target[0])): + return sentence_eed + + for hypothesis, target_words in zip(preds, target): + score = _compute_sentence_statistics(hypothesis, target_words, alpha, rho, deletion, insertion) + sentence_eed.append(score) + + return sentence_eed + + +def extended_edit_distance( + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], + language: Literal["en", "ja"] = "en", + return_sentence_level_score: bool = False, + alpha: float = 2.0, + rho: float = 0.3, + deletion: float = 0.2, + insertion: float = 1.0, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Computes extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings. The + metric utilises the Levenshtein distance and extends it by adding an additional jump operation. + + Args: + preds: + An iterable of hypothesis corpus. + target: + An iterable of iterables of reference corpus. + language: + Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en + return_sentence_level_score: + An indication of whether sentence-level EED score is to be returned. + alpha: + optimal jump penalty, penalty for jumps between characters + rho: + coverage cost, penalty for repetition of characters + deletion: + penalty for deletion of character + insertion: + penalty for insertion or substitution of character + + Return: + Extended edit distance score as a tensor + + Example: + >>> from torchmetrics.functional import extended_edit_distance + >>> preds = ["this is the prediction", "here is an other sample"] + >>> target = ["this is the reference", "here is another one"] + >>> extended_edit_distance(preds=preds, target=target) + tensor(0.3078) + + References: + [1] P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, + submitted to WMT 2019. 
`ExtendedEditDistance`_ + """ + # input validation for parameters + for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]): + if not isinstance(param, float) or isinstance(param, float) and param < 0: + raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.") + + sentence_level_scores = _eed_update(preds, target, language, alpha, rho, deletion, insertion) + + average = _eed_compute(sentence_level_scores) + + if return_sentence_level_score: + return average, stack(sentence_level_scores) + return average diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py index b3ca31681bd..49f73b7ad71 100644 --- a/torchmetrics/text/__init__.py +++ b/torchmetrics/text/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.text.bleu import BLEUScore # noqa: F401 from torchmetrics.text.cer import CharErrorRate # noqa: F401 from torchmetrics.text.chrf import CHRFScore # noqa: F401 +from torchmetrics.text.eed import ExtendedEditDistance # noqa: F401 from torchmetrics.text.mer import MatchErrorRate # noqa: F401 from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401 from torchmetrics.text.squad import SQuAD # noqa: F401 diff --git a/torchmetrics/text/eed.py b/torchmetrics/text/eed.py new file mode 100644 index 00000000000..99c7b90e72a --- /dev/null +++ b/torchmetrics/text/eed.py @@ -0,0 +1,141 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union + +from torch import Tensor, stack +from typing_extensions import Literal + +from torchmetrics.functional.text.eed import _eed_compute, _eed_update +from torchmetrics.metric import Metric + + +class ExtendedEditDistance(Metric): + """Computes extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings. The + metric utilises the Levenshtein distance and extends it by adding an additional jump operation. + + Args: + language: + Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en + return_sentence_level_score: + An indication of whether sentence-level EED score is to be returned + alpha: + optimal jump penalty, penalty for jumps between characters + rho: + coverage cost, penalty for repetition of characters + deletion: + penalty for deletion of character + insertion: + penalty for insertion or substitution of character + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather + + Return: + Extended edit distance score as a tensor + + Example: + >>> from torchmetrics import ExtendedEditDistance + >>> preds = ["this is the prediction", "here is an other sample"] + >>> target = ["this is the reference", "here is another one"] + >>> metric = ExtendedEditDistance() + >>> metric(preds=preds, target=target) + tensor(0.3078) + + References: + [1] P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, submitted + to WMT 2019. `ExtendedEditDistance`_ + """ + + sentence_eed: List[Tensor] + higher_is_better = False + is_differentiable = False + + def __init__( + self, + language: Literal["en", "ja"] = "en", + return_sentence_level_score: bool = False, + alpha: float = 2.0, + rho: float = 0.3, + deletion: float = 0.2, + insertion: float = 1.0, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + if language not in ("en", "ja"): + raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}") + self.language: Literal["en", "ja"] = language + self.return_sentence_level_score = return_sentence_level_score + + # input validation for parameters + for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]): + if not isinstance(param, float) or isinstance(param, float) and param < 0: + raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.") + + self.alpha = alpha + self.rho = rho + self.deletion = deletion + self.insertion = insertion + + self.add_state("sentence_eed", [], dist_reduce_fx="cat") + + def update( # type: ignore + self, + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], + ) -> None: + """Update ExtendedEditDistance statistics. + + Args: + preds: An iterable of hypothesis corpus + target: An iterable of iterables of reference corpus + """ + self.sentence_eed = _eed_update( + preds, + target, + self.language, + self.alpha, + self.rho, + self.deletion, + self.insertion, + self.sentence_eed, + ) + + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Calculate extended edit distance score. + + Return: + Extended edit distance score as tensor + """ + average = _eed_compute(self.sentence_eed) + + if self.return_sentence_level_score: + return average, stack(self.sentence_eed) + return average
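For reference, a minimal usage sketch in the same doctest style as the examples above. It exercises the module and functional APIs together with ``return_sentence_level_score`` and the EED penalty parameters; the penalty values shown are simply the defaults defined in this PR, and output tensors are omitted since only the averaged example score, ``tensor(0.3078)``, is verified by the docstring examples.

>>> from torchmetrics import ExtendedEditDistance
>>> from torchmetrics.functional import extended_edit_distance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> # module API: corpus-level average plus one EED score per sentence
>>> metric = ExtendedEditDistance(language="en", return_sentence_level_score=True)
>>> average, sentence_scores = metric(preds=preds, target=target)
>>> # functional API with the penalty parameters spelled out (defaults shown)
>>> average, sentence_scores = extended_edit_distance(
...     preds, target, return_sentence_level_score=True, alpha=2.0, rho=0.3, deletion=0.2, insertion=1.0
... )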