Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Corrector transformation #1006

Merged
merged 12 commits into from
May 22, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/996>)
- Add VocInstanceSegmentationImporter and VocInstanceSegmentationExporter
(<https://github.com/openvinotoolkit/datumaro/pull/997>)
- Add Corrector transformation
(<https://github.com/openvinotoolkit/datumaro/pull/1006>)

### Enhancements
- Use autosummary for fully-automatic Python module docs generation
Expand Down
2 changes: 2 additions & 0 deletions datumaro/components/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


class Severity(Enum):
info = auto()
warning = auto()
error = auto()

Expand Down Expand Up @@ -51,6 +52,7 @@ def validate(self, dataset: IDataset) -> Dict:
summary = {
"errors": sum(map(lambda r: r["severity"] == "error", reports)),
"warnings": sum(map(lambda r: r["severity"] == "warning", reports)),
"infos": sum(map(lambda r: r["severity"] == "info", reports)),
}

validation_results["validation_reports"] = reports
Expand Down
135 changes: 133 additions & 2 deletions datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os.path as osp
import random
import re
from collections import Counter
from collections import Counter, defaultdict
from copy import deepcopy
from enum import Enum, auto
from itertools import chain
Expand Down Expand Up @@ -40,7 +40,7 @@
from datumaro.components.errors import DatumaroError
from datumaro.components.media import Image
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.util import NOTSET, filter_dict, parse_str_enum_value, take_by
from datumaro.util import NOTSET, filter_dict, parse_json_file, parse_str_enum_value, take_by
from datumaro.util.annotation_util import find_group_leader, find_instances


Expand Down Expand Up @@ -1229,3 +1229,134 @@ def transform_item(self, item: DatasetItem):
attributes=self._filter_attrs(item.attributes), annotations=filtered_annotations
)
return item


class Correct(Transform, CliPlugin):
"""
Correct the dataset from a validation report.
A user can should feed into validation_reports.json from validator to correct the dataset.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A user can should feed into validation_reports.json from validator to correct the dataset.
A user should feed `validation_reports.json` from validator to correct the dataset.

This helps to refine the dataset by rejecting undefined labels, missing annotations, and outliers.
"""

@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument(
"-r",
"--reports",
type=str,
default="validation_reports.json",
help="A validation report from a 'validate' CLI",
)
return parser

def __init__(
self,
extractor: IDataset,
reports: Union[str, Dict],
):
super().__init__(extractor)

if isinstance(reports, str):
reports = parse_json_file(reports)

self._reports = reports["validation_reports"]

self._remove_items = []
self._remove_anns = []
self._add_attrs = []
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

self._analyze_reports(report=self._reports)

def _parse_ann_ids(self, desc: str):
return [int(s) for s in str.split(desc, "'") if s.isdigit()][0]

def _analyze_reports(self, report):
for rep in report:
if rep["anomaly_type"] == "MissingLabelCategories":
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
label_categories = LabelCategories()
for item in self._extractor:
for ann in item.annotations:
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
attrs = set()
for attr in ann.attributes:
attrs.add(attr)
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
label_id = label_categories.find(str(ann.label))[0]
if label_id is None:
label_categories.add(name=str(ann.label), attributes=attrs)
else:
label_categories[label_id].attributes.add(attrs)
self._extractor.categories()[AnnotationType.label] = label_categories
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

if rep["anomaly_type"] == "UndefinedLabel":
label_categories = self._extractor.categories().get(AnnotationType.label)
desc = [s for s in str.split(rep["description"], "'")]
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
add_label_name = desc[1]
label_id = label_categories.find(add_label_name)[0]
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
if label_id is None:
label_categories.add(name=add_label_name)

if rep["anomaly_type"] == "UndefinedAttribute":
label_categories = self._extractor.categories().get(AnnotationType.label)
desc = [s for s in str.split(rep["description"], "'")]
attr_name, label_name = desc[1], desc[3]
label_id = label_categories.find(label_name)[0]
if label_id is not None:
label_categories[label_id].attributes.add(attr_name)

# [TODO] Correct LabeleDefinedButNotFound: removing a label, reindexing, remapping others
# if rep["anomaly_type"] == "LabelDefinedButNotFound":
# remove_label_name = self._parse_label_cat(rep["description"])
# label_cat = self._extractor.categories()[AnnotationType.label]
# if remove_label_name in [labels.name for labels in label_cat.items]:
# label_cat.remove(remove_label_name)

if rep["anomaly_type"] in ["MissingAnnotation", "MultiLabelAnnotations"]:
self._remove_items.append((rep["item_id"], rep["subset"]))
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

if rep["anomaly_type"] in [
"NegativeLength",
"InvalidValue",
"FarFromLabelMean",
"FarFromAttrMean",
]:
ann_id = None or self._parse_ann_ids(rep["description"])
self._remove_anns.append((rep["item_id"], rep["subset"], ann_id))
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

if rep["anomaly_type"] == "MissingAttribute":
desc = [s for s in str.split(rep["description"], "'")]
attr_name, label_name = desc[1], desc[3]
label_id = self._extractor.categories()[AnnotationType.label].find(label_name)[0]
self._add_attrs.append((rep["item_id"], rep["subset"], label_id, attr_name))
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

def _find_removing_anns_in_item(self, target: tuple[str, str]):
return [tup[2] for tup in self._remove_anns if tup[:2] == target]

def _find_adding_attrs_in_item(self, target: tuple[str, str]):
return [tup[2:] for tup in self._add_attrs if tup[:2] == target]
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved

def __iter__(self):
for item in self._extractor:
if (item.id, item.subset) in self._remove_items:
continue

ann_ids = self._find_removing_anns_in_item(target=(item.id, item.subset))
if ann_ids:
updated_anns = [ann for ann in item.annotations if ann.id not in ann_ids]
yield item.wrap(annotations=updated_anns)
else:
updated_attrs = defaultdict(list)
for label_id, attr_name in self._find_adding_attrs_in_item(
target=(item.id, item.subset)
):
if label_id in updated_attrs:
updated_attrs[label_id].append(attr_name)
else:
updated_attrs.update({label_id: [attr_name]})
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
updated_anns = []
for ann in item.annotations:
if ann.label in updated_attrs:
ann.attributes.update(
{attr_name: "" for attr_name in updated_attrs[ann.label]}
)
updated_anns.append(ann)
wonjuleee marked this conversation as resolved.
Show resolved Hide resolved
yield item.wrap(annotations=updated_anns)
18 changes: 8 additions & 10 deletions datumaro/plugins/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def _check_only_one_label(self, stats):

if len(labels_found) == 1:
validation_reports += self._generate_validation_report(
OnlyOneLabel, Severity.warning, labels_found[0]
OnlyOneLabel, Severity.info, labels_found[0]
)

return validation_reports
Expand All @@ -431,7 +431,7 @@ def _check_only_one_attribute(self, label_name, attr_name, attr_dets):
if len(values) == 1:
details = (label_name, attr_name, values[0])
validation_reports += self._generate_validation_report(
OnlyOneAttributeValue, Severity.warning, *details
OnlyOneAttributeValue, Severity.info, *details
)

return validation_reports
Expand All @@ -449,7 +449,7 @@ def _check_few_samples_in_label(self, stats):

for label_name, count in labels_with_few_samples:
validation_reports += self._generate_validation_report(
FewSamplesInLabel, Severity.warning, label_name, count
FewSamplesInLabel, Severity.info, label_name, count
)

return validation_reports
Expand All @@ -467,7 +467,7 @@ def _check_few_samples_in_attribute(self, label_name, attr_name, attr_dets):
for attr_value, count in attr_values_with_few_samples:
details = (label_name, attr_name, attr_value, count)
validation_reports += self._generate_validation_report(
FewSamplesInAttribute, Severity.warning, *details
FewSamplesInAttribute, Severity.info, *details
)

return validation_reports
Expand All @@ -486,9 +486,7 @@ def _check_imbalanced_labels(self, stats):
count_min = np.min(count_by_defined_labels)
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedLabels, Severity.warning
)
validation_reports += self._generate_validation_report(ImbalancedLabels, Severity.info)

return validation_reports

Expand All @@ -505,7 +503,7 @@ def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(
ImbalancedAttribute, Severity.warning, label_name, attr_name
ImbalancedAttribute, Severity.info, label_name, attr_name
)

return validation_reports
Expand Down Expand Up @@ -915,7 +913,7 @@ def _check_imbalanced_dist_in_label(self, label_name, label_stats):
if ratio >= thr:
details = (label_name, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
ImbalancedDistInLabel, Severity.warning, *details
ImbalancedDistInLabel, Severity.info, *details
)

return validation_reports
Expand All @@ -939,7 +937,7 @@ def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats):
if ratio >= thr:
details = (label_name, attr_name, attr_value, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
ImbalancedDistInAttribute, Severity.warning, *details
ImbalancedDistInAttribute, Severity.info, *details
)

return validation_reports
Expand Down
Loading