Add TER metric #104

Merged 5 commits on Mar 15, 2022
1 change: 1 addition & 0 deletions jury/metrics/__init__.py
@@ -26,4 +26,5 @@
from jury.metrics.rouge import Rouge
from jury.metrics.sacrebleu import Sacrebleu
from jury.metrics.squad import Squad
from jury.metrics.ter import TER
from jury.metrics.wer import WER
1 change: 1 addition & 0 deletions jury/metrics/ter/__init__.py
@@ -0,0 +1 @@
from jury.metrics.ter.ter import TER
8 changes: 8 additions & 0 deletions jury/metrics/ter/ter.py
@@ -0,0 +1,8 @@
from jury.metrics._core import MetricAlias
from jury.metrics.ter.ter_for_language_generation import TERForLanguageGeneration

__main_class__ = "TER"


class TER(MetricAlias):
    _SUBCLASS = TERForLanguageGeneration
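
For context, the alias above is what the metric loader resolves when "ter" is requested by name; TERForLanguageGeneration does the actual work. A minimal usage sketch, mirroring the test fixture and the docstring example later in this PR (the compute_kwargs value is simply the one the tests use):

from jury import Jury
from jury.metrics import AutoMetric

# Resolve the "ter" name to the TER alias, which wraps TERForLanguageGeneration.
metric = AutoMetric.load("ter", compute_kwargs={"normalized": True})
scorer = Jury(metrics=metric)

predictions = [["hello there general kenobi"], ["foo bar foobar"]]
references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]
scores = scorer(predictions=predictions, references=references)
print(scores["ter"])  # expected to contain a score, an edit count, and a reference length
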
227 changes: 227 additions & 0 deletions jury/metrics/ter/ter_for_language_generation.py
@@ -0,0 +1,227 @@
# coding=utf-8
# Copyright 2020 Open Business Software Solutions, The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Translation Edit Rate (TER) metric. The part of this file is adapted from HuggingFace's
datasets package implementation of TER metric. See
https://github.com/huggingface/datasets/blob/master/metrics/ter/ter.py
"""
from typing import Callable, Dict, Sequence

import datasets

from jury.collator import Collator
from jury.metrics import LanguageGenerationInstance, MetricForLanguageGeneration
from jury.metrics._core.utils import PackagePlaceholder, requirement_message

# `import sacrebleu as scb` placeholder
scb = PackagePlaceholder(version="2.0.0")


_CITATION = """\
@inproceedings{snover-etal-2006-study,
title = "A Study of Translation Edit Rate with Targeted Human Annotation",
author = "Snover, Matthew and
Dorr, Bonnie and
Schwartz, Rich and
Micciulla, Linnea and
Makhoul, John",
booktitle = "Proceedings of the 7th Conference of the Association for Machine Translation in the Americas: Technical Papers",
month = aug # " 8-12",
year = "2006",
address = "Cambridge, Massachusetts, USA",
publisher = "Association for Machine Translation in the Americas",
url = "https://aclanthology.org/2006.amta-papers.25",
pages = "223--231",
}
@inproceedings{post-2018-call,
title = "A Call for Clarity in Reporting {BLEU} Scores",
author = "Post, Matt",
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
month = oct,
year = "2018",
address = "Belgium, Brussels",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W18-6319",
pages = "186--191",
}
"""

_DESCRIPTION = """\
TER (Translation Edit Rate, also called Translation Error Rate) is a metric to quantify the edit operations that a
hypothesis requires to match a reference translation. We use the implementation that is already present in sacrebleu
(https://github.com/mjpost/sacreBLEU#ter), which in turn is inspired by the TERCOM implementation, which can be found
here: https://github.com/jhclark/tercom.
The implementation here differs slightly from sacrebleu in terms of the required input format: the references
and hypotheses lists must have the same length, so you may need to transpose your references relative to
sacrebleu's required input format. See https://github.com/huggingface/datasets/issues/3154#issuecomment-950746534
See the README.md file at https://github.com/mjpost/sacreBLEU#ter for more information.
"""

_KWARGS_DESCRIPTION = """
Produces TER scores alongside the number of edits and reference length.
Args:
    predictions: The system stream (a sequence of segments).
    references: A list of one or more reference streams (each a sequence of segments).
    normalized: Whether to apply basic tokenization to sentences.
    no_punct: Whether to remove punctuation from sentences.
    asian_support: Whether to support Asian character processing.
    case_sensitive: Whether to disable lowercasing (keep the original casing).
Returns:
    'score': TER score (num_edits / sum_ref_lengths),
    'num_edits': The cumulative number of edits,
    'ref_length': The cumulative average reference length.
Examples:
    >>> predictions = [["hello there general kenobi"], ["foo bar foobar"]]
    >>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]
    >>> ter = jury.load_metric("ter")
    >>> results = ter.compute(predictions=predictions, references=references)
    >>> print(results)
    {'score': 0.0, 'num_edits': 0, 'ref_length': 6.5}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TERForLanguageGeneration(MetricForLanguageGeneration):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="http://www.cs.umd.edu/~snover/tercom/",
            inputs_description=_KWARGS_DESCRIPTION,
            features=self._default_features,
            codebase_urls=["https://github.com/mjpost/sacreBLEU#ter"],
            reference_urls=[
                "https://github.com/jhclark/tercom",
            ],
        )

    def _download_and_prepare(self, dl_manager):
        global scb
        global TERScorer

        try:
            import sacrebleu as scb
            from sacrebleu import TER as TERScorer
        except ModuleNotFoundError:
            raise ModuleNotFoundError(requirement_message(path="TER", package_name="sacrebleu"))
        else:
            super(TERForLanguageGeneration, self)._download_and_prepare(dl_manager)

    def _validate_references(self, references: Collator) -> None:
        references_per_prediction = len(references[0])
        if any(len(refs) != references_per_prediction for refs in references):
            raise ValueError("Sacrebleu requires the same number of references for each prediction")

    def _compute_ter_score(
        self, predictions: Sequence[str], references: Sequence[Sequence[str]], sentence_level: bool = False, **kwargs
    ):
        sb_ter = TERScorer(**kwargs)
        if sentence_level:
            output = sb_ter.sentence_score(predictions, references)
        else:
            output = sb_ter.corpus_score(predictions, references)
        # sacrebleu reports TER as a percentage; convert it to a ratio here.
        return {"score": float(output.score / 100), "num_edits": output.num_edits, "ref_length": output.ref_length}

    def _compute_single_pred_single_ref(
        self,
        predictions: LanguageGenerationInstance,
        references: LanguageGenerationInstance,
        reduce_fn: Callable = None,
        normalized: bool = False,
        no_punct: bool = False,
        asian_support: bool = False,
        case_sensitive: bool = False,
    ):
        return self._compute_ter_score(
            predictions=predictions,
            references=references,
            normalized=normalized,
            no_punct=no_punct,
            asian_support=asian_support,
            case_sensitive=case_sensitive,
        )

    def _compute_single_pred_multi_ref(
        self,
        predictions: LanguageGenerationInstance,
        references: LanguageGenerationInstance,
        reduce_fn: Callable = None,
        normalized: bool = False,
        no_punct: bool = False,
        asian_support: bool = False,
        case_sensitive: bool = False,
    ):
        # SacreBleu inherently supports multiple references.
        return self._compute_ter_score(
            predictions=predictions,
            references=references,
            normalized=normalized,
            no_punct=no_punct,
            asian_support=asian_support,
            case_sensitive=case_sensitive,
        )

    def _compute_multi_pred_multi_ref(
        self,
        predictions: LanguageGenerationInstance,
        references: LanguageGenerationInstance,
        reduce_fn: Callable = None,
        normalized: bool = False,
        no_punct: bool = False,
        asian_support: bool = False,
        case_sensitive: bool = False,
    ):
        scores = []
        avg_num_edits = 0
        avg_ref_length = 0
        for preds, refs in zip(predictions, references):
            pred_scores = []
            num_edits = []
            ref_lengths = []
            for pred in preds:
                # Score each candidate prediction at sentence level against the item's references.
                score = self._compute_ter_score(
                    predictions=pred,
                    references=refs,
                    sentence_level=True,
                    normalized=normalized,
                    no_punct=no_punct,
                    asian_support=asian_support,
                    case_sensitive=case_sensitive,
                )
                pred_scores.append(score["score"])
                num_edits.append(score["num_edits"])
                ref_lengths.append(score["ref_length"])
            # Reduce the candidate scores to one score per item; edits and lengths are averaged.
            pred_score = reduce_fn(pred_scores).item()
            avg_num_edits += sum(num_edits) / len(num_edits)
            avg_ref_length += sum(ref_lengths) / len(ref_lengths)
            scores.append(pred_score)
        return {
            "score": sum(scores) / len(scores),
"avg_num_edits": avg_num_edits / len(predictions),
"avg_ref_length": avg_ref_length / len(predictions),
}

    def evaluate(
        self, predictions: Collator, references: Collator, reduce_fn: Callable = None, **kwargs
    ) -> Dict[str, float]:
        # Choose the scoring path based on whether predictions/references collapse to single items.
        if predictions.can_collapse() and references.can_collapse():
            predictions = predictions.collapse()
            eval_fn = self._compute_single_pred_single_ref
        elif predictions.can_collapse() and not references.can_collapse():
            predictions = predictions.collapse()
            eval_fn = self._compute_single_pred_multi_ref
        else:
            eval_fn = self._compute_multi_pred_multi_ref
        self._validate_references(references)
        return eval_fn(predictions=predictions, references=references, reduce_fn=reduce_fn, **kwargs)
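
The input-format note in _DESCRIPTION above is the main practical difference from calling sacrebleu directly: here the references are grouped per prediction, while sacrebleu's corpus_score expects one list per reference stream. A small sketch of that transposition, reusing the docstring's example sentences:

# jury-style input: references[i] holds all references for predictions[i]
predictions = ["hello there general kenobi", "foo bar foobar"]
references = [
    ["hello there general kenobi", "hello there !"],
    ["foo bar foobar", "foo bar foobar"],
]

# sacrebleu-style corpus_score input: one stream per reference position,
# i.e. the transpose of the structure above.
transposed_refs = list(map(list, zip(*references)))
# -> [['hello there general kenobi', 'foo bar foobar'],
#     ['hello there !', 'foo bar foobar']]
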
47 changes: 47 additions & 0 deletions tests/jury/metrics/test_ter.py
@@ -0,0 +1,47 @@
import pytest

from jury import Jury
from jury.metrics import AutoMetric
from tests.jury.conftest import get_expected_output
from tests.utils import assert_almost_equal_dict


@pytest.fixture(scope="module")
def jury_ter():
    metric = AutoMetric.load("ter", compute_kwargs={"normalized": True})
    return Jury(metrics=metric)


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_basic():
    return output_basic.output


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_multiple_ref():
    return output_multiple_ref.output


@pytest.fixture
@get_expected_output(prefix="metrics")
def output_multiple_pred_multiple_ref():
    return output_multiple_pred_multiple_ref.output


def test_basic(predictions, references, jury_ter, output_basic):
    scores = jury_ter(predictions=predictions, references=references)
    assert_almost_equal_dict(actual=scores, desired=output_basic)


def test_multiple_ref(predictions, multiple_references, jury_ter, output_multiple_ref):
    scores = jury_ter(predictions=predictions, references=multiple_references)
    assert_almost_equal_dict(actual=scores, desired=output_multiple_ref)


def test_multiple_pred_multiple_ref(
    multiple_predictions, multiple_references, jury_ter, output_multiple_pred_multiple_ref
):
    scores = jury_ter(predictions=multiple_predictions, references=multiple_references)
    assert_almost_equal_dict(actual=scores, desired=output_multiple_pred_multiple_ref)
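
test_multiple_pred_multiple_ref exercises the _compute_multi_pred_multi_ref path: each candidate prediction is scored at sentence level against its item's references, the candidate scores are reduced to one value per item, and the per-item values are averaged over the corpus. A standalone sketch of that reduction using sacrebleu directly (the plain Python max here is only a stand-in for whatever reduce_fn Jury passes in):

from statistics import mean

from sacrebleu import TER


def multi_pred_multi_ref_ter(predictions, references, reduce_fn=max):
    # predictions: one list of candidate strings per item;
    # references: one list of reference strings per item, aligned with predictions.
    scorer = TER()
    item_scores = []
    for preds, refs in zip(predictions, references):
        # Sentence-level TER for each candidate (sacrebleu reports a percentage).
        candidate_scores = [scorer.sentence_score(pred, refs).score / 100 for pred in preds]
        item_scores.append(reduce_fn(candidate_scores))
    return mean(item_scores)


print(multi_pred_multi_ref_ter(
    [["hello there general kenobi", "hi there kenobi"]],
    [["hello there general kenobi", "hello there !"]],
))
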
29 changes: 29 additions & 0 deletions tests/test_data/expected_outputs/metrics/test_ter.json
@@ -0,0 +1,29 @@
{
    "basic": {
        "total_items": 2,
        "empty_items": 0,
        "ter": {
            "score": 0.4615384615384615,
            "num_edits": 3,
            "ref_length": 6.5
        }
    },
    "multiple_ref": {
        "total_items": 2,
        "empty_items": 0,
        "ter": {
            "score": 0.6153846153846154,
            "num_edits": 8,
            "ref_length": 13.0
        }
    },
    "multiple_pred_multiple_ref": {
        "total_items": 2,
        "empty_items": 0,
        "ter": {
            "score": 0.812121212121212,
            "avg_num_edits": 3.25,
            "avg_ref_length": 6.5
        }
    }
}
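
As a sanity check on the "basic" case above, the stored values are internally consistent with the score formula in the metric docstring: score = num_edits / ref_length, i.e. 3 / 6.5 ≈ 0.4615. The same relationship can be reproduced with sacrebleu directly (the sentences below are made up for illustration and are not the test fixtures):

from sacrebleu import TER

scorer = TER(normalized=True)
hypotheses = ["the cat sat on mat"]
references = [["the cat sat on the mat"]]  # one reference stream

result = scorer.corpus_score(hypotheses, references)
# sacrebleu reports TER as a percentage; jury divides by 100,
# so result.score / 100 equals result.num_edits / result.ref_length.
print(result.score / 100, result.num_edits, result.ref_length)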