Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 4, 2025
1 parent 1934dee commit 7fbbbf9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
return False

if isinstance(state1, Tensor) and isinstance(state2, Tensor):
if not(state1.shape == state2.shape and allclose(state1, state2)):
if not (state1.shape == state2.shape and allclose(state1, state2)):
return False

if isinstance(state1, list) and isinstance(state2, list):
if not(all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))):
if not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))):
return False

return True
Expand Down
18 changes: 7 additions & 11 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,22 +760,18 @@ def test_collection_update():
"bleu-4": BLEUScore(4),
})

preds = ['the cat is on the mat']
target = [['there is a cat on the mat', 'a cat is on the mat']]
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

scores.update(preds, target)
actual = scores.compute()

expected = {
'bleu-1': torch.tensor(0.8333),
'bleu-2': torch.tensor(0.8165),
'bleu-3': torch.tensor(0.7937),
'bleu-4': torch.tensor(0.7598)
"bleu-1": torch.tensor(0.8333),
"bleu-2": torch.tensor(0.8165),
"bleu-3": torch.tensor(0.7937),
"bleu-4": torch.tensor(0.7598),
}

for k, v in expected.items():
torch.testing.assert_close(actual=actual.get(k),
expected=v,
rtol=1e-4,
atol=1e-4)

torch.testing.assert_close(actual=actual.get(k), expected=v, rtol=1e-4, atol=1e-4)

0 comments on commit 7fbbbf9

Please sign in to comment.