Add Corrector transformation #1006

Merged
merged 12 commits on May 22, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/997>)
- Add Segment Anything data format support
(<https://github.com/openvinotoolkit/datumaro/pull/1005>)
- Add Correct transformation
(<https://github.com/openvinotoolkit/datumaro/pull/1006>)

### Enhancements
- Use autosummary for fully-automatic Python module docs generation
2 changes: 2 additions & 0 deletions datumaro/components/validator.py
@@ -10,6 +10,7 @@


class Severity(Enum):
info = auto()
warning = auto()
error = auto()

@@ -51,6 +52,7 @@ def validate(self, dataset: IDataset) -> Dict:
summary = {
"errors": sum(map(lambda r: r["severity"] == "error", reports)),
"warnings": sum(map(lambda r: r["severity"] == "warning", reports)),
"infos": sum(map(lambda r: r["severity"] == "info", reports)),
}

validation_results["validation_reports"] = reports
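With the new `info` level, each report entry carries one of three severities. A hedged sketch of recounting them from a saved report (the file name matches the CLI default used by the transform below; the keys mirror the reads in this diff):

import json

with open("validation_reports.json") as f:
    results = json.load(f)

# Recompute the three-level summary exactly as the validator does above.
reports = results["validation_reports"]
summary = {
    "errors": sum(r["severity"] == "error" for r in reports),
    "warnings": sum(r["severity"] == "warning" for r in reports),
    "infos": sum(r["severity"] == "info" for r in reports),
}
print(summary)  # e.g. {"errors": 2, "warnings": 5, "infos": 11}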
128 changes: 127 additions & 1 deletion datumaro/plugins/transforms.py
@@ -40,7 +40,7 @@
from datumaro.components.errors import DatumaroError
from datumaro.components.media import Image
from datumaro.components.transformer import ItemTransform, Transform
-from datumaro.util import NOTSET, filter_dict, parse_str_enum_value, take_by
+from datumaro.util import NOTSET, filter_dict, parse_json_file, parse_str_enum_value, take_by
from datumaro.util.annotation_util import find_group_leader, find_instances


@@ -1249,3 +1249,129 @@ def transform_item(self, item: DatasetItem):
attributes=self._filter_attrs(item.attributes), annotations=filtered_annotations
)
return item


class Correct(Transform, CliPlugin):
"""
Correct the dataset from a validation report.
    A user should feed `validation_reports.json` from the validator to correct the dataset.
    This refines the dataset by registering undefined labels and attributes, dropping items
    with missing annotations, and removing outlier annotations.
"""

@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument(
"-r",
"--reports",
type=str,
default="validation_reports.json",
help="A validation report from a 'validate' CLI",
)
return parser
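    # Hedged CLI sketch (not part of this diff): with the plugin registered
    # as "correct", the transform could be invoked as, e.g.,
    #
    #     datum transform -t correct -- -r validation_reports.json
    #
    # where arguments after "--" are handled by the parser built above;
    # the exact plugin name depends on how the CLI registers this class.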

def __init__(
self,
extractor: IDataset,
reports: Union[str, Dict],
):
super().__init__(extractor)

if isinstance(reports, str):
reports = parse_json_file(reports)

self._reports = reports["validation_reports"]

self._categories = self._extractor.categories()

self._remove_items = set()
self._remove_anns = defaultdict(list)
self._add_attrs = defaultdict(list)

self._analyze_reports(report=self._reports)

def categories(self):
return self._categories

    def _parse_ann_ids(self, desc: str):
        # Return the first quoted integer in the description,
        # e.g. "... annotation '5' ..." -> 5.
        return [int(s) for s in desc.split("'") if s.isdigit()][0]

def _analyze_reports(self, report):
for rep in report:
if rep["anomaly_type"] == "MissingLabelCategories":
                unique_labels = sorted(
                    {ann.label for item in self._extractor for ann in item.annotations}
                )
                label_categories = LabelCategories.from_iterable(
                    [str(label) for label in unique_labels]
                )
for item in self._extractor:
for ann in item.annotations:
                        attrs = set(ann.attributes)
label_categories[ann.label].attributes.update(attrs)
self._categories[AnnotationType.label] = label_categories
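                # Illustration (hypothetical, not in the diff): if annotations
                # carry label ids {0, 2}, the rebuilt categories are ["0", "2"],
                # and each entry inherits the attribute names seen on its
                # annotations.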

if rep["anomaly_type"] == "UndefinedLabel":
label_categories = self._categories[AnnotationType.label]
                desc = rep["description"].split("'")
add_label_name = desc[1]
label_id, _ = label_categories.find(add_label_name)
if label_id is None:
label_categories.add(name=add_label_name)

if rep["anomaly_type"] == "UndefinedAttribute":
label_categories = self._categories[AnnotationType.label]
                desc = rep["description"].split("'")
attr_name, label_name = desc[1], desc[3]
label_id = label_categories.find(label_name)[0]
if label_id is not None:
label_categories[label_id].attributes.add(attr_name)
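                # The quote-splitting assumes a description shaped roughly like
                # "... attribute 'x' ... label 'y' ..." (illustrative), so
                # desc[1] is the attribute name and desc[3] the label name.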

            # [TODO] Correct LabelDefinedButNotFound: removing a label, reindexing, remapping others
# if rep["anomaly_type"] == "LabelDefinedButNotFound":
# remove_label_name = self._parse_label_cat(rep["description"])
# label_cat = self._extractor.categories()[AnnotationType.label]
# if remove_label_name in [labels.name for labels in label_cat.items]:
# label_cat.remove(remove_label_name)

if rep["anomaly_type"] in ["MissingAnnotation", "MultiLabelAnnotations"]:
self._remove_items.add((rep["item_id"], rep["subset"]))

if rep["anomaly_type"] in [
"NegativeLength",
"InvalidValue",
"FarFromLabelMean",
"FarFromAttrMean",
]:
                ann_id = self._parse_ann_ids(rep["description"])
self._remove_anns[(rep["item_id"], rep["subset"])].append(ann_id)

if rep["anomaly_type"] == "MissingAttribute":
                desc = rep["description"].split("'")
attr_name, label_name = desc[1], desc[3]
label_id = self._extractor.categories()[AnnotationType.label].find(label_name)[0]
self._add_attrs[(rep["item_id"], rep["subset"])].append((label_id, attr_name))
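    # A hedged sketch of one report entry this method consumes; the four keys
    # read above are shown with illustrative values (entries also carry a
    # "severity" field, counted in the validator summary):
    #
    #     {
    #         "anomaly_type": "UndefinedLabel",
    #         "item_id": "img_0042",
    #         "subset": "train",
    #         "description": "... label 'cat' ...",
    #     }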

def __iter__(self):
for item in self._extractor:
if (item.id, item.subset) in self._remove_items:
continue

            ann_ids = self._remove_anns.get((item.id, item.subset), None)
            # If any annotations are flagged for removal, drop them; removal
            # takes precedence over attribute additions for this item.
            if ann_ids:
                updated_anns = [ann for ann in item.annotations if ann.id not in ann_ids]
                yield item.wrap(annotations=updated_anns)
else:
updated_attrs = defaultdict(list)
for label_id, attr_name in self._add_attrs.get((item.id, item.subset), []):
updated_attrs[label_id].append(attr_name)

updated_anns = []
for ann in item.annotations:
new_ann = ann.wrap(attributes=deepcopy(ann.attributes))
if ann.label in updated_attrs:
new_ann.attributes.update(
{attr_name: "" for attr_name in updated_attrs[ann.label]}
)
updated_anns.append(new_ann)
yield item.wrap(annotations=updated_anns)
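For context, a hedged end-to-end sketch of the new transform; the `Dataset` import path and format name are assumptions, not shown in this diff:

from datumaro.components.dataset import Dataset  # assumed import path
from datumaro.plugins.transforms import Correct

# Import a dataset, then wrap it with the correction transform, feeding
# the report produced by the validator (CLI default file name).
dataset = Dataset.import_from("./my_dataset", "datumaro")
corrected = Correct(dataset, reports="validation_reports.json")

# Iteration applies the corrections: flagged items are skipped, outlier
# annotations are dropped, and missing attributes are filled with "".
for item in corrected:
    print(item.id, len(item.annotations))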
18 changes: 8 additions & 10 deletions datumaro/plugins/validators.py
@@ -419,7 +419,7 @@ def _check_only_one_label(self, stats):

if len(labels_found) == 1:
validation_reports += self._generate_validation_report(
-                OnlyOneLabel, Severity.warning, labels_found[0]
+                OnlyOneLabel, Severity.info, labels_found[0]
)

return validation_reports
@@ -431,7 +431,7 @@ def _check_only_one_attribute(self, label_name, attr_name, attr_dets):
if len(values) == 1:
details = (label_name, attr_name, values[0])
validation_reports += self._generate_validation_report(
-                OnlyOneAttributeValue, Severity.warning, *details
+                OnlyOneAttributeValue, Severity.info, *details
)

return validation_reports
@@ -449,7 +449,7 @@ def _check_few_samples_in_label(self, stats):

for label_name, count in labels_with_few_samples:
validation_reports += self._generate_validation_report(
-                FewSamplesInLabel, Severity.warning, label_name, count
+                FewSamplesInLabel, Severity.info, label_name, count
)

return validation_reports
@@ -467,7 +467,7 @@ def _check_few_samples_in_attribute(self, label_name, attr_name, attr_dets):
for attr_value, count in attr_values_with_few_samples:
details = (label_name, attr_name, attr_value, count)
validation_reports += self._generate_validation_report(
-                FewSamplesInAttribute, Severity.warning, *details
+                FewSamplesInAttribute, Severity.info, *details
)

return validation_reports
@@ -486,9 +486,7 @@ def _check_imbalanced_labels(self, stats):
count_min = np.min(count_by_defined_labels)
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
-            validation_reports += self._generate_validation_report(
-                ImbalancedLabels, Severity.warning
-            )
+            validation_reports += self._generate_validation_report(ImbalancedLabels, Severity.info)

return validation_reports

@@ -505,7 +503,7 @@ def _check_imbalanced_attribute(self, label_name, attr_name, attr_dets):
balance = count_max / count_min if count_min > 0 else float("inf")
if balance >= thr:
validation_reports += self._generate_validation_report(
-                ImbalancedAttribute, Severity.warning, label_name, attr_name
+                ImbalancedAttribute, Severity.info, label_name, attr_name
)

return validation_reports
@@ -915,7 +913,7 @@ def _check_imbalanced_dist_in_label(self, label_name, label_stats):
if ratio >= thr:
details = (label_name, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
-                ImbalancedDistInLabel, Severity.warning, *details
+                ImbalancedDistInLabel, Severity.info, *details
)

return validation_reports
@@ -939,7 +937,7 @@ def _check_imbalanced_dist_in_attr(self, label_name, attr_name, attr_stats):
if ratio >= thr:
details = (label_name, attr_name, attr_value, f"{self.str_ann_type} {prop}")
validation_reports += self._generate_validation_report(
-                ImbalancedDistInAttribute, Severity.warning, *details
+                ImbalancedDistInAttribute, Severity.info, *details
)

return validation_reports