-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CVAT][Recording Oracle] Add first set of agreement measures (#757)
* add first set of agreement measures. add unit tests. * align import with project convention. * remove requirements.txt. move agreement to modules. * add input validation to agreement functions.
- Loading branch information
1 parent
4c675dc
commit 025a29d
Showing
6 changed files
with
757 additions
and
660 deletions.
There are no files selected for viewing
1,236 changes: 578 additions & 658 deletions
1,236
packages/examples/cvat/recording-oracle/poetry.lock
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
name = "recording-oracle" | ||
version = "0.1.0" | ||
description = "An example of recording with with CVAT as an annotation instrument" | ||
authors = ["Sergey Dzeranov <[email protected]>"] | ||
authors = ["Sergey Dzeranov <[email protected]>", "Marius Hamacher <[email protected]>"] | ||
readme = "README.md" | ||
packages = [{include = "recording_oracle"}] | ||
|
||
|
@@ -18,7 +18,7 @@ pytest = "^7.4.0" | |
human-protocol-sdk = "^1.1.5" | ||
alembic = "^1.11.1" | ||
httpx = "^0.24.1" | ||
|
||
numpy = "^1.25.2" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
black = "^23.3.0" | ||
|
1 change: 1 addition & 0 deletions
1
packages/examples/cvat/recording-oracle/src/modules/agreement/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .measures import percent_agreement, cohens_kappa, fleiss_kappa |
104 changes: 104 additions & 0 deletions
104
packages/examples/cvat/recording-oracle/src/modules/agreement/measures.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import numpy as np | ||
|
||
|
||
def _validate_nd(M: np.ndarray, n=2): | ||
"""Validates that M has n dimensions.""" | ||
if M.ndim != n: | ||
raise ValueError(f"Input must be a n-dimensional array-like.") | ||
|
||
|
||
def _validate_dtype_is_subtype_of(M: np.ndarray, supertype: np.dtype): | ||
"""Validates the data type of M is a subtype of supertype.""" | ||
if not issubclass(M.dtype.type, supertype): | ||
raise ValueError( | ||
f"Input must have a data type that is a subtype of " f"{supertype}" | ||
) | ||
|
||
|
||
def _validate_all_positive(M: np.ndarray): | ||
""" | ||
Validates that all entries in M are positive (including 0). | ||
Raises a ValueError if not. | ||
""" | ||
if np.any(M < 0): | ||
raise ValueError("Inputs must all be positive") | ||
|
||
|
||
def _validate_sufficient_annotations(M, n=1): | ||
"""Validates that M contains enough annotations.""" | ||
if M.sum() <= n: | ||
raise ValueError(f"Input must have more than {1} annotation.") | ||
|
||
|
||
def _validate_incidence_matrix(M): | ||
"""Validates that M is an incidence matrix.""" | ||
_validate_nd(M, n=2) | ||
_validate_dtype_is_subtype_of(M, np.integer) | ||
_validate_all_positive(M) | ||
_validate_sufficient_annotations(M, n=1) | ||
|
||
|
||
def _validate_confusion_matrix(M): | ||
"""Validates that M is a confusion Matrix.""" | ||
_validate_incidence_matrix(M) | ||
|
||
if M.shape[0] != M.shape[1]: | ||
raise ValueError("Input must be a square matrix.") | ||
|
||
|
||
def percent_agreement(data: np.ndarray, data_format="im") -> float: | ||
""" | ||
Returns the overall agreement percentage observed across the data. | ||
Args: | ||
data: Annotation data. | ||
data_format: The format of data. Options are 'im' for an incidence | ||
matrix and 'cm' for a confusion matrix. Defaults to 'im'. | ||
""" | ||
data = np.asarray(data) | ||
|
||
if data_format == "cm": | ||
_validate_confusion_matrix(data) | ||
return np.diag(data).sum() / data.sum() | ||
|
||
# implicitly assumes incidence matrix | ||
_validate_incidence_matrix(data) | ||
|
||
n_raters = np.max(data) | ||
item_agreements = np.sum(data * data, 1) - n_raters | ||
max_item_agreements = n_raters * (n_raters - 1) | ||
return (item_agreements / max_item_agreements).mean() | ||
|
||
|
||
def cohens_kappa(data: np.ndarray) -> float: | ||
""" | ||
Returns Cohen's Kappa for the provided annotations. | ||
Args: | ||
data: Annotation data, provided as K x K confusion matrix, with K = | ||
number of labels. | ||
""" | ||
data = np.asarray(data) | ||
|
||
agreement_observed = percent_agreement(data, "cm") | ||
agreement_expected = np.matmul(data.sum(0), data.sum(1)) / data.sum() ** 2 | ||
|
||
return (agreement_observed - agreement_expected) / (1 - agreement_expected) | ||
|
||
|
||
def fleiss_kappa(data: np.ndarray) -> float: | ||
""" | ||
Returns Fleisss' Kappa for the provided annotations. | ||
Args: | ||
data: Annotation data, provided as I x K incidence matrix, with | ||
I = number of items and K = number of labels. | ||
""" | ||
data = np.asarray(data) | ||
|
||
agreement_observed = percent_agreement(data, "im") | ||
|
||
class_probabilities = data.sum(0) / data.sum() | ||
agreement_expected = np.power(class_probabilities, 2).sum() | ||
|
||
return (agreement_observed - agreement_expected) / (1 - agreement_expected) |
39 changes: 39 additions & 0 deletions
39
packages/examples/cvat/recording-oracle/tests/unit/modules/agreement/conftest.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
import numpy as np | ||
|
||
|
||
@pytest.fixture | ||
def bin_2r_cm() -> np.ndarray: | ||
""" | ||
Returns a confusion matrix (rater_a x rater_b) for a binary classification | ||
problem with two raters. | ||
""" | ||
return np.asarray([[2, 2], [1, 5]]) | ||
|
||
|
||
@pytest.fixture | ||
def bin_2r_im() -> np.ndarray: | ||
""" | ||
Returns an incidence matrix (item x class) for a binary classification | ||
problem with two raters. | ||
""" | ||
return np.asarray( | ||
[[2, 0], [2, 0], [1, 1], [1, 1], [1, 1], [0, 2], [0, 2], [0, 2], [0, 2], [0, 2]] | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def bin_mr_im() -> np.ndarray: | ||
return np.asarray( | ||
[[3, 0], [2, 1], [2, 1], [2, 1], [1, 2], [0, 3], [0, 3], [1, 2], [1, 2], [1, 2]] | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def single_anno_cm() -> np.ndarray: | ||
return np.asarray([[1, 0], [0, 0]]) | ||
|
||
|
||
@pytest.fixture | ||
def wrong_dtype_cm() -> np.ndarray: | ||
return np.asarray([[1.0, 2.0], [3.0, 4.0]]) |
33 changes: 33 additions & 0 deletions
33
packages/examples/cvat/recording-oracle/tests/unit/modules/agreement/test_measures.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from src.modules.agreement.measures import percent_agreement, cohens_kappa, fleiss_kappa | ||
import pytest | ||
|
||
|
||
def _eq_rounded(a, b, n_digits=3): | ||
return round(a, n_digits) == round(b, n_digits) | ||
|
||
|
||
def test_percent_agreement(bin_2r_cm, bin_2r_im, single_anno_cm, wrong_dtype_cm): | ||
percentage = percent_agreement(bin_2r_cm, "cm") | ||
assert _eq_rounded(percentage, 0.7) | ||
|
||
percentage_incidence = percent_agreement(bin_2r_im, "im") | ||
assert _eq_rounded(percentage, percentage_incidence) | ||
|
||
with pytest.raises(ValueError, match="have more than 1 annotation"): | ||
percent_agreement(single_anno_cm, "cm") | ||
|
||
with pytest.raises(ValueError, match="must be a square"): | ||
percent_agreement(bin_2r_im, "cm") | ||
|
||
with pytest.raises(ValueError, match="is a subtype of"): | ||
percent_agreement(wrong_dtype_cm) | ||
|
||
|
||
def test_cohens_kappa(bin_2r_cm): | ||
kappa = cohens_kappa(bin_2r_cm) | ||
assert _eq_rounded(kappa, 0.348) | ||
|
||
|
||
def test_fleiss_kappa(bin_mr_im): | ||
kappa = fleiss_kappa(bin_mr_im) | ||
assert _eq_rounded(kappa, 0.05) |