Add TER metric (#104)
* WIP: Adding TER. Multi pred & multi ref have not been completed yet.

* TER metric added.

* Typo fixed in docstring.

* Typo fixed in docstring.

* Docstring example corrected.
- requirement_message corrected.
devrimcavusoglu authored Mar 15, 2022
1 parent ed369f4 commit 3c8c3dd
Showing 6 changed files with 313 additions and 0 deletions.
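A minimal usage sketch of the new metric, assuming the Jury and AutoMetric interfaces exercised in tests/jury/metrics/test_ter.py below; the inputs are reused from the metric's docstring example.

    from jury import Jury
    from jury.metrics import AutoMetric

    # Load TER via the AutoMetric factory; compute_kwargs are forwarded to sacrebleu's TER scorer.
    ter = AutoMetric.load("ter", compute_kwargs={"normalized": True})
    scorer = Jury(metrics=ter)

    predictions = [["hello there general kenobi"], ["foo bar foobar"]]
    references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]

    # Returns a score dict of the shape recorded in tests/test_data/expected_outputs/metrics/test_ter.json.
    print(scorer(predictions=predictions, references=references))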
1 change: 1 addition & 0 deletions jury/metrics/__init__.py
@@ -26,4 +26,5 @@
from jury.metrics.rouge import Rouge
from jury.metrics.sacrebleu import Sacrebleu
from jury.metrics.squad import Squad
from jury.metrics.ter import TER
from jury.metrics.wer import WER
1 change: 1 addition & 0 deletions jury/metrics/ter/__init__.py
@@ -0,0 +1 @@
from jury.metrics.ter.ter import TER
8 changes: 8 additions & 0 deletions jury/metrics/ter/ter.py
@@ -0,0 +1,8 @@
from jury.metrics._core import MetricAlias
from jury.metrics.ter.ter_for_language_generation import TERForLanguageGeneration

__main_class__ = "TER"


class TER(MetricAlias):
_SUBCLASS = TERForLanguageGeneration
227 changes: 227 additions & 0 deletions jury/metrics/ter/ter_for_language_generation.py
@@ -0,0 +1,227 @@
# coding=utf-8
# Copyright 2020 Open Business Software Solutions, The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Translation Edit Rate (TER) metric. The part of this file is adapted from HuggingFace's
datasets package implementation of TER metric. See
https://github.com/huggingface/datasets/blob/master/metrics/ter/ter.py
"""
from typing import Callable, Dict, Sequence

import datasets

from jury.collator import Collator
from jury.metrics import LanguageGenerationInstance, MetricForLanguageGeneration
from jury.metrics._core.utils import PackagePlaceholder, requirement_message

# `import sacrebleu as scb` placeholder
scb = PackagePlaceholder(version="2.0.0")


_CITATION = """\
@inproceedings{snover-etal-2006-study,
title = "A Study of Translation Edit Rate with Targeted Human Annotation",
author = "Snover, Matthew and
Dorr, Bonnie and
Schwartz, Rich and
Micciulla, Linnea and
Makhoul, John",
booktitle = "Proceedings of the 7th Conference of the Association for Machine Translation in the Americas: Technical Papers",
month = aug # " 8-12",
year = "2006",
address = "Cambridge, Massachusetts, USA",
publisher = "Association for Machine Translation in the Americas",
url = "https://aclanthology.org/2006.amta-papers.25",
pages = "223--231",
}
@inproceedings{post-2018-call,
title = "A Call for Clarity in Reporting {BLEU} Scores",
author = "Post, Matt",
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
month = oct,
year = "2018",
address = "Belgium, Brussels",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W18-6319",
pages = "186--191",
}
"""

_DESCRIPTION = """\
TER (Translation Edit Rate, also called Translation Error Rate) is a metric to quantify the edit operations that a
hypothesis requires to match a reference translation. We use the implementation that is already present in sacrebleu
(https://github.com/mjpost/sacreBLEU#ter), which in turn is inspired by the TERCOM implementation, which can be found
here: https://github.com/jhclark/tercom.
The implementation here differs slightly from sacrebleu in terms of the required input format: the references and
hypotheses lists must have the same length, so you may need to transpose your references relative to sacrebleu's
expected layout. See https://github.com/huggingface/datasets/issues/3154#issuecomment-950746534 and the README.md
file at https://github.com/mjpost/sacreBLEU#ter for more information.
"""

_KWARGS_DESCRIPTION = """
Produces TER scores alongside the number of edits and reference length.
Args:
predictions: The system stream (a sequence of segments).
references: A list of one or more reference streams (each a sequence of segments).
normalized: Whether to apply basic tokenization to sentences.
    no_punct: Whether to remove punctuation from sentences.
asian_support: Whether to support Asian character processing.
case_sensitive: Whether to disable lowercasing.
Returns:
'score': TER score (num_edits / sum_ref_lengths),
'num_edits': The cumulative number of edits,
'ref_length': The cumulative average reference length.
Examples:
>>> predictions = [["hello there general kenobi"], ["foo bar foobar"]]
>>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]
>>> ter = jury.load_metric("ter")
>>> results = ter.compute(predictions=predictions, references=references)
>>> print(results)
{'score': 0.0, 'num_edits': 0, 'ref_length': 6.5}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TERForLanguageGeneration(MetricForLanguageGeneration):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
homepage="http://www.cs.umd.edu/~snover/tercom/",
inputs_description=_KWARGS_DESCRIPTION,
features=self._default_features,
codebase_urls=["https://github.com/mjpost/sacreBLEU#ter"],
reference_urls=[
"https://github.com/jhclark/tercom",
],
)

def _download_and_prepare(self, dl_manager):
global scb
global TERScorer

try:
import sacrebleu as scb
from sacrebleu import TER as TERScorer
except ModuleNotFoundError:
raise ModuleNotFoundError(requirement_message(path="TER", package_name="sacrebleu"))
else:
super(TERForLanguageGeneration, self)._download_and_prepare(dl_manager)

def _validate_references(self, references: Collator) -> None:
references_per_prediction = len(references[0])
if any(len(refs) != references_per_prediction for refs in references):
raise ValueError("Sacrebleu requires the same number of references for each prediction")

def _compute_ter_score(
self, predictions: Sequence[str], references: Sequence[Sequence[str]], sentence_level: bool = False, **kwargs
):
sb_ter = TERScorer(**kwargs)
if sentence_level:
output = sb_ter.sentence_score(predictions, references)
else:
output = sb_ter.corpus_score(predictions, references)
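        # sacrebleu reports TER on a 0-100 scale; divide by 100 to return a 0-1 fraction.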
return {"score": float(output.score / 100), "num_edits": output.num_edits, "ref_length": output.ref_length}

def _compute_single_pred_single_ref(
self,
predictions: LanguageGenerationInstance,
references: LanguageGenerationInstance,
reduce_fn: Callable = None,
normalized: bool = False,
no_punct: bool = False,
asian_support: bool = False,
case_sensitive: bool = False,
):
return self._compute_ter_score(
predictions=predictions,
references=references,
normalized=normalized,
no_punct=no_punct,
asian_support=asian_support,
case_sensitive=case_sensitive,
)

def _compute_single_pred_multi_ref(
self,
predictions: LanguageGenerationInstance,
references: LanguageGenerationInstance,
reduce_fn: Callable = None,
normalized: bool = False,
no_punct: bool = False,
asian_support: bool = False,
case_sensitive: bool = False,
):
# SacreBleu inherently supports multiple references.
return self._compute_ter_score(
predictions=predictions,
references=references,
normalized=normalized,
no_punct=no_punct,
asian_support=asian_support,
case_sensitive=case_sensitive,
)

def _compute_multi_pred_multi_ref(
self,
predictions: LanguageGenerationInstance,
references: LanguageGenerationInstance,
reduce_fn: Callable = None,
normalized: bool = False,
no_punct: bool = False,
asian_support: bool = False,
case_sensitive: bool = False,
):
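        # Score each candidate prediction against its item's references at sentence level,
        # reduce the per-prediction scores with reduce_fn, and average the edit counts and
        # reference lengths over all items.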
scores = []
avg_num_edits = 0
avg_ref_length = 0
for preds, refs in zip(predictions, references):
pred_scores = []
num_edits = []
ref_lengths = []
for pred in preds:
score = self._compute_ter_score(
predictions=pred,
references=refs,
sentence_level=True,
normalized=normalized,
no_punct=no_punct,
asian_support=asian_support,
case_sensitive=case_sensitive,
)
pred_scores.append(score["score"])
num_edits.append(score["num_edits"])
ref_lengths.append(score["ref_length"])
pred_score = reduce_fn(pred_scores).item()
avg_num_edits += sum(num_edits) / len(num_edits)
avg_ref_length += sum(ref_lengths) / len(ref_lengths)
scores.append(pred_score)
return {
"score": sum(scores) / len(scores),
"avg_num_edits": avg_num_edits / len(predictions),
"avg_ref_length": avg_ref_length / len(predictions),
}

def evaluate(
self, predictions: Collator, references: Collator, reduce_fn: Callable = None, **kwargs
) -> Dict[str, float]:
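        # Dispatch on input shape: collapse predictions when each item holds a single
        # hypothesis, then pick the single- or multi-reference path; otherwise fall back
        # to the multi-prediction, multi-reference computation.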
if predictions.can_collapse() and references.can_collapse():
predictions = predictions.collapse()
eval_fn = self._compute_single_pred_single_ref
elif predictions.can_collapse() and not references.can_collapse():
predictions = predictions.collapse()
eval_fn = self._compute_single_pred_multi_ref
else:
eval_fn = self._compute_multi_pred_multi_ref
self._validate_references(references)
return eval_fn(predictions=predictions, references=references, reduce_fn=reduce_fn, **kwargs)
47 changes: 47 additions & 0 deletions tests/jury/metrics/test_ter.py
@@ -0,0 +1,47 @@
import pytest

from jury import Jury
from jury.metrics import AutoMetric
from tests.jury.conftest import get_expected_output
from tests.utils import assert_almost_equal_dict


@pytest.fixture(scope="module")
def jury_ter():
metric = AutoMetric.load("ter", compute_kwargs={"normalized": True})
return Jury(metrics=metric)


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_basic():
return output_basic.output


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_multiple_ref():
return output_multiple_ref.output


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_multiple_pred_multiple_ref():
return output_multiple_pred_multiple_ref.output


def test_basic(predictions, references, jury_ter, output_basic):
scores = jury_ter(predictions=predictions, references=references)
assert_almost_equal_dict(actual=scores, desired=output_basic)


def test_multiple_ref(predictions, multiple_references, jury_ter, output_multiple_ref):
scores = jury_ter(predictions=predictions, references=multiple_references)
assert_almost_equal_dict(actual=scores, desired=output_multiple_ref)


def test_multiple_pred_multiple_ref(
multiple_predictions, multiple_references, jury_ter, output_multiple_pred_multiple_ref
):
scores = jury_ter(predictions=multiple_predictions, references=multiple_references)
assert_almost_equal_dict(actual=scores, desired=output_multiple_pred_multiple_ref)
29 changes: 29 additions & 0 deletions tests/test_data/expected_outputs/metrics/test_ter.json
@@ -0,0 +1,29 @@
{
"basic": {
"total_items": 2,
"empty_items": 0,
"ter": {
"score": 0.4615384615384615,
"num_edits": 3,
"ref_length": 6.5
}
},
"multiple_ref": {
"total_items": 2,
"empty_items": 0,
"ter": {
"score": 0.6153846153846154,
"num_edits": 8,
"ref_length": 13.0
}
},
"multiple_pred_multiple_ref": {
"total_items": 2,
"empty_items": 0,
"ter": {
"score": 0.812121212121212,
"avg_num_edits": 3.25,
"avg_ref_length": 6.5
}
}
}
