diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f2df75422..f42f1c712e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -86,6 +86,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - TBD
 
 ### Fixed
+- Collision between parents and names in LabelCategories
+  (<https://github.com/cvat-ai/datumaro/pull/51>)
 - Detection for LFW format
   ()
 - Export of masks with background class with id != 0 in the VOC, KITTI and Cityscapes formats
diff --git a/datumaro/components/annotation.py b/datumaro/components/annotation.py
index ee8bd73c1c..2808dbcf72 100644
--- a/datumaro/components/annotation.py
+++ b/datumaro/components/annotation.py
@@ -1,5 +1,5 @@
 # Copyright (C) 2021-2022 Intel Corporation
-# Copyright (C) 2022 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
 
@@ -106,8 +106,8 @@ class Category:
         parent: str = field(default="", validator=default_if_none(str))
         attributes: Set[str] = field(factory=set, validator=default_if_none(set))
 
-    items: List[str] = field(factory=list, validator=default_if_none(list))
-    _indices: Dict[str, int] = field(factory=dict, init=False, eq=False)
+    items: List[Category] = field(factory=list, validator=default_if_none(list))
+    _indices: Dict[Tuple[str, str], int] = field(factory=dict, init=False, eq=False)
 
     @classmethod
     def from_iterable(
@@ -147,18 +147,29 @@ def __attrs_post_init__(self):
         self._reindex()
 
     def _reindex(self):
+        # Build the index locally and publish it only once it is complete, so that
+        # a duplicate label cannot leave a half-built index behind.
         indices = {}
         for index, item in enumerate(self.items):
-            assert (item.parent + item.name) not in self._indices
-            indices[item.parent + item.name] = index
+            key = (item.parent, item.name)
+            if key in indices:
+                raise KeyError(f"Item with duplicate label {item.parent!r} {item.name!r}")
+            indices[key] = index
         self._indices = indices
+
+    @property
+    def labels(self) -> Dict[int, str]:
+        """Maps each label index to its parent name concatenated with its own name."""
+        return {label_index: parent + name for (parent, name), label_index in self._indices.items()}
 
     def add(
         self, name: str, parent: Optional[str] = "", attributes: Optional[Set[str]] = None
     ) -> int:
-        assert name
-        key = (parent or "") + name
-        assert key not in self._indices
+        if not name:
+            raise ValueError("Label name must not be empty")
+        key = (parent or "", name)
+        if key in self._indices:
+            raise KeyError(f"Label {parent!r} {name!r} already exists")
 
         index = len(self.items)
         self.items.append(self.Category(name, parent, attributes))
@@ -166,7 +177,8 @@ def add(
         return index
 
     def find(self, name: str, parent: str = "") -> Tuple[Optional[int], Optional[Category]]:
-        index = self._indices.get(parent + name)
+        # Normalize a missing parent the same way add() does, so both build the same key.
+        index = self._indices.get((parent or "", name))
         if index is not None:
             return index, self.items[index]
         return index, None
diff --git a/datumaro/plugins/camvid_format.py b/datumaro/plugins/camvid_format.py
index 5e5230283a..92f1f2a7e2 100644
--- a/datumaro/plugins/camvid_format.py
+++ b/datumaro/plugins/camvid_format.py
@@ -1,4 +1,5 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2022 Intel Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
 
@@ -198,8 +199,7 @@ def _load_categories(self, path):
     def _load_items(self, path):
         items = {}
 
-        labels = self._categories[AnnotationType.label]._indices
-        labels = {labels[label_name]: label_name for label_name in labels}
+        labels = self._categories[AnnotationType.label].labels
 
         with open(path, encoding="utf-8") as f:
             for line in f:
diff --git a/datumaro/plugins/coco_format/extractor.py b/datumaro/plugins/coco_format/extractor.py
index a0b8e4c2e1..b00ff1f19f 100644
--- a/datumaro/plugins/coco_format/extractor.py
+++ b/datumaro/plugins/coco_format/extractor.py
@@ -1,5 +1,5 @@
 # Copyright (C) 2019-2022 Intel Corporation
-# Copyright (C) 2022 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
 
@@ -348,7 +348,7 @@ def _load_annotations(self, ann, image_info=None, parsed_annotations=None):
                 for x, y, v in take_by(keypoints, 3):
                     sublabel = None
                     if i < len(sublabels):
-                        sublabel = label_cat.find(label_cat.items[label_id].name + sublabels[i])[0]
+                        sublabel = label_cat.find(sublabels[i], label_cat.items[label_id].name)[0]
                     points.append(Points([x, y], [v], label=sublabel))
                     i += 1
 
diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py
index 1041454f94..9518d47502 100644
--- a/datumaro/util/test_utils.py
+++ b/datumaro/util/test_utils.py
@@ -191,7 +191,9 @@ def compare_datasets(
                 ann_b_matches, lambda x: _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs)
             )
             if ann_b is None:
-                test.fail("ann %s, candidates %s" % (ann_a, ann_b_matches))
+                test.fail(
+                    "ann\n\t%s,\ncandidates\n\t%s" % (ann_a, "\n\t".join(map(str, ann_b_matches)))
+                )
             item_b.annotations.remove(ann_b)  # avoid repeats
 
 
diff --git a/tests/cli/test_revpath.py b/tests/cli/test_revpath.py
index 336db86b9b..70f2d5143c 100644
--- a/tests/cli/test_revpath.py
+++ b/tests/cli/test_revpath.py
@@ -15,6 +15,7 @@
 from datumaro.util.scope import scope_add, scoped
 from datumaro.util.test_utils import TestDir
 
+from ..conftest import ASSETS_DIR
 from ..requirements import Requirements, mark_requirement
 
 
@@ -133,13 +134,13 @@ def test_ambiguous_format(self):
             # create an ambiguous dataset by merging annotations from
             # datasets in different formats
             annotation_dir = osp.join(dataset_url, "training/street")
-            assets_dir = osp.join(osp.dirname(__file__), "../assets")
             os.makedirs(annotation_dir)
-            for asset in [
-                "ade20k2017_dataset/dataset/training/street/1_atr.txt",
-                "ade20k2020_dataset/dataset/training/street/1.json",
-            ]:
-                shutil.copy(osp.join(assets_dir, asset), annotation_dir)
+            for root, asset_name in (
+                ("ade20k2017_dataset", "1_atr.txt"),
+                ("ade20k2020_dataset", "1.json"),
+            ):
+                asset = ASSETS_DIR / root / "dataset" / "training" / "street" / asset_name
+                shutil.copy(asset, annotation_dir)
 
             with self.subTest("no context"):
                 with self.assertRaises(WrongRevpathError) as cm:
diff --git a/tests/conftest.py b/tests/conftest.py
index 971ecb61e4..0196fe7bd2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,11 +2,15 @@
 # Copyright (C) 2022-2024 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
+from pathlib import Path
+
 from datumaro.util.test_utils import TestDir
 
 from .fixtures import *
 from .utils.test_utils import TestCaseHelper
 
+ASSETS_DIR = Path(__file__).parent / "assets"
+
 
 def pytest_configure(config):
     # register additional markers
diff --git a/tests/test_label_categories.py b/tests/test_label_categories.py
new file mode 100644
index 0000000000..cc90048f19
--- /dev/null
+++ b/tests/test_label_categories.py
@@ -0,0 +1,180 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import pytest
+
+from datumaro import LabelCategories
+
+from tests.requirements import Requirements, mark_requirement
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_add_category():
+    categories = LabelCategories()
+    index = categories.add("cat")
+    assert index == 0
+    assert len(categories) == 1
+    assert categories[0].name == "cat"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_add_category_with_parent():
+    categories = LabelCategories()
+    index = categories.add("cat", parent="animal")
+    assert index == 0
+    assert len(categories) == 1
+    assert categories[0].name == "cat"
+    assert categories[0].parent == "animal"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_add_category_with_attributes():
+    categories = LabelCategories()
+    attributes = {"color", "size"}
+    index = categories.add("cat", attributes=attributes)
+    assert index == 0
+    assert len(categories) == 1
+    assert categories[0].attributes == attributes
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+@pytest.mark.parametrize("name,parent", [("cat", "animal"), ("cat", "")])
+def test_cant_add_duplicate_category(name, parent):
+    categories = LabelCategories()
+    categories.add(name, parent=parent)
+    with pytest.raises(KeyError, match=f"Label '{parent}' '{name}' already exists"):
+        categories.add(name, parent=parent)
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_resolve_potential_collision():
+    """
+    Previously indices were computed as (parent or "") + name
+
+    See https://github.com/cvat-ai/datumaro/pull/51
+    """
+    categories = LabelCategories()
+    categories.add("22", parent="parent")
+    categories.add("2", parent="parent2")
+    assert categories.items[0].name == "22"
+    assert categories.items[0].parent == "parent"
+    assert categories.items[1].name == "2"
+    assert categories.items[1].parent == "parent2"
+    assert categories.find("22", parent="parent")[0] == 0
+    assert categories.find("2", parent="parent2")[0] == 1
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_find_category():
+    categories = LabelCategories()
+    categories.add("cat")
+    index, category = categories.find("cat")
+    assert index == 0
+    assert category.name == "cat"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_cant_find_non_existent_category():
+    categories = LabelCategories()
+    index, category = categories.find("dog")
+    assert index is None
+    assert category is None
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_find_category_added_with_none_parent():
+    categories = LabelCategories()
+    categories.add("cat", parent=None)
+    index, category = categories.find("cat", parent=None)
+    assert index == 0
+    assert category.name == "cat"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_construct_from_iterable():
+    categories = LabelCategories.from_iterable(["cat", "dog"])
+    assert len(categories) == 2
+    assert categories[0].name == "cat"
+    assert categories[1].name == "dog"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_construct_from_iterable_with_parents():
+    categories = LabelCategories.from_iterable([("cat", "animal"), ("dog", "animal")])
+    assert len(categories) == 2
+    assert categories[0].name == "cat"
+    assert categories[0].parent == "animal"
+    assert categories[1].name == "dog"
+    assert categories[1].parent == "animal"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_construct_from_iterable_with_attributes():
+    categories = LabelCategories.from_iterable([("cat", "animal", ["color", "size"])])
+    assert len(categories) == 1
+    assert categories[0].name == "cat"
+    assert categories[0].parent == "animal"
+    assert categories[0].attributes == {"color", "size"}
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_reindex_on_init():
+    categories = LabelCategories(
+        items=[
+            cat_category := LabelCategories.Category("cat"),
+            dog_category := LabelCategories.Category("dog", "animal"),
+        ]
+    )
+    assert categories._indices == {("", "cat"): 0, ("animal", "dog"): 1}
+    assert categories.find("cat") == (0, cat_category)
+    assert categories.find("dog", parent="animal") == (1, dog_category)
+    assert categories.find("dog") == (None, None)
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_cant_reindex_on_init():
+    with pytest.raises(KeyError, match="Item with duplicate label '' 'cat'"):
+        LabelCategories(
+            items=[
+                cat_category := LabelCategories.Category("cat"),
+                cat_category,
+            ]
+        )
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_has_labels_property():
+    categories = LabelCategories()
+    categories.add("cat")
+    categories.add("dog", parent="animal")
+    labels = categories.labels
+    assert labels[0] == "cat"
+    assert labels[1] == "animaldog"
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_has_len():
+    categories = LabelCategories()
+    assert len(categories) == 0
+    categories.add("cat")
+    assert len(categories) == 1
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_has_contains():
+    categories = LabelCategories()
+    categories.add("cat")
+    assert "cat" in categories
+    assert "dog" not in categories
+
+
+@mark_requirement(Requirements.DATUM_GENERAL_REQ)
+def test_can_iter():
+    categories = LabelCategories()
+    categories.add("cat")
+    categories.add("dog")
+    items = list(categories)
+    assert len(items) == 2
+    assert items[0].name == "cat"
+    assert items[1].name == "dog"
diff --git a/tests/test_mars_format.py b/tests/test_mars_format.py
index a5152e7481..7ec2c68e46 100644
--- a/tests/test_mars_format.py
+++ b/tests/test_mars_format.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: MIT
 
-import os.path as osp
 from unittest.case import TestCase
 
 import numpy as np
@@ -14,10 +13,10 @@
 from datumaro.plugins.mars_format import MarsImporter
 from datumaro.util.test_utils import compare_datasets
 
+from tests.conftest import ASSETS_DIR
 from tests.requirements import Requirements, mark_requirement
 
-ASSETS_DIR = osp.join(osp.dirname(__file__), "assets")
-DUMMY_MARS_DATASET = osp.join(ASSETS_DIR, "mars_dataset")
+DUMMY_MARS_DATASET = str(ASSETS_DIR / "mars_dataset")
 
 
 class MarsImporterTest(TestCase):
diff --git a/tests/test_open_images_format.py b/tests/test_open_images_format.py
index 99d64a776c..9194bc756e 100644
--- a/tests/test_open_images_format.py
+++ b/tests/test_open_images_format.py
@@ -17,6 +17,7 @@
 from datumaro.plugins.open_images_format import OpenImagesConverter, OpenImagesImporter
 from datumaro.util.test_utils import TestDir, compare_datasets
 
+from tests.conftest import ASSETS_DIR
 from tests.requirements import Requirements, mark_requirement
 
 
@@ -308,10 +309,8 @@ def test_can_save_and_load_with_meta_file(self):
             compare_datasets(self, dataset, parsed_dataset, require_media=True)
 
 
-ASSETS_DIR = osp.join(osp.dirname(__file__), "assets")
-
-DUMMY_DATASET_DIR_V6 = osp.join(ASSETS_DIR, "open_images_dataset/v6")
-DUMMY_DATASET_DIR_V5 = osp.join(ASSETS_DIR, "open_images_dataset/v5")
+DUMMY_DATASET_DIR_V6 = str(ASSETS_DIR / "open_images_dataset" / "v6")
+DUMMY_DATASET_DIR_V5 = str(ASSETS_DIR / "open_images_dataset" / "v5")
 
 
 class OpenImagesImporterTest(TestCase):
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 0409c79bb4..8bb5c0bfc9 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -13,6 +13,8 @@
 from datumaro.plugins.sampler.random_sampler import LabelRandomSampler, RandomSampler
 from datumaro.util.test_utils import compare_datasets, compare_datasets_strict
 
+from .conftest import ASSETS_DIR
+
 try:
     import pandas as pd
 
@@ -31,8 +33,8 @@ class TestRelevancySampler(TestCase):
     @staticmethod
     def _get_probs(out_range=False):
         probs = []
-        inference_file = "tests/assets/sampler/inference.csv"
-        with open(inference_file) as csv_file:
+        inference_file = ASSETS_DIR / "sampler" / "inference.csv"
+        with inference_file.open(newline="") as csv_file:
             csv_reader = csv.reader(csv_file)
             col = 0
             for row in csv_reader: