Move BLEU score to the respective folders (#360)
* Added BLEU score to the respective folders

* File naming correction and moved existing tests

* Fixes from pre-commit

* Updated function definitions to be in sync with nltk style

* Updated docs references in the rst files so they are reflected in the HTML.

* Made naming changes for consistency, updated references in docs, added tests for class implementation

* Added functional/nlp.py back with a DeprecationWarning for continued support

* Fixed import error

* Updated docstring for deprecation and added tests for metric computation in batches

* deprecate

* chlog

* types

* Fixing doctests, updating test variable types

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka <[email protected]>
4 people authored Jul 19, 2021
1 parent 5c2069e commit 591bb2d
Showing 11 changed files with 387 additions and 109 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `torchmetrics.functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))


- Moved `bleu_score` from `functional.nlp` to `functional.text.bleu` ([#360](https://github.com/PyTorchLightning/metrics/pull/360))


### Removed

- Removed restriction that `threshold` has to be in (0,1) range to support logit input ([#351](https://github.com/PyTorchLightning/metrics/pull/351))
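In practice, the move leaves three ways to reach the function during the deprecation window. A minimal sketch of the import paths, as introduced by the diffs below (shown together only for comparison; in real code pick one):

# New home of the implementation:
from torchmetrics.functional.text.bleu import bleu_score

# Top-level re-export, unchanged for users:
from torchmetrics.functional import bleu_score

# Old path, kept as a shim that emits a DeprecationWarning (removal planned for v0.6):
from torchmetrics.functional.nlp import bleu_score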
21 changes: 10 additions & 11 deletions docs/source/references/functional.rst
@@ -290,17 +290,6 @@ symmetric_mean_absolute_percentage_error [func]
.. autofunction:: torchmetrics.functional.symmetric_mean_absolute_percentage_error
:noindex:


***
NLP
***

bleu_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bleu_score
:noindex:

********
Pairwise
********
@@ -355,3 +344,13 @@

.. autofunction:: torchmetrics.functional.retrieval_normalized_dcg
:noindex:

****
Text
****

bleu_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bleu_score
:noindex:
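A usage sketch of the functional interface documented here, mirroring the doctest shipped in this PR (note that reference_corpus now precedes translate_corpus):

from torchmetrics.functional import bleu_score

translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
score = bleu_score(reference_corpus, translate_corpus)  # tensor(0.7598)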
10 changes: 10 additions & 0 deletions docs/source/references/modules.rst
@@ -507,6 +507,16 @@ RetrievalNormalizedDCG
.. autoclass:: torchmetrics.RetrievalNormalizedDCG
:noindex:

****
Text
****

BLEUScore
~~~~~~~~~

.. autoclass:: torchmetrics.BLEUScore
:noindex:


********
Wrappers
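A usage sketch of the module interface documented above, following the argument order used in the new class tests (illustrative only):

from torchmetrics import BLEUScore

bleu = BLEUScore(n_gram=4, smooth=False)
translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
score = bleu(reference_corpus, translate_corpus)  # forward: updates state and returns the score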
88 changes: 77 additions & 11 deletions tests/functional/test_nlp.py → tests/text/test_blue.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
from torch import tensor

from torchmetrics.functional import bleu_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.text.bleu import BLEUScore

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
@@ -39,8 +41,13 @@
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()

LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]
TUPLE_OF_REFERENCES = ((REF1A, REF1B, REF1C), tuple([REF2A]))
HYPOTHESES = (HYP1, HYP2)

BATCHES = [
dict(reference_corpus=[[REF1A, REF1B, REF1C]], translate_corpus=[HYP1]),
dict(reference_corpus=[[REF2A]], translate_corpus=[HYP2])
]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2
@@ -55,28 +62,87 @@
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
def test_bleu_score_functional(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3],
HYPOTHESIS1,
weights=weights,
smoothing_function=smooth_func,
)
pl_output = bleu_score([[REFERENCE1, REFERENCE2, REFERENCE3]], [HYPOTHESIS1], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, tensor(nltk_output))

nltk_output = corpus_bleu(TUPLE_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(TUPLE_OF_REFERENCES, HYPOTHESES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, tensor(nltk_output))


def test_bleu_empty_functional():
hyp = [[]]
ref = [[[]]]
assert bleu_score(ref, hyp) == tensor(0.0)


def test_no_4_gram_functional():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(refs, hyps) == tensor(0.0)


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score_class(weights, n_gram, smooth_func, smooth):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3],
HYPOTHESIS1,
weights=weights,
smoothing_function=smooth_func,
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
pl_output = bleu([[REFERENCE1, REFERENCE2, REFERENCE3]], [HYPOTHESIS1])
assert torch.allclose(pl_output, tensor(nltk_output))

nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
nltk_output = corpus_bleu(TUPLE_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu(TUPLE_OF_REFERENCES, HYPOTHESES)
assert torch.allclose(pl_output, tensor(nltk_output))


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score_class_batches(weights, n_gram, smooth_func, smooth):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)

nltk_output = corpus_bleu(TUPLE_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)

for batch in BATCHES:
bleu.update(batch['reference_corpus'], batch['translate_corpus'])
pl_output = bleu.compute()
assert torch.allclose(pl_output, tensor(nltk_output))


def test_bleu_empty():
def test_bleu_empty_class():
bleu = BLEUScore()
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == tensor(0.0)
assert bleu(ref, hyp) == tensor(0.0)


def test_no_4_gram():
def test_no_4_gram_class():
bleu = BLEUScore()
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == tensor(0.0)
assert bleu(refs, hyps) == tensor(0.0)
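The batch test above exercises the usual stateful pattern for torchmetrics: update() accumulates n-gram statistics per batch and compute() produces one corpus-level score at the end, matching the single-shot result. A condensed sketch using the BATCHES fixture defined at the top of the file:

from torchmetrics import BLEUScore

bleu = BLEUScore(n_gram=4, smooth=True)
for batch in BATCHES:
    bleu.update(batch['reference_corpus'], batch['translate_corpus'])
score = bleu.compute()  # corpus-level BLEU over everything seen via update()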
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -60,4 +60,5 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import BLEUScore # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
2 changes: 1 addition & 1 deletion torchmetrics/functional/__init__.py
@@ -35,7 +35,6 @@
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import psnr # noqa: F401
from torchmetrics.functional.image.ssim import ssim # noqa: F401
from torchmetrics.functional.nlp import bleu_score # noqa: F401
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
@@ -58,3 +57,4 @@
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
100 changes: 14 additions & 86 deletions torchmetrics/functional/nlp.py
@@ -11,107 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Sequence
from warnings import warn

import torch
from torch import Tensor, tensor
from torch import Tensor


def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
"""
Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4
Return:
ngram_counter: a collections.Counter object of ngram
"""

ngram_counter: Counter = Counter()

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j:(i + j)])
ngram_counter[ngram_key] += 1

return ngram_counter
from torchmetrics.functional.text.bleu import bleu_score as _bleu_score


def bleu_score(
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
n_gram: int = 4,
smooth: bool = False
) -> Tensor:
"""
Calculate BLEU score of machine translated text with one or more references
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether or not to apply smoothing – Lin et al. 2004
Return:
Tensor with BLEU Score
Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more references
Example:
>>> from torchmetrics.functional import bleu_score
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
>>> bleu_score(reference_corpus, translate_corpus)
tensor(0.7598)
"""

if len(translate_corpus) != len(reference_corpus):
raise ValueError(f"Corpus has different size {len(translate_corpus)} != {len(reference_corpus)}")
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
c = 0.0
r = 0.0

for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter: Counter = _count_ngram(translation, n_gram)
reference_counter: Counter = Counter()

for ref in references:
reference_counter |= _count_ngram(ref, n_gram)

ngram_counter_clip = translation_counter & reference_counter

for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]

trans_len = tensor(c)
ref_len = tensor(r)

if min(numerator) == 0.0:
return tensor(0.0)

if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
precision_scores[0] = numerator[0] / denominator[0]
else:
precision_scores = numerator / denominator

log_precision_scores = tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean

return bleu
.. deprecated:: v0.5
Use :func:`torchmetrics.functional.text.bleu.bleu_score`. Will be removed in v0.6.
"""
warn(
"Function `functional.nlp.bleu_score` is deprecated in v0.5 and will be removed in v0.6."
" Use `functional.text.bleu.bleu_score` instead.", DeprecationWarning
)
return _bleu_score(reference_corpus, translate_corpus, n_gram, smooth)
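A sketch of the shim's runtime behavior under the code above: the old entry point warns and forwards its arguments, so existing call sites keep working until v0.6.

import warnings

from torchmetrics.functional.nlp import bleu_score  # deprecated shim

translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    score = bleu_score(reference_corpus, translate_corpus)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)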
15 changes: 15 additions & 0 deletions torchmetrics/functional/text/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.text.bleu import bleu_score # noqa: F401