Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Segment Anything data format #1005

Merged
16 commits merged on May 22, 2023
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/996>)
- Add VocInstanceSegmentationImporter and VocInstanceSegmentationExporter
(<https://github.com/openvinotoolkit/datumaro/pull/997>)
- Add Segment Anything data format support
(<https://github.com/openvinotoolkit/datumaro/pull/1005>)

### Enhancements
- Use autosummary for fully-automatic Python module docs generation
Expand Down
14 changes: 14 additions & 0 deletions datumaro/components/format_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,20 @@ def require_files(

return sorted(self._require_files_iter(pattern, exclude_fnames=exclude_fnames))

def require_files_iter(
    self,
    pattern: str,
    *,
    exclude_fnames: Union[str, Collection[str]] = (),
) -> Iterator[str]:
    """
    Same as `require_files`, but returns a generator.

    Args:
        pattern: glob-style pattern of the files that must be present.
        exclude_fnames: file name pattern(s) to skip while matching.

    Returns:
        An iterator over matching paths. Unlike ``require_files``, the
        results are not sorted — they come straight from
        ``_require_files_iter``.
    """

    # Register this requirement with the detector before yielding anything.
    self._start_requirement("require_files_iter")

    return self._require_files_iter(pattern, exclude_fnames=exclude_fnames)

def _require_files_iter(
self,
pattern: str,
Expand Down
9 changes: 9 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

"""Plugin package for the Segment Anything (SA-1B) data format:
reader (base), importer, and exporter entry points."""

from .base import SegmentAnythingBase
from .exporter import SegmentAnythingExporter
from .importer import SegmentAnythingImporter

# Explicit public API; limits `from ... import *` to the plugin classes.
__all__ = ["SegmentAnythingBase", "SegmentAnythingImporter", "SegmentAnythingExporter"]
169 changes: 169 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
from glob import glob
from inspect import isclass
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

from datumaro.components.annotation import Bbox, RleMask
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import (
DatasetImportError,
InvalidAnnotationError,
InvalidFieldTypeError,
MissingFieldError,
)
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image
from datumaro.util import NOTSET, parse_json_file

T = TypeVar("T")


def parse_field(
    ann: Dict[str, Any],
    key: str,
    cls: Union[Type[T], Tuple[Type, ...]],
    default: Any = NOTSET,
) -> Any:
    """Fetch ``ann[key]`` and validate its type.

    Args:
        ann: the parsed JSON mapping to read from.
        key: the field name to look up.
        cls: expected type, or a tuple of acceptable types.
        default: value returned when the field is absent; when left as the
            ``NOTSET`` sentinel, a missing field is an error instead.

    Returns:
        The field value (or ``default`` when the field is absent and a
        default was supplied).

    Raises:
        MissingFieldError: the field is absent and no default was given.
        InvalidFieldTypeError: the value is not an instance of ``cls``.
    """
    value = ann.get(key, NOTSET)

    if value is NOTSET:
        # Absent field: either substitute the caller's default or fail loudly.
        if default is NOTSET:
            raise MissingFieldError(key)
        return default

    # Normalize to a tuple so the error message can enumerate expected types.
    expected_types = (cls,) if isclass(cls) else cls
    if not isinstance(value, expected_types):
        raise InvalidFieldTypeError(
            key, actual=str(type(value)), expected=tuple(str(t) for t in expected_types)
        )

    return value


class SegmentAnythingBase(SubsetBase):
    """Dataset reader for the Segment Anything (SA-1B) format.

    Each image is described by a sibling ``<image-stem>.json`` file holding
    an ``image`` record (id, file name, size) and an ``annotations`` list of
    instances with COCO-style RLE ``segmentation`` dicts.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """
        Args:
            path: directory containing the per-image ``*.json`` files.
            subset: subset name assigned to all loaded items.
            ctx: import context supplying the progress reporter and
                error policy used during loading.

        Raises:
            DatasetImportError: if ``path`` is not an existing directory.
        """
        if not osp.isdir(path):
            raise DatasetImportError(f"path {path} must be directory.")
        self._path = path

        super().__init__(subset=subset, ctx=ctx)
        self._items = self._load_items()

    def _load_items(self):
        """Parse every ``*.json`` file in the directory into DatasetItems.

        Failures are reported through the context's error policy instead of
        aborting the whole import, so one broken file only affects that item.
        """
        pbar = self._ctx.progress_reporter
        items = []
        for annotation_file in pbar.iter(
            glob(osp.join(self._path, "*.json")),
            desc=f"Parsing data in {osp.basename(self._path)}",
        ):
            image_id = None
            annotations = []
            # Accumulated DatasetItem constructor arguments for this file.
            item_kwargs = {
                "id": None,
                "subset": self._subset,
                "media": None,
                "annotations": [],
                "attributes": {},
            }

            try:
                contents = parse_json_file(annotation_file)
                image_info = contents["image"]
                annotations = contents["annotations"]

                image_id = parse_field(image_info, "image_id", int)
                item_kwargs["attributes"]["id"] = image_id

                # Size is optional: if either dimension is missing, pass None
                # so the Image object determines the size from the file itself.
                image_size = (
                    parse_field(image_info, "height", int, default=None),
                    parse_field(image_info, "width", int, default=None),
                )
                if any(i is None for i in image_size):
                    image_size = None
                file_name = parse_field(image_info, "file_name", str)

                # The item id is the image file name without its extension.
                item_kwargs["id"] = osp.splitext(file_name)[0]
                item_kwargs["media"] = Image.from_file(
                    path=osp.join(self._path, file_name), size=image_size
                )
            except Exception as e:
                # NOTE(review): on failure we still fall through to the
                # annotation loop with whatever was parsed so far; whether the
                # import continues at all is up to the error policy.
                self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

            try:
                for annotation in annotations:
                    anno_id = parse_field(annotation, "id", int)
                    # SA-1B per-mask metadata; missing fields fall back to
                    # neutral defaults rather than failing the annotation.
                    attributes = {
                        "predicted_iou": parse_field(
                            annotation,
                            "predicted_iou",
                            float,
                            0.0,
                        ),
                        "stability_score": parse_field(
                            annotation,
                            "stability_score",
                            float,
                            0.0,
                        ),
                        "point_coords": parse_field(
                            annotation,
                            "point_coords",
                            list,
                            [[]],
                        ),
                        "crop_box": parse_field(annotation, "crop_box", list, []),
                    }

                    group = anno_id  # make sure all tasks' annotations are merged

                    # The RLE mask is mandatory; its absence is an annotation
                    # error for this item, reported via the policy below.
                    segmentation = parse_field(annotation, "segmentation", dict, None)
                    if segmentation is None:
                        raise InvalidAnnotationError("'segmentation' label is not found.")
                    item_kwargs["annotations"].append(
                        RleMask(
                            rle=segmentation,
                            label=None,
                            id=anno_id,
                            attributes=attributes,
                            group=group,
                        )
                    )

                    # Prefer the bbox from the JSON; otherwise derive it from
                    # the RleMask appended just above.
                    bbox = parse_field(annotation, "bbox", list, None)
                    if bbox is None:
                        bbox = item_kwargs["annotations"][-1].get_bbox().tolist()

                    if len(bbox) > 0:
                        if len(bbox) != 4:
                            raise InvalidAnnotationError(
                                f"Bbox has wrong value count {len(bbox)}. Expected 4 values."
                            )
                        x, y, w, h = bbox
                        # The Bbox shares id/group/attributes with its mask so
                        # downstream merging treats them as one instance.
                        item_kwargs["annotations"].append(
                            Bbox(
                                x,
                                y,
                                w,
                                h,
                                label=None,
                                id=anno_id,
                                attributes=attributes,
                                group=group,
                            )
                        )
            except Exception as e:
                self._ctx.error_policy.report_annotation_error(e, item_id=(image_id, self._subset))

            try:
                items.append(DatasetItem(**item_kwargs))
            except Exception as e:
                self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

        return items
173 changes: 173 additions & 0 deletions datumaro/plugins/data_formats/segment_anything/exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT


import logging as log
import os
import os.path as osp
from itertools import chain
from typing import List, Union

from pycocotools import mask as mask_utils

from datumaro.components.annotation import AnnotationType, Ellipse, Polygon
from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import NOTSET
from datumaro.util import annotation_util as anno_tools
from datumaro.util import dump_json_file, mask_tools


def replace(json_data, key, value_new):
    """Fill in ``json_data[key]`` when it is still the ``NOTSET`` sentinel.

    If the slot already holds a value, it must equal ``value_new``; a
    conflicting value raises, since the field is expected to be consistent
    across the sources being merged.

    Raises:
        DatumaroError: ``json_data[key]`` is set and differs from ``value_new``.
    """
    value_origin = json_data[key]
    if value_origin is NOTSET:
        json_data[key] = value_new
    elif value_origin != value_new:
        raise DatumaroError(f"The value for '{key}' is not same for item {value_new}")


class SegmentAnythingExporter(Exporter):
    """Writes a dataset in the Segment Anything (SA-1B) layout: one JSON
    side-car file per image, optionally saving the image next to it."""

    DEFAULT_IMAGE_EXT = ".jpg"

    # Shape types that are rasterized into the exported RLE mask.
    _polygon_types = {AnnotationType.polygon, AnnotationType.ellipse}
    # All annotation types this exporter knows how to consume.
    _allowed_types = {
        AnnotationType.bbox,
        AnnotationType.polygon,
        AnnotationType.mask,
        AnnotationType.ellipse,
    }

    def __init__(
        self,
        extractor,
        save_dir,
        **kwargs,
    ):
        # No extra state: forwards straight to the generic Exporter.
        super().__init__(extractor, save_dir, **kwargs)

    @staticmethod
    def find_instance_anns(annotations):
        """Keep only the annotation types this format can represent."""
        return [a for a in annotations if a.type in SegmentAnythingExporter._allowed_types]

    @classmethod
    def find_instances(cls, annotations):
        """Cluster the supported annotations into per-instance groups."""
        return anno_tools.find_instances(cls.find_instance_anns(annotations))

    def get_annotation_info(self, group, img_width, img_height):
        """Collapse one instance group into a single SA-1B annotation record.

        Boxes, polygons/ellipses, and masks belonging to the same instance
        are merged: shapes are rasterized and OR-ed into one binary mask,
        which is then RLE-encoded.

        Returns:
            A JSON-serializable dict for the ``annotations`` list, or
            ``None`` when the group produces no mask at all.
        """
        boxes = [a for a in group if a.type == AnnotationType.bbox]
        polygons: List[Union[Polygon, Ellipse]] = [
            a for a in group if a.type in self._polygon_types
        ]
        masks = [a for a in group if a.type == AnnotationType.mask]

        anns = boxes + polygons + masks
        leader = anno_tools.find_group_leader(anns)
        # Prefer a bbox computed from explicit boxes; otherwise span all shapes.
        if len(boxes) > 0:
            bbox = anno_tools.max_bbox(boxes)
        else:
            bbox = anno_tools.max_bbox(anns)
        polygons = [p.as_polygon() for p in polygons]

        mask = None
        if polygons:
            mask = mask_tools.rles_to_mask(polygons, img_width, img_height)
        if masks:
            masks = (m.image for m in masks)
            if mask is not None:
                masks = chain(masks, [mask])
            mask = mask_tools.merge_masks(masks)
        if mask is None:
            # Nothing rasterizable in this group (e.g. boxes only).
            return None
        mask = mask_tools.mask_to_rle(mask)

        # Re-encode the uncompressed RLE through pycocotools, then make the
        # compressed counts JSON-safe by decoding the bytes to str.
        segmentation = {
            "counts": list(int(c) for c in mask["counts"]),
            "size": list(int(c) for c in mask["size"]),
        }
        rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
        if isinstance(rles["counts"], bytes):
            rles["counts"] = rles["counts"].decode()
        # NOTE(review): mask_utils.area returns a numpy scalar; presumably
        # dump_json_file can serialize it — confirm.
        area = mask_utils.area(rles)

        annotation_data = {
            # The group id doubles as the exported annotation id.
            "id": leader.group,
            "segmentation": rles,
            "bbox": bbox,
            "area": area,
            "predicted_iou": max(ann.attributes.get("predicted_iou", 0.0) for ann in anns),
            "stability_score": max(ann.attributes.get("stability_score", 0.0) for ann in anns),
            # NOTE(review): crop_box lists are fed to max_bbox as if they were
            # annotations; verify it accepts raw [x, y, w, h] lists (and empty ones).
            "crop_box": anno_tools.max_bbox([ann.attributes.get("crop_box", []) for ann in anns]),
            # De-duplicated prompt points; set iteration makes the order
            # nondeterministic across runs.
            "point_coords": list(
                set(
                    tuple(point_coord)
                    for ann in anns
                    for point_coord in ann.attributes.get("point_coords", [[]])
                )
            ),
        }
        return annotation_data

    def apply(self):
        """Export every subset: one ``<image-stem>.json`` per item, plus the
        image itself when ``save_media`` is enabled. Items without media size
        or without exportable annotations are skipped with a warning."""
        if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)

        subsets = self._extractor.subsets()
        pbars = self._ctx.progress_reporter.split(len(subsets))

        # Fallback counter for items without a usable integer "id" attribute.
        max_image_id = 1
        for pbar, (subset_name, subset) in zip(pbars, subsets.items()):
            for item in pbar.iter(subset, desc=f"Exporting {subset_name}"):
                try:
                    # make sure file_name is flat
                    file_name = self._make_image_filename(item).replace("/", "__")
                    try:
                        image_id = int(item.attributes.get("id", max_image_id))
                    except ValueError:
                        image_id = max_image_id
                    # Incremented unconditionally, even when the item supplied
                    # its own id.
                    max_image_id += 1

                    if not item.media or not item.media.size:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no image info available"
                        )
                        continue

                    height, width = item.media.size
                    json_data = {
                        "image": {
                            "image_id": image_id,
                            "file_name": file_name,
                            "height": height,
                            "width": width,
                        },
                        "annotations": [],
                    }

                    instances = self.find_instances(item.annotations)
                    annotations = [self.get_annotation_info(i, width, height) for i in instances]
                    annotations = [i for i in annotations if i is not None]
                    if not annotations:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no annotation available"
                        )
                        continue
                    json_data["annotations"] = annotations

                    dump_json_file(
                        os.path.join(self._save_dir, osp.splitext(file_name)[0] + ".json"),
                        json_data,
                    )

                    if self._save_media:
                        self._save_image(
                            item,
                            path=osp.abspath(osp.join(self._save_dir, file_name)),
                        )

                except Exception as e:
                    self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))
Loading