Merge branch 'develop' into feat_correct
wonjuleee authored May 22, 2023
2 parents d4780fc + 5c3e21a commit c034d11
Showing 14 changed files with 579 additions and 14 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -15,7 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/996>)
- Add VocInstanceSegmentationImporter and VocInstanceSegmentationExporter
(<https://github.com/openvinotoolkit/datumaro/pull/997>)
- Add Segment Anything data format support
(<https://github.com/openvinotoolkit/datumaro/pull/1005>)
- Add Correct transformation
(<https://github.com/openvinotoolkit/datumaro/pull/1006>)

### Enhancements
@@ -27,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/981>)
- Add MOT and MOTS data format docs
(<https://github.com/openvinotoolkit/datumaro/pull/999>)
- Improve RemoveAnnotations to remove specific annotations with ids
(<https://github.com/openvinotoolkit/datumaro/pull/1004>)

### Bug fixes
- Fix Mapillary Vistas data format
14 changes: 14 additions & 0 deletions datumaro/components/format_detection.py
@@ -265,6 +265,20 @@ def require_files(

return sorted(self._require_files_iter(pattern, exclude_fnames=exclude_fnames))

def require_files_iter(
self,
pattern: str,
*,
exclude_fnames: Union[str, Collection[str]] = (),
) -> Iterator[str]:
"""
Same as `require_files`, but returns an unsorted generator instead of a sorted list.
"""

self._start_requirement("require_files_iter")

return self._require_files_iter(pattern, exclude_fnames=exclude_fnames)

def _require_files_iter(
self,
pattern: str,
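A minimal sketch of how a detector could use the new generator variant. The `detect` callback and the glob pattern below are hypothetical; only `require_files_iter` itself comes from this diff:

from datumaro.components.format_detection import FormatDetectionContext

def detect(context: FormatDetectionContext) -> None:
    # Lazily iterate over matching files instead of materializing the
    # sorted list that require_files() builds; stop at the first hit.
    for _ in context.require_files_iter("annotations/*.json"):
        break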
9 changes: 9 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/__init__.py
@@ -0,0 +1,9 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from .base import SegmentAnythingBase
from .exporter import SegmentAnythingExporter
from .importer import SegmentAnythingImporter

__all__ = ["SegmentAnythingBase", "SegmentAnythingImporter", "SegmentAnythingExporter"]
169 changes: 169 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/base.py
@@ -0,0 +1,169 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
from glob import glob
from inspect import isclass
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

from datumaro.components.annotation import Bbox, RleMask
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import (
DatasetImportError,
InvalidAnnotationError,
InvalidFieldTypeError,
MissingFieldError,
)
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image
from datumaro.util import NOTSET, parse_json_file

T = TypeVar("T")


def parse_field(
ann: Dict[str, Any],
key: str,
cls: Union[Type[T], Tuple[Type, ...]],
default: Any = NOTSET,
) -> Any:
value = ann.get(key, NOTSET)
if value is NOTSET:
if default is not NOTSET:
return default
raise MissingFieldError(key)
elif not isinstance(value, cls):
cls = (cls,) if isclass(cls) else cls
raise InvalidFieldTypeError(
key, actual=str(type(value)), expected=tuple(str(t) for t in cls)
)
return value
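
For reference, a sketch of how `parse_field` behaves on a hand-made record (illustrative values, not from any real dataset):

ann = {"id": 3, "predicted_iou": "high"}

parse_field(ann, "id", int)                      # -> 3
parse_field(ann, "stability_score", float, 0.0)  # -> 0.0 (default applied)
parse_field(ann, "bbox", list)                   # raises MissingFieldError
parse_field(ann, "predicted_iou", float)         # raises InvalidFieldTypeError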


class SegmentAnythingBase(SubsetBase):
def __init__(
self,
path: str,
*,
subset: Optional[str] = None,
ctx: Optional[ImportContext] = None,
):
if not osp.isdir(path):
raise DatasetImportError(f"Path {path} must be a directory.")
self._path = path

super().__init__(subset=subset, ctx=ctx)
self._items = self._load_items()

def _load_items(self):
pbar = self._ctx.progress_reporter
items = []
for annotation_file in pbar.iter(
glob(osp.join(self._path, "*.json")),
desc=f"Parsing data in {osp.basename(self._path)}",
):
image_id = None
annotations = []
item_kwargs = {
"id": None,
"subset": self._subset,
"media": None,
"annotations": [],
"attributes": {},
}

try:
contents = parse_json_file(annotation_file)
image_info = contents["image"]
annotations = contents["annotations"]

image_id = parse_field(image_info, "image_id", int)
item_kwargs["attributes"]["id"] = image_id

image_size = (
parse_field(image_info, "height", int, default=None),
parse_field(image_info, "width", int, default=None),
)
if any(i is None for i in image_size):
image_size = None
file_name = parse_field(image_info, "file_name", str)

item_kwargs["id"] = osp.splitext(file_name)[0]
item_kwargs["media"] = Image.from_file(
path=osp.join(self._path, file_name), size=image_size
)
except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

try:
for annotation in annotations:
anno_id = parse_field(annotation, "id", int)
attributes = {
"predicted_iou": parse_field(
annotation,
"predicted_iou",
float,
0.0,
),
"stability_score": parse_field(
annotation,
"stability_score",
float,
0.0,
),
"point_coords": parse_field(
annotation,
"point_coords",
list,
[[]],
),
"crop_box": parse_field(annotation, "crop_box", list, []),
}

group = anno_id  # group the mask with its bbox so they form a single instance

segmentation = parse_field(annotation, "segmentation", dict, None)
if segmentation is None:
raise InvalidAnnotationError("'segmentation' field is not found.")
item_kwargs["annotations"].append(
RleMask(
rle=segmentation,
label=None,
id=anno_id,
attributes=attributes,
group=group,
)
)

bbox = parse_field(annotation, "bbox", list, None)
if bbox is None:
bbox = item_kwargs["annotations"][-1].get_bbox().tolist()

if len(bbox) > 0:
if len(bbox) != 4:
raise InvalidAnnotationError(
f"Bbox has wrong value count {len(bbox)}. Expected 4 values."
)
x, y, w, h = bbox
item_kwargs["annotations"].append(
Bbox(
x,
y,
w,
h,
label=None,
id=anno_id,
attributes=attributes,
group=group,
)
)
except Exception as e:
self._ctx.error_policy.report_annotation_error(e, item_id=(image_id, self._subset))

try:
items.append(DatasetItem(**item_kwargs))
except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

return items
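
A minimal import sketch, assuming the importer registers this format under the name "segment_anything" (registration is not shown in this diff) and using a placeholder path:

from datumaro.components.dataset import Dataset

dataset = Dataset.import_from("path/to/sa_dataset", "segment_anything")
for item in dataset:
    # each item carries an RleMask plus a matching Bbox sharing one group
    print(item.id, [type(a).__name__ for a in item.annotations])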
163 changes: 163 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/exporter.py
@@ -0,0 +1,163 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT


import logging as log
import os
import os.path as osp
from itertools import chain
from typing import List, Union

from pycocotools import mask as mask_utils

from datumaro.components.annotation import AnnotationType, Ellipse, Polygon
from datumaro.components.errors import MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import annotation_util as anno_tools
from datumaro.util import dump_json_file, mask_tools


class SegmentAnythingExporter(Exporter):
DEFAULT_IMAGE_EXT = ".jpg"

_polygon_types = {AnnotationType.polygon, AnnotationType.ellipse}
_allowed_types = {
AnnotationType.bbox,
AnnotationType.polygon,
AnnotationType.mask,
AnnotationType.ellipse,
}

def __init__(
self,
extractor,
save_dir,
**kwargs,
):
super().__init__(extractor, save_dir, **kwargs)

@staticmethod
def find_instance_anns(annotations):
return [a for a in annotations if a.type in SegmentAnythingExporter._allowed_types]

@classmethod
def find_instances(cls, annotations):
return anno_tools.find_instances(cls.find_instance_anns(annotations))

def get_annotation_info(self, group, img_width, img_height):
boxes = [a for a in group if a.type == AnnotationType.bbox]
polygons: List[Union[Polygon, Ellipse]] = [
a for a in group if a.type in self._polygon_types
]
masks = [a for a in group if a.type == AnnotationType.mask]

anns = boxes + polygons + masks
leader = anno_tools.find_group_leader(anns)
if len(boxes) > 0:
bbox = anno_tools.max_bbox(boxes)
else:
bbox = anno_tools.max_bbox(anns)
polygons = [p.as_polygon() for p in polygons]

mask = None
if polygons:
mask = mask_tools.rles_to_mask(polygons, img_width, img_height)
if masks:
masks = (m.image for m in masks)
if mask is not None:
masks = chain(masks, [mask])
mask = mask_tools.merge_masks(masks)
if mask is None:
return None
mask = mask_tools.mask_to_rle(mask)

segmentation = {
"counts": list(int(c) for c in mask["counts"]),
"size": list(int(c) for c in mask["size"]),
}
rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
if isinstance(rles["counts"], bytes):
rles["counts"] = rles["counts"].decode()
area = mask_utils.area(rles)

annotation_data = {
"id": leader.group,
"segmentation": rles,
"bbox": bbox,
"area": area,
"predicted_iou": max(ann.attributes.get("predicted_iou", 0.0) for ann in anns),
"stability_score": max(ann.attributes.get("stability_score", 0.0) for ann in anns),
"crop_box": anno_tools.max_bbox([ann.attributes.get("crop_box", []) for ann in anns]),
"point_coords": list(
set(
tuple(point_coord)
for ann in anns
for point_coord in ann.attributes.get("point_coords", [[]])
)
),
}
return annotation_data
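
For orientation, one record returned by this method looks roughly like the following (illustrative values; "segmentation" holds a COCO-style RLE):

{
    "id": 1,
    "segmentation": {"size": [480, 640], "counts": "..."},
    "bbox": [10.0, 20.0, 100.0, 50.0],
    "area": 5000,
    "predicted_iou": 0.97,
    "stability_score": 0.95,
    "crop_box": [0, 0, 640, 480],
    "point_coords": [(320.0, 240.0)],
}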

def apply(self):
if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
raise MediaTypeError("Media type is not an image")

os.makedirs(self._save_dir, exist_ok=True)

subsets = self._extractor.subsets()
pbars = self._ctx.progress_reporter.split(len(subsets))

max_image_id = 1
for pbar, (subset_name, subset) in zip(pbars, subsets.items()):
for item in pbar.iter(subset, desc=f"Exporting {subset_name}"):
try:
# flatten nested paths so every output file lands directly in save_dir
file_name = self._make_image_filename(item).replace("/", "__")
try:
image_id = int(item.attributes.get("id", max_image_id))
except ValueError:
image_id = max_image_id
max_image_id += 1

if not item.media or not item.media.size:
log.warning(
f"Item '{item.id}': skipping writing instances since no image info available"
)
continue

height, width = item.media.size
json_data = {
"image": {
"image_id": image_id,
"file_name": file_name,
"height": height,
"width": width,
},
"annotations": [],
}

instances = self.find_instances(item.annotations)
annotations = [self.get_annotation_info(i, width, height) for i in instances]
annotations = [i for i in annotations if i is not None]
if not annotations:
log.warning(
f"Item '{item.id}': skipping writing instances since no annotation available"
)
continue
json_data["annotations"] = annotations

dump_json_file(
osp.join(self._save_dir, osp.splitext(file_name)[0] + ".json"),
json_data,
)

if self._save_media:
self._save_image(
item,
path=osp.abspath(osp.join(self._save_dir, file_name)),
)

except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))
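
A minimal export sketch, again assuming the "segment_anything" format name and placeholder paths; the source format here ("coco_instances") is just an example:

from datumaro.components.dataset import Dataset

dataset = Dataset.import_from("path/to/coco_dataset", "coco_instances")
dataset.export("path/to/output", "segment_anything", save_media=True)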