-
Notifications
You must be signed in to change notification settings - Fork 416
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Extended Edit Distance (EED) metric (#668)
* add Extended Edit Distance (EED) metric * flake8, mypy, and doctest * fixed weird bug where parallelized metric was giving different answers to non-parallelized metric * update CHANGELOG.md Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Daniel Stancl <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
- Loading branch information
1 parent
3bd4fb0
commit 8ef281c
Showing
11 changed files
with
721 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from functools import partial | ||
|
||
import pytest | ||
from torch import Tensor, tensor | ||
|
||
from tests.text.helpers import TextTester | ||
from tests.text.inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references | ||
from torchmetrics.functional.text.eed import extended_edit_distance | ||
from torchmetrics.text.eed import ExtendedEditDistance | ||
|
||
|
||
def rwth_manual_metric(preds, targets) -> Tensor: | ||
"""The results were obtained w.r.t. | ||
the examples defined in `tests.text.inputs` with the script from https://github.com/rwth-i6/ExtendedEditDistance. | ||
""" | ||
ans_1 = tensor(0.24248056001808083) | ||
ans_2 = tensor(0.19152276295133436) | ||
|
||
HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party" | ||
|
||
# If hypothesis A and B are in preds, the average of ans_1 and ans_2 is given | ||
if len(preds) == 4: | ||
return (ans_1 + ans_2) / 2 | ||
# If only hypothesis A or B are given, ans_1 and ans_2 are given, respectively | ||
if HYPOTHESIS_A in preds: | ||
return ans_1 | ||
return ans_2 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["preds", "targets"], | ||
[(_inputs_single_reference.preds, _inputs_single_reference.targets)], | ||
) | ||
class TestExtendedEditDistance(TextTester): | ||
@pytest.mark.parametrize("ddp", [False, True]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [False, True]) | ||
def test_eed_class(self, preds, targets, ddp, dist_sync_on_step): | ||
rwth_metric = partial(rwth_manual_metric) | ||
self.run_class_metric_test( | ||
ddp=ddp, | ||
preds=preds, | ||
targets=targets, | ||
metric_class=ExtendedEditDistance, | ||
sk_metric=rwth_metric, | ||
dist_sync_on_step=dist_sync_on_step, | ||
) | ||
|
||
def test_eed_functional(self, preds, targets): | ||
rwth_metric = partial(rwth_manual_metric) | ||
self.run_functional_metric_test( | ||
preds, | ||
targets, | ||
metric_functional=extended_edit_distance, | ||
sk_metric=rwth_metric, | ||
) | ||
|
||
def test_eed_differentiability(self, preds, targets): | ||
self.run_differentiability_test( | ||
preds=preds, | ||
targets=targets, | ||
metric_module=ExtendedEditDistance, | ||
metric_functional=extended_edit_distance, | ||
) | ||
|
||
|
||
# test blank edge cases | ||
def test_eed_empty_functional(): | ||
hyp = [] | ||
ref = [[]] | ||
assert extended_edit_distance(hyp, ref) == tensor(0.0) | ||
|
||
|
||
def test_eed_empty_class(): | ||
eed_metric = ExtendedEditDistance() | ||
hyp = [] | ||
ref = [[]] | ||
assert eed_metric(hyp, ref) == tensor(0.0) | ||
|
||
|
||
def test_eed_empty_with_non_empty_hyp_functional(): | ||
hyp = ["python"] | ||
ref = [[]] | ||
assert extended_edit_distance(hyp, ref) == tensor(0.0) | ||
|
||
|
||
def test_eed_empty_with_non_empty_hyp_class(): | ||
eed_metric = ExtendedEditDistance() | ||
hyp = ["python"] | ||
ref = [[]] | ||
assert eed_metric(hyp, ref) == tensor(0.0) | ||
|
||
|
||
def test_eed_return_sentence_level_score_functional(): | ||
hyp = _inputs_single_sentence_multiple_references.preds | ||
ref = _inputs_single_sentence_multiple_references.targets | ||
_, sentence_eed = extended_edit_distance(hyp, ref, return_sentence_level_score=True) | ||
isinstance(sentence_eed, Tensor) | ||
|
||
|
||
def test_eed_return_sentence_level_class(): | ||
metric = ExtendedEditDistance(return_sentence_level_score=True) | ||
hyp = _inputs_single_sentence_multiple_references.preds | ||
ref = _inputs_single_sentence_multiple_references.targets | ||
_, sentence_eed = metric(hyp, ref) | ||
isinstance(sentence_eed, Tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.