diff --git a/CHANGELOG.md b/CHANGELOG.md index 6200571f0ac..1fbf0b1c90d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed when `_stable_1d_sort` to work when n >= N ([PL^6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177)) - Fixed `_computed` attribute not being correctly reset ([#147](https://github.com/PyTorchLightning/metrics/pull/147)) +- Fixed to blau score ([#165](https://github.com/PyTorchLightning/metrics/pull/165)) ## [0.2.0] - 2021-03-12 diff --git a/requirements/test.txt b/requirements/test.txt index 2f79ef04731..ff77d9bc3b3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,4 +14,4 @@ yapf>=0.29.0 cloudpickle>=1.3 scikit-learn>0.22.1 scikit-image>0.17.1 -nltk>=3.3 +nltk>=3.6 diff --git a/torchmetrics/functional/nlp.py b/torchmetrics/functional/nlp.py index 53f5e47e40c..50e39e5762c 100644 --- a/torchmetrics/functional/nlp.py +++ b/torchmetrics/functional/nlp.py @@ -101,6 +101,7 @@ def bleu_score( if smooth: precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) + precision_scores[0] = numerator[0] / denominator[0] else: precision_scores = numerator / denominator