Implement own BERT score #473

Merged: 46 commits, Aug 30, 2021
Changes from 32 commits
Commits
46 commits
f72661a
[WIP] Start adding own BERTScore implementation
stancld Aug 18, 2021
2620b36
[WIP] Prepare the basic backbone for own BERTScore
stancld Aug 18, 2021
a79b091
[WIP] Make working BERTScore with all_layers=True + init bert_score i…
stancld Aug 19, 2021
b145470
Fix cuda device placing
stancld Aug 19, 2021
f1c66b1
Use IDF only if asked
stancld Aug 20, 2021
3b741cb
Add data collators
stancld Aug 20, 2021
6a65526
Add some docs and clean code
stancld Aug 20, 2021
9a1b275
Update docs, write new tests, clean the code
stancld Aug 20, 2021
97f6796
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
f544c70
Fix IDF rescaling and add new tests + fix type
stancld Aug 21, 2021
aec66b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2021
4cd7ea3
Adjust code to work with the DDP plus some changes
stancld Aug 21, 2021
10a63a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2021
8f32421
Fix a bug with tokenizer and add hash_code
stancld Aug 21, 2021
8c68e01
Fix transformers import
stancld Aug 21, 2021
474baf7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2021
89b49d8
Fix some mypy, flake8 and logic errors
stancld Aug 21, 2021
d97153e
Add support for the user's own model
stancld Aug 21, 2021
31abcf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2021
af98261
[WIP] Fix error raised by default tokenizer
stancld Aug 21, 2021
3c07bf0
Add support for the rescale with baseline
stancld Aug 22, 2021
a42cdc7
black formatting
stancld Aug 22, 2021
6052525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2021
77474ad
Update requirements + clean a bit
stancld Aug 22, 2021
947fcca
Try to fix recursive-import issue + decrease tqdm version
stancld Aug 22, 2021
50158f1
Merge branch 'master' into own_bert-score
stancld Aug 24, 2021
e20307f
Apply Borda's suggestions
stancld Aug 24, 2021
61506a3
Clean some code + add some docstrings
stancld Aug 24, 2021
54c492e
Do some refactoring
stancld Aug 24, 2021
1caea8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
741f36d
Merge branch 'master' into own_bert-score
stancld Aug 24, 2021
5d198ed
Merge branch 'master' into own_bert-score
stancld Aug 25, 2021
5e24e17
Apply suggestions from code review
Borda Aug 25, 2021
aa3e6a8
Merge branch 'master' into own_bert-score
Borda Aug 26, 2021
cb66482
Apply some of Borda's suggestions
stancld Aug 26, 2021
1df2bca
Merge branch 'master' into own_bert-score
stancld Aug 26, 2021
e2262d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
88655b7
Merge verbose/tqdm_available if statements
stancld Aug 26, 2021
eb234c7
Add docstring
stancld Aug 26, 2021
4b12107
Change docstring
stancld Aug 26, 2021
4e5c1da
Merge branch 'master' into own_bert-score
stancld Aug 27, 2021
d0d1419
Run bert-ddp tests only if torch.distributed.is_available()
stancld Aug 27, 2021
9759298
Simplify a condition
stancld Aug 27, 2021
a8dda94
Merge branch 'master' into own_bert-score
mergify[bot] Aug 27, 2021
64e0863
Use smaller model, 'albert-base-v2', for testing because of OOM issues
stancld Aug 27, 2021
f23451e
Set join=False for mp_spawn in the ddp_test
stancld Aug 29, 2021
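
At a glance, this PR drops the external `bert-score` dependency and ships an in-house implementation exposed through the functional `bert_score` and the `BERTScore` metric class. The sketch below is a minimal usage example reconstructed from the test code in this diff; the argument names (`model_name_or_path`, `num_layers`, `idf`, `batch_size`), the `predictions`/`references` keywords, and the returned `precision`/`recall`/`f1` keys follow the tests and may differ slightly in the released API. The sample sentences are placeholders.

# Minimal usage sketch based on the tests below (not the authoritative API).
from torchmetrics.functional import bert_score
from torchmetrics.text import BERTScore

preds = ["hello there", "general kenobi"]  # placeholder inputs
refs = ["hello there", "master kenobi"]

# Functional interface: returns a dict with "precision", "recall" and "f1" lists.
score = bert_score(preds, refs, model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3)
print(score["f1"])

# Metric class: accumulate batches with `update`, then aggregate with `compute`.
scorer = BERTScore(model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3)
scorer.update(predictions=preds, references=refs)
print(scorer.compute()["f1"])
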
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed `jiwer` as dependency for text package ([#446](https://github.com/PyTorchLightning/metrics/pull/446))


- Removed `bert-score` as dependency for text package ([#473](https://github.com/PyTorchLightning/metrics/pull/473))

### Fixed

- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
3 changes: 3 additions & 0 deletions MANIFEST.in
@@ -13,6 +13,9 @@ include LICENSE
# Include marker file for PEP 561
include torchmetrics/py.typed

# Include examples
recursive-include tm_examples *.py

exclude *.sh
exclude *.toml
exclude *.svg
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ known_first_party = [
"torchmetrics",
"tests",
"integrations",
"tm_examples",
]
skip_glob = []
profile = "black"
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -29,3 +29,5 @@ speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/mas
# text
jiwer>=2.2.0
rouge-score>=0.0.4
bert_score==0.3.10
transformers>=4.0
2 changes: 1 addition & 1 deletion requirements/text.txt
@@ -1,2 +1,2 @@
nltk>=3.6
bert-score==0.3.10
tqdm>=4.41.0
288 changes: 272 additions & 16 deletions tests/text/test_bertscore.py
@@ -1,12 +1,21 @@
from typing import Any
import os
from typing import Any, Dict, List

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.functional import bert_score
from torchmetrics.functional import bert_score as metrics_bert_score
from torchmetrics.text import BERTScore
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

if _BERTSCORE_AVAILABLE:
from bert_score import score as original_bert_score

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Examples and expected values taken from:
# https://github.com/Tiiiger/bert_score/blob/master/tests/test_scorer.py
preds = [
@@ -25,11 +34,20 @@
]


_METRICS = ["precision", "recall", "f1"]


def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8):
"""Assert two lists are equal."""
assert np.allclose(preds, refs, atol=threshold, equal_nan=True)


def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:
"""Parse the BERT score returned by the original `bert-score` package."""
score_dict = {metric: value.tolist() for metric, value in zip(_METRICS, score)}
return score_dict


preds_batched = [preds[0:2], preds[2:]]
refs_batched = [refs[0:2], refs[2:]]

@@ -41,10 +59,145 @@ def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8):
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn(preds, refs):
"""Tests for functional."""
Score = bert_score(preds, refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", num_layers=8, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_with_idf(preds, refs):
"""Tests for functional with IDF rescaling."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", num_layers=12, idf=True, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path="bert-base-uncased", num_layers=12, idf=True, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers(preds, refs):
"""Tests for functional and all layers."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", all_layers=True, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path="bert-base-uncased", all_layers=True, idf=False, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers_with_idf(preds, refs):
"""Tests for functional and all layers with IDF rescaling."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", all_layers=True, idf=True, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path="bert-base-uncased", all_layers=True, idf=True, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_rescale_with_baseline(preds, refs):
"""Tests for functional with baseline rescaling."""
original_score = original_bert_score(
preds,
refs,
model_type="bert-base-uncased",
lang="en",
num_layers=8,
idf=False,
batch_size=3,
rescale_with_baseline=True,
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds,
refs,
model_name_or_path="bert-base-uncased",
lang="en",
num_layers=8,
idf=False,
batch_size=3,
rescale_with_baseline=True,
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers_rescale_with_baseline(preds, refs):
"""Tests for functional with baseline rescaling with all layers."""
original_score = original_bert_score(
preds,
refs,
model_type="bert-base-uncased",
lang="en",
all_layers=True,
idf=False,
batch_size=3,
rescale_with_baseline=True,
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds,
refs,
model_name_or_path="bert-base-uncased",
lang="en",
all_layers=True,
idf=False,
batch_size=3,
rescale_with_baseline=True,
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
@@ -54,12 +207,77 @@ def test_score_fn(preds, refs):
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score(preds, refs):
"""Tests for metric."""
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", num_layers=8, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Score = Scorer.compute()
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_with_idf(preds, refs):
"""Tests for metric with IDF rescaling."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", num_layers=8, idf=True, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path="bert-base-uncased", num_layers=8, idf=True, batch_size=3)
Scorer.update(predictions=preds, references=refs)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_all_layers(preds, refs):
"""Tests for metric and all layers."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", all_layers=True, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path="bert-base-uncased", all_layers=True, idf=False, batch_size=3)
Scorer.update(predictions=preds, references=refs)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_all_layers_with_idf(preds, refs):
"""Tests for metric and all layers with IDF rescaling."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", all_layers=True, idf=True, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path="bert-base-uncased", all_layers=True, idf=True, batch_size=3)
Scorer.update(predictions=preds, references=refs)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
@@ -69,10 +287,48 @@ def test_score(preds, refs):
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_accumulation(preds, refs):
"""Tests for metric works with accumulation."""
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
original_score = original_bert_score(
sum(preds, []), sum(refs, []), model_type="bert-base-uncased", num_layers=8, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3)
for p, r in zip(preds, refs):
Scorer.update(predictions=p, references=r)
Score = Scorer.compute()
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


def _bert_score_ddp(rank, world_size, preds, refs, original_score):
"""Define a DDP process for BERTScore."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=rank, world_size=world_size)
Scorer = BERTScore(model_name_or_path="bert-base-uncased", num_layers=8, idf=False, batch_size=3)
Scorer.update(preds, refs)
metrics_score = Scorer.compute()
for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])
dist.destroy_process_group()


def _test_score_ddp_fn(rank, world_size, preds, refs):
"""Core functionality for the `test_score_ddp` test."""
original_score = original_bert_score(
preds, refs, model_type="bert-base-uncased", num_layers=8, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)
_bert_score_ddp(rank, world_size, preds, refs, original_score)


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_ddp(preds, refs):
"""Tests for metric using DDP."""
world_size = 2
mp.spawn(_test_score_ddp_fn, args=(world_size, preds, refs), nprocs=world_size, join=True)