diff --git a/CHANGELOG.md b/CHANGELOG.md
index af120dfdaf2..f86a4e62a1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -55,6 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Renamed IoU -> Jaccard Index ([#662](https://github.com/PyTorchLightning/metrics/pull/662))
 
+- Renamed `WER` -> `WordErrorRate` and `wer` -> `word_error_rate` ([#714](https://github.com/PyTorchLightning/metrics/pull/714))
+
 - Renamed correlation coefficient classes: ([#710](https://github.com/PyTorchLightning/metrics/pull/710))
   * `MatthewsCorrcoef` -> `MatthewsCorrCoef`
diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py
index 23fea84224a..d75fecf4db3 100644
--- a/tests/text/test_wer.py
+++ b/tests/text/test_wer.py
@@ -11,8 +11,8 @@
 else:
     compute_measures = Callable
 
-from torchmetrics.functional.text.wer import wer
-from torchmetrics.text.wer import WER
+from torchmetrics.functional.text.wer import word_error_rate
+from torchmetrics.text.wer import WordErrorRate
 
 
 def _compute_wer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
@@ -36,7 +36,7 @@ def test_wer_class(self, ddp, dist_sync_on_step, preds, targets):
             ddp=ddp,
             preds=preds,
             targets=targets,
-            metric_class=WER,
+            metric_class=WordErrorRate,
             sk_metric=_compute_wer_metric_jiwer,
             dist_sync_on_step=dist_sync_on_step,
         )
@@ -46,7 +46,7 @@ def test_wer_functional(self, preds, targets):
         self.run_functional_metric_test(
             preds,
             targets,
-            metric_functional=wer,
+            metric_functional=word_error_rate,
             sk_metric=_compute_wer_metric_jiwer,
         )
@@ -55,6 +55,6 @@ def test_wer_differentiability(self, preds, targets):
         self.run_differentiability_test(
             preds=preds,
             targets=targets,
-            metric_module=WER,
-            metric_functional=wer,
+            metric_module=WordErrorRate,
+            metric_functional=word_error_rate,
         )
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 95ba4efd7ec..95199124583 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -78,6 +78,7 @@
     SacreBLEUScore,
     SQuAD,
     TranslationEditRate,
+    WordErrorRate,
     WordInfoLost,
     WordInfoPreserved,
 )
@@ -154,6 +155,7 @@
     "SymmetricMeanAbsolutePercentageError",
     "TranslationEditRate",
     "WER",
+    "WordErrorRate",
     "CharErrorRate",
     "MatchErrorRate",
     "WordInfoLost",
diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py
index c5f75d64ace..c500d574cc8 100644
--- a/torchmetrics/functional/__init__.py
+++ b/torchmetrics/functional/__init__.py
@@ -73,7 +73,7 @@
 from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
 from torchmetrics.functional.text.squad import squad
 from torchmetrics.functional.text.ter import translation_edit_rate
-from torchmetrics.functional.text.wer import wer
+from torchmetrics.functional.text.wer import wer, word_error_rate
 from torchmetrics.functional.text.wil import word_information_lost
 from torchmetrics.functional.text.wip import word_information_preserved
@@ -140,6 +140,7 @@
     "symmetric_mean_absolute_percentage_error",
     "translation_edit_rate",
     "wer",
+    "word_error_rate",
     "char_error_rate",
     "match_error_rate",
     "word_information_lost",
diff --git a/torchmetrics/functional/text/__init__.py b/torchmetrics/functional/text/__init__.py
index c08a3af9b3f..2430ec68e4c 100644
--- a/torchmetrics/functional/text/__init__.py
+++ b/torchmetrics/functional/text/__init__.py
@@ -19,6 +19,6 @@
 from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score  # noqa: F401
 from torchmetrics.functional.text.squad import squad  # noqa: F401
 from torchmetrics.functional.text.ter import translation_edit_rate  # noqa: F401
-from torchmetrics.functional.text.wer import wer  # noqa: F401
+from torchmetrics.functional.text.wer import wer, word_error_rate  # noqa: F401
 from torchmetrics.functional.text.wil import word_information_lost  # noqa: F401
 from torchmetrics.functional.text.wip import word_information_preserved  # noqa: F401
diff --git a/torchmetrics/functional/text/wer.py b/torchmetrics/functional/text/wer.py
index 8885adaaa82..25dafcb2b98 100644
--- a/torchmetrics/functional/text/wer.py
+++ b/torchmetrics/functional/text/wer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import List, Tuple, Union
+from warnings import warn
 
 import torch
 from torch import Tensor, tensor
@@ -61,7 +62,7 @@ def _wer_compute(errors: Tensor, total: Tensor) -> Tensor:
     return errors / total
 
 
-def wer(
+def word_error_rate(
     predictions: Union[str, List[str]],
     references: Union[str, List[str]],
 ) -> Tensor:
@@ -79,8 +80,27 @@
     Examples:
         >>> predictions = ["this is the prediction", "there is an other sample"]
         >>> references = ["this is the reference", "there is another one"]
-        >>> wer(predictions=predictions, references=references)
+        >>> word_error_rate(predictions=predictions, references=references)
         tensor(0.5000)
     """
     errors, total = _wer_update(predictions, references)
     return _wer_compute(errors, total)
+
+
+def wer(
+    predictions: Union[str, List[str]],
+    references: Union[str, List[str]],
+) -> Tensor:
+    """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system.
+
+    .. deprecated:: v0.7
+        Use :func:`torchmetrics.functional.word_error_rate`. Will be removed in v0.8.
+
+    Examples:
+        >>> predictions = ["this is the prediction", "there is an other sample"]
+        >>> references = ["this is the reference", "there is another one"]
+        >>> wer(predictions=predictions, references=references)
+        tensor(0.5000)
+    """
+    warn("`wer` was renamed to `word_error_rate` in v0.7 and it will be removed in v0.8", DeprecationWarning)
+    return word_error_rate(predictions, references)
diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py
index b45d5ef29bc..b3ca31681bd 100644
--- a/torchmetrics/text/__init__.py
+++ b/torchmetrics/text/__init__.py
@@ -18,6 +18,6 @@
 from torchmetrics.text.sacre_bleu import SacreBLEUScore  # noqa: F401
 from torchmetrics.text.squad import SQuAD  # noqa: F401
 from torchmetrics.text.ter import TranslationEditRate  # noqa: F401
-from torchmetrics.text.wer import WER  # noqa: F401
+from torchmetrics.text.wer import WER, WordErrorRate  # noqa: F401
 from torchmetrics.text.wil import WordInfoLost  # noqa: F401
 from torchmetrics.text.wip import WordInfoPreserved  # noqa: F401
diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py
index c01bfeb7c0b..7eea21fbf68 100644
--- a/torchmetrics/text/wer.py
+++ b/torchmetrics/text/wer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Callable, List, Optional, Union
+from warnings import warn
 
 import torch
 from torch import Tensor, tensor
@@ -20,7 +21,7 @@
 from torchmetrics.metric import Metric
 
 
-class WER(Metric):
+class WordErrorRate(Metric):
     r"""
     Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system.
     This value indicates the percentage of words that were incorrectly predicted.
@@ -57,7 +58,7 @@ class WER(Metric):
     Examples:
         >>> predictions = ["this is the prediction", "there is an other sample"]
         >>> references = ["this is the reference", "there is another one"]
-        >>> metric = WER()
+        >>> metric = WordErrorRate()
         >>> metric(predictions, references)
         tensor(0.5000)
     """
@@ -100,3 +101,29 @@ def compute(self) -> Tensor:
             Word error rate score
         """
         return _wer_compute(self.errors, self.total)
+
+
+class WER(WordErrorRate):
+    r"""
+    Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system.
+
+    .. deprecated:: v0.7
+        Use :class:`torchmetrics.WordErrorRate`. Will be removed in v0.8.
+
+    Examples:
+        >>> predictions = ["this is the prediction", "there is an other sample"]
+        >>> references = ["this is the reference", "there is another one"]
+        >>> metric = WER()
+        >>> metric(predictions, references)
+        tensor(0.5000)
+    """
+
+    def __init__(
+        self,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+    ):
+        warn("`WER` was renamed to `WordErrorRate` in v0.7 and it will be removed in v0.8", DeprecationWarning)
+        super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
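
Not part of the patch: a minimal sketch of how the renamed API and the deprecated aliases are expected to behave after this change. The expected tensor values come from the docstring examples above; the warning-capture scaffolding and variable names are illustrative only.

    # Sketch only: expected behaviour of the new names and the deprecated aliases.
    import warnings

    from torchmetrics import WER, WordErrorRate
    from torchmetrics.functional import wer, word_error_rate

    predictions = ["this is the prediction", "there is an other sample"]
    references = ["this is the reference", "there is another one"]

    # New names: no deprecation warning expected.
    print(word_error_rate(predictions, references))  # tensor(0.5000)
    print(WordErrorRate()(predictions, references))  # tensor(0.5000)

    # Old names keep working but should emit a DeprecationWarning until removal in v0.8.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        print(wer(predictions, references))  # tensor(0.5000)
        print(WER()(predictions, references))  # tensor(0.5000)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

The deprecated names stay importable for one release cycle, giving downstream code a migration window: v0.7 warns, v0.8 removes.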