Skip to content

Commit

Permalink
refactor: ter (#713)
Browse files Browse the repository at this point in the history
* translation_edit_rate
* TranslationEditRate
* docs
  • Loading branch information
Borda authored Jan 6, 2022
1 parent 30e33c2 commit 5c18b6b
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `WordInfoLost` and `WordInfoPreserved` ([#630](https://github.com/PyTorchLightning/metrics/pull/630))
- `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))
- `TER` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))
- `TranslationEditRate` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))


- Added a default VSCode devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621))
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,10 @@ squad [func]
.. autofunction:: torchmetrics.functional.squad
:noindex:

ter [func]
~~~~~~~~~~
translation_edit_rate [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ter
.. autofunction:: torchmetrics.functional.translation_edit_rate
:noindex:

wer [func]
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,10 @@ SQuAD
.. autoclass:: torchmetrics.SQuAD
:noindex:

TER
~~~
TranslationEditRate
~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.TER
.. autoclass:: torchmetrics.TranslationEditRate
:noindex:

WER
Expand Down
24 changes: 12 additions & 12 deletions tests/text/test_ter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references
from torchmetrics.functional.text.ter import ter
from torchmetrics.text.ter import TER
from torchmetrics.functional.text.ter import translation_edit_rate
from torchmetrics.text.ter import TranslationEditRate
from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

if _SACREBLEU_AVAILABLE:
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_chrf_score_class(
ddp=ddp,
preds=preds,
targets=targets,
metric_class=TER,
metric_class=TranslationEditRate,
sk_metric=nltk_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
Expand All @@ -95,7 +95,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a
self.run_functional_metric_test(
preds,
targets,
metric_functional=ter,
metric_functional=translation_edit_rate,
sk_metric=nltk_metric,
metric_args=metric_args,
)
Expand All @@ -111,20 +111,20 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu
self.run_differentiability_test(
preds=preds,
targets=targets,
metric_module=TER,
metric_functional=ter,
metric_module=TranslationEditRate,
metric_functional=translation_edit_rate,
metric_args=metric_args,
)


def test_ter_empty_functional():
hyp = []
ref = [[]]
assert ter(hyp, ref) == tensor(0.0)
assert translation_edit_rate(hyp, ref) == tensor(0.0)


def test_ter_empty_class():
ter_metric = TER()
ter_metric = TranslationEditRate()
hyp = []
ref = [[]]
assert ter_metric(hyp, ref) == tensor(0.0)
Expand All @@ -133,11 +133,11 @@ def test_ter_empty_class():
def test_ter_empty_with_non_empty_hyp_functional():
hyp = ["python"]
ref = [[]]
assert ter(hyp, ref) == tensor(0.0)
assert translation_edit_rate(hyp, ref) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
ter_metric = TER()
ter_metric = TranslationEditRate()
hyp = ["python"]
ref = [[]]
assert ter_metric(hyp, ref) == tensor(0.0)
Expand All @@ -146,12 +146,12 @@ def test_ter_empty_with_non_empty_hyp_class():
def test_ter_return_sentence_level_score_functional():
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter(hyp, ref, return_sentence_level_score=True)
_, sentence_ter = translation_edit_rate(hyp, ref, return_sentence_level_score=True)
isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
ter_metric = TER(return_sentence_level_score=True)
ter_metric = TranslationEditRate(return_sentence_level_score=True)
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter_metric(hyp, ref)
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@
RetrievalRPrecision,
)
from torchmetrics.text import ( # noqa: E402
TER,
WER,
BLEUScore,
CharErrorRate,
CHRFScore,
MatchErrorRate,
SacreBLEUScore,
SQuAD,
TranslationEditRate,
WordInfoLost,
WordInfoPreserved,
)
Expand Down Expand Up @@ -152,7 +152,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TER",
"TranslationEditRate",
"WER",
"CharErrorRate",
"MatchErrorRate",
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.squad import squad
from torchmetrics.functional.text.ter import ter
from torchmetrics.functional.text.ter import translation_edit_rate
from torchmetrics.functional.text.wer import wer
from torchmetrics.functional.text.wil import word_information_lost
from torchmetrics.functional.text.wip import word_information_preserved
Expand Down Expand Up @@ -138,7 +138,7 @@
"ssim",
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"ter",
"translation_edit_rate",
"wer",
"char_error_rate",
"match_error_rate",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
from torchmetrics.functional.text.ter import ter # noqa: F401
from torchmetrics.functional.text.ter import translation_edit_rate # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
from torchmetrics.functional.text.wil import word_information_lost # noqa: F401
from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401
4 changes: 2 additions & 2 deletions torchmetrics/functional/text/ter.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def _ter_compute(total_num_edits: Tensor, total_ref_length: Tensor) -> Tensor:
return _compute_ter_score_from_statistics(total_num_edits, total_ref_length)


def ter(
def translation_edit_rate(
hypothesis_corpus: Union[str, Sequence[str]],
reference_corpus: Sequence[Union[str, Sequence[str]]],
normalize: bool = False,
Expand Down Expand Up @@ -594,7 +594,7 @@ def ter(
Example:
>>> hypothesis_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> ter(hypothesis_corpus, reference_corpus)
>>> translation_edit_rate(hypothesis_corpus, reference_corpus)
tensor(0.1538)
References:
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchmetrics.text.mer import MatchErrorRate # noqa: F401
from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401
from torchmetrics.text.squad import SQuAD # noqa: F401
from torchmetrics.text.ter import TER # noqa: F401
from torchmetrics.text.ter import TranslationEditRate # noqa: F401
from torchmetrics.text.wer import WER # noqa: F401
from torchmetrics.text.wil import WordInfoLost # noqa: F401
from torchmetrics.text.wip import WordInfoPreserved # noqa: F401
4 changes: 2 additions & 2 deletions torchmetrics/text/ter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchmetrics.metric import Metric


class TER(Metric):
class TranslationEditRate(Metric):
"""Calculate Translation edit rate (`TER`_) of machine translated text with one or more references. This
implementation follows the implmenetaions from
https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py. The `sacrebleu` implmenetation is a
Expand Down Expand Up @@ -52,7 +52,7 @@ class TER(Metric):
Example:
>>> hypothesis_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric = TER()
>>> metric = TranslationEditRate()
>>> metric(hypothesis_corpus, reference_corpus)
tensor(0.1538)
Expand Down

0 comments on commit 5c18b6b

Please sign in to comment.