-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into feat_correct
- Loading branch information
Showing
14 changed files
with
579 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from .base import SegmentAnythingBase | ||
from .exporter import SegmentAnythingExporter | ||
from .importer import SegmentAnythingImporter | ||
|
||
__all__ = ["SegmentAnythingBase", "SegmentAnythingImporter", "SegmentAnythingExporter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import os.path as osp | ||
from glob import glob | ||
from inspect import isclass | ||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union | ||
|
||
from datumaro.components.annotation import Bbox, RleMask | ||
from datumaro.components.dataset_base import DatasetItem, SubsetBase | ||
from datumaro.components.errors import ( | ||
DatasetImportError, | ||
InvalidAnnotationError, | ||
InvalidFieldTypeError, | ||
MissingFieldError, | ||
) | ||
from datumaro.components.importer import ImportContext | ||
from datumaro.components.media import Image | ||
from datumaro.util import NOTSET, parse_json_file | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def parse_field(
    ann: Dict[str, Any],
    key: str,
    cls: Union[Type[T], Tuple[Type, ...]],
    default: Any = NOTSET,
) -> Any:
    """Fetch ``key`` from ``ann`` and require the value to be of type ``cls``.

    Parameters:
        ann: the JSON-derived dictionary to read from.
        key: the field name to look up.
        cls: a type, or tuple of types, the value must be an instance of.
        default: value returned when the field is absent; when left as the
            ``NOTSET`` sentinel, a missing field is an error instead.

    Returns:
        The field value, or ``default`` when the field is missing and a
        default was supplied.

    Raises:
        MissingFieldError: the field is absent and no default was given.
        InvalidFieldTypeError: the value is present but not of type ``cls``.
    """
    value = ann.get(key, NOTSET)
    # Missing field: either fall back to the caller's default or fail.
    if value is NOTSET:
        if default is NOTSET:
            raise MissingFieldError(key)
        return default
    # Present and well-typed: return as-is.
    if isinstance(value, cls):
        return value
    # Normalize a single type into a tuple so the error message is uniform.
    expected_types = (cls,) if isclass(cls) else cls
    raise InvalidFieldTypeError(
        key, actual=str(type(value)), expected=tuple(str(t) for t in expected_types)
    )
|
||
|
||
class SegmentAnythingBase(SubsetBase):
    """Dataset base for the Segment Anything (SA-1B style) format.

    Expects ``path`` to be a directory holding one ``*.json`` annotation file
    per image; each file contains an ``image`` record plus a list of
    ``annotations`` (RLE segmentations with optional bboxes and quality
    attributes).
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        # The format is directory-based (one JSON per image), so reject
        # anything that is not a directory up front.
        if not osp.isdir(path):
            raise DatasetImportError(f"path {path} must be directory.")
        self._path = path

        super().__init__(subset=subset, ctx=ctx)
        self._items = self._load_items()

    def _load_items(self):
        """Parse every ``*.json`` file under ``self._path`` into DatasetItems.

        Failures are routed through the import context's error policy rather
        than raised, so one malformed file does not abort the whole import.
        """
        pbar = self._ctx.progress_reporter
        items = []
        for annotation_file in pbar.iter(
            glob(osp.join(self._path, "*.json")),
            desc=f"Parsing data in {osp.basename(self._path)}",
        ):
            image_id = None
            annotations = []
            item_kwargs = {
                "id": None,
                "subset": self._subset,
                "media": None,
                "annotations": [],
                "attributes": {},
            }

            try:
                contents = parse_json_file(annotation_file)
                image_info = contents["image"]
                annotations = contents["annotations"]

                image_id = parse_field(image_info, "image_id", int)
                item_kwargs["attributes"]["id"] = image_id

                # Height/width are optional; when either is missing the size
                # is left undetermined for the Image to resolve lazily.
                image_size = (
                    parse_field(image_info, "height", int, default=None),
                    parse_field(image_info, "width", int, default=None),
                )
                if any(i is None for i in image_size):
                    image_size = None
                file_name = parse_field(image_info, "file_name", str)

                # Item id is the image file name without its extension.
                item_kwargs["id"] = osp.splitext(file_name)[0]
                item_kwargs["media"] = Image.from_file(
                    path=osp.join(self._path, file_name), size=image_size
                )
            except Exception as e:
                # NOTE(review): execution deliberately continues below even
                # after an image-level failure, so whatever annotations were
                # read (possibly none) are still processed — presumably
                # intentional best-effort behavior; confirm.
                self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

            try:
                for annotation in annotations:
                    anno_id = parse_field(annotation, "id", int)
                    # SA-1B quality attributes; all optional, with neutral
                    # defaults when absent.
                    attributes = {
                        "predicted_iou": parse_field(
                            annotation,
                            "predicted_iou",
                            float,
                            0.0,
                        ),
                        "stability_score": parse_field(
                            annotation,
                            "stability_score",
                            float,
                            0.0,
                        ),
                        "point_coords": parse_field(
                            annotation,
                            "point_coords",
                            list,
                            [[]],
                        ),
                        "crop_box": parse_field(annotation, "crop_box", list, []),
                    }

                    group = anno_id  # make sure all tasks' annotations are merged

                    # A segmentation is mandatory; None (missing) is an error.
                    segmentation = parse_field(annotation, "segmentation", dict, None)
                    if segmentation is None:
                        raise InvalidAnnotationError("'segmentation' label is not found.")
                    item_kwargs["annotations"].append(
                        RleMask(
                            rle=segmentation,
                            label=None,
                            id=anno_id,
                            attributes=attributes,
                            group=group,
                        )
                    )

                    # Fall back to the just-added mask's own bounding box when
                    # the JSON does not supply one.
                    bbox = parse_field(annotation, "bbox", list, None)
                    if bbox is None:
                        bbox = item_kwargs["annotations"][-1].get_bbox().tolist()

                    if len(bbox) > 0:
                        if len(bbox) != 4:
                            raise InvalidAnnotationError(
                                f"Bbox has wrong value count {len(bbox)}. Expected 4 values."
                            )
                        x, y, w, h = bbox
                        item_kwargs["annotations"].append(
                            Bbox(
                                x,
                                y,
                                w,
                                h,
                                label=None,
                                id=anno_id,
                                attributes=attributes,
                                group=group,
                            )
                        )
            except Exception as e:
                self._ctx.error_policy.report_annotation_error(e, item_id=(image_id, self._subset))

            # Item construction itself may fail (e.g. id still None after an
            # earlier error); report instead of raising.
            try:
                items.append(DatasetItem(**item_kwargs))
            except Exception as e:
                self._ctx.error_policy.report_item_error(e, item_id=(image_id, self._subset))

        return items
163 changes: 163 additions & 0 deletions
163
datumaro/plugins/data_formats/segment_anything/exporter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
|
||
import logging as log | ||
import os | ||
import os.path as osp | ||
from itertools import chain | ||
from typing import List, Union | ||
|
||
from pycocotools import mask as mask_utils | ||
|
||
from datumaro.components.annotation import AnnotationType, Ellipse, Polygon | ||
from datumaro.components.errors import MediaTypeError | ||
from datumaro.components.exporter import Exporter | ||
from datumaro.components.media import Image | ||
from datumaro.util import annotation_util as anno_tools | ||
from datumaro.util import dump_json_file, mask_tools | ||
|
||
|
||
class SegmentAnythingExporter(Exporter):
    """Writes a dataset in the Segment Anything (SA-1B style) format:
    one JSON annotation file per image, alongside the optionally saved image.
    """

    DEFAULT_IMAGE_EXT = ".jpg"

    # Shape-like annotation types that get rasterized into the instance mask.
    _polygon_types = {AnnotationType.polygon, AnnotationType.ellipse}
    # Annotation types that participate in instance grouping at all.
    _allowed_types = {
        AnnotationType.bbox,
        AnnotationType.polygon,
        AnnotationType.mask,
        AnnotationType.ellipse,
    }

    def __init__(
        self,
        extractor,
        save_dir,
        **kwargs,
    ):
        super().__init__(extractor, save_dir, **kwargs)

    @staticmethod
    def find_instance_anns(annotations):
        # Keep only the annotation types this format can represent.
        return [a for a in annotations if a.type in SegmentAnythingExporter._allowed_types]

    @classmethod
    def find_instances(cls, annotations):
        # Cluster exportable annotations into per-instance groups.
        return anno_tools.find_instances(cls.find_instance_anns(annotations))

    def get_annotation_info(self, group, img_width, img_height):
        """Collapse one instance group into a single SA-1B annotation dict.

        Returns None when the group contains nothing rasterizable
        (no polygons, ellipses, or masks), since the format requires a
        segmentation for every annotation.
        """
        boxes = [a for a in group if a.type == AnnotationType.bbox]
        polygons: List[Union[Polygon, Ellipse]] = [
            a for a in group if a.type in self._polygon_types
        ]
        masks = [a for a in group if a.type == AnnotationType.mask]

        anns = boxes + polygons + masks
        leader = anno_tools.find_group_leader(anns)
        # Prefer explicit bboxes; otherwise derive the bbox from all shapes.
        if len(boxes) > 0:
            bbox = anno_tools.max_bbox(boxes)
        else:
            bbox = anno_tools.max_bbox(anns)
        polygons = [p.as_polygon() for p in polygons]

        # Merge rasterized polygons and masks into one binary mask.
        mask = None
        if polygons:
            mask = mask_tools.rles_to_mask(polygons, img_width, img_height)
        if masks:
            masks = (m.image for m in masks)
            if mask is not None:
                masks = chain(masks, [mask])
            mask = mask_tools.merge_masks(masks)
        if mask is None:
            return None
        mask = mask_tools.mask_to_rle(mask)

        segmentation = {
            "counts": list(int(c) for c in mask["counts"]),
            "size": list(int(c) for c in mask["size"]),
        }
        # Re-encode through pycocotools to obtain the compressed COCO RLE.
        rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
        if isinstance(rles["counts"], bytes):
            # JSON cannot hold bytes; store the compressed counts as a string.
            rles["counts"] = rles["counts"].decode()
        area = mask_utils.area(rles)

        annotation_data = {
            "id": leader.group,
            "segmentation": rles,
            "bbox": bbox,
            "area": area,
            # Aggregate per-annotation quality attributes across the group.
            "predicted_iou": max(ann.attributes.get("predicted_iou", 0.0) for ann in anns),
            "stability_score": max(ann.attributes.get("stability_score", 0.0) for ann in anns),
            "crop_box": anno_tools.max_bbox([ann.attributes.get("crop_box", []) for ann in anns]),
            # Deduplicate prompt points across the group's annotations.
            "point_coords": list(
                set(
                    tuple(point_coord)
                    for ann in anns
                    for point_coord in ann.attributes.get("point_coords", [[]])
                )
            ),
        }
        return annotation_data

    def apply(self):
        """Export every subset: one JSON file (and optionally the image) per item.

        Per-item failures are reported through the error policy instead of
        aborting the export.
        """
        if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)

        subsets = self._extractor.subsets()
        pbars = self._ctx.progress_reporter.split(len(subsets))

        max_image_id = 1
        for pbar, (subset_name, subset) in zip(pbars, subsets.items()):
            for item in pbar.iter(subset, desc=f"Exporting {subset_name}"):
                try:
                    # make sure file_name is flat
                    file_name = self._make_image_filename(item).replace("/", "__")
                    # Use the item's numeric id attribute when available;
                    # fall back to the running counter otherwise.
                    try:
                        image_id = int(item.attributes.get("id", max_image_id))
                    except ValueError:
                        image_id = max_image_id
                    # NOTE(review): the counter advances even when the item
                    # supplied its own id — confirm this is intended.
                    max_image_id += 1

                    if not item.media or not item.media.size:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no image info available"
                        )
                        continue

                    height, width = item.media.size
                    json_data = {
                        "image": {
                            "image_id": image_id,
                            "file_name": file_name,
                            "height": height,
                            "width": width,
                        },
                        "annotations": [],
                    }

                    # One merged annotation per instance group; drop groups
                    # that yielded no segmentation.
                    instances = self.find_instances(item.annotations)
                    annotations = [self.get_annotation_info(i, width, height) for i in instances]
                    annotations = [i for i in annotations if i is not None]
                    if not annotations:
                        log.warning(
                            f"Item '{item.id}': skipping writing instances since no annotation available"
                        )
                        continue
                    json_data["annotations"] = annotations

                    # JSON file shares the image's base name.
                    dump_json_file(
                        os.path.join(self._save_dir, osp.splitext(file_name)[0] + ".json"),
                        json_data,
                    )

                    if self._save_media:
                        self._save_image(
                            item,
                            path=osp.abspath(osp.join(self._save_dir, file_name)),
                        )

                except Exception as e:
                    self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))
Oops, something went wrong.