Skip to content

Commit

Permalink
Fix collision in LabelCategories._indices (#51)
Browse files Browse the repository at this point in the history
* Fix collision in LabelCategories._indices

* Fix formatting

* Fix LabelCategories.items type annotation

* Introduce conftest.ASSETS_DIR for consistent assets access

* Add tests for label categories

* Fix linting issues

* Update copyright headers

* Update CHANGELOG.md

* Add link to related PR

* Test KeyError message

* Add @mark_requirement(Requirements.DATUM_GENERAL_REQ) and parametrize tests

* Comply with isort

* Add licence header to test_label_categories.py

* Comply with isort

* Rename tests to comply with current naming convention

* Add meaningful test for _reindex, fix _reindex

* Remove alias for self._indices

I don't entirely agree with the prohibition of chained assignments, but in this case the alias may simply be unnecessary.
  • Loading branch information
Bobronium authored Aug 5, 2024
1 parent 3cd8bc3 commit 04832fb
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 32 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- TBD

### Fixed
- Collision between parents and names in LabelCategories
  (<https://github.com/cvat-ai/datumaro/pull/51>)
- Detection for LFW format
(<https://github.com/openvinotoolkit/datumaro/pull/680>)
- Export of masks with background class with id != 0 in the VOC, KITTI and Cityscapes formats
Expand Down
29 changes: 18 additions & 11 deletions datumaro/components/annotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (C) 2021-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

Expand Down Expand Up @@ -106,8 +106,8 @@ class Category:
parent: str = field(default="", validator=default_if_none(str))
attributes: Set[str] = field(factory=set, validator=default_if_none(set))

items: List[str] = field(factory=list, validator=default_if_none(list))
_indices: Dict[str, int] = field(factory=dict, init=False, eq=False)
items: List[Category] = field(factory=list, validator=default_if_none(list))
_indices: Dict[Tuple[str, str], int] = field(factory=dict, init=False, eq=False)

@classmethod
def from_iterable(
Expand Down Expand Up @@ -147,26 +147,33 @@ def __attrs_post_init__(self):
self._reindex()

def _reindex(self):
    """Rebuild the ``(parent, name) -> index`` lookup table from ``self.items``.

    Raises:
        KeyError: if two items share the same ``(parent, name)`` pair.
    """
    indices = {}
    for index, item in enumerate(self.items):
        # A tuple key keeps ("parent", "22") distinct from ("parent2", "2");
        # concatenating parent and name would map both to "parent22".
        key = (item.parent, item.name)
        if key in indices:
            raise KeyError(f"Item with duplicate label {item.parent!r} {item.name!r}")
        indices[key] = index
    # Publish only after every item has been validated, so a failed reindex
    # never leaves a partially built table behind.
    self._indices = indices

@property
def labels(self):
    """Map every label index to its parent and name joined into one string."""
    index_to_label = {}
    for (parent, name), label_index in self._indices.items():
        index_to_label[label_index] = parent + name
    return index_to_label

def add(
    self, name: str, parent: Optional[str] = "", attributes: Optional[Set[str]] = None
) -> int:
    """Append a new label and register it in the lookup table.

    Args:
        name: Label name; must be non-empty.
        parent: Parent label name. ``None`` is treated like ``""`` for
            the purpose of duplicate detection.
        attributes: Attribute names passed through to the new ``Category``.

    Returns:
        The index of the newly added label in ``self.items``.

    Raises:
        ValueError: if ``name`` is empty.
        KeyError: if a label with the same ``(parent, name)`` pair exists.
    """
    if not name:
        raise ValueError("Label name must not be empty")

    # Normalize once so the stored key and the error message always agree.
    parent_key = parent or ""
    key = (parent_key, name)
    if key in self._indices:
        raise KeyError(f"Label {parent_key!r} {name!r} already exists")

    index = len(self.items)
    self.items.append(self.Category(name, parent, attributes))
    self._indices[key] = index
    return index

def find(self, name: str, parent: str = "") -> Tuple[Optional[int], Optional[Category]]:
    """Look up a label by its name and parent.

    Args:
        name: Label name to search for.
        parent: Parent label name. ``None`` is treated like ``""``,
            mirroring the key normalization performed by ``add``.

    Returns:
        ``(index, category)`` when the label exists, ``(None, None)`` otherwise.
    """
    index = self._indices.get((parent or "", name))
    if index is None:
        return None, None
    return index, self.items[index]
Expand Down
6 changes: 3 additions & 3 deletions datumaro/plugins/camvid_format.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (C) 2020-2021 Intel Corporation
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

Expand Down Expand Up @@ -198,8 +199,7 @@ def _load_categories(self, path):
def _load_items(self, path):
items = {}

labels = self._categories[AnnotationType.label]._indices
labels = {labels[label_name]: label_name for label_name in labels}
labels = self._categories[AnnotationType.label].labels

with open(path, encoding="utf-8") as f:
for line in f:
Expand Down
4 changes: 2 additions & 2 deletions datumaro/plugins/coco_format/extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (C) 2019-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

Expand Down Expand Up @@ -348,7 +348,7 @@ def _load_annotations(self, ann, image_info=None, parsed_annotations=None):
for x, y, v in take_by(keypoints, 3):
sublabel = None
if i < len(sublabels):
sublabel = label_cat.find(label_cat.items[label_id].name + sublabels[i])[0]
sublabel = label_cat.find(sublabels[i], label_cat.items[label_id].name)[0]
points.append(Points([x, y], [v], label=sublabel))
i += 1

Expand Down
4 changes: 3 additions & 1 deletion datumaro/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def compare_datasets(
ann_b_matches, lambda x: _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs)
)
if ann_b is None:
test.fail("ann %s, candidates %s" % (ann_a, ann_b_matches))
test.fail(
"ann\n\t%s,\ncandidates\n\t%s" % (ann_a, "\n\t".join(map(str, ann_b_matches)))
)
item_b.annotations.remove(ann_b) # avoid repeats


Expand Down
13 changes: 7 additions & 6 deletions tests/cli/test_revpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datumaro.util.scope import scope_add, scoped
from datumaro.util.test_utils import TestDir

from ..conftest import ASSETS_DIR
from ..requirements import Requirements, mark_requirement


Expand Down Expand Up @@ -133,13 +134,13 @@ def test_ambiguous_format(self):
# create an ambiguous dataset by merging annotations from
# datasets in different formats
annotation_dir = osp.join(dataset_url, "training/street")
assets_dir = osp.join(osp.dirname(__file__), "../assets")
os.makedirs(annotation_dir)
for asset in [
"ade20k2017_dataset/dataset/training/street/1_atr.txt",
"ade20k2020_dataset/dataset/training/street/1.json",
]:
shutil.copy(osp.join(assets_dir, asset), annotation_dir)
for root, asset_name in (
("ade20k2017_dataset", "1_atr.txt"),
("ade20k2020_dataset", "1.json"),
):
asset = ASSETS_DIR / root / "dataset" / "training" / "street" / asset_name
shutil.copy(asset, annotation_dir)

with self.subTest("no context"):
with self.assertRaises(WrongRevpathError) as cm:
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from pathlib import Path

from datumaro.util.test_utils import TestDir

from .fixtures import *
from .utils.test_utils import TestCaseHelper

ASSETS_DIR = Path(__file__).parent / "assets"


def pytest_configure(config):
# register additional markers
Expand Down
169 changes: 169 additions & 0 deletions tests/test_label_categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import pytest

from datumaro import LabelCategories

from tests.requirements import Requirements, mark_requirement


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_add_category():
    """Adding one label to empty categories returns index 0 and stores it."""
    label_cat = LabelCategories()
    new_index = label_cat.add("cat")
    assert new_index == 0
    assert len(label_cat) == 1
    assert label_cat[0].name == "cat"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_add_category_with_parent():
    """A label added with a parent keeps both its name and its parent."""
    label_cat = LabelCategories()
    new_index = label_cat.add("cat", parent="animal")
    assert new_index == 0
    assert len(label_cat) == 1
    assert label_cat[0].name == "cat"
    assert label_cat[0].parent == "animal"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_add_category_with_attributes():
    """Attributes passed to ``add`` are stored on the new category."""
    label_cat = LabelCategories()
    expected_attrs = {"color", "size"}
    new_index = label_cat.add("cat", attributes=expected_attrs)
    assert new_index == 0
    assert len(label_cat) == 1
    assert label_cat[0].attributes == expected_attrs


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("name,parent", [("cat", "animal"), ("cat", "")])
def test_can_add_duplicate_category(name, parent):
    """Adding the same name/parent pair a second time raises ``KeyError``."""
    label_cat = LabelCategories()
    label_cat.add(name, parent=parent)
    expected_message = f"Label '{parent}' '{name}' already exists"
    with pytest.raises(KeyError, match=expected_message):
        label_cat.add(name, parent=parent)


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_resolve_potential_collision():
    """
    Previously indices were computed as (parent or "") + name, so the pairs
    ("parent", "22") and ("parent2", "2") both produced the key "parent22".
    See https://github.com/cvat-ai/datumaro/pull/51
    """
    categories = LabelCategories()
    first_index = categories.add("22", parent="parent")
    second_index = categories.add("2", parent="parent2")
    assert first_index == 0
    assert second_index == 1

    assert categories.items[0].name == "22"
    assert categories.items[0].parent == "parent"
    assert categories.items[1].name == "2"
    assert categories.items[1].parent == "parent2"

    # Each pair must also resolve to its own entry through the index.
    assert categories.find("22", parent="parent") == (0, categories.items[0])
    assert categories.find("2", parent="parent2") == (1, categories.items[1])


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_find_category():
    """``find`` returns the index and the category of an existing label."""
    label_cat = LabelCategories()
    label_cat.add("cat")
    found_index, found_category = label_cat.find("cat")
    assert found_index == 0
    assert found_category.name == "cat"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_cant_find_non_existent_category():
    """``find`` yields ``None`` for both index and category on a miss."""
    label_cat = LabelCategories()
    found_index, found_category = label_cat.find("dog")
    assert found_index is None
    assert found_category is None


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_construct_from_iterable():
    """``from_iterable`` accepts plain label names and keeps their order."""
    label_names = ["cat", "dog"]
    label_cat = LabelCategories.from_iterable(label_names)
    assert len(label_cat) == 2
    assert label_cat[0].name == "cat"
    assert label_cat[1].name == "dog"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_construct_from_iterable_with_parents():
    """``from_iterable`` accepts ``(name, parent)`` tuples."""
    specs = [("cat", "animal"), ("dog", "animal")]
    label_cat = LabelCategories.from_iterable(specs)
    assert len(label_cat) == 2
    assert label_cat[0].name == "cat"
    assert label_cat[0].parent == "animal"
    assert label_cat[1].name == "dog"
    assert label_cat[1].parent == "animal"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_construct_from_iterable_with_attributes():
    """``from_iterable`` accepts ``(name, parent, attributes)`` tuples."""
    specs = [("cat", "animal", ["color", "size"])]
    label_cat = LabelCategories.from_iterable(specs)
    assert len(label_cat) == 1
    assert label_cat[0].name == "cat"
    assert label_cat[0].parent == "animal"
    assert label_cat[0].attributes == {"color", "size"}


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_reindex_on_init():
    """Items passed to the constructor are indexed by ``(parent, name)``."""
    cat_category = LabelCategories.Category("cat")
    dog_category = LabelCategories.Category("dog", "animal")
    label_cat = LabelCategories(items=[cat_category, dog_category])

    expected_indices = {("", "cat"): 0, ("animal", "dog"): 1}
    assert label_cat._indices == expected_indices
    assert label_cat.find("cat") == (0, cat_category)
    assert label_cat.find("dog", parent="animal") == (1, dog_category)
    assert label_cat.find("dog") == (None, None)


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_cant_reindex_on_init():
    """Constructing with the same category twice raises ``KeyError``."""
    with pytest.raises(KeyError, match="Item with duplicate label '' 'cat'"):
        duplicated = LabelCategories.Category("cat")
        LabelCategories(items=[duplicated, duplicated])


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_has_labels_property():
    """``labels`` maps each index to the parent and name joined together."""
    label_cat = LabelCategories()
    label_cat.add("cat")
    label_cat.add("dog", parent="animal")
    index_to_label = label_cat.labels
    assert index_to_label[0] == "cat"
    assert index_to_label[1] == "animaldog"


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_has_len():
    """``len`` is 0 for empty categories and grows by one after an ``add``."""
    label_cat = LabelCategories()
    assert len(label_cat) == 0
    label_cat.add("cat")
    assert len(label_cat) == 1


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_has_contains():
    """The ``in`` operator reports added label names and rejects unknown ones."""
    label_cat = LabelCategories()
    label_cat.add("cat")
    assert "cat" in label_cat
    assert "dog" not in label_cat


@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_iter():
    """Iterating the categories yields the items in insertion order."""
    label_cat = LabelCategories()
    for label_name in ("cat", "dog"):
        label_cat.add(label_name)
    collected = list(label_cat)
    assert len(collected) == 2
    assert collected[0].name == "cat"
    assert collected[1].name == "dog"
5 changes: 2 additions & 3 deletions tests/test_mars_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: MIT

import os.path as osp
from unittest.case import TestCase

import numpy as np
Expand All @@ -14,10 +13,10 @@
from datumaro.plugins.mars_format import MarsImporter
from datumaro.util.test_utils import compare_datasets

from tests.conftest import ASSETS_DIR
from tests.requirements import Requirements, mark_requirement

ASSETS_DIR = osp.join(osp.dirname(__file__), "assets")
DUMMY_MARS_DATASET = osp.join(ASSETS_DIR, "mars_dataset")
DUMMY_MARS_DATASET = str(ASSETS_DIR / "mars_dataset")


class MarsImporterTest(TestCase):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_open_images_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from datumaro.plugins.open_images_format import OpenImagesConverter, OpenImagesImporter
from datumaro.util.test_utils import TestDir, compare_datasets

from tests.conftest import ASSETS_DIR
from tests.requirements import Requirements, mark_requirement


Expand Down Expand Up @@ -308,10 +309,8 @@ def test_can_save_and_load_with_meta_file(self):
compare_datasets(self, dataset, parsed_dataset, require_media=True)


ASSETS_DIR = osp.join(osp.dirname(__file__), "assets")

DUMMY_DATASET_DIR_V6 = osp.join(ASSETS_DIR, "open_images_dataset/v6")
DUMMY_DATASET_DIR_V5 = osp.join(ASSETS_DIR, "open_images_dataset/v5")
DUMMY_DATASET_DIR_V6 = str(ASSETS_DIR / "open_images_dataset" / "v6")
DUMMY_DATASET_DIR_V5 = str(ASSETS_DIR / "open_images_dataset" / "v5")


class OpenImagesImporterTest(TestCase):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from datumaro.plugins.sampler.random_sampler import LabelRandomSampler, RandomSampler
from datumaro.util.test_utils import compare_datasets, compare_datasets_strict

from .conftest import ASSETS_DIR

try:
import pandas as pd

Expand All @@ -31,8 +33,8 @@ class TestRelevancySampler(TestCase):
@staticmethod
def _get_probs(out_range=False):
probs = []
inference_file = "tests/assets/sampler/inference.csv"
with open(inference_file) as csv_file:
inference_file = ASSETS_DIR / "sampler" / "inference.csv"
with inference_file.open() as csv_file:
csv_reader = csv.reader(csv_file)
col = 0
for row in csv_reader:
Expand Down

0 comments on commit 04832fb

Please sign in to comment.