From e4375afa1024b878b3555d5dae44eaab7c3f6748 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 16:46:44 +0400 Subject: [PATCH 01/25] sync components/annotation.py --- src/datumaro/components/annotation.py | 1478 +++++++++++++++-- src/datumaro/components/operations.py | 11 + .../plugins/data_formats/datumaro/exporter.py | 4 +- src/datumaro/util/annotation_util.py | 2 +- tests/unit/test_ops.py | 16 + 5 files changed, 1412 insertions(+), 99 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index 2808dbcf72..66dac1e0d6 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -1,17 +1,33 @@ -# Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2021-2024 Intel Corporation # Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from __future__ import annotations -from enum import Enum, auto +import math +from enum import IntEnum from functools import partial from itertools import zip_longest -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) import attr +import cv2 import numpy as np +import shapely.geometry as sg from attr import asdict, attrs, field from typing_extensions import Literal @@ -19,23 +35,31 @@ from datumaro.util.attrs_util import default_if_none, not_empty -class AnnotationType(Enum): - label = auto() - mask = auto() - points = auto() - polygon = auto() - polyline = auto() - bbox = auto() - caption = auto() - cuboid_3d = auto() - super_resolution_annotation = auto() - depth_annotation = auto() - skeleton = auto() +class AnnotationType(IntEnum): + unknown = 0 + label = 1 + mask = 2 + points = 3 + polygon = 4 + polyline = 5 + bbox = 6 + caption = 7 + cuboid_3d = 8 + super_resolution_annotation = 9 + depth_annotation = 10 + ellipse = 11 + hash_key = 12 + feature_vector = 13 + tabular = 14 + rotated_bbox = 15 + cuboid_2d = 16 + skeleton = 17 COORDINATE_ROUNDING_DIGITS = 2 - +CHECK_POLYGON_EQ_EPSILONE = 1e-7 NO_GROUP = 0 +NO_OBJECT_ID = -1 @attrs(slots=True, kw_only=True, order=False) @@ -66,6 +90,13 @@ class Annotation: # single object. The value of 0 means there is no group. group: int = field(default=NO_GROUP, validator=default_if_none(int)) + # object identifier over the multiple items + # e.g.) in a video, person 'A' could be annotated on the multiple frame images + # the user could assign >=0 value as id of person 'A'. 
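+    # For example (illustrative only), two boxes for the same person on consecutive
+    # frames could be created as Bbox(0, 0, 5, 5, label=0, object_id=1) and
+    # Bbox(1, 0, 5, 5, label=0, object_id=1), marking them as one tracked object.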
+ object_id: int = field(default=NO_OBJECT_ID, validator=default_if_none(int)) + + _type = AnnotationType.unknown + @property def type(self) -> AnnotationType: return self._type # must be set in subclasses @@ -92,6 +123,22 @@ class Categories: attributes: Set[str] = field(factory=set, validator=default_if_none(set), eq=False) +class GroupType(IntEnum): + EXCLUSIVE = 0 + INCLUSIVE = 1 + RESTRICTED = 2 + + def to_str(self) -> str: + return self.name.lower() + + @classmethod + def from_str(cls, text: str) -> GroupType: + try: + return cls[text.upper()] + except KeyError: + raise ValueError(f"Invalid GroupType: {text}") + + @attrs(slots=True, order=False) class LabelCategories(Categories): """ @@ -106,7 +153,16 @@ class Category: parent: str = field(default="", validator=default_if_none(str)) attributes: Set[str] = field(factory=set, validator=default_if_none(set)) + @attrs(slots=True, order=False) + class LabelGroup: + name: str = field(converter=str, validator=not_empty) + labels: List[str] = field(default=[], validator=default_if_none(list)) + group_type: GroupType = field( + default=GroupType.EXCLUSIVE, validator=default_if_none(GroupType) + ) + items: List[Category] = field(factory=list, validator=default_if_none(list)) + label_groups: List[LabelGroup] = field(factory=list, validator=default_if_none(list)) _indices: Dict[Tuple[str, str], int] = field(factory=dict, init=False, eq=False) @classmethod @@ -159,7 +215,10 @@ def labels(self): return {label_index: parent + name for (parent, name), label_index in self._indices.items()} def add( - self, name: str, parent: Optional[str] = "", attributes: Optional[Set[str]] = None + self, + name: str, + parent: Optional[str] = None, + attributes: Optional[Set[str]] = None, ) -> int: if not name: raise ValueError("Label name must not be empty") @@ -172,6 +231,18 @@ def add( self._indices[key] = index return index + def add_label_group( + self, + name: str, + labels: List[str], + group_type: GroupType, + ) -> int: + assert name + + index = len(self.label_groups) + self.label_groups.append(self.LabelGroup(name, labels, group_type)) + return index + def find(self, name: str, parent: str = "") -> Tuple[Optional[int], Optional[Category]]: index = self._indices.get((parent, name)) if index is not None: @@ -200,6 +271,38 @@ class Label(Annotation): label: int = field(converter=int) +@attrs(slots=True, eq=False, order=False) +class HashKey(Annotation): + _type = AnnotationType.hash_key + hash_key: np.ndarray = field(validator=attr.validators.instance_of(np.ndarray)) + + @hash_key.validator + def _validate(self, attribute, value: np.ndarray): + """Check whether value is a 1D Numpy array having 96 np.uint8 values""" + if value.ndim != 1 or value.shape[0] != 96 or value.dtype != np.uint8: + raise ValueError(value) + + def __eq__(self, other): + if not super().__eq__(other): + return False + if not isinstance(other, __class__): + return False + return np.array_equal(self.hash_key, other.hash_key) + + +@attrs(eq=False, order=False) +class FeatureVector(Annotation): + _type = AnnotationType.feature_vector + vector: np.ndarray = field(validator=attr.validators.instance_of(np.ndarray)) + + def __eq__(self, other): + if not super().__eq__(other): + return False + if not isinstance(other, __class__): + return False + return np.array_equal(self.hash_key, other.hash_key) + + RgbColor = Tuple[int, int, int] Colormap = Dict[int, RgbColor] @@ -260,7 +363,9 @@ def __eq__(self, other): BinaryMaskImage = np.ndarray # 2d array of type bool +BinaryMaskImageCallable = 
Callable[[], BinaryMaskImage] IndexMaskImage = np.ndarray # 2d array of type int +IndexMaskImageCallable = Callable[[], IndexMaskImage] @attrs(slots=True, eq=False, order=False) @@ -270,7 +375,7 @@ class Mask(Annotation): """ _type = AnnotationType.mask - _image = field() + _image: Union[BinaryMaskImage, BinaryMaskImageCallable] = field() label: Optional[int] = field( converter=attr.converters.optional(int), default=None, kw_only=True ) @@ -287,23 +392,58 @@ def image(self) -> BinaryMaskImage: image = image() return image - def as_class_mask(self, label_id: Optional[int] = None) -> IndexMaskImage: - """ - Produces a class index mask. Mask label id can be changed. + def as_class_mask( + self, + label_id: Optional[int] = None, + ignore_index: int = 0, + dtype: Optional[np.dtype] = None, + ) -> IndexMaskImage: + """Produces a class index mask based on the binary mask. + + Args: + label_id: Scalar value to represent the class index of the mask. + If not specified, `self.label` will be used. Defaults to None. + ignore_index: Scalar value to fill in the zeros in the binary mask. + Defaults to 0. + dtype: Data type for the resulting mask. If not specified, + it will be inferred from the provided `label_id` to hold its value. + For example, if `label_id=255`, the inferred dtype will be `np.uint8`. + Defaults to None. + + Returns: + IndexMaskImage: Class index mask generated from the binary mask. """ if label_id is None: label_id = self.label from datumaro.util.mask_tools import make_index_mask - return make_index_mask(self.image, label_id) + return make_index_mask(self.image, index=label_id, ignore_index=ignore_index, dtype=dtype) - def as_instance_mask(self, instance_id: int) -> IndexMaskImage: - """ - Produces a instance index mask. + def as_instance_mask( + self, + instance_id: int, + ignore_index: int = 0, + dtype: Optional[np.dtype] = None, + ) -> IndexMaskImage: + """Produces an instance index mask based on the binary mask. + + Args: + instance_id: Scalar value to represent the instance id. + ignore_index: Scalar value to fill in the zeros in the binary mask. + Defaults to 0. + dtype: Data type for the resulting mask. If not specified, + it will be inferred from the provided `label_id` to hold its value. + For example, if `label_id=255`, the inferred dtype will be `np.uint8`. + Defaults to None. + + Returns: + IndexMaskImage: Instance index mask generated from the binary mask. 
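+
+        Example (illustrative): for a 2x2 binary mask [[0, 1], [1, 0]] and
+        instance_id=7, the result is [[0, 7], [7, 0]]: ones take the instance id
+        and zeros are filled with `ignore_index`.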
""" from datumaro.util.mask_tools import make_index_mask - return make_index_mask(self.image, instance_id) + return make_index_mask( + self.image, index=instance_id, ignore_index=ignore_index, dtype=dtype + ) def get_area(self) -> int: return np.count_nonzero(self.image) @@ -327,16 +467,12 @@ def paint(self, colormap: Colormap) -> np.ndarray: return paint_mask(self.as_class_mask(), colormap) def __eq__(self, other): + if not super().__eq__(other): + return False if not isinstance(other, __class__): return False - - parent_keys = [f.name for f in attr.fields(Annotation)] - self_parent_fields = {k: v for k, v in self.as_dict().items() if k in parent_keys} - other_parent_fields = {k: v for k, v in other.as_dict().items() if k in parent_keys} - return ( - (self_parent_fields == other_parent_fields) - and (self.label == other.label) + (self.label == other.label) and (self.z_order == other.z_order) and (np.array_equal(self.image, other.image)) ) @@ -385,6 +521,47 @@ def __eq__(self, other): return self.rle == other.rle +@attrs(slots=True, eq=False, order=False) +class ExtractedMask(Mask): + """Mask annotation (binary mask) extracted from an index mask (integer 2D Numpy array). + + This class can extract a binary mask with given index mask and index value. + The advantage of this class is that we can create multiple binary mask + but they share a single index mask source. + + Attributes: + index_mask: Integer 2D Numpy array. Its pixel can indicate a label id (class) + or an instance id. + index: Integer value to extract a binary mask from the given index mask. + + Examples: + This example demonstrates how to create an `ExtractedMask` from a synthetic index mask, + which denotes a semantic segmentation mask with binary values such as 0 for background + and 1 for foreground. + + >>> import numpy as np + >>> from datumaro.components.annotation import ExtractedMask + >>> + >>> index_mask = np.random.randint(low=0, high=2, size=(10, 10), dtype=np.uint8) + >>> mask1 = ExtractedMask(index_mask=index_mask, index=0, label=0) # 0 for background + >>> mask2 = ExtractedMask(index_mask=index_mask, index=1, label=1) # 1 for foreground + >>> np.unique(mask1.image).tolist() # `image` property create a binary mask + np.array([0, 1]) + >>> mask1.index_mask == mask2.index_mask # They share the same source + True + """ + + index_mask: Union[IndexMaskImage, IndexMaskImageCallable] = field() + index: int = field() + + _image: None = field(init=False, default=None) + + @property + def image(self) -> BinaryMaskImage: + index_mask = self.index_mask() if callable(self.index_mask) else self.index_mask + return index_mask == self.index + + CompiledMaskImage = np.ndarray # 2d of integers (of different precision) @@ -532,8 +709,23 @@ def lazy_extract(self, instance_id: int) -> Callable[[], IndexMaskImage]: @attrs(slots=True, order=False) -class _Shape(Annotation): - # Flattened list of point coordinates +class Shape(Annotation): + """ + Base class for shape annotations. This class defines the common attributes and methods + for different types of shape annotations. + + Attributes: + points (List[float]): List of float values representing the coordinates of the shape. + label (Optional[int]): Optional label ID for the shape. Default is None. + z_order (int): Z-order of the shape, used to determine the rendering order. Default is 0. + + Methods: + get_area: Abstract method to calculate the area of the shape. + as_polygon: Abstract method to convert the shape into a polygon representation. 
+ get_bbox: Returns the bounding box of the shape as [x, y, w, h]. + get_points: Returns the points of the shape as a list of (x, y) tuples. + """ + points: List[float] = field( converter=lambda x: np.around(x, COORDINATE_ROUNDING_DIGITS).tolist(), factory=list ) @@ -545,10 +737,24 @@ class _Shape(Annotation): z_order: int = field(default=0, validator=default_if_none(int), kw_only=True) def get_area(self): + """ + Calculate the area of the shape. + """ + raise NotImplementedError() + + def as_polygon(self) -> List[float]: + """ + Convert the shape into a polygon representation. + """ raise NotImplementedError() def get_bbox(self) -> Tuple[float, float, float, float]: - "Returns [x, y, w, h]" + """ + Calculate and return the bounding box of the shape. + + Returns: + Tuple[float, float, float, float]: The bounding box as [x, y, w, h]. + """ points = self.points if not points: @@ -562,9 +768,40 @@ def get_bbox(self) -> Tuple[float, float, float, float]: y1 = max(ys) return [x0, y0, x1 - x0, y1 - y0] + def get_points(self) -> Optional[List[Tuple[float, float]]]: + """ + Convert and return the points of the shape as a list of (x, y) tuples. + + Returns: + Optional[List[Tuple[float, float]]]: List of points as (x, y) tuples, + or None if no points. + """ + points = self.points + if not points: + return None + + assert len(points) % 2 == 0, "points should have (2 x points) number of float values." + + xs = [p for p in points[0::2]] + ys = [p for p in points[1::2]] + + return [(x, y) for x, y in zip(xs, ys)] + @attrs(slots=True, order=False) -class PolyLine(_Shape): +class PolyLine(Shape): + """ + PolyLine annotation class. + This class represents a polyline shape, which is a series of connected line segments. + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.polyline`. + + Methods: + as_polygon: Returns the points of the polyline as a polygon. + get_area: Returns the area of the polyline, which is always 0. + """ + _type = AnnotationType.polyline def as_polygon(self): @@ -576,6 +813,23 @@ def get_area(self): @attrs(slots=True, init=False, order=False) class Cuboid3d(Annotation): + """ + Cuboid3d annotation class. + This class represents a 3D cuboid annotation with position, rotation, and scale. + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.cuboid_3d`. + _points (List[float]): List of float values representing the position, + rotation, and scale of the cuboid. + label (Optional[int]): Optional label ID for the cuboid. Default is None. + + Methods: + __init__: Initializes the Cuboid3d with position, rotation, and scale. + position: Property to get and set the position of the cuboid. + rotation: Property to get and set the rotation of the cuboid. + scale: Property to get and set the scale of the cuboid. + """ + _type = AnnotationType.cuboid_3d _points: List[float] = field(default=None) label: Optional[int] = field( @@ -584,14 +838,31 @@ class Cuboid3d(Annotation): @_points.validator def _points_validator(self, attribute, points): + """ + Validate and round the points representing the cuboid's position, rotation, and scale. + + Args: + attribute: The attribute being validated. + points: The list of float values to validate. 
+ """ if points is None: - points = [0, 0, 0, 0, 0, 0, 1, 1, 1] + points = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0] else: assert len(points) == 3 + 3 + 3, points points = np.around(points, COORDINATE_ROUNDING_DIGITS).tolist() self._points = points def __init__(self, position, rotation=None, scale=None, **kwargs): + """ + Initialize the Cuboid3d with position, rotation, and scale. + + Args: + position (List[float]): List of 3 float values representing the position [x, y, z]. + rotation (List[float], optional): List of 3 float values + representing the rotation [rx, ry, rz]. + scale (List[float], optional): List of 3 float values + representing the scale [sx, sy, sz]. + """ assert len(position) == 3, position if not rotation: rotation = [0] * 3 @@ -602,11 +873,22 @@ def __init__(self, position, rotation=None, scale=None, **kwargs): @property def position(self): - """[x, y, z]""" + """ + Get the position of the cuboid. + + Returns: + List[float]: The position [x, y, z] of the cuboid. + """ return self._points[0:3] @position.setter def _set_poistion(self, value): + """ + Set the position of the cuboid. + + Args: + value (List[float]): The new position [x, y, z] of the cuboid. + """ # TODO: fix the issue with separate coordinate rounding: # self.position[0] = 12.345676 # - the number assigned won't be rounded. @@ -614,142 +896,783 @@ def _set_poistion(self, value): @property def rotation(self): - """[rx, ry, rz]""" + """ + Get the rotation of the cuboid. + + Returns: + List[float]: The rotation [rx, ry, rz] of the cuboid. + """ return self._points[3:6] @rotation.setter def _set_rotation(self, value): + """ + Set the rotation of the cuboid. + + Args: + value (List[float]): The new rotation [rx, ry, rz] of the cuboid. + """ self.rotation[:] = np.around(value, COORDINATE_ROUNDING_DIGITS).tolist() @property def scale(self): - """[sx, sy, sz]""" + """ + Get the scale of the cuboid. + + Returns: + List[float]: The scale [sx, sy, sz] of the cuboid. + """ return self._points[6:9] @scale.setter def _set_scale(self, value): + """ + Set the scale of the cuboid. + + Args: + value (List[float]): The new scale [sx, sy, sz] of the cuboid. + """ self.scale[:] = np.around(value, COORDINATE_ROUNDING_DIGITS).tolist() -@attrs(slots=True, order=False) -class Polygon(_Shape): +@attrs(slots=True, order=False, eq=False) +class Polygon(Shape): + """ + Polygon annotation class. This class represents a polygon shape defined by a series of points. + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.polygon`. + + Methods: + __attrs_post_init__: Validates the points to ensure they form a valid polygon. + get_area: Calculates the area of the polygon using the shoelace formula. + as_polygon: Returns the points of the polygon. + __eq__: Compares this polygon with another for equality. + _get_shoelace_area: Helper method to calculate the area of the polygon + using the shoelace formula. + """ + _type = AnnotationType.polygon def __attrs_post_init__(self): + """ + Validate the points to ensure they form a valid polygon. + + Raises: + AssertionError: If the number of points is not even or less than 3 pairs of coordinates. + """ # keep the message on a single line to produce informative output assert len(self.points) % 2 == 0 and 3 <= len(self.points) // 2, ( "Wrong polygon points: %s" % self.points ) def get_area(self): - import pycocotools.mask as mask_utils + """ + Calculate the area of the polygon using the shoelace formula. 
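+        For example (illustrative), a triangle with vertices (0, 0), (4, 0), (0, 3)
+        gives the shoelace sum (0*0 - 0*4) + (4*3 - 0*0) + (0*0 - 3*0) = 12,
+        so the area is abs(12) / 2 = 6.0.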
- x, y, w, h = self.get_bbox() - rle = mask_utils.frPyObjects([self.points], y + h, x + w) - area = mask_utils.area(rle)[0] + Returns: + float: The area of the polygon. + """ + # import pycocotools.mask as mask_utils + + # x, y, w, h = self.get_bbox() + # rle = mask_utils.frPyObjects([self.points], y + h, x + w) + # area = mask_utils.area(rle)[0] + area = self._get_shoelace_area() return area + def as_polygon(self) -> List[float]: + """ + Return the points of the polygon. + + Returns: + List[float]: The points of the polygon. + """ + return self.points + + def __eq__(self, other): + """ + Compare this polygon with another for equality. + + Args: + other: The other polygon to compare with. + + Returns: + bool: True if the polygons are equal, False otherwise. + """ + if not isinstance(other, __class__): + return False + if ( + not Annotation.__eq__(self, other) + or self.label != other.label + or self.z_order != other.z_order + ): + return False + + self_points = self.get_points() + other_points = other.get_points() + self_polygon = sg.Polygon(self_points) + other_polygon = sg.Polygon(other_points) + # if polygon is not valid, compare points + if not (self_polygon.is_valid and other_polygon.is_valid): + return self_points == other_points + inter_area = self_polygon.intersection(other_polygon).area + return abs(self_polygon.area - inter_area) < CHECK_POLYGON_EQ_EPSILONE + + def _get_shoelace_area(self): + """ + Calculate the area of the polygon using the shoelace formula. + + Returns: + float: The area of the polygon. + """ + points = self.get_points() + n = len(points) + # Not a polygon + if n < 3: + return 0 + + area = 0.0 + for i in range(n): + x1, y1 = points[i] + x2, y2 = points[(i + 1) % n] # Next vertex, wrapping around using modulo + area += x1 * y2 - y1 * x2 + + return abs(area) / 2.0 + @attrs(slots=True, init=False, order=False) -class Bbox(_Shape): +class Bbox(Shape): + """ + Bbox annotation class. This class represents a bounding box + defined by its top-left corner (x, y) and its width and height (w, h). + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.bbox`. + + Methods: + __init__: Initializes the Bbox with its coordinates and dimensions. + x: Property to get the x-coordinate of the bounding box. + y: Property to get the y-coordinate of the bounding box. + w: Property to get the width of the bounding box. + h: Property to get the height of the bounding box. + get_area: Calculates the area of the bounding box. + get_bbox: Returns the bounding box coordinates and dimensions. + as_polygon: Returns the bounding box as a list of points forming a polygon. + iou: Calculates the Intersection over Union (IoU) with another shape. + wrap: Creates a new Bbox instance with updated attributes. + """ + _type = AnnotationType.bbox def __init__(self, x, y, w, h, *args, **kwargs): + """ + Initialize the Bbox with its top-left corner (x, y) and its width and height (w, h). + + Args: + x (float): The x-coordinate of the top-left corner. + y (float): The y-coordinate of the top-left corner. + w (float): The width of the bounding box. + h (float): The height of the bounding box. + """ kwargs.pop("points", None) # comes from wrap() self.__attrs_init__([x, y, x + w, y + h], *args, **kwargs) @property def x(self): + """ + Get the x-coordinate of the top-left corner of the bounding box. + + Returns: + float: The x-coordinate of the bounding box. 
+ """ return self.points[0] @property def y(self): + """ + Get the y-coordinate of the top-left corner of the bounding box. + + Returns: + float: The y-coordinate of the bounding box. + """ return self.points[1] @property def w(self): + """ + Get the width of the bounding box. + + Returns: + float: The width of the bounding box. + """ return self.points[2] - self.points[0] @property def h(self): + """ + Get the height of the bounding box. + + Returns: + float: The height of the bounding box. + """ return self.points[3] - self.points[1] def get_area(self): + """ + Calculate the area of the bounding box. + + Returns: + float: The area of the bounding box. + """ return self.w * self.h def get_bbox(self): + """ + Get the bounding box coordinates and dimensions. + + Returns: + List[float]: The bounding box as [x, y, w, h]. + """ return [self.x, self.y, self.w, self.h] - def as_polygon(self): + def as_polygon(self) -> List[float]: + """ + Convert the bounding box into a polygon representation. + + Returns: + List[float]: The bounding box as a polygon. + """ x, y, w, h = self.get_bbox() return [x, y, x + w, y, x + w, y + h, x, y + h] - def iou(self, other: _Shape) -> Union[float, Literal[-1]]: + def iou(self, other: Shape) -> Union[float, Literal[-1]]: + """ + Calculate the Intersection over Union (IoU) with another shape. + + Args: + other (Shape): The other shape to compare with. + + Returns: + Union[float, Literal[-1]]: The IoU value or -1 if not applicable. + """ from datumaro.util.annotation_util import bbox_iou return bbox_iou(self.get_bbox(), other.get_bbox()) def wrap(item, **kwargs): + """ + Create a new Bbox instance with updated attributes. + + Args: + item (Bbox): The original Bbox instance. + kwargs: Additional attributes to update. + + Returns: + Bbox: A new Bbox instance with updated attributes. + """ d = {"x": item.x, "y": item.y, "w": item.w, "h": item.h} d.update(kwargs) return attr.evolve(item, **d) -@attrs(slots=True, order=False) -class PointsCategories(Categories): +@attrs(slots=True, init=False, order=False) +class RotatedBbox(Shape): """ - Describes (key-)point metainfo such as point names and joints. + RotatedBbox annotation class. This class represents a rotated bounding box defined + by its center (cx, cy), width (w), height (h), and rotation angle (r). + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.rotated_bbox`. + + Methods: + __init__: Initializes the RotatedBbox with its center, dimensions, and rotation angle. + from_rectangle: Creates a RotatedBbox from a list of four corner points. + cx: Property to get the x-coordinate of the center of the bounding box. + cy: Property to get the y-coordinate of the center of the bounding box. + w: Property to get the width of the bounding box. + h: Property to get the height of the bounding box. + r: Property to get the rotation angle of the bounding box. + get_area: Calculates the area of the bounding box. + get_bbox: Returns the bounding box coordinates and dimensions. + get_rotated_bbox: Returns the rotated bounding box parameters. + as_polygon: Converts the rotated bounding box into a list of corner points. + iou: Calculates the Intersection over Union (IoU) with another shape. + wrap: Creates a new RotatedBbox instance with updated attributes. """ - @attrs(slots=True, order=False) - class Category: - # Names for specific points, e.g. eye, hose, mouth etc. 
- # These labels are not required to be in LabelCategories - labels: List[str] = field(factory=list, validator=default_if_none(list)) + _type = AnnotationType.rotated_bbox - # Pairs of connected point indices - joints: Set[Tuple[int, int]] = field(factory=set, validator=default_if_none(set)) + def __init__(self, cx, cy, w, h, r, *args, **kwargs): + """ + Initialize the RotatedBbox with its center (cx, cy), width (w), height (h), + and rotation angle (r). - items: Dict[int, Category] = field(factory=dict, validator=default_if_none(dict)) + Args: + cx (float): The x-coordinate of the center. + cy (float): The y-coordinate of the center. + w (float): The width of the bounding box. + h (float): The height of the bounding box. + r (float): The rotation angle of the bounding box in degrees. + """ + kwargs.pop("points", None) # comes from wrap() + self.__attrs_init__([cx, cy, w, h, r], *args, **kwargs) @classmethod - def from_iterable( - cls, - iterable: Iterable[ - Union[ - Tuple[int, List[str]], - Tuple[int, List[str], Set[Tuple[int, int]]], - ], - ], - ) -> PointsCategories: + def from_rectangle(cls, points: List[Tuple[float, float]], *args, **kwargs): """ - Create PointsCategories from an iterable. + Create a RotatedBbox from a list of four corner points. Args: - iterable: An Iterable with the following elements: - - - a label id - - a list of positional arguments for Categories + points (List[Tuple[float, float]]): A list of four points defining the rectangle. + args: Additional arguments. + kwargs: Additional keyword arguments. Returns: - PointsCategories: PointsCategories object + RotatedBbox: A new RotatedBbox instance. """ - temp_categories = cls() + assert len(points) == 4, "polygon for a rotated bbox should have only 4 coordinates." - for args in iterable: - temp_categories.add(*args) - return temp_categories + # Calculate rotation angle + rot = math.atan2(points[1][1] - points[0][1], points[1][0] - points[0][0]) - def add( - self, - label_id: int, - labels: Optional[Iterable[str]] = None, - joints: Iterable[Tuple[int, int]] = None, - ): - if joints is None: - joints = [] - joints = set(map(tuple, joints)) - self.items[label_id] = self.Category(labels, joints) + # Calculate the center of the bounding box + cx = (points[0][0] + points[2][0]) / 2 + cy = (points[0][1] + points[2][1]) / 2 + + # Calculate the width and height + width = math.sqrt((points[1][0] - points[0][0]) ** 2 + (points[1][1] - points[0][1]) ** 2) + height = math.sqrt((points[2][0] - points[1][0]) ** 2 + (points[2][1] - points[1][1]) ** 2) + + return cls(cx=cx, cy=cy, w=width, h=height, r=math.degrees(rot), *args, **kwargs) + + @property + def cx(self): + """ + Get the x-coordinate of the center of the bounding box. + + Returns: + float: The x-coordinate of the center. + """ + return self.points[0] + + @property + def cy(self): + """ + Get the y-coordinate of the center of the bounding box. + + Returns: + float: The y-coordinate of the center. + """ + return self.points[1] + + @property + def w(self): + """ + Get the width of the bounding box. + + Returns: + float: The width of the bounding box. + """ + return self.points[2] + + @property + def h(self): + """ + Get the height of the bounding box. + + Returns: + float: The height of the bounding box. + """ + return self.points[3] + + @property + def r(self): + """ + Get the rotation angle of the bounding box in degrees. + + Returns: + float: The rotation angle of the bounding box. 
+ """ + return self.points[4] + + def get_area(self): + """ + Calculate the area of the bounding box. + + Returns: + float: The area of the bounding box. + """ + return self.w * self.h + + def get_bbox(self): + """ + Get the bounding box coordinates and dimensions. + + Returns: + List[float]: The bounding box as [x, y, w, h]. + """ + polygon = self.as_polygon() + xs = [pt[0] for pt in polygon] + ys = [pt[1] for pt in polygon] + + return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)] + + def get_rotated_bbox(self): + """ + Get the rotated bounding box parameters. + + Returns: + List[float]: The rotated bounding box as [cx, cy, w, h, r]. + """ + return [self.cx, self.cy, self.w, self.h, self.r] + + def as_polygon(self) -> List[Tuple[float, float]]: + """ + Convert the rotated bounding box into a list of corner points. + + Returns: + List[Tuple[float, float]]: The bounding box as a list of four corner points. + """ + + def _rotate_point(x, y, angle): + """ + Rotate a point around another point. + + Args: + x (float): The x-coordinate of the point. + y (float): The y-coordinate of the point. + angle (float): The rotation angle in degrees. + + Returns: + Tuple[float, float]: The rotated point coordinates. + """ + angle_rad = math.radians(angle) + cos_theta = math.cos(angle_rad) + sin_theta = math.sin(angle_rad) + nx = cos_theta * x - sin_theta * y + ny = sin_theta * x + cos_theta * y + return nx, ny + + # Calculate corner points of the rectangle + corners = [ + (-self.w / 2, -self.h / 2), + (self.w / 2, -self.h / 2), + (self.w / 2, self.h / 2), + (-self.w / 2, self.h / 2), + ] + + # Rotate each corner point + rotated_corners = [_rotate_point(p[0], p[1], self.r) for p in corners] + + # Translate the rotated points to the original position + return [(p[0] + self.cx, p[1] + self.cy) for p in rotated_corners] + + def iou(self, other: Shape) -> Union[float, Literal[-1]]: + """ + Calculate the Intersection over Union (IoU) with another shape. + + Args: + other (Shape): The other shape to compare with. + + Returns: + Union[float, Literal[-1]]: The IoU value or -1 if not applicable. + """ + from datumaro.util.annotation_util import bbox_iou + + return bbox_iou(self.get_bbox(), other.get_bbox()) + + def wrap(item, **kwargs): + """ + Create a new RotatedBbox instance with updated attributes. + + Args: + item (RotatedBbox): The original RotatedBbox instance. + kwargs: Additional attributes to update. + + Returns: + RotatedBbox: A new RotatedBbox instance with updated attributes. + """ + d = {"x": item.x, "y": item.y, "w": item.w, "h": item.h, "r": item.r} + d.update(kwargs) + return attr.evolve(item, **d) + + +@attrs(slots=True, init=False, order=False) +class Cuboid2D(Annotation): + """ + Cuboid2D annotation class. This class represents a 3D bounding box + defined by its point coordinates in the following way: + [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), (x6, y6), (x7, y7), (x8, y8)]. + + + 2---3 + /| /| + 1-+-4 | + | 5 + 6 + |/ |/ + 8---7 + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.cuboid_2d`. + + Methods: + __init__: Initializes the Cuboid2D with its coordinates. + wrap: Creates a new Cuboid2D instance with updated attributes. 
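+
+    Example (illustrative, the coordinates below are arbitrary):
+        >>> points = [(1, 1), (3, 1), (4, 2), (2, 2), (1, 4), (3, 4), (4, 5), (2, 5)]
+        >>> cuboid = Cuboid2D(points, label=0)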
+ """ + + _type = AnnotationType.cuboid_2d + points = field(default=None) + label: Optional[int] = field( + converter=attr.converters.optional(int), default=None, kw_only=True + ) + z_order: int = field(default=0, validator=default_if_none(int), kw_only=True) + y_3d: float = field(default=None, validator=default_if_none(float), kw_only=True) + + def __init__( + self, + _points: Iterable[Tuple[float, float]], + *args, + **kwargs, + ): + kwargs.pop("points", None) # comes from wrap() + self.__attrs_init__(points=_points, *args, **kwargs) + + @staticmethod + def _get_plane_equation(points): + """Calculates coefficients of the plane equation from three points.""" + x1, y1, z1 = points[0, 0], points[0, 1], points[0, 2] + x2, y2, z2 = points[1, 0], points[1, 1], points[1, 2] + x3, y3, z3 = points[2, 0], points[2, 1], points[2, 2] + a1 = x2 - x1 + b1 = y2 - y1 + c1 = z2 - z1 + a2 = x3 - x1 + b2 = y3 - y1 + c2 = z3 - z1 + a = b1 * c2 - b2 * c1 + b = a2 * c1 - a1 * c2 + c = a1 * b2 - b1 * a2 + d = -a * x1 - b * y1 - c * z1 + return np.array([a, b, c, d]) + + @staticmethod + def _get_denorm(Tr_velo_to_cam_homo): + """Calculates the denormalized vector perpendicular to the image plane. + Args: + Tr_velo_to_cam_homo (np.ndarray): + Homogeneous (4x4) LiDAR-to-camera transformation matrix + Returns: + np.ndarray: vector""" + ground_points_lidar = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) + ground_points_lidar = np.concatenate( + (ground_points_lidar, np.ones((ground_points_lidar.shape[0], 1))), axis=1 + ) + ground_points_cam = np.matmul(Tr_velo_to_cam_homo, ground_points_lidar.T).T + denorm = -1 * Cuboid2D._get_plane_equation(ground_points_cam) + return denorm + + @staticmethod + def _get_3d_points(dim, location, rotation_y, denorm): + """Get corner points according to the 3D bounding box parameters. + + Args: + dim (List[float]): The dimensions of the 3D bounding box as [l, w, h]. + location (List[float]): The location of the 3D bounding box as [x, y, z]. + rotation_y (float): The rotation angle around the y-axis. + + Returns: + np.ndarray: The corner points of the 3D bounding box. + """ + + c, s = np.cos(rotation_y), np.sin(rotation_y) + R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float32) + l, w, h = dim[2], dim[1], dim[0] + x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] + y_corners = [0, 0, 0, 0, -h, -h, -h, -h] + z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] + + corners = np.array([x_corners, y_corners, z_corners], dtype=np.float32) + corners_3d = np.dot(R, corners) + + denorm = denorm[:3] + denorm_norm = denorm / np.sqrt(denorm[0] ** 2 + denorm[1] ** 2 + denorm[2] ** 2) + ori_denorm = np.array([0.0, -1.0, 0.0]) + theta = -1 * math.acos(np.dot(denorm_norm, ori_denorm)) + n_vector = np.cross(denorm, ori_denorm) + n_vector_norm = n_vector / np.sqrt(n_vector[0] ** 2 + n_vector[1] ** 2 + n_vector[2] ** 2) + rotation_matrix, j = cv2.Rodrigues(theta * n_vector_norm) + corners_3d = np.dot(rotation_matrix, corners_3d) + corners_3d = corners_3d + np.array(location, dtype=np.float32).reshape(3, 1) + return corners_3d.transpose(1, 0) + + @staticmethod + def _project_to_2d(pts_3d, P): + """Project 3D points to 2D image plane. + + Args: + pts_3d (np.ndarray): The 3D points to project. + P (np.ndarray): The projection matrix. 
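+                Assumed to be the 3x4 camera-to-image matrix also used in `from_3d`;
+                each homogeneous point [X, Y, Z, 1] is divided by its third projected
+                coordinate to obtain pixel coordinates.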
+ + Returns: + np.ndarray: The 2D points projected to the image + """ + # Convert to homogeneous coordinates + pts_3d = pts_3d.T + pts_3d_homo = np.vstack((pts_3d, np.ones(pts_3d.shape[1]))) + pts_2d = P @ pts_3d_homo + pts_2d[0, :] = np.divide(pts_2d[0, :], pts_2d[2, :]) + pts_2d[1, :] = np.divide(pts_2d[1, :], pts_2d[2, :]) + pts_2d = pts_2d[:2, :].T + + return pts_2d + + @classmethod + def from_3d( + cls, + dim: np.ndarray, + location: np.ndarray, + rotation_y: float, + P: np.ndarray, + Tr_velo_to_cam: np.ndarray, + ) -> Cuboid2D: + """Creates an instance of Cuboid2D class from 3D bounding box parameters. + + Args: + dim (np.ndarray): 3 scalars describing length, width and height of a 3D bounding box + location (np.ndarray): (x, y, z) coordinates of the middle of the top face. + rotation_y (np.ndarray): rotation along the Y-axis (from -pi to pi) + P (np.ndarray): Camera-to-Image transformation matrix (3x4) + Tr_velo_to_cam (np.ndarray): LiDAR-to-Camera transformation matrix (3x4) + + Returns: + Cuboid2D: Projection points for the given bounding box + """ + Tr_velo_to_cam_homo = np.eye(4) + Tr_velo_to_cam_homo[:3, :4] = Tr_velo_to_cam + denorm = cls._get_denorm(Tr_velo_to_cam_homo) + pts_3d = cls._get_3d_points(dim, location, rotation_y, denorm) + y_3d = np.mean(pts_3d[:4, 1]) + pts_2d = cls._project_to_2d(pts_3d, P) + + return cls(list(map(tuple, pts_2d)), y_3d=y_3d) + + def to_3d(self, P_inv: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]: + """Reconstructs 3D object Velodyne coordinates + (dimensions, location and rotation along the Y-axis) + from the given Cuboid2D instance. + + Args: + P_inv (np.ndarray): Pseudo-inverse of Camera-to-Image projection matrix + Returns: + tuple: dimensions, location and rotation along the Y-axis + """ + recon_3d = [] + for idx, coord_2d in enumerate(self.points): + coord_2d = np.append(coord_2d, 1) + coord_3d = P_inv @ coord_2d + if idx < 4: + coord_3d = coord_3d * self.y_3d / coord_3d[1] + else: + coord_3d = coord_3d * recon_3d[idx - 4][0] / coord_3d[0] + recon_3d.append(coord_3d[:3]) + recon_3d = np.array(recon_3d) + + x = np.mean(recon_3d[:, 0]) + z = np.mean(recon_3d[:, 2]) + + yaws = [] + pairs = [(0, 1), (3, 2), (4, 5), (7, 6)] + for p in pairs: + delta_x = recon_3d[p[0]][0] - recon_3d[p[1]][0] + delta_z = recon_3d[p[0]][2] - recon_3d[p[1]][2] + yaws.append(np.arctan2(delta_x, delta_z)) + yaw = np.mean(yaws) + + widths = [] + pairs = [(0, 1), (2, 3), (4, 5), (6, 7)] + for p in pairs: + delta_x = np.sqrt( + (recon_3d[p[0]][0] - recon_3d[p[1]][0]) ** 2 + + (recon_3d[p[0]][2] - recon_3d[p[1]][2]) ** 2 + ) + widths.append(delta_x) + w = np.mean(widths) + + lengths = [] + pairs = [(1, 2), (0, 3), (5, 6), (4, 7)] + for p in pairs: + delta_z = np.sqrt( + (recon_3d[p[0]][0] - recon_3d[p[1]][0]) ** 2 + + (recon_3d[p[0]][2] - recon_3d[p[1]][2]) ** 2 + ) + lengths.append(delta_z) + l = np.mean(lengths) + + heights = [] + pairs = [(0, 4), (1, 5), (2, 6), (3, 7)] + for p in pairs: + delta_y = np.abs(recon_3d[p[0]][1] - recon_3d[p[1]][1]) + heights.append(delta_y) + h = np.mean(heights) + return np.array([h, w, l]), np.array([x, self.y_3d, z]), yaw + + +@attrs(slots=True, order=False) +class PointsCategories(Categories): + """ + Describes (key-)point metainfo such as point names and joints. + """ + + @attrs(slots=True, order=False) + class Category: + # Names for specific points, e.g. eye, hose, mouth etc. 
+ # These labels are not required to be in LabelCategories + labels: List[str] = field(factory=list, validator=default_if_none(list)) + + # Pairs of connected point indices + joints: Set[Tuple[int, int]] = field(factory=set, validator=default_if_none(set)) + + items: Dict[int, Category] = field(factory=dict, validator=default_if_none(dict)) + + @classmethod + def from_iterable( + cls, + iterable: Iterable[ + Union[ + Tuple[int, List[str]], + Tuple[int, List[str], Set[Tuple[int, int]]], + ], + ], + ) -> PointsCategories: + """ + Create PointsCategories from an iterable. + + Args: + iterable: An Iterable with the following elements: + + - a label id + - a list of positional arguments for Categories + + Returns: + PointsCategories: PointsCategories object + """ + temp_categories = cls() + + for args in iterable: + temp_categories.add(*args) + return temp_categories + + def add( + self, + label_id: int, + labels: Optional[Iterable[str]] = None, + joints: Iterable[Tuple[int, int]] = None, + ): + if joints is None: + joints = [] + joints = set(map(tuple, joints)) + self.items[label_id] = self.Category(labels, joints) def __contains__(self, idx: int) -> bool: return idx in self.items @@ -762,22 +1685,57 @@ def __len__(self) -> int: @attrs(slots=True, order=False) -class Points(_Shape): +class Points(Shape): """ Represents an ordered set of points. + + Attributes: + _type (AnnotationType): The type of annotation, set to `AnnotationType.points`. + visibility (List[IntEnum]): A list indicating the visibility status of each point. + + Nested Class: + Visibility (IntEnum): Enum representing the visibility state of points. It has three states: + - absent: Point is absent (0). + - hidden: Point is hidden (1). + - visible: Point is visible (2). + + Methods: + __attrs_post_init__: Validates that the number of points is even. + get_area: Returns the area covered by the points, always zero. + get_bbox: Returns the bounding box containing all visible or hidden points. """ - class Visibility(Enum): + class Visibility(IntEnum): + """ + Enum representing the visibility state of points. + + Attributes: + absent (int): Point is absent (0). + hidden (int): Point is hidden (1). + visible (int): Point is visible (2). + """ + absent = 0 hidden = 1 visible = 2 _type = AnnotationType.points - visibility: List[bool] = field(default=None) + visibility: List[IntEnum] = field(default=None) @visibility.validator def _visibility_validator(self, attribute, visibility): + """ + Validates and initializes the visibility list. + + Args: + attribute: The attribute being validated. + visibility (List[IntEnum]): A list indicating the visibility status of each point. + + Raises: + AssertionError: If the length of the visibility list + does not match half the length of the points list. + """ if visibility is None: visibility = [self.Visibility.visible] * (len(self.points) // 2) else: @@ -788,12 +1746,30 @@ def _visibility_validator(self, attribute, visibility): self.visibility = visibility def __attrs_post_init__(self): + """ + Validates that the number of points is even after initialization. + + Raises: + AssertionError: If the number of points is not even. + """ assert len(self.points) % 2 == 0, self.points def get_area(self): + """ + Returns the area covered by the points. + + Returns: + int: Always returns 0. + """ return 0 def get_bbox(self): + """ + Returns the bounding box containing all visible or hidden points. + + Returns: + List[float]: The bounding box as [x0, y0, width, height]. 
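+
+        Example (illustrative): Points([1, 2, 3, 4, 5, 6],
+        visibility=[Visibility.visible, Visibility.hidden, Visibility.absent])
+        yields [1, 2, 2, 2] (x, y, w, h), because absent points are excluded.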
+ """ xs = [ p for p, v in zip(self.points[0::2], self.visibility) @@ -821,12 +1797,19 @@ class Caption(Annotation): caption: str = field(converter=str) -@attrs(slots=True, order=False) +@attrs(slots=True, order=False, eq=False) class _ImageAnnotation(Annotation): image: Image = field() + def __eq__(self, other): + if not super().__eq__(other): + return False + if not isinstance(other, __class__): + return False + return np.array_equal(self.image, other.image) -@attrs(slots=True, order=False) + +@attrs(slots=True, order=False, eq=False) class SuperResolutionAnnotation(_ImageAnnotation): """ Represents high resolution images. @@ -835,7 +1818,7 @@ class SuperResolutionAnnotation(_ImageAnnotation): _type = AnnotationType.super_resolution_annotation -@attrs(slots=True, order=False) +@attrs(slots=True, order=False, eq=False) class DepthAnnotation(_ImageAnnotation): """ Represents depth images. @@ -844,6 +1827,309 @@ class DepthAnnotation(_ImageAnnotation): _type = AnnotationType.depth_annotation +@attrs(slots=True, init=False, order=False) +class Ellipse(Shape): + """ + Ellipse represents an ellipse that is encapsulated by a rectangle. + + - x1 and y1 represent the top-left coordinate of the encapsulating rectangle + - x2 and y2 representing the bottom-right coordinate of the encapsulating rectangle + + Parameters + ---------- + + x1: float + left x coordinate of encapsulating rectangle + y1: float + top y coordinate of encapsulating rectangle + x2: float + right x coordinate of encapsulating rectangle + y2: float + bottom y coordinate of encapsulating rectangle + """ + + _type = AnnotationType.ellipse + + def __init__(self, x1: float, y1: float, x2: float, y2: float, *args, **kwargs): + kwargs.pop("points", None) # comes from wrap() + self.__attrs_init__([x1, y1, x2, y2], *args, **kwargs) + + @property + def x1(self): + return self.points[0] + + @property + def y1(self): + return self.points[1] + + @property + def x2(self): + return self.points[2] + + @property + def y2(self): + return self.points[3] + + @property + def w(self): + return self.points[2] - self.points[0] + + @property + def h(self): + return self.points[3] - self.points[1] + + @property + def c_x(self): + return 0.5 * (self.points[0] + self.points[2]) + + @property + def c_y(self): + return 0.5 * (self.points[1] + self.points[3]) + + def get_area(self): + return 0.25 * np.pi * self.w * self.h + + def get_bbox(self): + return [self.x1, self.y1, self.w, self.h] + + def get_points(self, num_points: int = 720) -> List[Tuple[float, float]]: + """ + Return points as a list of tuples, e.g. [(x0, y0), (x1, y1), ...]. + + Parameters + ---------- + num_points: int + The number of boundary points of the ellipse. + By default, one point is created for every 1 degree of interior angle (num_points=360). + """ + points = self.as_polygon(num_points) + + return [(x, y) for x, y in zip(points[0::2], points[1::2])] + + def as_polygon(self, num_points: int = 720) -> List[float]: + """ + Return a polygon as a list of tuples, e.g. [x0, y0, x1, y1, ...]. + + Parameters + ---------- + num_points: int + The number of boundary points of the ellipse. + By default, one point is created for every 1 degree of interior angle (num_points=360). 
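+
+        Example (illustrative): Ellipse(0, 0, 4, 2).as_polygon(num_points=4) returns
+        8 values, an x and a y for each of the 4 boundary points.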
+ """ + theta = np.linspace(0, 2 * np.pi, num=num_points) + + l1 = 0.5 * self.w + l2 = 0.5 * self.h + x_points = self.c_x + l1 * np.cos(theta) + y_points = self.c_y + l2 * np.sin(theta) + + points = [] + for x, y in zip(x_points, y_points): + points += [x, y] + + return points + + def iou(self, other: Shape) -> Union[float, Literal[-1]]: + from datumaro.util.annotation_util import bbox_iou + + return bbox_iou(self.get_bbox(), other.get_bbox()) + + def wrap(item: Ellipse, **kwargs) -> Ellipse: + d = {"x1": item.x1, "y1": item.y1, "x2": item.x2, "y2": item.y2} + d.update(kwargs) + return attr.evolve(item, **d) + + +TableDtype = TypeVar("TableDtype", str, int, float) + + +@attrs(slots=True, order=False, eq=False) +class TabularCategories(Categories): + """ + Describes tabular data metainfo such as column names and types. + """ + + @attrs(slots=True, order=False, eq=False) + class Category: + name: str = field(converter=str, validator=not_empty) + dtype: Type[TableDtype] = field() + labels: Set[Union[str, int]] = field(factory=set, validator=default_if_none(set)) + + def __eq__(self, other): + same_name = self.name == other.name + same_dtype = self.dtype.__name__ == other.dtype.__name__ + same_labels = self.labels == other.labels + return same_name and same_dtype and same_labels + + def __repr__(self): + return f"name: {self.name}, dtype: {self.dtype.__name__}, labels: {self.labels}" + + items: List[Category] = field(factory=list, validator=default_if_none(list)) + _indices_by_name: Dict[str, int] = field(factory=dict, init=False, eq=False) + + @classmethod + def from_iterable( + cls, + iterable: Iterable[ + Union[Tuple[str, Type[TableDtype]], Tuple[str, Type[TableDtype], Set[str]]] + ], + ) -> TabularCategories: + """ + Creates a TabularCategories from iterable. + + Args: + iterable: a list of (Category name, type) or (Category name, type, set of labels) + + Returns: a TabularCategories object + """ + + temp_categories = cls() + + for category in iterable: + temp_categories.add(*category) + + return temp_categories + + def add( + self, + name: str, + dtype: Type[TableDtype], + labels: Optional[Set[str]] = None, + ) -> int: + """ + Add a Tabular Category. + + Args: + name (str): Column name + dtype (type): Type of the corresponding column. (str, int, or float) + labels (optional, set(str)): Label values where the column can have. + + Returns: + int: A index of added category. + """ + assert name + assert name not in self._indices_by_name + assert dtype + + index = len(self.items) + self.items.append(self.Category(name, dtype, labels)) + self._indices_by_name[name] = index + + return index + + def find(self, name: str) -> Tuple[Optional[int], Optional[Category]]: + """ + Find Category information for the given column name. + + Args: + name (str): Column name + + Returns: + tuple(int, Category): A index and Category information. 
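+
+        Example (illustrative):
+            >>> categories = TabularCategories.from_iterable([("age", int)])
+            >>> index, category = categories.find("age")
+            >>> index, category.dtype
+            (0, <class 'int'>)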
+ """ + index = self._indices_by_name.get(name) + return index, self.items[index] if index is not None else None + + def __getitem__(self, index: int) -> Category: + return self.items[index] + + def __contains__(self, name: str) -> bool: + return self.find(name)[1] is not None + + def __len__(self) -> int: + return len(self.items) + + def __iter__(self) -> Iterator[Category]: + return iter(self.items) + + def __eq__(self, other) -> bool: + if not super().__eq__(other): + return False + if not isinstance(other, __class__): + return False + return self.items == other.items + + +@attrs(slots=True, order=False) +class Tabular(Annotation): + """ + Represents values of target columns in a tabular dataset. + """ + + _type = AnnotationType.tabular + values: Dict[str, TableDtype] = field(converter=dict) + + +class Annotations(List[Annotation]): + """List of `Annotation` equipped with additional utility functions.""" + + def get_semantic_seg_mask( + self, ignore_index: int = 0, dtype: np.dtype = np.uint8 + ) -> np.ndarray: + """Extract semantic segmentation mask from a collection of Datumaro `Mask`s. + + Args: + ignore_index: Scalar value to fill in the zeros in each binary mask + before merging into a semantic segmentation mask. This value is usually used + to represent a pixel denoting a not-interested region. Defaults to 0. + dtype: Data type for the resulting mask. Defaults to np.uint8. + + Returns: + Semantic segmentation mask generated by merging Datumaro `Mask`s. + + Raises: + ValueError: If there are no mask annotations or + if there is an inconsistency in mask sizes. + """ + + masks = [ann for ann in self if isinstance(ann, Mask)] + # Mask with a lower z_order value will come first + masks.sort(key=lambda mask: mask.z_order) + + if not masks: + msg = "There is no mask annotations." + raise ValueError(msg) + + # Dispatching for better performance + # If all masks are `ExtractedMask`, share a same source `index_mask`, and + # there is no label remapping. + if ( + all(isinstance(mask, ExtractedMask) for mask in masks) + # and set(id(mask.index_mask) for mask in masks) == 1 + and all(mask.index_mask == next(iter(masks)).index_mask for mask in masks) + and all(mask.index == mask.label for mask in masks) + ): + index_mask = next(iter(masks)).index_mask + semantic_seg_mask: np.ndarray = index_mask() if callable(index_mask) else index_mask + if semantic_seg_mask.dtype != dtype: + semantic_seg_mask = semantic_seg_mask.astype(dtype) + + labels = np.unique(np.array([mask.label for mask in masks])) + ignore_index_mask = np.isin(semantic_seg_mask, labels, invert=True) + + return np.where(ignore_index_mask, ignore_index, semantic_seg_mask) + + class_masks = [mask.as_class_mask(ignore_index=ignore_index, dtype=dtype) for mask in masks] + + max_h = max([mask.shape[0] for mask in class_masks]) + max_w = max([mask.shape[1] for mask in class_masks]) + + semantic_seg_mask = np.full(shape=(max_h, max_w), fill_value=ignore_index, dtype=dtype) + + for class_mask in class_masks: + if class_mask.shape != semantic_seg_mask.shape: + msg = ( + f"There is inconsistency in mask size: " + f"{class_mask.shape}!={semantic_seg_mask.shape}." 
+ ) + raise ValueError(msg, class_mask.shape, semantic_seg_mask.shape) + + ignore_index_mask = class_mask == ignore_index + semantic_seg_mask = np.where(ignore_index_mask, semantic_seg_mask, class_mask) + + return semantic_seg_mask + + @attrs(slots=True, order=False) class Skeleton(Annotation): """ diff --git a/src/datumaro/components/operations.py b/src/datumaro/components/operations.py index f972b315fa..c1b70aecda 100644 --- a/src/datumaro/components/operations.py +++ b/src/datumaro/components/operations.py @@ -724,6 +724,17 @@ def _for_type(t, **kwargs): elif t is AnnotationType.skeleton: # to do: add skeletons merge return _make(ImageAnnotationMerger, **kwargs) + # TODO: remove later + elif ( + t is AnnotationType.unknown + or t is AnnotationType.ellipse + or t is AnnotationType.hash_key + or t is AnnotationType.feature_vector + or t is AnnotationType.tabular + or t is AnnotationType.rotated_bbox + or t is AnnotationType.cuboid_2d + ): + return None else: raise NotImplementedError("Type %s is not supported" % t) diff --git a/src/datumaro/plugins/data_formats/datumaro/exporter.py b/src/datumaro/plugins/data_formats/datumaro/exporter.py index e18bb65f81..afa07fe28f 100644 --- a/src/datumaro/plugins/data_formats/datumaro/exporter.py +++ b/src/datumaro/plugins/data_formats/datumaro/exporter.py @@ -29,8 +29,8 @@ Polygon, PolyLine, RleMask, + Shape, Skeleton, - _Shape, ) from datumaro.components.dataset_base import CategoriesInfo, DatasetItem from datumaro.components.dataset_item_storage import ItemStatus @@ -215,7 +215,7 @@ def _convert_mask_object(self, obj): return converted def _convert_shape_object(self, obj): - assert isinstance(obj, _Shape) + assert isinstance(obj, Shape) converted = self._convert_annotation(obj) converted.update( diff --git a/src/datumaro/util/annotation_util.py b/src/datumaro/util/annotation_util.py index 64a1d28efc..fa32129ec2 100644 --- a/src/datumaro/util/annotation_util.py +++ b/src/datumaro/util/annotation_util.py @@ -15,8 +15,8 @@ LabelCategories, Mask, RleMask, + Shape, ) -from datumaro.components.annotation import _Shape as Shape from datumaro.util.mask_tools import mask_to_rle BboxCoords = Tuple[float, float, float, float] diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index 487012f474..9458b5ee90 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -246,6 +246,13 @@ def test_stats(self): "super_resolution_annotation": {"count": 0}, "depth_annotation": {"count": 0}, "skeleton": {"count": 0}, + "cuboid_2d": {"count": 0}, + "ellipse": {"count": 0}, + "feature_vector": {"count": 0}, + "hash_key": {"count": 0}, + "rotated_bbox": {"count": 0}, + "tabular": {"count": 0}, + "unknown": {"count": 0}, }, "annotations": { "labels": { @@ -302,6 +309,7 @@ def test_stats(self): } actual = compute_ann_statistics(dataset) + self.maxDiff = None self.assertEqual(expected, actual) @@ -346,6 +354,13 @@ def test_stats_with_empty_dataset(self): "super_resolution_annotation": {"count": 0}, "depth_annotation": {"count": 0}, "skeleton": {"count": 0}, + "cuboid_2d": {"count": 0}, + "ellipse": {"count": 0}, + "feature_vector": {"count": 0}, + "hash_key": {"count": 0}, + "rotated_bbox": {"count": 0}, + "tabular": {"count": 0}, + "unknown": {"count": 0}, }, "annotations": { "labels": { @@ -372,6 +387,7 @@ def test_stats_with_empty_dataset(self): } actual = compute_ann_statistics(dataset) + self.maxDiff = None self.assertEqual(expected, actual) From 7ae972d8dd482501aecd8cb3929f280ddb814d4c Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 
2025 17:00:54 +0400 Subject: [PATCH 02/25] sync components/shift_analyzer.py --- src/datumaro/components/shift_analyzer.py | 269 ++++++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 src/datumaro/components/shift_analyzer.py diff --git a/src/datumaro/components/shift_analyzer.py b/src/datumaro/components/shift_analyzer.py new file mode 100644 index 0000000000..332f5d7bec --- /dev/null +++ b/src/datumaro/components/shift_analyzer.py @@ -0,0 +1,269 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# ruff: noqa: E501 + +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional + +import numpy as np + +from datumaro.components.annotation import FeatureVector +from datumaro.components.dataset import IDataset +from datumaro.components.launcher import LauncherWithModelInterpreter +from datumaro.util import take_by + +if TYPE_CHECKING: + import pyemd + from scipy import linalg, stats + + from datumaro.plugins.openvino_plugin import shift_launcher +else: + from datumaro.util.import_util import lazy_import + + pyemd = lazy_import("pyemd") + linalg = lazy_import("scipy.linalg") + stats = lazy_import("scipy.stats") + shift_launcher = lazy_import("datumaro.plugins.openvino_plugin.shift_launcher") + + +class RunningStats1D: + def __init__(self): + self.running_mean = None + self.running_sq_mean = None + self.num: int = 0 + + def add(self, feats: List[FeatureVector]) -> None: + arr = np.stack([feat.vector for feat in feats], axis=0) + assert arr.ndim == 2 + + batch_size, _ = arr.shape + mean = arr.mean(0) + arr = np.expand_dims(arr, axis=-1) # B x D x 1 + sq_mean = np.mean(np.matmul(arr, np.transpose(arr, axes=(0, 2, 1))), axis=0) # D x D + + self.num += batch_size + + if self.running_mean is not None: + self.running_mean = self.running_mean + batch_size / float(self.num) * ( + mean - self.running_mean + ) + else: + self.running_mean = mean + + if self.running_sq_mean is not None: + self.running_sq_mean = self.running_sq_mean + batch_size / float(self.num) * ( + sq_mean - self.running_sq_mean + ) + else: + self.running_sq_mean = sq_mean + + @property + def mean(self) -> np.ndarray: + return self.running_mean + + @property + def cov(self) -> np.ndarray: + mean = np.expand_dims(self.running_mean, axis=-1) # D x 1 + return self.running_sq_mean - np.matmul(mean, mean.transpose()) + + +class FeatureAccumulator: + def __init__(self, model: LauncherWithModelInterpreter): + self.model = model + self._batch_size = 1 + + def get_activation_stats(self, dataset: IDataset) -> RunningStats1D: + running_stats = RunningStats1D() + + for batch in take_by(dataset, self._batch_size): + outputs = self.model.launch(batch)[0] + features = [outputs[-1]] # extracted feature vector of googlenet-v4 + running_stats.add(features) + + return running_stats + + +class FeatureAccumulatorByLabel(FeatureAccumulator): + def __init__(self, model): + super().__init__(model) + + def get_activation_stats(self, dataset: IDataset) -> Dict[int, RunningStats1D]: + running_stats = defaultdict(RunningStats1D) + + for batch in take_by(dataset, self._batch_size): + inputs, targets = [], [] + for item in batch: + for ann in item.annotations: + inputs.append(np.atleast_3d(item.media.data)) + targets.append(ann.label) + + outputs = self.model.launch(batch)[0] + features = [outputs[-1]] # extracted feature vector of googlenet-v4 + + for target in targets: + running_stats[target].add(features) + + return running_stats + + +class ShiftAnalyzer: + def __init__(self) -> 
None: + """ + Searcher for Datumaro dataitems + + Parameters + ---------- + dataset: + Datumaro dataset to search similar dataitem. + topk: + Number of images. + """ + self._model = shift_launcher.ShiftLauncher( + model_name="googlenet-v4-tf", + output_layers="InceptionV4/Logits/PreLogitsFlatten/flatten_1/Reshape:0", + ) + + def compute_covariate_shift(self, sources: List[IDataset], method: Optional[str] = "fid"): + assert ( + len(sources) == 2 + ), "Shift analyzer should get two datasets to compute shifts between them." + + if method == "fid": + _feat_aggregator = FeatureAccumulator(model=self._model) + + src_stats = _feat_aggregator.get_activation_stats(sources[0]) + tgt_stats = _feat_aggregator.get_activation_stats(sources[1]) + + src_mu, src_sigma = src_stats.mean, src_stats.cov + tgt_mu, tgt_sigma = tgt_stats.mean, tgt_stats.cov + + return self._frechet_distance(src_mu, src_sigma, tgt_mu, tgt_sigma, atol=1e-3) + + elif method == "emd": + _feat_aggregator = FeatureAccumulatorByLabel(model=self._model) + + src_stats = _feat_aggregator.get_activation_stats(sources[0]) + tgt_stats = _feat_aggregator.get_activation_stats(sources[1]) + + w_s = np.array([stats.num for stats in src_stats.values()]) + w_t = np.array([stats.num for stats in tgt_stats.values()]) + + f_s = np.stack([stats.mean for stats in src_stats.values()], axis=0) + f_t = np.stack([stats.mean for stats in tgt_stats.values()], axis=0) + + # earth_mover_distance returns the similarity score in [0, 1]. + # We return the dissimilarity score by 1 - similarity score. + return 1.0 - self._earth_mover_distance(w_s, f_s, w_t, f_t, gamma=0.01) + + def compute_label_shift(self, sources: List[IDataset]): + assert ( + len(sources) == 2 + ), "Shift analyzer should get two datasets to compute shifts between them." + + labels = defaultdict(list) + for idx, source in enumerate(sources): + for item in source: + for ann in item.annotations: + labels[idx].append(ann.label) + + _, _, pv = stats.anderson_ksamp([labels[0], labels[1]]) + + return 1 - pv + + def _frechet_distance( + self, + mu1: np.ndarray, + sigma1: np.ndarray, + mu2: np.ndarray, + sigma2: np.ndarray, + eps: float = 1e-6, + atol: float = 1e-3, + ): + """ + Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + We borrowed the implementation of [1]_ (Apache 2.0 license). + Our implementation forces 64-bit floating-type calculations to avoid numerical instability. + Parameters + ---------- + mu1 + Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + mu2 + The sample mean over activations, precalculated on an representative data set. + sigma1 + The covariance matrix over activations for generated samples. + sigma2 + The covariance matrix over activations, precalculated on an representative data set. + eps + Epsilone term to the diagonal part of sigma covariance matrix. + atol + Threshold value to check whether the covariance matrix is real valued. + If any imagenary diagonal part of the covariance matrix is greather than `atol`, + raise `ValueError`. + Returns + ------- + Distance + Frechet distance + References + ---------- + .. 
[1] https://github.com/mseitzer/pytorch-fid/blob/3d604a25516746c3a4a5548c8610e99010b2c819/src/pytorch_fid/fid_score.py#L150 + """ + mu1 = np.atleast_1d(mu1).astype(np.float64) + mu2 = np.atleast_1d(mu2).astype(np.float64) + + sigma1 = np.atleast_2d(sigma1).astype(np.float64) + sigma2 = np.atleast_2d(sigma2).astype(np.float64) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths." + assert ( + sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions." + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates." + ) % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=atol): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + def _earth_mover_distance( + self, + w_s: np.ndarray, + f_s: np.ndarray, + w_t: np.ndarray, + f_t: np.ndarray, + gamma: float, + ) -> float: + w_1 = np.zeros((len(w_s) + len(w_t),), np.float64) + w_2 = np.zeros((len(w_s) + len(w_t),), np.float64) + w_1[: len(w_s)] = w_s / np.sum(w_s) + w_2[len(w_s) :] = w_t / np.sum(w_t) + + f_concat = np.concatenate([f_s, f_t], axis=0) + distances = np.linalg.norm(f_concat[:, None] - f_concat[None, :], axis=2).astype(np.float64) + + emd = pyemd.emd(w_1, w_2, distances) + return np.exp(-gamma * emd).item() From 3b2bd5647b6b4fdbd235f80c39d0766c32c02683 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 17:26:01 +0400 Subject: [PATCH 03/25] sync components/dataset_base.py --- src/datumaro/components/dataset_base.py | 167 +++++++----------- .../plugins/data_formats/widerface.py | 2 +- 2 files changed, 62 insertions(+), 107 deletions(-) diff --git a/src/datumaro/components/dataset_base.py b/src/datumaro/components/dataset_base.py index b54c0c8a84..8ae2caa332 100644 --- a/src/datumaro/components/dataset_base.py +++ b/src/datumaro/components/dataset_base.py @@ -1,21 +1,19 @@ -# Copyright (C) 2019-2022 Intel Corporation +# Copyright (C) 2019-2023 Intel Corporation # Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from __future__ import annotations -import warnings from typing import Any, Dict, Iterator, List, Optional, Sequence, Type, TypeVar, Union, cast import attr -import numpy as np from attr import attrs, field -from datumaro.components.annotation import Annotation, AnnotationType, Categories +from datumaro.components.annotation import Annotation, Annotations, AnnotationType, Categories from datumaro.components.cli_plugin import CliPlugin from datumaro.components.contexts.importer import ImportContext, NullImportContext -from datumaro.components.media import Image, MediaElement, PointCloud +from datumaro.components.media import Image, MediaElement from datumaro.util.attrs_util import default_if_none, not_empty from datumaro.util.definitions import DEFAULT_SUBSET_NAME @@ -32,7 +30,7 @@ class DatasetItem: default=None, validator=attr.validators.optional(attr.validators.instance_of(MediaElement)) ) - annotations: List[Annotation] = 
field(factory=list, validator=default_if_none(list)) + annotations: Annotations = field(factory=Annotations, validator=default_if_none(Annotations)) attributes: Dict[str, Any] = field(factory=dict, validator=default_if_none(dict)) @@ -51,107 +49,13 @@ def __init__( media: Union[str, MediaElement, None] = None, annotations: Optional[List[Annotation]] = None, attributes: Dict[str, Any] = None, - image=None, - point_cloud=None, - related_images=None, ): - if image is not None: - warnings.warn( - "'image' is deprecated and will be " "removed in future. Use 'media' instead.", - DeprecationWarning, - stacklevel=2, - ) - if isinstance(image, str): - image = Image(path=image) - elif isinstance(image, np.ndarray) or callable(image): - image = Image(data=image) - assert isinstance(image, Image) - media = image - elif point_cloud is not None: - warnings.warn( - "'point_cloud' is deprecated and will be " - "removed in future. Use 'media' instead.", - DeprecationWarning, - stacklevel=2, - ) - if related_images is not None: - warnings.warn( - "'related_images' is deprecated and will be " - "removed in future. Use 'media' instead.", - DeprecationWarning, - stacklevel=2, - ) - if isinstance(point_cloud, str): - point_cloud = PointCloud(path=point_cloud, extra_images=related_images) - assert isinstance(point_cloud, PointCloud) - media = point_cloud - self.__attrs_init__( id=id, subset=subset, media=media, annotations=annotations, attributes=attributes ) - # Deprecated. Provided for backward compatibility. - @property - def image(self) -> Optional[Image]: - warnings.warn( - "'DatasetItem.image' is deprecated and will be " - "removed in future. Use '.media' and '.media_as()' instead.", - DeprecationWarning, - stacklevel=2, - ) - if not isinstance(self.media, Image): - return None - return self.media_as(Image) - - # Deprecated. Provided for backward compatibility. - @property - def point_cloud(self) -> Optional[str]: - warnings.warn( - "'DatasetItem.point_cloud' is deprecated and will be " - "removed in future. Use '.media' and '.media_as()' instead.", - DeprecationWarning, - stacklevel=2, - ) - if not isinstance(self.media, PointCloud): - return None - return self.media_as(PointCloud).path - - # Deprecated. Provided for backward compatibility. - @property - def related_images(self) -> List[Image]: - warnings.warn( - "'DatasetItem.related_images' is deprecated and will be " - "removed in future. Use '.media' and '.media_as()' instead.", - DeprecationWarning, - stacklevel=2, - ) - if not isinstance(self.media, PointCloud): - return [] - return self.media_as(PointCloud).extra_images - - # Deprecated. Provided for backward compatibility. - @property - def has_image(self): - warnings.warn( - "'DatasetItem.has_image' is deprecated and will be " - "removed in future. Use '.media' and '.media_as()' instead.", - DeprecationWarning, - stacklevel=2, - ) - return isinstance(self.media, Image) - - # Deprecated. Provided for backward compatibility. - @property - def has_point_cloud(self): - warnings.warn( - "'DatasetItem.has_point_cloud' is deprecated and will be " - "removed in future. Use '.media' and '.media_as()' instead.", - DeprecationWarning, - stacklevel=2, - ) - return isinstance(self.media, PointCloud) - +DatasetInfo = Dict[str, Any] CategoriesInfo = Dict[AnnotationType, Categories] @@ -177,6 +81,12 @@ def subsets(self) -> Dict[str, IDataset]: def get_subset(self, name) -> IDataset: raise NotImplementedError() + def infos(self) -> DatasetInfo: + """ + Returns meta-info of dataset. 
+ """ + raise NotImplementedError() + def categories(self) -> CategoriesInfo: """ Returns metainfo about dataset labels. @@ -199,11 +109,26 @@ def media_type(self) -> Type[MediaElement]: """ raise NotImplementedError() + def ann_types(self) -> List[AnnotationType]: + """ + Returns available task type from dataset annotation types. + """ + raise NotImplementedError() + + @property + def is_stream(self) -> bool: + """Boolean indicating whether the dataset is a stream + + If the dataset is a stream, the dataset item is generated on demand from its iterator. + """ + return False + class _DatasetBase(IDataset): def __init__(self, *, length: Optional[int] = None, subsets: Optional[Sequence[str]] = None): self._length = length self._subsets = subsets + self._ann_types = set() def _init_cache(self): subsets = set() @@ -227,7 +152,7 @@ def subsets(self) -> Dict[str, IDataset]: self._init_cache() return {name or DEFAULT_SUBSET_NAME: self.get_subset(name) for name in self._subsets} - def get_subset(self, name): + def get_subset(self, name: str) -> IDataset: if self._subsets is None: self._init_cache() if name in self._subsets: @@ -250,18 +175,27 @@ class _DatasetFilter(_DatasetBase): def __iter__(_): return filter(pred, iter(self)) + def infos(_): + return self.infos() + def categories(_): return self.categories() def media_type(_): return self.media_type() + def ann_types(_): + return self.ann_types() + return _DatasetFilter() - def categories(self): + def infos(self) -> DatasetInfo: return {} - def get(self, id, subset=None): + def categories(self) -> CategoriesInfo: + return {} + + def get(self, id, subset=None) -> Optional[DatasetItem]: subset = subset or DEFAULT_SUBSET_NAME for item in self: if item.id == id and item.subset == subset: @@ -272,7 +206,7 @@ def get(self, id, subset=None): class DatasetBase(_DatasetBase, CliPlugin): """ A base class for user-defined and built-in extractors. - Should be used in cases, where SourceExtractor is not enough, + Should be used in cases, where SubsetBase is not enough, or its use makes problems with performance, implementation etc. 
""" @@ -282,16 +216,21 @@ def __init__( length: Optional[int] = None, subsets: Optional[Sequence[str]] = None, media_type: Type[MediaElement] = Image, + ann_types: Optional[List[AnnotationType]] = None, ctx: Optional[ImportContext] = None, ): super().__init__(length=length, subsets=subsets) self._ctx: ImportContext = ctx or NullImportContext() self._media_type = media_type + self._ann_types = ann_types if ann_types else set() def media_type(self): return self._media_type + def ann_types(self): + return self._ann_types + class SubsetBase(DatasetBase): """ @@ -305,14 +244,25 @@ def __init__( length: Optional[int] = None, subset: Optional[str] = None, media_type: Type[MediaElement] = Image, + ann_types: List[AnnotationType] = None, ctx: Optional[ImportContext] = None, ): self._subset = subset or DEFAULT_SUBSET_NAME - super().__init__(length=length, subsets=[self._subset], media_type=media_type, ctx=ctx) + super().__init__( + length=length, + subsets=[self._subset], + media_type=media_type, + ann_types=ann_types, + ctx=ctx, + ) + self._infos = {} self._categories = {} self._items = [] + def infos(self): + return self._infos + def categories(self): return self._categories @@ -325,3 +275,8 @@ def __len__(self): def get(self, id, subset=None): assert subset == self._subset, "%s != %s" % (subset, self._subset) return super().get(id, subset or self._subset) + + @property + def subset(self) -> str: + """Subset name of this instance.""" + return self._subset diff --git a/src/datumaro/plugins/data_formats/widerface.py b/src/datumaro/plugins/data_formats/widerface.py index 97cee58a06..89604a368b 100644 --- a/src/datumaro/plugins/data_formats/widerface.py +++ b/src/datumaro/plugins/data_formats/widerface.py @@ -147,7 +147,7 @@ def _load_items(self, path): attributes[attr] = bbox_list[i] i += 1 - annotations.append( + items[item_id].annotations.append( Bbox( float(bbox_list[0]), float(bbox_list[1]), From f0c3d7af7dbf7d6df724615ceeec44ee7c61b2c9 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 18:02:17 +0400 Subject: [PATCH 04/25] sync components/annotations (extracting from components/operations.py) --- .../components/annotations/__init__.py | 6 + .../components/annotations/matcher.py | 386 ++++++++++++++ src/datumaro/components/annotations/merger.py | 218 ++++++++ src/datumaro/components/operations.py | 469 +----------------- tests/unit/test_ops.py | 6 +- 5 files changed, 631 insertions(+), 454 deletions(-) create mode 100644 src/datumaro/components/annotations/__init__.py create mode 100644 src/datumaro/components/annotations/matcher.py create mode 100644 src/datumaro/components/annotations/merger.py diff --git a/src/datumaro/components/annotations/__init__.py b/src/datumaro/components/annotations/__init__.py new file mode 100644 index 0000000000..66933d3065 --- /dev/null +++ b/src/datumaro/components/annotations/__init__.py @@ -0,0 +1,6 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from .matcher import * +from .merger import * diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py new file mode 100644 index 0000000000..eb7c874cc4 --- /dev/null +++ b/src/datumaro/components/annotations/matcher.py @@ -0,0 +1,386 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Optional, Union + +import numpy as np +from attr import attrib, attrs + +from datumaro.components.abstracts import IMergerContext +from datumaro.components.abstracts.merger import 
IMatcherContext +from datumaro.components.annotation import Annotation, Points +from datumaro.util.annotation_util import ( + OKS, + approximate_line, + bbox_iou, + max_bbox, + mean_bbox, + segment_iou, +) + +__all__ = [ + "match_segments_pair", + "match_segments_more_than_pair", + "AnnotationMatcher", + "LabelMatcher", + "ShapeMatcher", + "BboxMatcher", + "PolygonMatcher", + "MaskMatcher", + "PointsMatcher", + "LineMatcher", + "CaptionsMatcher", + "Cuboid3dMatcher", + "ImageAnnotationMatcher", + "HashKeyMatcher", + "FeatureVectorMatcher", + "Cuboid2DMatcher", +] + + +def match_segments_pair( + a_segms, + b_segms, + distance=segment_iou, + dist_thresh=1.0, + label_matcher=lambda a, b: a.label == b.label, +): + """Match segments and return pairs of the two matched segments""" + + assert callable(distance), distance + assert callable(label_matcher), label_matcher + + a_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1)) + b_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1)) + + # a_matches: indices of b_segms matched to a bboxes + # b_matches: indices of a_segms matched to b bboxes + a_matches = -np.ones(len(a_segms), dtype=int) + b_matches = -np.ones(len(b_segms), dtype=int) + + distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms]) + + # matches: boxes we succeeded to match completely + # mispred: boxes we succeeded to match, having label mismatch + matches = [] + mispred = [] + + # It needs len(a_segms) > 0 and len(b_segms) > 0 + if len(b_segms) > 0: + for a_idx, a_segm in enumerate(a_segms): + matched_b = -1 + max_dist = -1 + b_indices = np.argsort( + [not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable" + ) # prioritize those with same label, keep score order + for b_idx in b_indices: + if 0 <= b_matches[b_idx]: # assign a_segm with max conf + continue + d = distances[a_idx, b_idx] + if d < dist_thresh or d <= max_dist: + continue + max_dist = d + matched_b = b_idx + + if matched_b < 0: + continue + a_matches[a_idx] = matched_b + b_matches[matched_b] = a_idx + + b_segm = b_segms[matched_b] + + if label_matcher(a_segm, b_segm): + matches.append((a_segm, b_segm)) + else: + mispred.append((a_segm, b_segm)) + + # *_umatched: boxes of (*) we failed to match + a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0] + b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0] + + return matches, mispred, a_unmatched, b_unmatched + + +def match_segments_more_than_pair( + a_segms, + b_segms, + distance=segment_iou, + dist_thresh=1.0, + label_matcher=lambda a, b: a.label == b.label, +): + """Match segments and return sets of the matched segments which can be more than two""" + + assert callable(distance), distance + assert callable(label_matcher), label_matcher + + # a_matches: indices of b_segms matched to a bboxes + # b_matches: indices of a_segms matched to b bboxes + a_matches = -np.ones(len(a_segms), dtype=int) + b_matches = -np.ones(len(b_segms), dtype=int) + + distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms]) + + # matches: boxes we succeeded to match completely + # mispred: boxes we succeeded to match, having label mismatch + matches = [] + mispred = [] + + # It needs len(a_segms) > 0 and len(b_segms) > 0 + if len(b_segms) > 0: + for a_idx, a_segm in enumerate(a_segms): + b_indices = np.argsort( + [not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable" + ) # prioritize those with same label, keep score order + for b_idx in b_indices: + d = distances[a_idx, b_idx] 
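+                # Unlike the pairwise version above, every b_segm whose distance reaches the
+                # threshold is linked to this a_segm, so one segment can take part in several matches.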
+ if d < dist_thresh: + continue + + a_matches[a_idx] = b_idx + b_matches[b_idx] = a_idx + + b_segm = b_segms[b_idx] + + if label_matcher(a_segm, b_segm): + matches.append((a_segm, b_segm)) + else: + mispred.append((a_segm, b_segm)) + + # *_umatched: boxes of (*) we failed to match + a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0] + b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0] + + return matches, mispred, a_unmatched, b_unmatched + + +@attrs(kw_only=True) +class AnnotationMatcher: + _context: Optional[Union[IMatcherContext, IMergerContext]] = attrib(default=None) + + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class LabelMatcher(AnnotationMatcher): + def distance(self, a, b): + a_label = self._context.get_any_label_name(a, a.label) + b_label = self._context.get_any_label_name(b, b.label) + return a_label == b_label + + def match_annotations(self, sources): + return [sum(sources, [])] + + +@attrs(kw_only=True) +class ShapeMatcher(AnnotationMatcher): + pairwise_dist = attrib(converter=float, default=0.9) + cluster_dist = attrib(converter=float, default=-1.0) + _match_segments = attrib(default=match_segments_pair) + + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: + distance = self.distance + label_matcher = self.label_matcher + pairwise_dist = self.pairwise_dist + cluster_dist = self.cluster_dist + + if cluster_dist < 0: + cluster_dist = pairwise_dist + + id_segm = {id(a): (a, id(s)) for s in sources for a in s} + + def _is_close_enough(cluster, extra_id): + # check if whole cluster IoU will not be broken + # when this segment is added + b = id_segm[extra_id][0] + for a_id in cluster: + a = id_segm[a_id][0] + if distance(a, b) < cluster_dist: + return False + return True + + def _has_same_source(cluster, extra_id): + b = id_segm[extra_id][1] + for a_id in cluster: + a = id_segm[a_id][1] + if a == b: + return True + return False + + # match segments in sources, pairwise + adjacent = {i: [] for i in id_segm} # id(sgm) -> [id(adj_sgm1), ...] 
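+        # Matched pairs across different sources become edges of this adjacency graph;
+        # clusters are then grown below as connected segments, subject to the
+        # cluster-distance and same-source checks.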
+ for a_idx, src_a in enumerate(sources): + for src_b in sources[a_idx + 1 :]: + matches, _, _, _ = self._match_segments( + src_a, + src_b, + dist_thresh=pairwise_dist, + distance=distance, + label_matcher=label_matcher, + ) + for a, b in matches: + adjacent[id(a)].append(id(b)) + + # join all segments into matching clusters + clusters = [] + visited = set() + for cluster_idx in adjacent: + if cluster_idx in visited: + continue + + cluster = set() + to_visit = {cluster_idx} + while to_visit: + c = to_visit.pop() + cluster.add(c) + visited.add(c) + + for i in adjacent[c]: + if i in visited: + continue + if 0 < cluster_dist and not _is_close_enough(cluster, i): + continue + if _has_same_source(cluster, i): + continue + + to_visit.add(i) + + clusters.append([id_segm[i][0] for i in cluster]) + + return clusters + + def distance(self, a, b): + return segment_iou(a, b) + + def label_matcher(self, a, b): + a_label = self._context.get_any_label_name(a, a.label) + b_label = self._context.get_any_label_name(b, b.label) + return a_label == b_label + + +@attrs +class BboxMatcher(ShapeMatcher): + pass + + +@attrs +class PolygonMatcher(ShapeMatcher): + pass + + +@attrs +class MaskMatcher(ShapeMatcher): + pass + + +@attrs(kw_only=True) +class PointsMatcher(ShapeMatcher): + sigma: Optional[list] = attrib(default=None) + instance_map = attrib(converter=dict) + + def distance(self, a, b): + a_bbox = self.instance_map[id(a)][1] + b_bbox = self.instance_map[id(b)][1] + if bbox_iou(a_bbox, b_bbox) <= 0: + return 0 + bbox = mean_bbox([a_bbox, b_bbox]) + return OKS(a, b, sigma=self.sigma, bbox=bbox) + + +@attrs +class LineMatcher(ShapeMatcher): + def distance(self, a, b): + # Compute inter-line area by using the Trapezoid formulae + # https://en.wikipedia.org/wiki/Trapezoidal_rule + # Normalize by common bbox and get the bbox fill ratio + # Call this ratio the "distance" + + # The box area is an early-exit filter for non-intersected figures + bbox = max_bbox([a, b]) + box_area = bbox[2] * bbox[3] + if not box_area: + return 1 + + def _approx(line, segments): + if len(line) // 2 != segments + 1: + line = approximate_line(line, segments=segments) + return np.reshape(line, (-1, 2)) + + segments = max(len(a.points) // 2, len(b.points) // 2, 5) - 1 + + a = _approx(a.points, segments) + b = _approx(b.points, segments) + dists = np.linalg.norm(a - b, axis=1) + dists = dists[:-1] + dists[1:] + a_steps = np.linalg.norm(a[1:] - a[:-1], axis=1) + b_steps = np.linalg.norm(b[1:] - b[:-1], axis=1) + + # For the common bbox we can't use + # - the AABB (axis-alinged bbox) of a point set + # - the exterior of a point set + # - the convex hull of a point set + # because these soultions won't be correctly normalized. + # The lines can have multiple self-intersections, which can give + # the inter-line area more than internal area of the options above, + # producing the value of the distance outside of the [0; 1] range. + # + # Instead, we can compute the upper boundary for the inter-line + # area based on the maximum point distance and line length. 
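+        # With this upper bound the normalized area cannot exceed 1, so the returned
+        # distance stays within [0; 1]; coinciding lines give exactly 1.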
+ max_area = np.max(dists) * max(np.sum(a_steps), np.sum(b_steps)) + + area = np.dot(dists, a_steps + b_steps) * 0.5 * 0.5 / max(max_area, 1.0) + + return abs(1 - area) + + +@attrs +class CaptionsMatcher(AnnotationMatcher): + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class Cuboid3dMatcher(ShapeMatcher): + def distance(self, a, b): + raise NotImplementedError() + + +@attrs +class ImageAnnotationMatcher(AnnotationMatcher): + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class HashKeyMatcher(AnnotationMatcher): + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class FeatureVectorMatcher(AnnotationMatcher): + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class TabularMatcher(AnnotationMatcher): + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class RotatedBboxMatcher(ShapeMatcher): + sigma: Optional[list] = attrib(default=None) + + def distance(self, a, b): + a = Points([p for pt in a.as_polygon() for p in pt]) + b = Points([p for pt in b.as_polygon() for p in pt]) + + return OKS(a, b, sigma=self.sigma) + + +@attrs +class Cuboid2DMatcher(ShapeMatcher): + pass diff --git a/src/datumaro/components/annotations/merger.py b/src/datumaro/components/annotations/merger.py new file mode 100644 index 0000000000..8ff7593a61 --- /dev/null +++ b/src/datumaro/components/annotations/merger.py @@ -0,0 +1,218 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from attr import attrib, attrs + +from datumaro.components.annotation import Bbox, Label +from datumaro.components.errors import FailedLabelVotingError +from datumaro.util.annotation_util import mean_bbox, segment_iou + +from .matcher import ( + AnnotationMatcher, + BboxMatcher, + CaptionsMatcher, + Cuboid2DMatcher, + Cuboid3dMatcher, + FeatureVectorMatcher, + HashKeyMatcher, + ImageAnnotationMatcher, + LabelMatcher, + LineMatcher, + MaskMatcher, + PointsMatcher, + PolygonMatcher, + RotatedBboxMatcher, + ShapeMatcher, + TabularMatcher, +) + +__all__ = [ + "AnnotationMerger", + "LabelMerger", + "BboxMerger", + "RotatedBboxMerger", + "PolygonMerger", + "MaskMerger", + "PointsMerger", + "LineMerger", + "CaptionsMerger", + "Cuboid3dMerger", + "ImageAnnotationMerger", + "EllipseMerger", + "HashKeyMerger", + "FeatureVectorMerger", +] + + +@attrs(kw_only=True) +class AnnotationMerger(AnnotationMatcher): + def merge_clusters(self, clusters): + raise NotImplementedError() + + +@attrs(kw_only=True) +class LabelMerger(AnnotationMerger, LabelMatcher): + quorum = attrib(converter=int, default=0) + + def merge_clusters(self, clusters): + assert len(clusters) <= 1 + if len(clusters) == 0: + return [] + + votes = {} # label -> score + for ann in clusters[0]: + label = self._context._get_src_label_name(ann, ann.label) + votes[label] = 1 + votes.get(label, 0) + + merged = [] + for label, count in votes.items(): + if count < self.quorum: + sources = set( + self.get_ann_source(id(a)) + for a in clusters[0] + if label not in [self._context._get_src_label_name(l, l.label) for l in a] + ) + sources = [self._context._dataset_map[s][1] for s in sources] + self._context.add_item_error(FailedLabelVotingError, votes, sources=sources) + continue + + merged.append( + Label( + self._context._get_label_id(label), + attributes={"score": count / len(self._context._dataset_map)}, + ) + ) + + return merged + + +@attrs(kw_only=True) +class _ShapeMerger(AnnotationMerger, ShapeMatcher): + 
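+    # Merges each matched cluster into a single shape: the label is chosen by
+    # score-weighted voting (subject to `quorum`), and the resulting geometry is
+    # the cluster member closest to the mean bounding box.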
quorum = attrib(converter=int, default=0) + + def merge_clusters(self, clusters): + return list(map(self.merge_cluster, clusters)) + + def find_cluster_label(self, cluster): + votes = {} + for s in cluster: + label = self._context._get_src_label_name(s, s.label) + state = votes.setdefault(label, [0, 0]) + state[0] += s.attributes.get("score", 1.0) + state[1] += 1 + + label, (score, count) = max(votes.items(), key=lambda e: e[1][0]) + if count < self.quorum: + self._context.add_item_error(FailedLabelVotingError, votes) + label = None + score = score / len(self._context._dataset_map) + label = self._context._get_label_id(label) + return label, score + + @staticmethod + def _merge_cluster_shape_mean_box_nearest(cluster): + mbbox = Bbox(*mean_bbox(cluster)) + dist = (segment_iou(mbbox, s) for s in cluster) + nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) + return cluster[nearest_pos] + + def merge_cluster_shape(self, cluster): + shape = self._merge_cluster_shape_mean_box_nearest(cluster) + shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster) + return shape, shape_score + + def merge_cluster(self, cluster): + label, label_score = self.find_cluster_label(cluster) + shape, shape_score = self.merge_cluster_shape(cluster) + + shape.z_order = max(cluster, key=lambda a: a.z_order).z_order + shape.label = label + shape.attributes["score"] = label_score * shape_score if label is not None else shape_score + + return shape + + +@attrs +class BboxMerger(_ShapeMerger, BboxMatcher): + pass + + +@attrs +class PolygonMerger(_ShapeMerger, PolygonMatcher): + pass + + +@attrs +class MaskMerger(_ShapeMerger, MaskMatcher): + pass + + +@attrs +class PointsMerger(_ShapeMerger, PointsMatcher): + pass + + +@attrs +class LineMerger(_ShapeMerger, LineMatcher): + pass + + +@attrs +class CaptionsMerger(AnnotationMerger, CaptionsMatcher): + pass + + +@attrs +class Cuboid3dMerger(_ShapeMerger, Cuboid3dMatcher): + @staticmethod + def _merge_cluster_shape_mean_box_nearest(cluster): + raise NotImplementedError() + # mbbox = Bbox(*mean_cuboid(cluster)) + # dist = (segment_iou(mbbox, s) for s in cluster) + # nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) + # return cluster[nearest_pos] + + def merge_cluster(self, cluster): + label, label_score = self.find_cluster_label(cluster) + shape, shape_score = self.merge_cluster_shape(cluster) + + shape.label = label + shape.attributes["score"] = label_score * shape_score if label is not None else shape_score + + return shape + + +@attrs +class ImageAnnotationMerger(AnnotationMerger, ImageAnnotationMatcher): + pass + + +@attrs +class EllipseMerger(_ShapeMerger, ShapeMatcher): + pass + + +@attrs +class HashKeyMerger(AnnotationMerger, HashKeyMatcher): + pass + + +@attrs +class FeatureVectorMerger(AnnotationMerger, FeatureVectorMatcher): + pass + + +@attrs +class TabularMerger(AnnotationMerger, TabularMatcher): + pass + + +@attrs +class RotatedBboxMerger(_ShapeMerger, RotatedBboxMatcher): + pass + + +@attrs +class Cuboid2DMerger(_ShapeMerger, Cuboid2DMatcher): + pass diff --git a/src/datumaro/components/operations.py b/src/datumaro/components/operations.py index c1b70aecda..d48e9ca324 100644 --- a/src/datumaro/components/operations.py +++ b/src/datumaro/components/operations.py @@ -4,42 +4,37 @@ # SPDX-License-Identifier: MIT import hashlib -import itertools import logging as log from collections import OrderedDict from copy import deepcopy -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - 
Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from unittest import TestCase import attr import cv2 import numpy as np from attr import attrib, attrs -from scipy.optimize import linear_sum_assignment from datumaro.components.annotation import ( Annotation, AnnotationType, - Bbox, - Label, LabelCategories, MaskCategories, Points, PointsCategories, ) +from datumaro.components.annotations import LineMatcher, PointsMatcher, match_segments_pair +from datumaro.components.annotations.merger import ( + BboxMerger, + CaptionsMerger, + Cuboid3dMerger, + ImageAnnotationMerger, + LabelMerger, + LineMerger, + MaskMerger, + PointsMerger, + PolygonMerger, +) from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import Dataset, IDataset from datumaro.components.dataset_base import CategoriesInfo, DatasetItem @@ -49,7 +44,6 @@ ConflictingCategoriesError, DatasetMergeError, FailedAttrVotingError, - FailedLabelVotingError, MediaTypeError, MismatchingAttributesError, MismatchingImageInfoError, @@ -62,15 +56,7 @@ ) from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video from datumaro.util import filter_dict, find -from datumaro.util.annotation_util import ( - OKS, - approximate_line, - bbox_iou, - find_instances, - max_bbox, - mean_bbox, - segment_iou, -) +from datumaro.util.annotation_util import find_instances, max_bbox from datumaro.util.attrs_util import default_if_none, ensure_cls @@ -909,7 +895,7 @@ def _get_src_label_name(self, ann, label_id): self._dataset_map[dataset_id][0].categories()[AnnotationType.label].items[label_id].name ) - def _get_any_label_name(self, ann, label_id): + def get_any_label_name(self, ann, label_id): if label_id is None: return None try: @@ -929,425 +915,6 @@ def _check_groups_definition(self): ) -@attrs(kw_only=True) -class AnnotationMatcher: - _context: Optional[IntersectMerge] = attrib(default=None) - - def match_annotations(self, sources): - raise NotImplementedError() - - -@attrs -class LabelMatcher(AnnotationMatcher): - def distance(self, a, b): - a_label = self._context._get_any_label_name(a, a.label) - b_label = self._context._get_any_label_name(b, b.label) - return a_label == b_label - - def match_annotations(self, sources): - return [sum(sources, [])] - - -@attrs(kw_only=True) -class _ShapeMatcher(AnnotationMatcher): - pairwise_dist = attrib(converter=float, default=0.9) - cluster_dist = attrib(converter=float, default=-1.0) - - def match_annotations(self, sources): - distance = self.distance - label_matcher = self.label_matcher - pairwise_dist = self.pairwise_dist - cluster_dist = self.cluster_dist - - if cluster_dist < 0: - cluster_dist = pairwise_dist - - id_segm = {id(a): (a, id(s)) for s in sources for a in s} - - def _is_close_enough(cluster, extra_id): - # check if whole cluster IoU will not be broken - # when this segment is added - b = id_segm[extra_id][0] - for a_id in cluster: - a = id_segm[a_id][0] - if distance(a, b) < cluster_dist: - return False - return True - - def _has_same_source(cluster, extra_id): - b = id_segm[extra_id][1] - for a_id in cluster: - a = id_segm[a_id][1] - if a == b: - return True - return False - - # match segments in sources, pairwise - adjacent = {i: [] for i in id_segm} # id(sgm) -> [id(adj_sgm1), ...] 
- for a_idx, src_a in enumerate(sources): - for src_b in sources[a_idx + 1 :]: - matches, _, _, _ = match_segments( - src_a, - src_b, - dist_thresh=pairwise_dist, - distance=distance, - label_matcher=label_matcher, - ) - for a, b in matches: - adjacent[id(a)].append(id(b)) - - # join all segments into matching clusters - clusters = [] - visited = set() - for cluster_idx in adjacent: - if cluster_idx in visited: - continue - - cluster = set() - to_visit = {cluster_idx} - while to_visit: - c = to_visit.pop() - cluster.add(c) - visited.add(c) - - for i in adjacent[c]: - if i in visited: - continue - if 0 < cluster_dist and not _is_close_enough(cluster, i): - continue - if _has_same_source(cluster, i): - continue - - to_visit.add(i) - - clusters.append([id_segm[i][0] for i in cluster]) - - return clusters - - def distance(self, a, b): - return segment_iou(a, b) - - def label_matcher(self, a, b): - a_label = self._context._get_any_label_name(a, a.label) - b_label = self._context._get_any_label_name(b, b.label) - return a_label == b_label - - -@attrs -class BboxMatcher(_ShapeMatcher): - pass - - -@attrs -class PolygonMatcher(_ShapeMatcher): - pass - - -@attrs -class MaskMatcher(_ShapeMatcher): - pass - - -@attrs(kw_only=True) -class PointsMatcher(_ShapeMatcher): - sigma: Optional[list] = attrib(default=None) - instance_map = attrib(converter=dict) - - def distance(self, a, b): - a_bbox = self.instance_map[id(a)][1] - b_bbox = self.instance_map[id(b)][1] - if bbox_iou(a_bbox, b_bbox) <= 0: - return 0 - bbox = mean_bbox([a_bbox, b_bbox]) - return OKS(a, b, sigma=self.sigma, bbox=bbox) - - -@attrs -class LineMatcher(_ShapeMatcher): - def distance(self, a, b): - # Compute inter-line area by using the Trapezoid formulae - # https://en.wikipedia.org/wiki/Trapezoidal_rule - # Normalize by common bbox and get the bbox fill ratio - # Call this ratio the "distance" - - # The box area is an early-exit filter for non-intersected figures - bbox = max_bbox([a, b]) - box_area = bbox[2] * bbox[3] - if not box_area: - return 1 - - def _approx(line, segments): - if len(line) // 2 != segments + 1: - line = approximate_line(line, segments=segments) - return np.reshape(line, (-1, 2)) - - segments = max(len(a.points) // 2, len(b.points) // 2, 5) - 1 - - a = _approx(a.points, segments) - b = _approx(b.points, segments) - dists = np.linalg.norm(a - b, axis=1) - dists = dists[:-1] + dists[1:] - a_steps = np.linalg.norm(a[1:] - a[:-1], axis=1) - b_steps = np.linalg.norm(b[1:] - b[:-1], axis=1) - - # For the common bbox we can't use - # - the AABB (axis-alinged bbox) of a point set - # - the exterior of a point set - # - the convex hull of a point set - # because these soultions won't be correctly normalized. - # The lines can have multiple self-intersections, which can give - # the inter-line area more than internal area of the options above, - # producing the value of the distance outside of the [0; 1] range. - # - # Instead, we can compute the upper boundary for the inter-line - # area based on the maximum point distance and line length. 
- max_area = np.max(dists) * max(np.sum(a_steps), np.sum(b_steps)) - - area = np.dot(dists, a_steps + b_steps) * 0.5 * 0.5 / max(max_area, 1.0) - - return abs(1 - area) - - -@attrs -class CaptionsMatcher(AnnotationMatcher): - def match_annotations(self, sources): - raise NotImplementedError() - - -@attrs -class Cuboid3dMatcher(_ShapeMatcher): - def distance(self, a, b): - raise NotImplementedError() - - -@attrs -class ImageAnnotationMatcher(AnnotationMatcher): - def match_annotations(self, sources): - raise NotImplementedError() - - -@attrs(kw_only=True) -class AnnotationMerger: - def merge_clusters(self, clusters): - raise NotImplementedError() - - -@attrs(kw_only=True) -class LabelMerger(AnnotationMerger, LabelMatcher): - quorum = attrib(converter=int, default=0) - - def merge_clusters(self, clusters): - assert len(clusters) <= 1 - if len(clusters) == 0: - return [] - - votes = {} # label -> score - for ann in clusters[0]: - label = self._context._get_src_label_name(ann, ann.label) - votes[label] = 1 + votes.get(label, 0) - - merged = [] - for label, count in votes.items(): - if count < self.quorum: - sources = set( - self.get_ann_source(id(a)) - for a in clusters[0] - if label not in [self._context._get_src_label_name(l, l.label) for l in a] - ) - sources = [self._context._dataset_map[s][1] for s in sources] - self._context.add_item_error(FailedLabelVotingError, votes, sources=sources) - continue - - merged.append( - Label( - self._context._get_label_id(label), - attributes={"score": count / len(self._context._dataset_map)}, - ) - ) - - return merged - - -@attrs(kw_only=True) -class _ShapeMerger(AnnotationMerger, _ShapeMatcher): - quorum = attrib(converter=int, default=0) - - def merge_clusters(self, clusters): - return list(map(self.merge_cluster, clusters)) - - def find_cluster_label(self, cluster): - votes = {} - for s in cluster: - label = self._context._get_src_label_name(s, s.label) - state = votes.setdefault(label, [0, 0]) - state[0] += s.attributes.get("score", 1.0) - state[1] += 1 - - label, (score, count) = max(votes.items(), key=lambda e: e[1][0]) - if count < self.quorum: - self._context.add_item_error(FailedLabelVotingError, votes) - label = None - score = score / len(self._context._dataset_map) - label = self._context._get_label_id(label) - return label, score - - @staticmethod - def _merge_cluster_shape_mean_box_nearest(cluster): - mbbox = Bbox(*mean_bbox(cluster)) - dist = (segment_iou(mbbox, s) for s in cluster) - nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) - return cluster[nearest_pos] - - def merge_cluster_shape(self, cluster): - shape = self._merge_cluster_shape_mean_box_nearest(cluster) - shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster) - return shape, shape_score - - def merge_cluster(self, cluster): - label, label_score = self.find_cluster_label(cluster) - shape, shape_score = self.merge_cluster_shape(cluster) - - shape.z_order = max(cluster, key=lambda a: a.z_order).z_order - shape.label = label - shape.attributes["score"] = label_score * shape_score if label is not None else shape_score - - return shape - - -@attrs -class BboxMerger(_ShapeMerger, BboxMatcher): - pass - - -@attrs -class PolygonMerger(_ShapeMerger, PolygonMatcher): - pass - - -@attrs -class MaskMerger(_ShapeMerger, MaskMatcher): - pass - - -@attrs -class PointsMerger(_ShapeMerger, PointsMatcher): - pass - - -@attrs -class LineMerger(_ShapeMerger, LineMatcher): - pass - - -@attrs -class CaptionsMerger(AnnotationMerger, CaptionsMatcher): - pass - - 
-@attrs -class Cuboid3dMerger(_ShapeMerger, Cuboid3dMatcher): - @staticmethod - def _merge_cluster_shape_mean_box_nearest(cluster): - raise NotImplementedError() - # mbbox = Bbox(*mean_cuboid(cluster)) - # dist = (segment_iou(mbbox, s) for s in cluster) - # nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) - # return cluster[nearest_pos] - - def merge_cluster(self, cluster): - label, label_score = self.find_cluster_label(cluster) - shape, shape_score = self.merge_cluster_shape(cluster) - - shape.label = label - shape.attributes["score"] = label_score * shape_score if label is not None else shape_score - - return shape - - -@attrs -class ImageAnnotationMerger(AnnotationMerger, ImageAnnotationMatcher): - pass - - -_AT1 = TypeVar("_AT1") -_AT2 = TypeVar("_AT2") - - -def match_segments( - a_segms: Sequence[_AT1], - b_segms: Sequence[_AT2], - distance: Callable[[_AT1, _AT2], float] = segment_iou, - dist_thresh: float = 1.0, - label_matcher: Callable[[_AT1, _AT2], bool] = lambda a, b: a.label == b.label, -) -> Tuple[List[Tuple[_AT1, _AT2]], List[Tuple[_AT1, _AT2]], List[_AT1], List[_AT2]]: - """ - Finds the best matching annotations using the provided distance function. - If the annotations match by distance, but have different labels, - they are considered mismatching. - - Parameters: - - distance: func(a_ann, b_ann) -> float [0; 1] - a function that estimates annotation - similarity, with 0 meaning 'not similar' and 1 - 'exactly the same'. - - dist_thresh: a value in the range [0; 1], minimal distance between a pair of annotations - to be considered for matching - - - Returns (matching, mismatching, a_unmatched, b_unmatched), where: - - 'matching' and 'mismatching' - lists of (a_ann, b_ann) tuples - - 'a_unmatched' and 'b_unmatched' - lists of corresponding unmatched annotations - """ - - assert callable(distance), distance - assert callable(label_matcher), label_matcher - - max_anns = max(len(a_segms), len(b_segms)) - distances = np.array( - [ - [ - 1 - distance(a, b) if a is not None and b is not None else 1 - for b, _ in itertools.zip_longest(b_segms, range(max_anns), fillvalue=None) - ] - for a, _ in itertools.zip_longest(a_segms, range(max_anns), fillvalue=None) - ] - ) - distances[distances > 1 - dist_thresh] = 1 - - if a_segms and b_segms: - a_matches, b_matches = linear_sum_assignment(distances) - else: - a_matches = [] - b_matches = [] - - # matches: boxes we succeeded to match completely - # mispred: boxes we succeeded to match, having label mismatch - matches = [] - mismatches = [] - # *_umatched: boxes of (*) we failed to match - a_unmatched = [] - b_unmatched = [] - - for a_idx, b_idx in zip(a_matches, b_matches): - dist = distances[a_idx, b_idx] - if dist > 1 - dist_thresh or dist == 1: - if a_idx < len(a_segms): - a_unmatched.append(a_segms[a_idx]) - if b_idx < len(b_segms): - b_unmatched.append(b_segms[b_idx]) - else: - a_ann = a_segms[a_idx] - b_ann = b_segms[b_idx] - if label_matcher(a_ann, b_ann): - matches.append((a_ann, b_ann)) - else: - mismatches.append((a_ann, b_ann)) - - if not len(a_matches) and not len(b_matches): - a_unmatched = list(a_segms) - b_unmatched = list(b_segms) - - return matches, mismatches, a_unmatched, b_unmatched - - def mean_std(dataset: IDataset): counter = _MeanStdCounter() @@ -1695,7 +1262,7 @@ def match_labels(self, item_a, item_b): def _match_segments(self, t, item_a, item_b): a_boxes = self._get_ann_type(t, item_a) b_boxes = self._get_ann_type(t, item_b) - return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold) + return 
match_segments_pair(a_boxes, b_boxes, dist_thresh=self.iou_threshold) def match_polygons(self, item_a, item_b): return self._match_segments(AnnotationType.polygon, item_a, item_b) @@ -1719,7 +1286,7 @@ def match_points(self, item_a, item_b): instance_map[id(ann)] = [inst, inst_bbox] matcher = PointsMatcher(instance_map=instance_map) - return match_segments( + return match_segments_pair( a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance ) @@ -1729,7 +1296,7 @@ def match_lines(self, item_a, item_b): matcher = LineMatcher() - return match_segments( + return match_segments_pair( a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance ) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index 9458b5ee90..b96764769a 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -16,6 +16,7 @@ Polygon, PolyLine, ) +from datumaro.components.annotations import match_segments_pair from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image, MultiframeImage, PointCloud @@ -28,7 +29,6 @@ compute_ann_statistics, compute_image_statistics, find_unique_images, - match_segments, mean_std, ) from datumaro.util.definitions import DEFAULT_SUBSET_NAME @@ -438,7 +438,7 @@ def test_can_match_shape_first_and_label_later(self): Bbox(0, 0, 4, 4, label=1, id=1), ] - matches, mismatches, a_extra, b_extra = match_segments(anns1, anns2, dist_thresh=0.5) + matches, mismatches, a_extra, b_extra = match_segments_pair(anns1, anns2, dist_thresh=0.5) assert sorted(mismatches, key=lambda e: e[0].id) == [ (anns1[0], anns2[1]), (anns1[1], anns2[0]), @@ -469,7 +469,7 @@ def test_can_match(self): Bbox(0, 6, 4, 4, label=1, id=5), ] - matches, mismatches, a_extra, b_extra = match_segments(anns1, anns2, dist_thresh=0.5) + matches, mismatches, a_extra, b_extra = match_segments_pair(anns1, anns2, dist_thresh=0.5) assert sorted(mismatches, key=lambda e: e[0].id) == [ (anns1[0], anns2[1]), (anns1[1], anns2[0]), From 481767c89dd00bf7fa4c4823102918620694d521 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 19:04:09 +0400 Subject: [PATCH 05/25] sync plugins/openvino_plugin --- .../plugins/openvino_plugin/launcher.py | 280 ++++++++++++------ .../openvino_plugin/samples/__init__.py | 0 .../samples/clip_text_ViT-B_32_interp.py | 30 ++ .../clip_text_vit_l_14_336px_int8_interp.py | 30 ++ .../samples/clip_visual_ViT-B_32_interp.py | 51 ++++ .../clip_visual_vit_l_14_336px_int8_interp.py | 52 ++++ .../samples/googlenet-v4-tf_interp.py | 59 ++++ ...ustom_object_detection_gen3_atss_interp.py | 43 +++ ...custom_object_detection_gen3_ssd_interp.py | 43 +++ ...tx_custom_object_detection_yolox_interp.py | 43 +++ .../plugins/openvino_plugin/samples/utils.py | 111 +++++++ .../plugins/openvino_plugin/shift_launcher.py | 37 +++ 12 files changed, 683 insertions(+), 96 deletions(-) create mode 100644 src/datumaro/plugins/openvino_plugin/samples/__init__.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/clip_text_ViT-B_32_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/clip_text_vit_l_14_336px_int8_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/clip_visual_ViT-B_32_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/clip_visual_vit_l_14_336px_int8_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/googlenet-v4-tf_interp.py create mode 100644 
src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_atss_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_ssd_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_yolox_interp.py create mode 100644 src/datumaro/plugins/openvino_plugin/samples/utils.py create mode 100644 src/datumaro/plugins/openvino_plugin/shift_launcher.py diff --git a/src/datumaro/plugins/openvino_plugin/launcher.py b/src/datumaro/plugins/openvino_plugin/launcher.py index 583d40cf88..9802ab0ca6 100644 --- a/src/datumaro/plugins/openvino_plugin/launcher.py +++ b/src/datumaro/plugins/openvino_plugin/launcher.py @@ -1,19 +1,27 @@ -# Copyright (C) 2019-2021 Intel Corporation +# Copyright (C) 2019-2024 Intel Corporation # # SPDX-License-Identifier: MIT # pylint: disable=exec-used + import logging as log import os.path as osp import shutil +import urllib +from dataclasses import dataclass, fields +from typing import Dict, List, Optional -import cv2 import numpy as np -from openvino.inference_engine import IECore +from openvino.runtime import Core +from tqdm import tqdm +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.launcher import Launcher +from datumaro.components.launcher import LauncherWithModelInterpreter +from datumaro.errors import DatumaroError +from datumaro.util.definitions import get_datumaro_cache_dir +from datumaro.util.samples import get_samples_path class _OpenvinoImporter(CliPlugin): @@ -56,69 +64,163 @@ def copy_model(model_dir, model): model["interpreter"] = osp.basename(model["interpreter"]) -class InterpreterScript: - def __init__(self, path): - with open(path, "r", encoding="utf-8") as f: - script = f.read() +@dataclass +class OpenvinoModelInfo: + interpreter: Optional[str] + description: Optional[str] + weights: Optional[str] + model_dir: Optional[str] - context = {} - exec(script, context, context) + def validate(self): + """Validate integrity of the member variables""" - process_outputs = context.get("process_outputs") - if not callable(process_outputs): - raise Exception("Can't find 'process_outputs' function in " "the interpreter script") - self.__dict__["process_outputs"] = process_outputs + def _validate(key: str): + path = getattr(self, key) + if not osp.isfile(path): + path = osp.join(self.model_dir, path) + if not osp.isfile(path): + raise DatumaroError(f'Failed to open model {key} file "{path}"') + setattr(self, key, path) - get_categories = context.get("get_categories") - assert get_categories is None or callable(get_categories) - if get_categories: - self.__dict__["get_categories"] = get_categories + for field in fields(self): + if field.name != "model_dir": + _validate(field.name) - @staticmethod - def get_categories(): - return None - @staticmethod - def process_outputs(inputs, outputs): - raise NotImplementedError("Function should be implemented in the interpreter script") +@dataclass +class BuiltinOpenvinoModelInfo(OpenvinoModelInfo): + downloadable_models = { + "clip_text_ViT-B_32", + "clip_visual_ViT-B_32", + "clip_visual_vit_l_14_336px_int8", + "clip_text_vit_l_14_336px_int8", + "googlenet-v4-tf", + } + @classmethod + def create_from_model_name(cls, model_name: str) -> "BuiltinOpenvinoModelInfo": + openvino_plugin_samples_dir = get_samples_path() + interpreter = osp.join(openvino_plugin_samples_dir, model_name + "_interp.py") + 
interpreter = interpreter if osp.exists(interpreter) else interpreter + + model_dir = get_datumaro_cache_dir() + + # Please visit open-model-zoo repository for OpenVINO public models if you are interested in + # https://github.com/openvinotoolkit/open_model_zoo/blob/master/models/public/index.md + url_folder = "https://storage.openvinotoolkit.org/repositories/datumaro/models/" + + description = osp.join(model_dir, model_name + ".xml") + if not osp.exists(description): + description = ( + cls._download_file(osp.join(url_folder, model_name + ".xml"), description) + if model_name in cls.downloadable_models + else None + ) -class OpenvinoLauncher(Launcher): + weights = osp.join(model_dir, model_name + ".bin") + if not osp.exists(weights): + weights = ( + cls._download_file(osp.join(url_folder, model_name + ".bin"), weights) + if model_name in cls.downloadable_models + else None + ) + + return cls( + interpreter=interpreter, + description=description, + weights=weights, + model_dir=model_dir, + ) + + @staticmethod + def _download_file(url: str, file_root: str) -> str: + log.info('Downloading: "{}" to {}\n'.format(url, file_root)) + req = urllib.request.Request(url) + with urllib.request.urlopen(req) as source, open(file_root, "wb") as output: # nosec B310 + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + return file_root + + def override(self, other: OpenvinoModelInfo) -> None: + """Override builtin model variables to other""" + + def _apply(key: str) -> None: + other_item = getattr(other, key) + self_item = getattr(self, key) + if other_item is None and self_item: + log.info(f"Override description with the builtin model {key}: {self.description}.") + setattr(other, key, self_item) + + for field in fields(self): + _apply(field.name) + + +class OpenvinoLauncher(LauncherWithModelInterpreter): cli_plugin = _OpenvinoImporter def __init__( - self, description, weights, interpreter, device=None, model_dir=None, output_layers=None + self, + description: Optional[str] = None, + weights: Optional[str] = None, + interpreter: Optional[str] = None, + model_dir: Optional[str] = None, + model_name: Optional[str] = None, + output_layers: List[str] = [], + device: Optional[str] = None, + compile_model_config: Optional[Dict] = None, ): - if not model_dir: - model_dir = "" - if not osp.isfile(description): - description = osp.join(model_dir, description) - if not osp.isfile(description): - raise Exception('Failed to open model description file "%s"' % (description)) + model_info = OpenvinoModelInfo( + interpreter=interpreter, + description=description, + weights=weights, + model_dir=model_dir, + ) + if model_name: + builtin_model_info = BuiltinOpenvinoModelInfo.create_from_model_name(model_name) + builtin_model_info.override(model_info) - if not osp.isfile(weights): - weights = osp.join(model_dir, weights) - if not osp.isfile(weights): - raise Exception('Failed to open model weights file "%s"' % (weights)) + model_info.validate() - if not osp.isfile(interpreter): - interpreter = osp.join(model_dir, interpreter) - if not osp.isfile(interpreter): - raise Exception('Failed to open model interpreter script file "%s"' % (interpreter)) + super().__init__(model_interpreter_path=model_info.interpreter) - self._interpreter = InterpreterScript(interpreter) + self.model_info = model_info self._device = device or 
"CPU" - self._output_blobs = output_layers + self._compile_model_config = compile_model_config + + self._core = Core() + self._network = self._core.read_model(model_info.description, model_info.weights) + + if output_layers: + log.info(f"Add additional output layers {output_layers} to the model outputs.") + self._network.add_outputs(output_layers) - self._ie = IECore() - self._network = self._ie.read_network(description, weights) self._check_model_support(self._network, self._device) self._load_executable_net() + @property + def inputs(self): + return self._network.inputs + + @property + def outputs(self): + return self._network.outputs + def _check_model_support(self, net, device): not_supported_layers = set( - name for name, dev in self._ie.query_network(net, device).items() if not dev + name for name, dev in self._core.query_model(net, device).items() if not dev ) if len(not_supported_layers) != 0: log.error( @@ -127,69 +229,55 @@ def _check_model_support(self, net, device): ) raise NotImplementedError("Some layers are not supported on the device") - def _load_executable_net(self, batch_size=1): + def _load_executable_net(self, batch_size: int = 1): network = self._network - if self._output_blobs: - network.add_outputs(self._output_blobs) - - iter_inputs = iter(network.input_info) + iter_inputs = iter(network.inputs) self._input_blob = next(iter_inputs) - # NOTE: handling for the inclusion of `image_info` in OpenVino2019 - self._require_image_info = "image_info" in network.input_info - if self._input_blob == "image_info": - self._input_blob = next(iter_inputs) + is_dynamic_layout = False + try: + self._input_layout = self._input_blob.shape + except ValueError: + # In case of that the input has dynamic shape + self._input_layout = self._input_blob.partial_shape + is_dynamic_layout = True + + if is_dynamic_layout: + self._input_layout[0] = batch_size + network.reshape({self._input_blob: self._input_layout}) + else: + model_batch_size = self._input_layout[0] + if batch_size != model_batch_size: + log.warning( + "Input layout of the model is static, so that we cannot change " + f"the model batch size ({model_batch_size}) to batch size ({batch_size})! " + "Set the batch size to {model_batch_size}." 
+ ) + batch_size = model_batch_size - self._input_layout = network.input_info[self._input_blob].input_data.shape - self._input_layout[0] = batch_size - network.reshape({self._input_blob: self._input_layout}) self._batch_size = batch_size - self._net = self._ie.load_network(network=network, num_requests=1, device_name=self._device) - - def infer(self, inputs): - assert len(inputs.shape) == 4, "Expected an input image in (N, H, W, C) format, got %s" % ( - inputs.shape, + self._net = self._core.compile_model( + model=network, + device_name=self._device, + config=self._compile_model_config, ) + self._request = self._net.create_infer_request() - if inputs.shape[3] == 1: # A batch of single-channel images - inputs = np.repeat(inputs, 3, axis=3) - - assert inputs.shape[3] == 3, "Expected BGR input, got %s" % (inputs.shape,) - - n, c, h, w = self._input_layout - if inputs.shape[1:3] != (h, w): - resized_inputs = np.empty((n, h, w, c), dtype=inputs.dtype) - for inp, resized_input in zip(inputs, resized_inputs): - cv2.resize(inp, (w, h), resized_input) - inputs = resized_inputs - inputs = inputs.transpose((0, 3, 1, 2)) # NHWC to NCHW - inputs = {self._input_blob: inputs} - if self._require_image_info: - info = np.zeros([1, 3]) - info[0, 0] = h - info[0, 1] = w - info[0, 2] = 1.0 # scale - inputs["image_info"] = info - - results = self._net.infer(inputs) - if len(results) == 1: - return next(iter(results.values())) - else: - return results - - def launch(self, inputs): + def infer(self, inputs: LauncherInputType) -> List[ModelPred]: batch_size = len(inputs) if self._batch_size < batch_size: self._load_executable_net(batch_size) - outputs = self.infer(inputs) - results = self.process_outputs(inputs, outputs) - return results + inputs = ( + {self._input_blob.get_any_name(): inputs} if isinstance(inputs, np.ndarray) else inputs + ) + results = self._request.infer(inputs=inputs) - def categories(self): - return self._interpreter.get_categories() + outputs_group_by_item = [ + {key.any_name: output for key, output in zip(results.keys(), outputs)} + for outputs in zip(*results.values()) + ] - def process_outputs(self, inputs, outputs): - return self._interpreter.process_outputs(inputs, outputs) + return outputs_group_by_item diff --git a/src/datumaro/plugins/openvino_plugin/samples/__init__.py b/src/datumaro/plugins/openvino_plugin/samples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/datumaro/plugins/openvino_plugin/samples/clip_text_ViT-B_32_interp.py b/src/datumaro/plugins/openvino_plugin/samples/clip_text_ViT-B_32_interp.py new file mode 100644 index 0000000000..d3d5ce80ed --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/clip_text_ViT-B_32_interp.py @@ -0,0 +1,30 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatumaroError +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import gen_hash_key + + +class ClipTextViTB32ModelInterpreter(IModelInterpreter): + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + return img, None + + def 
postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + feature_vector = pred.get("output") + if feature_vector is None: + raise DatumaroError('"output" key should exist in the model prediction.') + + return [gen_hash_key(feature_vector)] + + def get_categories(self): + label_categories = LabelCategories() + return {AnnotationType.label: label_categories} diff --git a/src/datumaro/plugins/openvino_plugin/samples/clip_text_vit_l_14_336px_int8_interp.py b/src/datumaro/plugins/openvino_plugin/samples/clip_text_vit_l_14_336px_int8_interp.py new file mode 100644 index 0000000000..3e7b6ad5a2 --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/clip_text_vit_l_14_336px_int8_interp.py @@ -0,0 +1,30 @@ +# Copyright (C) 2024 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatumaroError +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import gen_hash_key + + +class ClipTextViTL14ModelInterpreter(IModelInterpreter): + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + return img, None + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + feature_vector = pred.get("output") + if feature_vector is None: + raise DatumaroError('"output" key should exist in the model prediction.') + + return [gen_hash_key(feature_vector)] + + def get_categories(self): + label_categories = LabelCategories() + return {AnnotationType.label: label_categories} diff --git a/src/datumaro/plugins/openvino_plugin/samples/clip_visual_ViT-B_32_interp.py b/src/datumaro/plugins/openvino_plugin/samples/clip_visual_ViT-B_32_interp.py new file mode 100644 index 0000000000..844e9806db --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/clip_visual_ViT-B_32_interp.py @@ -0,0 +1,51 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import os.path as osp +from typing import List, Tuple + +import cv2 +import numpy as np + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatumaroError +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import gen_hash_key +from datumaro.util.samples import get_samples_path + + +class ClipVisualViTB32ModelInterpreter(IModelInterpreter): + mean = (255 * np.array([0.485, 0.456, 0.406])).reshape(1, 1, 3) + std = (255 * np.array([0.229, 0.224, 0.225])).reshape(1, 1, 3) + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + img = cv2.resize(img, (224, 224)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = (img - self.mean) / self.std + if len(img.shape) == 3 and img.shape[2] in {3, 4}: + img = np.transpose(img, (2, 0, 1)) + return img, None + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + 
feature_vector = pred.get("output") + if feature_vector is None: + raise DatumaroError('"output" key should exist in the model prediction.') + + return [gen_hash_key(feature_vector)] + + def get_categories(self): + label_categories = LabelCategories() + + openvino_plugin_samples_dir = get_samples_path() + imagenet_class_path = osp.join(openvino_plugin_samples_dir, "imagenet.class") + with open(imagenet_class_path, "r", encoding="utf-8") as file: + for line in file.readlines(): + label = line.strip() + label_categories.add(label) + + return {AnnotationType.label: label_categories} diff --git a/src/datumaro/plugins/openvino_plugin/samples/clip_visual_vit_l_14_336px_int8_interp.py b/src/datumaro/plugins/openvino_plugin/samples/clip_visual_vit_l_14_336px_int8_interp.py new file mode 100644 index 0000000000..320059357a --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/clip_visual_vit_l_14_336px_int8_interp.py @@ -0,0 +1,52 @@ +# Copyright (C) 2024 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import os.path as osp +from typing import List, Tuple + +import cv2 +import numpy as np + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatumaroError +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import gen_hash_key +from datumaro.util.samples import get_samples_path + + +class ClipViTL14ModelInterpreter(IModelInterpreter): + mean = (255 * np.array([0.485, 0.456, 0.406])).reshape(1, 1, 3) + std = (255 * np.array([0.229, 0.224, 0.225])).reshape(1, 1, 3) + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + img = cv2.resize(img, (336, 336)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = (img - self.mean) / self.std + + if img.ndim == 3 and img.shape[2] in {3, 4}: + img = np.transpose(img, (2, 0, 1)) + return img, None + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + feature_vector = pred.get("output") + if feature_vector is None: + raise DatumaroError('"output" key should exist in the model prediction.') + + return [gen_hash_key(feature_vector)] + + def get_categories(self): + label_categories = LabelCategories() + openvino_plugin_samples_dir = get_samples_path() + imagenet_class_path = osp.join(openvino_plugin_samples_dir, "imagenet.class") + + with open(imagenet_class_path, "r", encoding="utf-8") as file: + labels = [line.strip() for line in file] + for label in labels: + label_categories.add(label) + + return {AnnotationType.label: label_categories} diff --git a/src/datumaro/plugins/openvino_plugin/samples/googlenet-v4-tf_interp.py b/src/datumaro/plugins/openvino_plugin/samples/googlenet-v4-tf_interp.py new file mode 100644 index 0000000000..54d1b49ce5 --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/googlenet-v4-tf_interp.py @@ -0,0 +1,59 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +import cv2 + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import ( + Annotation, + AnnotationType, + FeatureVector, + 
Label, + LabelCategories, +) +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.errors import DatumaroError +from datumaro.components.media import Image + + +class GooglenetV4TfModelInterpreter(IModelInterpreter): + LOGIT_KEY = "InceptionV4/Logits/Predictions" + FEAT_KEY = "InceptionV4/Logits/PreLogitsFlatten/flatten_1/Reshape:0" + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + img = cv2.resize(img, (299, 299)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img, None + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + logit = pred.get(self.LOGIT_KEY) + if logit is None: + raise DatumaroError(f'"{self.LOGIT_KEY}" key should exist in the model prediction.') + + feature_vector = pred.get(self.FEAT_KEY) + if feature_vector is None: + raise DatumaroError(f'"{self.FEAT_KEY}" key should exist in the model prediction.') + + outputs = [ + Label(label=label, attributes={"score": score}) for label, score in enumerate(logit) + ] + outputs += [FeatureVector(feature_vector)] + + return outputs # [FeatureVector(logit), FeatureVector(feature_vector)] + + def get_categories(self): + # output categories - label map etc. + + label_categories = LabelCategories() + + with open("samples/imagenet.class", "r", encoding="utf-8") as file: + for line in file.readlines(): + label = line.strip() + label_categories.add(label) + + return {AnnotationType.label: label_categories} diff --git a/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_atss_interp.py b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_atss_interp.py new file mode 100644 index 0000000000..0216908568 --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_atss_interp.py @@ -0,0 +1,43 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +import cv2 + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import ( + create_bboxes_with_rescaling, + rescale_img_keeping_aspect_ratio, +) + +__all__ = ["OTXATSSModelInterpreter"] + + +class OTXATSSModelInterpreter(IModelInterpreter): + h_model = 736 + w_model = 992 + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + output = rescale_img_keeping_aspect_ratio(img, self.h_model, self.w_model) + + # From BGR to RGB + output.image = cv2.cvtColor(output.image, cv2.COLOR_BGR2RGB) + # From HWC to CHW + output.image = output.image.transpose(2, 0, 1) + + return output.image, output.scale + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + scale = info + r_scale = 1 / scale + return create_bboxes_with_rescaling(pred["boxes"], pred["labels"], r_scale) + + def get_categories(self): + return None diff --git a/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_ssd_interp.py b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_ssd_interp.py new file mode 100644 index 0000000000..db31c6c9fe --- /dev/null +++ 
b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_gen3_ssd_interp.py @@ -0,0 +1,43 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +import cv2 + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import ( + create_bboxes_with_rescaling, + rescale_img_keeping_aspect_ratio, +) + +__all__ = ["OTXSSDModelInterpreter"] + + +class OTXSSDModelInterpreter(IModelInterpreter): + h_model = 864 + w_model = 864 + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + output = rescale_img_keeping_aspect_ratio(img, self.h_model, self.w_model) + + # From BGR to RGB + output.image = cv2.cvtColor(output.image, cv2.COLOR_BGR2RGB) + # From HWC to CHW + output.image = output.image.transpose(2, 0, 1) + + return output.image, output.scale + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + scale = info + r_scale = 1 / scale + return create_bboxes_with_rescaling(pred["boxes"], pred["labels"], r_scale) + + def get_categories(self): + return None diff --git a/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_yolox_interp.py b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_yolox_interp.py new file mode 100644 index 0000000000..64ac0141c8 --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/otx_custom_object_detection_yolox_interp.py @@ -0,0 +1,43 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +import cv2 + +from datumaro.components.abstracts import IModelInterpreter +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.samples.utils import ( + create_bboxes_with_rescaling, + rescale_img_keeping_aspect_ratio, +) + +__all__ = ["OTXYoloXModelInterpreter"] + + +class OTXYoloXModelInterpreter(IModelInterpreter): + h_model = 416 + w_model = 416 + + def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + img = inp.media_as(Image).data + output = rescale_img_keeping_aspect_ratio(img, self.h_model, self.w_model) + + # From BGR to RGB + output.image = cv2.cvtColor(output.image, cv2.COLOR_BGR2RGB) + # From HWC to CHW + output.image = output.image.transpose(2, 0, 1) + + return output.image, output.scale + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + scale = info + r_scale = 1 / scale + return create_bboxes_with_rescaling(pred["boxes"], pred["labels"], r_scale) + + def get_categories(self): + return None diff --git a/src/datumaro/plugins/openvino_plugin/samples/utils.py b/src/datumaro/plugins/openvino_plugin/samples/utils.py new file mode 100644 index 0000000000..94492ef0fd --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/samples/utils.py @@ -0,0 +1,111 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass +from typing import List + +import cv2 
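A worked example of the scale round-trip shared by the ATSS, SSD, and YOLOX interpreters above: preprocess() resizes with rescale_img_keeping_aspect_ratio (defined just below) and hands the scale to postprocess() through PrepInfo, where boxes are mapped back with 1 / scale by create_bboxes_with_rescaling. The frame size and box coordinates here are made up for illustration:

    # A 1920x1080 frame fed to the 992x736 ATSS input is shrunk by
    # scale = min(736 / 1080, 992 / 1920) ~= 0.517 and zero-padded to 992x736.
    scale = min(736 / 1080, 992 / 1920)
    r_scale = 1 / scale  # ~= 1.94
    x1, y1, x2, y2 = (v * r_scale for v in (100.0, 50.0, 300.0, 250.0))
    # -> roughly (193.5, 96.8, 580.6, 483.9) in original-image coordinates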
+import numpy as np + +from datumaro.components.annotation import Bbox, HashKey + + +def gen_hash_key(features: np.ndarray) -> HashKey: + features = np.sign(features) + hash_key = np.clip(features, 0, None) + hash_key = hash_key.astype(np.uint8) + hash_key = np.packbits(hash_key, axis=-1) + return HashKey(hash_key) + + +def create_bboxes_with_rescaling( + bboxes: np.ndarray, labels: np.ndarray, r_scale: float +) -> List[Bbox]: + idx = 0 + anns = [] + for bbox, label in zip(bboxes, labels): + points = r_scale * bbox[:4] + x1, y1, x2, y2 = points + conf = bbox[4] + anns.append( + Bbox( + x=x1, + y=y1, + w=x2 - x1, + h=y2 - y1, + id=idx, + label=label, + attributes={"score": conf}, + ) + ) + idx += 1 + return anns + + +@dataclass +class RescaledImage: + """Dataclass for a rescaled image. + + This dataclass represents a rescaled image along with the scaling information. + + Attributes: + image: The rescaled image as a NumPy array. + scale: The scale factor by which the image was resized to fit the model input size. + The scale factor is the same for both height and width. + + Note: + The `image` attribute stores the rescaled image as a NumPy array. + The `scale` attribute represents the scale factor used to resize the image. + The scale factor indicates how much the image was scaled to fit the model's input size. + """ + + image: np.ndarray + scale: float + + +def rescale_img_keeping_aspect_ratio( + img: np.ndarray, h_model: int, w_model: int, padding: bool = True +) -> RescaledImage: + """ + Rescale image while maintaining its aspect ratio. + + This function rescales the input image to fit the requirements of the model input. + It also attempts to preserve the original aspect ratio of the input image. + If the aspect ratio of the input image does not match the aspect ratio required by the model, + the function applies zero padding to the image boundaries to maintain the aspect ratio if `padding` option is true. + + Parameters: + img: The image to be rescaled. + h_model: The desired height of the image required by the model. + w_model: The desired width of the image required by the model. + padding: If true, pad the output image boundaries to make the output image size `(h_model, w_model). + Otherwise, there is no pad, so that the output image size can be different with `(h_model, w_model)`. 
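A worked example of the bit packing performed by gen_hash_key above: sign() maps the features to {-1, 0, 1}, clip() keeps a 1 only for positive values, and packbits() folds 8 of those bits into each output byte. The 768-dimensional random vector is a stand-in, chosen so the result matches the 96-byte constraint of the HashKey validator:

    import numpy as np

    features = np.random.randn(768).astype(np.float32)           # placeholder feature vector
    bits = np.clip(np.sign(features), 0, None).astype(np.uint8)  # 1 where the feature is positive
    key = np.packbits(bits, axis=-1)                              # 8 sign bits per byte
    assert key.shape == (96,) and key.dtype == np.uint8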
+ """ + assert len(img.shape) == 3 + + h_img, w_img = img.shape[:2] + + scale = min(h_model / h_img, w_model / w_img) + + h_resize = min(int(scale * h_img), h_model) + w_resize = min(int(scale * w_img), w_model) + + num_channel = img.shape[-1] + + if padding: + resized_inputs = np.zeros((h_model, w_model, num_channel), dtype=np.uint8) + + resized_inputs[:h_resize, :w_resize, :] = cv2.resize( + img, + (w_resize, h_resize), + interpolation=cv2.INTER_LINEAR, + ) + else: + resized_inputs = cv2.resize( + img, + (w_resize, h_resize), + interpolation=cv2.INTER_LINEAR, + ) + + return RescaledImage(image=resized_inputs, scale=scale) diff --git a/src/datumaro/plugins/openvino_plugin/shift_launcher.py b/src/datumaro/plugins/openvino_plugin/shift_launcher.py new file mode 100644 index 0000000000..71e9c29a9c --- /dev/null +++ b/src/datumaro/plugins/openvino_plugin/shift_launcher.py @@ -0,0 +1,37 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from datumaro.components.errors import MediaTypeError +from datumaro.components.media import Image +from datumaro.plugins.openvino_plugin.launcher import OpenvinoLauncher + + +class ShiftLauncher(OpenvinoLauncher): + def __init__( + self, + description=None, + weights=None, + interpreter=None, + model_dir=None, + model_name=None, + output_layers=None, + device=None, + ): + super().__init__( + description, + weights, + interpreter, + model_dir, + model_name, + output_layers, + device, + ) + + self._device = device or "cpu" + self._output_blobs = next(iter(self._net.outputs)) + + def type_check(self, item): + if not isinstance(item.media, Image): + raise MediaTypeError(f"Media type should be Image, Current type={type(item.media)}") + return True From f5db58d28fec4386368cd480fe29f9ade3805012 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 18:25:56 +0400 Subject: [PATCH 06/25] sync components/comparator.py (extracting from components/operations.py) --- src/datumaro/cli/commands/diff.py | 24 +- src/datumaro/components/comparator.py | 691 ++++++++++++++++++++++++++ src/datumaro/components/operations.py | 299 +---------- tests/integration/cli/test_compare.py | 2 +- tests/unit/test_compare.py | 43 +- 5 files changed, 726 insertions(+), 333 deletions(-) create mode 100644 src/datumaro/components/comparator.py diff --git a/src/datumaro/cli/commands/diff.py b/src/datumaro/cli/commands/diff.py index 800172df72..20a5d7f437 100644 --- a/src/datumaro/cli/commands/diff.py +++ b/src/datumaro/cli/commands/diff.py @@ -8,8 +8,8 @@ import os.path as osp from enum import Enum, auto +from datumaro.components.comparator import DistanceComparator, EqualityComparator from datumaro.components.errors import ProjectNotFoundError -from datumaro.components.operations import DistanceComparator, ExactComparator from datumaro.util import dump_json_file from datumaro.util.os_util import rmtree from datumaro.util.scope import on_error_do, scope_add, scoped @@ -221,24 +221,14 @@ def diff_command(args): if args.method is ComparisonMethod.equality: if args.ignore_field: args.ignore_field = eq_default_if - comparator = ExactComparator( + comparator = EqualityComparator( match_images=args.match_images, ignored_fields=args.ignore_field, ignored_attrs=args.ignore_attr, ignored_item_attrs=args.ignore_item_attr, + all=args.all, ) - matches, mismatches, a_extra, b_extra, errors = comparator.compare_datasets( - first_dataset, second_dataset - ) - - output = { - "mismatches": mismatches, - "a_extra_items": sorted(a_extra), - "b_extra_items": 
sorted(b_extra), - "errors": errors, - } - if args.all: - output["matches"] = matches + output = comparator.compare_datasets(first_dataset, second_dataset) output_file = osp.join( dst_dir, generate_next_file_name("diff", ext=".json", basedir=dst_dir) @@ -246,12 +236,6 @@ def diff_command(args): log.info("Saving diff to '%s'" % output_file) dump_json_file(output_file, output, indent=True) - print("Found:") - print("The first project has %s unmatched items" % len(a_extra)) - print("The second project has %s unmatched items" % len(b_extra)) - print("%s item conflicts" % len(errors)) - print("%s matching annotations" % len(matches)) - print("%s mismatching annotations" % len(mismatches)) elif args.method is ComparisonMethod.distance: comparator = DistanceComparator(iou_threshold=args.iou_thresh) diff --git a/src/datumaro/components/comparator.py b/src/datumaro/components/comparator.py new file mode 100644 index 0000000000..b33b205b64 --- /dev/null +++ b/src/datumaro/components/comparator.py @@ -0,0 +1,691 @@ +# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import logging as log +import os +import os.path as osp +from textwrap import wrap +from typing import Dict, List, Set, Tuple +from unittest import TestCase + +from attr import attrib, attrs +from tabulate import tabulate + +from datumaro.cli.util.project import generate_next_file_name +from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories, Points +from datumaro.components.annotations.matcher import LineMatcher, PointsMatcher, match_segments_pair +from datumaro.components.dataset import Dataset +from datumaro.components.operations import ( + compute_ann_statistics, + compute_image_statistics, + match_items_by_id, + match_items_by_image_hash, +) +from datumaro.components.shift_analyzer import ShiftAnalyzer +from datumaro.util import dump_json_file, filter_dict, find +from datumaro.util.annotation_util import find_instances, max_bbox +from datumaro.util.attrs_util import default_if_none + + +@attrs +class DistanceComparator: + iou_threshold = attrib(converter=float, default=0.5) + + def match_annotations(self, item_a, item_b): + return {t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType} + + def _match_ann_type(self, t, *args): + # pylint: disable=no-value-for-parameter + if t == AnnotationType.label: + return self.match_labels(*args) + elif t == AnnotationType.bbox: + return self.match_boxes(*args) + elif t == AnnotationType.polygon: + return self.match_polygons(*args) + elif t == AnnotationType.mask: + return self.match_masks(*args) + elif t == AnnotationType.points: + return self.match_points(*args) + elif t == AnnotationType.polyline: + return self.match_lines(*args) + # pylint: enable=no-value-for-parameter + else: + raise NotImplementedError("Unexpected annotation type %s" % t) + + @staticmethod + def _get_ann_type(t, item): + return [a for a in item.annotations if a.type == t] + + def match_labels(self, item_a, item_b): + a_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_a)) + b_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_b)) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def _match_segments(self, t, item_a, item_b): + a_boxes = self._get_ann_type(t, item_a) + b_boxes = self._get_ann_type(t, item_b) + return match_segments_pair(a_boxes, b_boxes, 
dist_thresh=self.iou_threshold) + + def match_polygons(self, item_a, item_b): + return self._match_segments(AnnotationType.polygon, item_a, item_b) + + def match_masks(self, item_a, item_b): + return self._match_segments(AnnotationType.mask, item_a, item_b) + + def match_boxes(self, item_a, item_b): + return self._match_segments(AnnotationType.bbox, item_a, item_b) + + def match_points(self, item_a, item_b): + a_points = self._get_ann_type(AnnotationType.points, item_a) + b_points = self._get_ann_type(AnnotationType.points, item_b) + + instance_map = {} + for s in [item_a.annotations, item_b.annotations]: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox(inst) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + matcher = PointsMatcher(instance_map=instance_map) + + return match_segments_pair( + a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance + ) + + def match_lines(self, item_a, item_b): + a_lines = self._get_ann_type(AnnotationType.polyline, item_a) + b_lines = self._get_ann_type(AnnotationType.polyline, item_b) + + matcher = LineMatcher() + + return match_segments_pair( + a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance + ) + + +@attrs +class EqualityComparator: + match_images: bool = attrib(kw_only=True, default=False) + ignored_fields = attrib(kw_only=True, factory=set, validator=default_if_none(set)) + ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) + ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) + all = attrib(kw_only=True, default=False) + + _test: TestCase = attrib(init=False) + errors: list = attrib(init=False) + + def __attrs_post_init__(self): + self._test = TestCase() + self._test.maxDiff = None + + def _match_items(self, a, b): + if self.match_images: + return match_items_by_image_hash(a, b) + else: + return match_items_by_id(a, b) + + def _compare_categories(self, a, b): + test = self._test + errors = self.errors + + try: + test.assertEqual(sorted(a, key=lambda t: t.value), sorted(b, key=lambda t: t.value)) + except AssertionError as e: + errors.append({"type": "categories", "message": str(e)}) + + if AnnotationType.label in a: + try: + test.assertEqual( + a[AnnotationType.label].items, + b[AnnotationType.label].items, + ) + except AssertionError as e: + errors.append({"type": "labels", "message": str(e)}) + if AnnotationType.mask in a: + try: + test.assertEqual( + a[AnnotationType.mask].colormap, + b[AnnotationType.mask].colormap, + ) + except AssertionError as e: + errors.append({"type": "colormap", "message": str(e)}) + if AnnotationType.points in a: + try: + test.assertEqual( + a[AnnotationType.points].items, + b[AnnotationType.points].items, + ) + except AssertionError as e: + errors.append({"type": "points", "message": str(e)}) + + def _compare_annotations(self, a: Annotation, b: Annotation): + ignored_fields = self.ignored_fields + ignored_attrs = self.ignored_attrs + + a_fields = {k: None for k in a.as_dict() if k in ignored_fields} + b_fields = {k: None for k in b.as_dict() if k in ignored_fields} + if "attributes" not in ignored_fields: + a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs) + b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs) + + if a.type == b.type == AnnotationType.skeleton and "elements" not in ignored_fields: + a_fields["elements"] = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, a.elements), + key=lambda p: p.label 
if p.label is not None else -1, + ) + b_fields["elements"] = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, b.elements), + key=lambda p: p.label if p.label is not None else -1, + ) + + result = a.wrap(**a_fields) == b.wrap(**b_fields) + + return result + + def _compare_items(self, item_a, item_b): + test = self._test + + a_id = (item_a.id, item_a.subset) + b_id = (item_b.id, item_b.subset) + + matched = [] + unmatched = [] + errors = [] + + try: + test.assertEqual( + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs), + ) + except AssertionError as e: + errors.append({"type": "item_attr", "a_item": a_id, "b_item": b_id, "message": str(e)}) + + b_annotations = item_b.annotations[:] + for ann_a in item_a.annotations: + ann_b_candidates = [x for x in item_b.annotations if x.type == ann_a.type] + + ann_b = find( + enumerate(self._compare_annotations(ann_a, x) for x in ann_b_candidates), + lambda x: x[1], + ) + if ann_b is None: + unmatched.append( + { + "item": a_id, + "source": "a", + "ann": str(ann_a), + } + ) + continue + else: + ann_b = ann_b_candidates[ann_b[0]] + + b_annotations.remove(ann_b) # avoid repeats + matched.append({"a_item": a_id, "b_item": b_id, "a": str(ann_a), "b": str(ann_b)}) + + for ann_b in b_annotations: + unmatched.append({"item": b_id, "source": "b", "ann": str(ann_b)}) + + return matched, unmatched, errors + + @staticmethod + def _print_output(output: dict): + print("Found:") + print("The first project has %s unmatched items" % len(output.get("a_extra_items", []))) + print("The second project has %s unmatched items" % len(output.get("b_extra_items", []))) + print("%s item conflicts" % len(output.get("errors", []))) + print("%s matching annotations" % len(output.get("matches", []))) + print("%s mismatching annotations" % len(output.get("mismatches", []))) + + def compare_datasets(self, a, b): + self.errors = [] + errors = self.errors + + self._compare_categories(a.categories(), b.categories()) + + matched = [] + unmatched = [] + + matches, a_unmatched, b_unmatched = self._match_items(a, b) + + if a.categories().get(AnnotationType.label) != b.categories().get(AnnotationType.label): + output = { + "mismatches": unmatched, + "a_extra_items": sorted(a_unmatched), + "b_extra_items": sorted(b_unmatched), + "errors": errors, + } + if self.all: + output["matches"] = matched + + self._print_output(output) + return output + + _dist = lambda s: len(s[1]) + len(s[2]) + for a_ids, b_ids in matches: + # build distance matrix + match_status = {} # (a_id, b_id): [matched, unmatched, errors] + a_matches = {a_id: None for a_id in a_ids} + b_matches = {b_id: None for b_id in b_ids} + + for a_id in a_ids: + item_a = a.get(*a_id) + candidates = {} + + for b_id in b_ids: + item_b = b.get(*b_id) + + i_m, i_um, i_err = self._compare_items(item_a, item_b) + candidates[b_id] = [i_m, i_um, i_err] + + if len(i_um) == 0: + a_matches[a_id] = b_id + b_matches[b_id] = a_id + matched.extend(i_m) + errors.extend(i_err) + break + + match_status[a_id] = candidates + + # assign + for a_id in a_ids: + if len(b_ids) == 0: + break + + # find the closest, ignore already assigned + matched_b = a_matches[a_id] + if matched_b is not None: + continue + min_dist = -1 + for b_id in b_ids: + if b_matches[b_id] is not None: + continue + d = _dist(match_status[a_id][b_id]) + if d < min_dist and 0 <= min_dist: + continue + min_dist = d + matched_b = b_id + + if matched_b is None: + continue + a_matches[a_id] = matched_b + 
b_matches[matched_b] = a_id + + m = match_status[a_id][matched_b] + matched.extend(m[0]) + unmatched.extend(m[1]) + errors.extend(m[2]) + + a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m) + b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m) + + output = { + "mismatches": unmatched, + "a_extra_items": sorted(a_unmatched), + "b_extra_items": sorted(b_unmatched), + "errors": errors, + } + if self.all: + output["matches"] = matched + self._print_output(output) + return output + + @staticmethod + def save_compare_report( + output: Dict, + report_dir: str, + ) -> None: + """Saves the comparison report to JSON and text files. + + Args: + output: A dictionary containing the comparison data. + report_dir: A string representing the directory to save the report files. + """ + os.makedirs(report_dir, exist_ok=True) + output_file = osp.join( + report_dir, + generate_next_file_name("equality_compare", ext=".json", basedir=report_dir), + ) + + log.info(f"Saving compare json to {output_file}") + dump_json_file(output_file, output, indent=True) + + +@attrs +class TableComparator: + """ + Class for comparing datasets and generating comparison report table. + """ + + @staticmethod + def _extract_labels(dataset: Dataset) -> Set[str]: + """Extracts labels from the dataset. + + Args: + dataset: An instance of a Dataset class. + + Returns: + A set of labels present in the dataset. + """ + label_cat = dataset.categories().get(AnnotationType.label, LabelCategories()) + return set(c.name for c in label_cat) + + @staticmethod + def _compute_statistics(dataset: Dataset) -> Tuple[Dict, Dict]: + """Computes image and annotation statistics of the dataset. + + Args: + dataset: An instance of a Dataset class. + + Returns: + A tuple containing image statistics and annotation statistics. + """ + image_stats = compute_image_statistics(dataset) + ann_stats = compute_ann_statistics(dataset) + return image_stats, ann_stats + + def _analyze_dataset(self, dataset: Dataset) -> Tuple[str, Set[str], Dict, Dict]: + """Analyzes the dataset to get labels, format, and statistics. + + Args: + dataset: An instance of a Dataset class. + + Returns: + A tuple containing Dataset format, set of label names, image statistics, + and annotation statistics. + """ + dataset_format = dataset.format + dataset_labels = self._extract_labels(dataset) + image_stats, ann_stats = self._compute_statistics(dataset) + return dataset_format, dataset_labels, image_stats, ann_stats + + @staticmethod + def _create_table(headers: List[str], rows: List[List[str]]) -> str: + """Creates a table with the given headers and rows using the tabulate module. + + Args: + headers: A list containing table headers. + rows: A list containing table rows. + + Returns: + A string representation of the table. + """ + + def wrapfunc(item): + """Wrap a item consisted of text, returning a list of wrapped lines.""" + max_len = 35 + return "\n".join(wrap(item, max_len)) + + wrapped_rows = [] + for row in rows: + new_row = [wrapfunc(item) for item in row] + wrapped_rows.append(new_row) + + return tabulate(wrapped_rows, headers, tablefmt="grid") + + @staticmethod + def _create_dict(rows: List[List[str]]) -> Dict[str, List[str]]: + """Creates a dictionary from the rows of the table. + + Args: + rows: A list containing table rows. + + Returns: + A dictionary where the key is the first element of a row and the value is + the rest of the row. 
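A minimal usage sketch of the EqualityComparator defined above, mirroring the updated unit tests later in this patch; the toy datasets are placeholders:

    from datumaro.components.comparator import EqualityComparator
    from datumaro.components.dataset import Dataset

    a = Dataset.from_iterable([], categories=["a", "b", "c"])
    b = Dataset.from_iterable([], categories=["b", "c"])

    comp = EqualityComparator(all=True)
    output = comp.compare_datasets(a, b)
    # `output` is a plain dict with "mismatches", "a_extra_items", "b_extra_items",
    # "errors" and, because all=True, "matches".
    print(len(output["errors"]))  # 1, since the label sets differ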
+ """ + data_dict = {row[0]: row[1:] for row in rows[1:]} + return data_dict + + def _create_high_level_comparison_table( + self, first_info: Tuple, second_info: Tuple + ) -> Tuple[str, Dict]: + """Generates a high-level comparison table. + + Args: + first_info: A tuple containing information about the first dataset. + second_info: A tuple containing information about the second dataset. + + Returns: + A tuple containing the table as a string and a dictionary representing the data + of the table. + """ + first_format, first_labels, first_image_stats, first_ann_stats = first_info + second_format, second_labels, second_image_stats, second_ann_stats = second_info + + headers = ["Field", "First", "Second"] + + rows = [ + ["Format", first_format, second_format], + ["Number of classes", str(len(first_labels)), str(len(second_labels))], + [ + "Common classes", + ", ".join(sorted(list(first_labels.intersection(second_labels)))), + ", ".join(sorted(list(second_labels.intersection(first_labels)))), + ], + ["Classes", ", ".join(sorted(first_labels)), ", ".join(sorted(second_labels))], + [ + "Images count", + str(first_image_stats["dataset"]["images count"]), + str(second_image_stats["dataset"]["images count"]), + ], + [ + "Unique images count", + str(first_image_stats["dataset"]["unique images count"]), + str(second_image_stats["dataset"]["unique images count"]), + ], + [ + "Repeated images count", + str(first_image_stats["dataset"]["repeated images count"]), + str(second_image_stats["dataset"]["repeated images count"]), + ], + [ + "Annotations count", + str(first_ann_stats["annotations count"]), + str(second_ann_stats["annotations count"]), + ], + [ + "Unannotated images count", + str(first_ann_stats["unannotated images count"]), + str(second_ann_stats["unannotated images count"]), + ], + ] + + table = self._create_table(headers, rows) + data_dict = self._create_dict(rows) + + return table, data_dict + + def _create_mid_level_comparison_table( + self, first_info: Tuple, second_info: Tuple + ) -> Tuple[str, Dict]: + """Generates a mid-level comparison table. + + Args: + first_info: A tuple containing information about the first dataset. + second_info: A tuple containing information about the second dataset. + + Returns: + A tuple containing the table as a string and a dictionary representing the data + of the table. 
+ """ + _, _, first_image_stats, first_ann_stats = first_info + _, _, second_image_stats, second_ann_stats = second_info + + headers = ["Field", "First", "Second"] + + rows = [] + + first_subsets = sorted(list(first_image_stats["subsets"].keys())) + second_subsets = sorted(list(second_image_stats["subsets"].keys())) + + subset_names = first_subsets.copy() + subset_names.extend(item for item in second_subsets if item not in first_subsets) + + for subset_name in subset_names: + first_subset_data = first_image_stats["subsets"].get(subset_name, {}) + second_subset_data = second_image_stats["subsets"].get(subset_name, {}) + mean_str_first = ( + ", ".join(f"{val:6.2f}" for val in first_subset_data.get("image mean (RGB)", [])) + if "image mean (RGB)" in first_subset_data + else "" + ) + std_str_first = ( + ", ".join(f"{val:6.2f}" for val in first_subset_data.get("image std (RGB)", [])) + if "image std" in first_subset_data + else "" + ) + mean_str_second = ( + ", ".join(f"{val:6.2f}" for val in second_subset_data.get("image mean (RGB)", [])) + if "image mean (RGB)" in second_subset_data + else "" + ) + std_str_second = ( + ", ".join(f"{val:6.2f}" for val in second_subset_data.get("image std", [])) + if "image std (RGB)" in second_subset_data + else "" + ) + rows.append([f"{subset_name} - Image Mean (RGB)", mean_str_first, mean_str_second]) + rows.append([f"{subset_name} - Image Std (RGB)", std_str_first, std_str_second]) + + first_labels = sorted(list(first_ann_stats["annotations"]["labels"]["distribution"].keys())) + second_labels = sorted( + list(second_ann_stats["annotations"]["labels"]["distribution"].keys()) + ) + + label_names = first_labels.copy() + label_names.extend(item for item in second_labels if item not in first_labels) + + for label_name in label_names: + count_dist_first = first_ann_stats["annotations"]["labels"]["distribution"].get( + label_name, [0, 0.0] + ) + count_dist_second = second_ann_stats["annotations"]["labels"]["distribution"].get( + label_name, [0, 0.0] + ) + count_first, dist_first = count_dist_first if count_dist_first[0] != 0 else ["", ""] + count_second, dist_second = count_dist_second if count_dist_second[0] != 0 else ["", ""] + rows.append( + [ + f"Label - {label_name}", + f"imgs: {count_first}, percent: {dist_first:.4f}" if count_first != "" else "", + f"imgs: {count_second}, percent: {dist_second:.4f}" + if count_second != "" + else "", + ] + ) + + table = self._create_table(headers, rows) + data_dict = self._create_dict(rows) + + return table, data_dict + + def _create_low_level_comparison_table( + self, first_dataset: Dataset, second_dataset: Dataset + ) -> Tuple[str, Dict]: + """Generates a low-level comparison table. + + Args: + first_dataset: The first dataset to compare. + second_dataset: The second dataset to compare. + + Returns: + A tuple containing the table as a string and a dictionary representing the data + of the table. + """ + shift_analyzer = ShiftAnalyzer() + cov_shift = shift_analyzer.compute_covariate_shift([first_dataset, second_dataset]) + label_shift = shift_analyzer.compute_label_shift([first_dataset, second_dataset]) + + headers = ["Field", "Value"] + + rows = [ + ["Covariate shift", str(cov_shift)], + ["Label shift", str(label_shift)], + ] + + table = self._create_table(headers, rows) + data_dict = self._create_dict(rows) + + return table, data_dict + + def compare_datasets( + self, first: Dataset, second: Dataset, mode: str = "all" + ) -> Tuple[str, str, str, Dict]: + """Compares two datasets and generates comparison reports. 
+ + Args: + first: The first dataset to compare. + second: The second dataset to compare. + + Returns: + A tuple containing high-level table, mid-level table, low-level table, and a + dictionary representation of the comparison. + """ + first_info = self._analyze_dataset(first) + second_info = self._analyze_dataset(second) + + high_level_table, high_level_dict = None, {} + mid_level_table, mid_level_dict = None, {} + low_level_table, low_level_dict = None, {} + + if mode in ["high", "all"]: + high_level_table, high_level_dict = self._create_high_level_comparison_table( + first_info, second_info + ) + if mode in ["mid", "all"]: + mid_level_table, mid_level_dict = self._create_mid_level_comparison_table( + first_info, second_info + ) + if mode in ["low", "all"]: + low_level_table, low_level_dict = self._create_low_level_comparison_table(first, second) + + comparison_dict = dict( + high_level=high_level_dict, mid_level=mid_level_dict, low_level=low_level_dict + ) + + print(f"High-level comparison:\n{high_level_table}\n") + print(f"Mid-level comparison:\n{mid_level_table}\n") + print(f"Low-level comparison:\n{low_level_table}\n") + + return high_level_table, mid_level_table, low_level_table, comparison_dict + + @staticmethod + def save_compare_report( + high_level_table: str, + mid_level_table: str, + low_level_table: str, + comparison_dict: Dict, + report_dir: str, + ) -> None: + """Saves the comparison report to JSON and text files. + + Args: + high_level_table: High-level comparison table as a string. + mid_level_table: Mid-level comparison table as a string. + low_level_table: Low-level comparison table as a string. + comparison_dict: A dictionary containing the comparison data. + report_dir: A string representing the directory to save the report files. 
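A hedged usage sketch of the TableComparator workflow above; first_dataset and second_dataset stand in for any two loaded Datumaro datasets:

    comp = TableComparator()
    high, mid, low, data = comp.compare_datasets(first_dataset, second_dataset, mode="all")
    TableComparator.save_compare_report(high, mid, low, data, report_dir="compare_report")
    # compare_datasets() also prints the three tables; mode="high", "mid", or "low"
    # limits the report to a single level.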
+ """ + os.makedirs(report_dir, exist_ok=True) + json_output_file = osp.join( + report_dir, generate_next_file_name("table_compare", ext=".json", basedir=report_dir) + ) + txt_output_file = osp.join( + report_dir, generate_next_file_name("table_compare", ext=".txt", basedir=report_dir) + ) + + log.info(f"Saving compare json to {json_output_file}") + log.info(f"Saving compare table to {txt_output_file}") + + dump_json_file(json_output_file, comparison_dict, indent=True) + with open(txt_output_file, "w") as f: + f.write(f"High-level Comparison:\n{high_level_table}\n\n") + f.write(f"Mid-level Comparison:\n{mid_level_table}\n\n") + f.write(f"Low-level Comparison:\n{low_level_table}\n\n") diff --git a/src/datumaro/components/operations.py b/src/datumaro/components/operations.py index d48e9ca324..420c203453 100644 --- a/src/datumaro/components/operations.py +++ b/src/datumaro/components/operations.py @@ -8,7 +8,6 @@ from collections import OrderedDict from copy import deepcopy from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union -from unittest import TestCase import attr import cv2 @@ -20,10 +19,8 @@ AnnotationType, LabelCategories, MaskCategories, - Points, PointsCategories, ) -from datumaro.components.annotations import LineMatcher, PointsMatcher, match_segments_pair from datumaro.components.annotations.merger import ( BboxMerger, CaptionsMerger, @@ -55,9 +52,9 @@ WrongGroupError, ) from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video -from datumaro.util import filter_dict, find +from datumaro.util import find from datumaro.util.annotation_util import find_instances, max_bbox -from datumaro.util.attrs_util import default_if_none, ensure_cls +from datumaro.util.attrs_util import ensure_cls def match_annotations_equal(a, b): @@ -1221,86 +1218,6 @@ def get_label(ann): return stats -@attrs -class DistanceComparator: - iou_threshold = attrib(converter=float, default=0.5) - - def match_annotations(self, item_a, item_b): - return {t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType} - - def _match_ann_type(self, t, *args): - # pylint: disable=no-value-for-parameter - if t == AnnotationType.label: - return self.match_labels(*args) - elif t == AnnotationType.bbox: - return self.match_boxes(*args) - elif t == AnnotationType.polygon: - return self.match_polygons(*args) - elif t == AnnotationType.mask: - return self.match_masks(*args) - elif t == AnnotationType.points: - return self.match_points(*args) - elif t == AnnotationType.polyline: - return self.match_lines(*args) - # pylint: enable=no-value-for-parameter - else: - raise NotImplementedError("Unexpected annotation type %s" % t) - - @staticmethod - def _get_ann_type(t, item): - return [a for a in item.annotations if a.type == t] - - def match_labels(self, item_a, item_b): - a_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_a)) - b_labels = set(a.label for a in self._get_ann_type(AnnotationType.label, item_b)) - - matches = a_labels & b_labels - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - return matches, a_unmatched, b_unmatched - - def _match_segments(self, t, item_a, item_b): - a_boxes = self._get_ann_type(t, item_a) - b_boxes = self._get_ann_type(t, item_b) - return match_segments_pair(a_boxes, b_boxes, dist_thresh=self.iou_threshold) - - def match_polygons(self, item_a, item_b): - return self._match_segments(AnnotationType.polygon, item_a, item_b) - - def match_masks(self, item_a, item_b): - return 
self._match_segments(AnnotationType.mask, item_a, item_b) - - def match_boxes(self, item_a, item_b): - return self._match_segments(AnnotationType.bbox, item_a, item_b) - - def match_points(self, item_a, item_b): - a_points = self._get_ann_type(AnnotationType.points, item_a) - b_points = self._get_ann_type(AnnotationType.points, item_b) - - instance_map = {} - for s in [item_a.annotations, item_b.annotations]: - s_instances = find_instances(s) - for inst in s_instances: - inst_bbox = max_bbox(inst) - for ann in inst: - instance_map[id(ann)] = [inst, inst_bbox] - matcher = PointsMatcher(instance_map=instance_map) - - return match_segments_pair( - a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance - ) - - def match_lines(self, item_a, item_b): - a_lines = self._get_ann_type(AnnotationType.polyline, item_a) - b_lines = self._get_ann_type(AnnotationType.polyline, item_b) - - matcher = LineMatcher() - - return match_segments_pair( - a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance - ) - - def match_items_by_id(a: IDataset, b: IDataset): a_items = set((item.id, item.subset) for item in a) b_items = set((item.id, item.subset) for item in b) @@ -1367,215 +1284,3 @@ def find_unique_images(dataset: IDataset, item_hash: Optional[Callable] = None): for item in dataset: matcher.process_item(item) return matcher.get_result() - - -def match_classes(a: CategoriesInfo, b: CategoriesInfo): - a_label_cat = a.get(AnnotationType.label, LabelCategories()) - b_label_cat = b.get(AnnotationType.label, LabelCategories()) - - a_labels = set(c.name for c in a_label_cat) - b_labels = set(c.name for c in b_label_cat) - - matches = a_labels & b_labels - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - return matches, a_unmatched, b_unmatched - - -@attrs -class ExactComparator: - match_images: bool = attrib(kw_only=True, default=False) - ignored_fields = attrib(kw_only=True, factory=set, validator=default_if_none(set)) - ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) - ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) - - _test: TestCase = attrib(init=False) - errors: list = attrib(init=False) - - def __attrs_post_init__(self): - self._test = TestCase() - self._test.maxDiff = None - - def _match_items(self, a, b): - if self.match_images: - return match_items_by_image_hash(a, b) - else: - return match_items_by_id(a, b) - - def _compare_categories(self, a, b): - test = self._test - errors = self.errors - - try: - test.assertEqual(sorted(a, key=lambda t: t.value), sorted(b, key=lambda t: t.value)) - except AssertionError as e: - errors.append({"type": "categories", "message": str(e)}) - - if AnnotationType.label in a: - try: - test.assertEqual( - a[AnnotationType.label].items, - b[AnnotationType.label].items, - ) - except AssertionError as e: - errors.append({"type": "labels", "message": str(e)}) - if AnnotationType.mask in a: - try: - test.assertEqual( - a[AnnotationType.mask].colormap, - b[AnnotationType.mask].colormap, - ) - except AssertionError as e: - errors.append({"type": "colormap", "message": str(e)}) - if AnnotationType.points in a: - try: - test.assertEqual( - a[AnnotationType.points].items, - b[AnnotationType.points].items, - ) - except AssertionError as e: - errors.append({"type": "points", "message": str(e)}) - - def _compare_annotations(self, a: Annotation, b: Annotation): - ignored_fields = self.ignored_fields - ignored_attrs = self.ignored_attrs - - a_fields = 
{k: None for k in a.as_dict() if k in ignored_fields} - b_fields = {k: None for k in b.as_dict() if k in ignored_fields} - if "attributes" not in ignored_fields: - a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs) - b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs) - - if a.type == b.type == AnnotationType.skeleton and "elements" not in ignored_fields: - a_fields["elements"] = sorted( - filter(lambda p: p.visibility[0] != Points.Visibility.absent, a.elements), - key=lambda p: p.label if p.label is not None else -1, - ) - b_fields["elements"] = sorted( - filter(lambda p: p.visibility[0] != Points.Visibility.absent, b.elements), - key=lambda p: p.label if p.label is not None else -1, - ) - - result = a.wrap(**a_fields) == b.wrap(**b_fields) - - return result - - def _compare_items(self, item_a, item_b): - test = self._test - - a_id = (item_a.id, item_a.subset) - b_id = (item_b.id, item_b.subset) - - matched = [] - unmatched = [] - errors = [] - - try: - test.assertEqual( - filter_dict(item_a.attributes, self.ignored_item_attrs), - filter_dict(item_b.attributes, self.ignored_item_attrs), - ) - except AssertionError as e: - errors.append({"type": "item_attr", "a_item": a_id, "b_item": b_id, "message": str(e)}) - - b_annotations = item_b.annotations[:] - for ann_a in item_a.annotations: - ann_b_candidates = [x for x in item_b.annotations if x.type == ann_a.type] - - ann_b = find( - enumerate(self._compare_annotations(ann_a, x) for x in ann_b_candidates), - lambda x: x[1], - ) - if ann_b is None: - unmatched.append( - { - "item": a_id, - "source": "a", - "ann": str(ann_a), - } - ) - continue - else: - ann_b = ann_b_candidates[ann_b[0]] - - b_annotations.remove(ann_b) # avoid repeats - matched.append({"a_item": a_id, "b_item": b_id, "a": str(ann_a), "b": str(ann_b)}) - - for ann_b in b_annotations: - unmatched.append({"item": b_id, "source": "b", "ann": str(ann_b)}) - - return matched, unmatched, errors - - def compare_datasets(self, a, b): - self.errors = [] - errors = self.errors - - self._compare_categories(a.categories(), b.categories()) - - matched = [] - unmatched = [] - - matches, a_unmatched, b_unmatched = self._match_items(a, b) - - if a.categories().get(AnnotationType.label) != b.categories().get(AnnotationType.label): - return matched, unmatched, a_unmatched, b_unmatched, errors - - _dist = lambda s: len(s[1]) + len(s[2]) - for a_ids, b_ids in matches: - # build distance matrix - match_status = {} # (a_id, b_id): [matched, unmatched, errors] - a_matches = {a_id: None for a_id in a_ids} - b_matches = {b_id: None for b_id in b_ids} - - for a_id in a_ids: - item_a = a.get(*a_id) - candidates = {} - - for b_id in b_ids: - item_b = b.get(*b_id) - - i_m, i_um, i_err = self._compare_items(item_a, item_b) - candidates[b_id] = [i_m, i_um, i_err] - - if len(i_um) == 0: - a_matches[a_id] = b_id - b_matches[b_id] = a_id - matched.extend(i_m) - errors.extend(i_err) - break - - match_status[a_id] = candidates - - # assign - for a_id in a_ids: - if len(b_ids) == 0: - break - - # find the closest, ignore already assigned - matched_b = a_matches[a_id] - if matched_b is not None: - continue - min_dist = -1 - for b_id in b_ids: - if b_matches[b_id] is not None: - continue - d = _dist(match_status[a_id][b_id]) - if d < min_dist and 0 <= min_dist: - continue - min_dist = d - matched_b = b_id - - if matched_b is None: - continue - a_matches[a_id] = matched_b - b_matches[matched_b] = a_id - - m = match_status[a_id][matched_b] - matched.extend(m[0]) - unmatched.extend(m[1]) - 
errors.extend(m[2]) - - a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m) - b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m) - - return matched, unmatched, a_unmatched, b_unmatched, errors diff --git a/tests/integration/cli/test_compare.py b/tests/integration/cli/test_compare.py index 21ecdb6ea9..292e60eb77 100644 --- a/tests/integration/cli/test_compare.py +++ b/tests/integration/cli/test_compare.py @@ -18,9 +18,9 @@ Polygon, PolyLine, ) +from datumaro.components.comparator import DistanceComparator from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.operations import DistanceComparator from datumaro.components.project import Dataset from tests.requirements import Requirements, mark_requirement diff --git a/tests/unit/test_compare.py b/tests/unit/test_compare.py index 528f1f422a..5ada3488b1 100644 --- a/tests/unit/test_compare.py +++ b/tests/unit/test_compare.py @@ -13,9 +13,9 @@ PointsCategories, Skeleton, ) +from datumaro.components.comparator import DistanceComparator, EqualityComparator from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.operations import DistanceComparator, ExactComparator from datumaro.components.project import Dataset from datumaro.util.definitions import DEFAULT_SUBSET_NAME @@ -166,14 +166,15 @@ def test_can_match_points(self): self.assertEqual(0, len(mismatches)) -class ExactComparatorTest(TestCase): +class EqualityComparatorTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_class_comparison(self): a = Dataset.from_iterable([], categories=["a", "b", "c"]) b = Dataset.from_iterable([], categories=["b", "c"]) - comp = ExactComparator() - _, _, _, _, errors = comp.compare_datasets(a, b) + comp = EqualityComparator() + output = comp.compare_datasets(a, b) + errors = output["errors"] self.assertEqual(1, len(errors), errors) @@ -195,11 +196,15 @@ def test_item_comparison(self): categories=["a", "b", "c"], ) - comp = ExactComparator() - _, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b) + comp = EqualityComparator() + output = comp.compare_datasets(a, b) - self.assertEqual({("1", "train")}, a_extra_items) - self.assertEqual({("3", DEFAULT_SUBSET_NAME)}, b_extra_items) + a_extra_items = output["a_extra_items"] + b_extra_items = output["b_extra_items"] + errors = output["errors"] + + self.assertEqual([("1", "train")], a_extra_items) + self.assertEqual([("3", DEFAULT_SUBSET_NAME)], b_extra_items) self.assertEqual(1, len(errors), errors) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @@ -272,9 +277,12 @@ def test_annotation_comparison(self): categories=["a", "b", "c", "d"], ) - comp = ExactComparator() - matched, unmatched, _, _, errors = comp.compare_datasets(a, b) + comp = EqualityComparator(all=True) + output = comp.compare_datasets(a, b) + matched = output["matches"] + unmatched = output["mismatches"] + errors = output["errors"] self.assertEqual(6, len(matched), matched) self.assertEqual(2, len(unmatched), unmatched) self.assertEqual(0, len(errors), errors) @@ -385,10 +393,10 @@ def test_skeleton_annotation_comparison(self): categories=categories, ) - comp = ExactComparator() - _, unmatched, _, _, _ = comp.compare_datasets(a, b) + comp = EqualityComparator() + output = comp.compare_datasets(a, b) - assert unmatched == [ + assert output["mismatches"] == [ { "item": ("3", "default"), "source": "a", @@ -509,9 +517,14 @@ def 
test_image_comparison(self): categories=["a", "b", "c", "d"], ) - comp = ExactComparator(match_images=True) - matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = comp.compare_datasets(a, b) + comp = EqualityComparator(match_images=True, all=True) + output = comp.compare_datasets(a, b) + matched_ann = output["matches"] + unmatched_ann = output["mismatches"] + a_unmatched = output["a_extra_items"] + b_unmatched = output["b_extra_items"] + errors = output["errors"] self.assertEqual(3, len(matched_ann), matched_ann) self.assertEqual(5, len(unmatched_ann), unmatched_ann) self.assertEqual(1, len(a_unmatched), a_unmatched) From 4e223a699ba321b84c61d55eac805d1bea1dc4ff Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 22:13:28 +0400 Subject: [PATCH 07/25] sync components/dataset_item_storage.py --- .../components/dataset_item_storage.py | 66 +++++++++++++++++-- src/datumaro/components/dataset_storage.py | 22 ++++++- src/datumaro/components/hl_ops/__init__.py | 6 +- 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/src/datumaro/components/dataset_item_storage.py b/src/datumaro/components/dataset_item_storage.py index b143e2639e..1b2a6cfc17 100644 --- a/src/datumaro/components/dataset_item_storage.py +++ b/src/datumaro/components/dataset_item_storage.py @@ -1,10 +1,14 @@ -from __future__ import annotations +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT from copy import copy from enum import Enum, auto -from typing import Any, Iterator, Optional, Tuple, Type, Union +from typing import Any, Iterator, Optional, Set, Tuple, Type, Union -from datumaro.components.dataset_base import CategoriesInfo, DatasetItem, IDataset, MediaElement +from datumaro.components.annotation import AnnotationType +from datumaro.components.dataset_base import CategoriesInfo, DatasetInfo, DatasetItem, IDataset +from datumaro.components.media import MediaElement from datumaro.util.definitions import DEFAULT_SUBSET_NAME __all__ = ["ItemStatus", "DatasetItemStorage", "DatasetItemStorageDatasetView"] @@ -20,6 +24,7 @@ class DatasetItemStorage: def __init__(self): self.data = {} # { subset_name: { id: DatasetItem } } self._traversal_order = {} # maintain the order of elements + self._order = [] # allow indexing def __iter__(self) -> Iterator[DatasetItem]: for item in self._traversal_order.values(): @@ -36,6 +41,8 @@ def put(self, item: DatasetItem) -> bool: subset = self.data.setdefault(item.subset, {}) is_new = subset.get(item.id) is None self._traversal_order[(item.id, item.subset)] = item + if is_new: + self._order.append((item.id, item.subset)) subset[item.id] = item return is_new @@ -61,7 +68,9 @@ def remove(self, id: Union[str, DatasetItem], subset: Optional[str] = None) -> b is_removed = subset_data.get(id) is not None subset_data[id] = None if is_removed: + # TODO : investigate why "del subset_data[id]" cannot replace "subset_data[id] = None". 
self._traversal_order.pop((id, subset)) + self._order.remove((id, subset)) return is_removed def __contains__(self, x: Union[DatasetItem, Tuple[str, str]]) -> bool: @@ -76,19 +85,41 @@ def get_subset(self, name): def subsets(self): return self.data + def get_annotated_items(self): + return sum(bool(s.annotations) for s in self._traversal_order.values()) + + def get_datasetitem_by_path(self, path): + for s in self._traversal_order.values(): + if getattr(s.media, "path", None) == path: + return s + + def get_annotations(self): + annotations_by_type = {t.name: {"count": 0} for t in AnnotationType} + for item in self._traversal_order.values(): + for ann in item.annotations: + annotations_by_type[ann.type.name]["count"] += 1 + return sum(t["count"] for t in annotations_by_type.values()) + def __copy__(self): copied = DatasetItemStorage() copied._traversal_order = copy(self._traversal_order) + copied._order = copy(self._order) copied.data = copy(self.data) return copied + def __getitem__(self, idx: int) -> DatasetItem: + _id, subset = self._order[idx] + item = self.data[subset][_id] + return item + class DatasetItemStorageDatasetView(IDataset): class Subset(IDataset): - def __init__(self, parent: DatasetItemStorageDatasetView, name: str): + def __init__(self, parent: "DatasetItemStorageDatasetView", name: str): super().__init__() self.parent = parent self.name = name + self._length = None @property def _data(self): @@ -100,9 +131,17 @@ def __iter__(self): yield item def __len__(self): - return len(self._data) + if self._length is not None: + return self._length + + self._length = 0 + for item in self._data.values(): + if item is not None: + self._length += 1 + return self._length def put(self, item): + self._length = None return self._data.put(item) def get(self, id, subset=None): @@ -111,6 +150,7 @@ def get(self, id, subset=None): def remove(self, id, subset=None): assert (subset or DEFAULT_SUBSET_NAME) == (self.name or DEFAULT_SUBSET_NAME) + self._length = None return self._data.remove(id, subset) def get_subset(self, name): @@ -120,21 +160,31 @@ def get_subset(self, name): def subsets(self): return {self.name or DEFAULT_SUBSET_NAME: self} + def infos(self): + return self.parent.infos() + def categories(self): return self.parent.categories() def media_type(self): return self.parent.media_type() + def ann_types(self): + return self.parent.ann_types() + def __init__( self, parent: DatasetItemStorage, + infos: DatasetInfo, categories: CategoriesInfo, media_type: Optional[Type[MediaElement]], + ann_types: Optional[Set[AnnotationType]], ): self._parent = parent + self._infos = infos self._categories = categories self._media_type = media_type + self._ann_types = ann_types def __iter__(self): yield from self._parent @@ -142,6 +192,9 @@ def __iter__(self): def __len__(self): return len(self._parent) + def infos(self): + return self._infos + def categories(self): return self._categories @@ -159,3 +212,6 @@ def get(self, id, subset=None): def media_type(self): return self._media_type + + def ann_types(self): + return self._ann_types diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 6adcb47747..90562fa5bf 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -30,7 +30,13 @@ class DatasetPatchWrapper(DatasetItemStorageDatasetView): # The purpose of this class is to indicate that the input dataset is # a patch and autofill patch info in Exporter def __init__(self, patch: DatasetPatch, parent: 
IDataset): - super().__init__(patch.data, parent.categories(), parent.media_type()) + super().__init__( + parent=patch.data, + infos={}, + categories=parent.categories(), + media_type=parent.media_type(), + ann_types=None, + ) self.patch = patch def subsets(self): @@ -195,7 +201,11 @@ def _update_status(item_id, new_status: ItemStatus): patch = self._storage # must be empty after transforming cache = DatasetItemStorage() source = self._source or DatasetItemStorageDatasetView( - self._storage, categories=self._categories, media_type=media_type + parent=self._storage, + infos={}, + categories=self._categories, + media_type=media_type, + ann_types=None, ) transform = None @@ -316,7 +326,13 @@ def _merged(self) -> IDataset: return self._source elif self._source is not None: self.init_cache() - return DatasetItemStorageDatasetView(self._storage, self._categories, self._media_type) + return DatasetItemStorageDatasetView( + parent=self._storage, + infos={}, + categories=self._categories, + media_type=self._media_type, + ann_types=None, + ) def __len__(self) -> int: if self._length is None: diff --git a/src/datumaro/components/hl_ops/__init__.py b/src/datumaro/components/hl_ops/__init__.py index c828c37ded..a5c61dca74 100644 --- a/src/datumaro/components/hl_ops/__init__.py +++ b/src/datumaro/components/hl_ops/__init__.py @@ -109,7 +109,11 @@ def merge(*datasets: IDataset) -> IDataset: categories = ExactMerge.merge_categories(d.categories() for d in datasets) media_type = ExactMerge.merge_media_types(datasets) return DatasetItemStorageDatasetView( - ExactMerge.merge(*datasets), categories=categories, media_type=media_type + parent=ExactMerge.merge(*datasets), + infos={}, + categories=categories, + media_type=media_type, + ann_types=None, ) From 7cb7a65a7aef049aa198fa360de6fec5e4d6f42d Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 22:25:33 +0400 Subject: [PATCH 08/25] sync components/filter.py --- src/datumaro/components/filter.py | 220 ++++++++++++++++++++++++------ 1 file changed, 178 insertions(+), 42 deletions(-) diff --git a/src/datumaro/components/filter.py b/src/datumaro/components/filter.py index 1962f5af9c..9b828e3365 100644 --- a/src/datumaro/components/filter.py +++ b/src/datumaro/components/filter.py @@ -1,17 +1,21 @@ -# Copyright (C) 2019-2021 Intel Corporation +# Copyright (C) 2019-2024 Intel Corporation # # SPDX-License-Identifier: MIT +from __future__ import annotations import logging as log +from typing import TYPE_CHECKING, Callable, Optional # Disable B410: import_lxml - the library is used for writing -from lxml import etree as ET # nosec, lxml has proper XPath implementation +from lxml import etree as ET # nosec from datumaro.components.annotation import ( Annotation, AnnotationType, Bbox, Caption, + Ellipse, + HashKey, Label, Mask, Points, @@ -21,10 +25,22 @@ from datumaro.components.media import Image from datumaro.components.transformer import ItemTransform +if TYPE_CHECKING: + from datumaro.components.dataset_base import CategoriesInfo, DatasetItem, IDataset + +__all__ = [ + "XPathDatasetFilter", + "XPathAnnotationsFilter", + "UserFunctionDatasetFilter", + "UserFunctionAnnotationsFilter", +] + class DatasetItemEncoder: @classmethod - def encode(cls, item, categories=None): + def encode( + cls, item: DatasetItem, categories: Optional[CategoriesInfo] = None + ) -> ET.ElementBase: item_elem = ET.Element("item") ET.SubElement(item_elem, "id").text = str(item.id) ET.SubElement(item_elem, "subset").text = str(item.subset) @@ -39,31 +55,34 @@ def 
encode(cls, item, categories=None): return item_elem @classmethod - def encode_image(cls, image): + def encode_image(cls, image: Image) -> ET.ElementBase: image_elem = ET.Element("image") size = image.size if size is not None: - h, w = size + h, w = str(size[0]), str(size[1]) else: h = "unknown" w = h - ET.SubElement(image_elem, "width").text = str(w) - ET.SubElement(image_elem, "height").text = str(h) + ET.SubElement(image_elem, "height").text = h + ET.SubElement(image_elem, "width").text = w ET.SubElement(image_elem, "has_data").text = "%d" % int(image.has_data) - ET.SubElement(image_elem, "path").text = image.path + if hasattr(image, "path"): + ET.SubElement(image_elem, "path").text = image.path return image_elem @classmethod - def encode_annotation_base(cls, annotation): + def encode_annotation_base(cls, annotation: Annotation) -> ET.ElementBase: assert isinstance(annotation, Annotation) ann_elem = ET.Element("annotation") ET.SubElement(ann_elem, "id").text = str(annotation.id) ET.SubElement(ann_elem, "type").text = str(annotation.type.name) for k, v in annotation.attributes.items(): + if k.isdigit(): + k = "_" + k ET.SubElement(ann_elem, k.replace(" ", "-")).text = str(v) ET.SubElement(ann_elem, "group").text = str(annotation.group) @@ -71,7 +90,7 @@ def encode_annotation_base(cls, annotation): return ann_elem @staticmethod - def _get_label(label_id, categories): + def _get_label(label_id: Optional[int], categories: Optional[CategoriesInfo]) -> str: label = "" if label_id is None: return "" @@ -82,7 +101,9 @@ def _get_label(label_id, categories): return label @classmethod - def encode_label_object(cls, obj, categories): + def encode_label_object( + cls, obj: Label, categories: Optional[CategoriesInfo] + ) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -91,7 +112,7 @@ def encode_label_object(cls, obj, categories): return ann_elem @classmethod - def encode_mask_object(cls, obj, categories): + def encode_mask_object(cls, obj: Mask, categories: Optional[CategoriesInfo]) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -100,7 +121,7 @@ def encode_mask_object(cls, obj, categories): return ann_elem @classmethod - def encode_bbox_object(cls, obj, categories): + def encode_bbox_object(cls, obj: Bbox, categories: Optional[CategoriesInfo]) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -114,7 +135,9 @@ def encode_bbox_object(cls, obj, categories): return ann_elem @classmethod - def encode_points_object(cls, obj, categories): + def encode_points_object( + cls, obj: Points, categories: Optional[CategoriesInfo] + ) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -139,7 +162,9 @@ def encode_points_object(cls, obj, categories): return ann_elem @classmethod - def encode_polygon_object(cls, obj, categories): + def encode_polygon_object( + cls, obj: Polygon, categories: Optional[CategoriesInfo] + ) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -163,7 +188,9 @@ def encode_polygon_object(cls, obj, categories): return ann_elem @classmethod - def encode_polyline_object(cls, obj, categories): + def 
encode_polyline_object( + cls, obj: PolyLine, categories: Optional[CategoriesInfo] + ) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) @@ -187,7 +214,7 @@ def encode_polyline_object(cls, obj, categories): return ann_elem @classmethod - def encode_caption_object(cls, obj): + def encode_caption_object(cls, obj: Caption) -> ET.ElementBase: ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, "caption").text = str(obj.caption) @@ -195,7 +222,26 @@ def encode_caption_object(cls, obj): return ann_elem @classmethod - def encode_annotation(cls, o, categories=None): + def encode_ellipse_object( + cls, obj: Ellipse, categories: Optional[CategoriesInfo] + ) -> ET.ElementBase: + ann_elem = cls.encode_annotation_base(obj) + + ET.SubElement(ann_elem, "label").text = str(cls._get_label(obj.label, categories)) + ET.SubElement(ann_elem, "label_id").text = str(obj.label) + + ET.SubElement(ann_elem, "x1").text = str(obj.x1) + ET.SubElement(ann_elem, "y1").text = str(obj.y1) + ET.SubElement(ann_elem, "x2").text = str(obj.x2) + ET.SubElement(ann_elem, "y2").text = str(obj.y2) + ET.SubElement(ann_elem, "area").text = str(obj.get_area()) + + return ann_elem + + @classmethod + def encode_annotation( + cls, o: Annotation, categories: Optional[CategoriesInfo] = None + ) -> ET.ElementBase: if isinstance(o, Label): return cls.encode_label_object(o, categories) if isinstance(o, Mask): @@ -210,51 +256,54 @@ def encode_annotation(cls, o, categories=None): return cls.encode_polygon_object(o, categories) if isinstance(o, Caption): return cls.encode_caption_object(o) + if isinstance(o, Ellipse): + return cls.encode_ellipse_object(o, categories) + if isinstance(o, HashKey): + return cls.encode_annotation_base(o) + raise NotImplementedError("Unexpected annotation object passed: %s" % o) @staticmethod - def to_string(encoded_item): + def to_string(encoded_item: ET.ElementBase) -> str: return ET.tostring(encoded_item, encoding="unicode", pretty_print=True) class XPathDatasetFilter(ItemTransform): - def __init__(self, extractor, xpath=None): + def __init__(self, extractor: IDataset, xpath: str) -> None: super().__init__(extractor) - if xpath is not None: - try: - xpath = ET.XPath(xpath) - except Exception: - log.error("Failed to create XPath from expression '%s'", xpath) - raise + try: + xpath_eval = ET.XPath(xpath) + except Exception: + log.error("Failed to create XPath from expression '%s'", xpath) + raise - self._f = lambda item: bool( - xpath(DatasetItemEncoder.encode(item, extractor.categories())) - ) - else: - self._f = None + # Return true -> filter out an item + self._f = lambda item: bool( + xpath_eval(DatasetItemEncoder.encode(item, extractor.categories())) + ) - def transform_item(self, item): - if self._f and not self._f(item): + def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: + if not self._f(item): return None return item class XPathAnnotationsFilter(ItemTransform): - def __init__(self, extractor, xpath=None, remove_empty=False): + def __init__(self, extractor: IDataset, xpath: str, remove_empty: bool = False) -> None: super().__init__(extractor) - if xpath is not None: - try: - xpath = ET.XPath(xpath) - except Exception: - log.error("Failed to create XPath from expression '%s'", xpath) - raise - self._filter = xpath + try: + xpath_eval = ET.XPath(xpath) + except Exception: + log.error("Failed to create XPath from expression '%s'", xpath) + raise + + self._filter = xpath_eval 
self._remove_empty = remove_empty - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: if self._filter is None: return item @@ -268,3 +317,90 @@ def transform_item(self, item): if self._remove_empty and len(annotations) == 0: return None return self.wrap_item(item, annotations=annotations) + + +class UserFunctionDatasetFilter(ItemTransform): + """Filter dataset items using a user-provided Python function. + + Parameters: + extractor: Datumaro `Dataset` to filter. + filter_func: A Python callable that takes a `DatasetItem` as its input and + returns a boolean. If the return value is True, that `DatasetItem` will be retained. + Otherwise, it is removed. + + Example: + This is an example of filtering dataset items with images larger than 1024 pixels:: + + from datumaro.components.media import Image + + def filter_func(item: DatasetItem) -> bool: + h, w = item.media_as(Image).size + return h > 1024 or w > 1024 + + filtered = UserFunctionDatasetFilter( + extractor=dataset, filter_func=filter_func) + # No items with an image height or width greater than 1024 + filtered_items = [item for item in filtered] + """ + + def __init__(self, extractor: IDataset, filter_func: Callable[[DatasetItem], bool]): + super().__init__(extractor) + + self._filter_func = filter_func + + def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: + return item if self._filter_func(item) else None + + +class UserFunctionAnnotationsFilter(ItemTransform): + """Filter annotations using a user-provided Python function. + + Parameters: + extractor: Datumaro `Dataset` to filter. + filter_func: A Python callable that takes `DatasetItem` and `Annotation` as its inputs + and returns a boolean. If the return value is True, the `Annotation` will be retained. + Otherwise, it is removed. + remove_empty: If True, `DatasetItem` without any annotations is removed + after filtering its annotations. Otherwise, do not filter `DatasetItem`. 
+ + Example: + This is an example of removing bounding boxes sized greater than 50% of the image size:: + + from datumaro.components.media import Image + from datumaro.components.annotation import Annotation, Bbox + + def filter_func(item: DatasetItem, ann: Annotation) -> bool: + # If the annotation is not a Bbox, do not filter + if not isinstance(ann, Bbox): + return False + + h, w = item.media_as(Image).size + image_size = h * w + bbox_size = ann.h * ann.w + + # Accept Bboxes smaller than 50% of the image size + return bbox_size < 0.5 * image_size + + filtered = UserFunctionAnnotationsFilter( + extractor=dataset, filter_func=filter_func) + # No bounding boxes with a size greater than 50% of their image + filtered_items = [item for item in filtered] + """ + + def __init__( + self, + extractor: IDataset, + filter_func: Callable[[DatasetItem, Annotation], bool], + remove_empty: bool = False, + ): + super().__init__(extractor) + + self._filter_func = filter_func + self._remove_empty = remove_empty + + def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: + filtered_anns = [ann for ann in item.annotations if self._filter_func(item, ann)] + + if self._remove_empty and not filtered_anns: + return None + return self.wrap_item(item, annotations=filtered_anns) From 67490bdd6066c2b961c53b8f9fd2137abe37ab9c Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 22:39:00 +0400 Subject: [PATCH 09/25] sync components/dataset_storage.py --- src/datumaro/components/dataset_storage.py | 391 ++++++++++++++++++--- 1 file changed, 334 insertions(+), 57 deletions(-) diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 90562fa5bf..80d0b902c6 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -1,10 +1,20 @@ -from __future__ import annotations +# Copyright (C) 2020-2023 Intel Corporation +# +# SPDX-License-Identifier: MIT -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union +import logging as log +from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.contexts.importer import _ImportFail -from datumaro.components.dataset_base import CategoriesInfo, DatasetBase, DatasetItem, IDataset +from datumaro.components.dataset_base import ( + DEFAULT_SUBSET_NAME, + CategoriesInfo, + DatasetBase, + DatasetInfo, + DatasetItem, + IDataset, +) from datumaro.components.dataset_item_storage import ( DatasetItemStorage, DatasetItemStorageDatasetView, @@ -13,14 +23,14 @@ from datumaro.components.errors import ( CategoriesRedefinedError, ConflictingCategoriesError, + DatasetInfosRedefinedError, MediaTypeError, + NotAvailableError, RepeatedItemError, ) from datumaro.components.media import MediaElement from datumaro.components.transformer import ItemTransform, Transform -from datumaro.plugins.transforms import ProjectLabels from datumaro.util import is_method_redefined -from datumaro.util.definitions import DEFAULT_SUBSET_NAME __all__ = ["DatasetPatch", "DatasetStorage"] @@ -29,13 +39,13 @@ class DatasetPatch: class DatasetPatchWrapper(DatasetItemStorageDatasetView): # The purpose of this class is to indicate that the input dataset is # a patch and autofill patch info in Exporter - def __init__(self, patch: DatasetPatch, parent: IDataset): + def __init__(self, patch: "DatasetPatch", parent: IDataset): super().__init__( - 
parent=patch.data, - infos={}, + patch.data, + infos=parent.infos(), categories=parent.categories(), media_type=parent.media_type(), - ann_types=None, + ann_types=parent.ann_types(), ) self.patch = patch @@ -45,11 +55,13 @@ def subsets(self): def __init__( self, data: DatasetItemStorage, + infos: DatasetInfo, categories: CategoriesInfo, updated_items: Dict[Tuple[str, str], ItemStatus], updated_subsets: Dict[str, ItemStatus] = None, ): self.data = data + self.infos = infos self.categories = categories self.updated_items = updated_items self._updated_subsets = updated_subsets @@ -67,12 +79,55 @@ def as_dataset(self, parent: IDataset) -> IDataset: return __class__.DatasetPatchWrapper(self, parent) +class _StackedTransform(Transform): + def __init__(self, source: IDataset, transforms: List[Transform]): + super().__init__(source) + + self.is_local = True + self.transforms: List[Transform] = [] + self.malformed_transform_indices: Dict[int, Exception] = {} + for idx, transform in enumerate(transforms): + try: + source = transform[0](source, *transform[1], **transform[2]) + except Exception as e: + self.malformed_transform_indices[idx] = e + + self.transforms.append(source) + + if self.is_local and not isinstance(source, ItemTransform): + self.is_local = False + + def transform_item(self, item: DatasetItem) -> DatasetItem: + for t in self.transforms: + if item is None: + break + item = t.transform_item(item) + return item + + def __iter__(self) -> Iterator[DatasetItem]: + yield from self.transforms[-1] + + def infos(self) -> DatasetInfo: + return self.transforms[-1].infos() + + def categories(self) -> CategoriesInfo: + return self.transforms[-1].categories() + + def media_type(self) -> Type[MediaElement]: + return self.transforms[-1].media_type() + + def ann_types(self) -> Set[AnnotationType]: + return self.transforms[-1].ann_types() + + class DatasetStorage(IDataset): def __init__( self, - source: Union[IDataset, DatasetItemStorage] = None, + source: Union[IDataset, DatasetItemStorage], + infos: Optional[DatasetInfo] = None, categories: Optional[CategoriesInfo] = None, media_type: Optional[Type[MediaElement]] = None, + ann_types: Optional[Set[AnnotationType]] = None, ): if source is None and categories is None: categories = {} @@ -80,6 +135,12 @@ def __init__( raise ValueError("Can't use both source and categories") self._categories = categories + if source is None and infos is None: + infos = {} + elif isinstance(source, IDataset) and infos is not None: + raise ValueError("Can't use both source and infos") + self._infos = infos + if media_type: pass elif isinstance(source, IDataset) and source.media_type(): @@ -87,8 +148,18 @@ def __init__( else: raise ValueError("Media type must be provided for a dataset") assert issubclass(media_type, MediaElement) + self._media_type = media_type + if ann_types: + pass + elif isinstance(source, IDataset) and source.ann_types(): + ann_types = source.ann_types() + else: + ann_types = set() + + self._ann_types = ann_types + # Possible combinations: # 1. source + storage # - Storage contains a patch to the Source data. 
@@ -117,7 +188,7 @@ def is_cache_initialized(self) -> bool: def _is_unchanged_wrapper(self) -> bool: return self._source is not None and self._storage.is_empty() and not self._transforms - def init_cache(self): + def init_cache(self) -> None: if not self.is_cache_initialized(): for _ in self._iter_init_cache(): pass @@ -149,35 +220,6 @@ def _iter_init_cache_unchecked(self) -> Iterable[DatasetItem]: # # The patch is always applied on top of the source / transforms stack. - class _StackedTransform(Transform): - def __init__(self, source, transforms): - super().__init__(source) - - self.is_local = True - self.transforms: List[Transform] = [] - for transform in transforms: - source = transform[0](source, *transform[1], **transform[2]) - self.transforms.append(source) - - if self.is_local and not isinstance(source, ItemTransform): - self.is_local = False - - def transform_item(self, item): - for t in self.transforms: - if item is None: - break - item = t.transform_item(item) - return item - - def __iter__(self): - yield from self.transforms[-1] - - def categories(self): - return self.transforms[-1].categories() - - def media_type(self): - return self.transforms[-1].media_type() - def _update_status(item_id, new_status: ItemStatus): current_status = self._updated_items.get(item_id) @@ -197,17 +239,22 @@ def _update_status(item_id, new_status: ItemStatus): else: assert False, "Unknown status %s" % new_status + def _add_ann_types(item: DatasetItem): + for ann in item.annotations: + if ann.type == AnnotationType.hash_key: + continue + self._ann_types.add(ann.type) + media_type = self._media_type patch = self._storage # must be empty after transforming cache = DatasetItemStorage() source = self._source or DatasetItemStorageDatasetView( - parent=self._storage, - infos={}, + self._storage, + infos=self._infos, categories=self._categories, media_type=media_type, - ann_types=None, + ann_types=self._ann_types, ) - transform = None old_ids = None if self._transforms: @@ -220,7 +267,7 @@ def _update_status(item_id, new_status: ItemStatus): # A generic way to find modified items: # Collect all the dataset original ids and compare # with transform outputs. - # TODO: introduce Extractor.items() / .ids() to avoid extra + # TODO: introduce DatasetBase.items() / .ids() to avoid extra # dataset traversals? 
old_ids = set((item.id, item.subset) for item in source) source = transform @@ -231,6 +278,8 @@ def _update_status(item_id, new_status: ItemStatus): "Transforms are not allowed to change media " "type of dataset items" ) + self._drop_malformed_transforms(transform.malformed_transform_indices) + i = -1 for i, item in enumerate(source): if item.media and not isinstance(item.media, media_type): @@ -275,6 +324,7 @@ def _update_status(item_id, new_status: ItemStatus): cache.put(item) yield item + _add_ann_types(item) if i == -1: cache = patch @@ -282,6 +332,7 @@ def _update_status(item_id, new_status: ItemStatus): if not self._flush_changes: _update_status((item.id, item.subset), ItemStatus.added) yield item + _add_ann_types(item) else: for item in patch: if item in cache: # already processed @@ -290,6 +341,7 @@ def _update_status(item_id, new_status: ItemStatus): _update_status((item.id, item.subset), ItemStatus.added) cache.put(item) yield item + _add_ann_types(item) if not self._flush_changes and transform and not transform.is_local: # Mark removed items that were not produced by transforms @@ -308,6 +360,13 @@ def _update_status(item_id, new_status: ItemStatus): # Don't need to override categories if already defined self._categories = source_cat + if transform: + source_infos = transform.infos() + else: + source_infos = source.infos() + if source_infos is not None: + self._infos = source_infos + self._source = None self._transforms = [] @@ -327,11 +386,11 @@ def _merged(self) -> IDataset: elif self._source is not None: self.init_cache() return DatasetItemStorageDatasetView( - parent=self._storage, - infos={}, + self._storage, + infos=self._infos, categories=self._categories, media_type=self._media_type, - ann_types=None, + ann_types=self._ann_types, ) def __len__(self) -> int: @@ -339,6 +398,22 @@ def __len__(self) -> int: self.init_cache() return self._length + def infos(self) -> DatasetInfo: + if self.is_cache_initialized(): + return self._infos + elif self._infos is not None: + return self._infos + elif any(is_method_redefined("infos", Transform, t[0]) for t in self._transforms): + self.init_cache() + return self._infos + else: + return self._source.infos() + + def define_infos(self, infos: DatasetInfo): + if self._infos or self._source is not None: + raise DatasetInfosRedefinedError() + self._infos = infos + def categories(self) -> CategoriesInfo: if self.is_cache_initialized(): return self._categories @@ -358,13 +433,20 @@ def define_categories(self, categories: CategoriesInfo): def media_type(self) -> Type[MediaElement]: return self._media_type - def put(self, item: DatasetItem): + def ann_types(self) -> Set[AnnotationType]: + return self._ann_types + + def put(self, item: DatasetItem) -> None: if item.media and not isinstance(item.media, self._media_type): raise MediaTypeError( "Mismatching item media type '%s', " "the dataset contains '%s' items." 
% (type(item.media), self._media_type) ) + ann_types = set([ann.type for ann in item.annotations]) + # hash_key can be included any task + ann_types.discard(AnnotationType.hash_key) + is_new = self._storage.put(item) if not self.is_cache_initialized() or is_new: @@ -374,16 +456,17 @@ def put(self, item: DatasetItem): if is_new and not self.is_cache_initialized(): self._length = None + self._ann_types = set() if self._length is not None: self._length += is_new - def get(self, id, subset=None) -> Optional[DatasetItem]: + def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: id = str(id) subset = subset or DEFAULT_SUBSET_NAME item = self._storage.get(id, subset) if item is None and not self.is_cache_initialized(): - if self._source.get.__func__ == DatasetBase.get or self._transforms: + if self._source.get.__func__ == DatasetBase.get: # can be improved if IDataset is ABC self.init_cache() item = self._storage.get(id, subset) @@ -393,7 +476,7 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: self._storage.put(item) return item - def remove(self, id, subset=None): + def remove(self, id: str, subset: Optional[str] = None) -> None: id = str(id) subset = subset or DEFAULT_SUBSET_NAME @@ -403,18 +486,28 @@ def remove(self, id, subset=None): self._updated_items[(id, subset)] = ItemStatus.removed if is_removed and not self.is_cache_initialized(): self._length = None + self._ann_types = set() if self._length is not None: self._length -= is_removed - def get_subset(self, name): + def get_subset(self, name: str) -> IDataset: return self._merged().get_subset(name) - def subsets(self): + def subsets(self) -> Dict[str, IDataset]: # TODO: check if this can be optimized in case of transforms # and other cases return self._merged().subsets() - def transform(self, method: Type[Transform], *args, **kwargs): + def get_annotated_items(self) -> int: + return self._storage.get_annotated_items() + + def get_annotations(self) -> int: + return self._storage.get_annotations() + + def get_datasetitem_by_path(self, path: str) -> Optional[DatasetItem]: + return self._storage.get_datasetitem_by_path(path) + + def transform(self, method: Type[Transform], *args, **kwargs) -> None: # Flush accumulated changes if not self._storage.is_empty(): source = self._merged() @@ -427,14 +520,18 @@ def transform(self, method: Type[Transform], *args, **kwargs): self._source = source self._transforms.append((method, args, kwargs)) + if is_method_redefined("infos", Transform, method): + self._infos = None + if is_method_redefined("categories", Transform, method): self._categories = None self._length = None + self._ann_types = set() def has_updated_items(self): return bool(self._transforms) or bool(self._updated_items) - def get_patch(self): + def get_patch(self) -> DatasetPatch: # Patch includes only added or modified items. # To find removed items, one needs to consult updated_items list. 
if self._transforms: @@ -452,7 +549,9 @@ def get_patch(self): else: patch.put(self._storage.get(item_id, subset)) - return DatasetPatch(patch, self._categories, self._updated_items) + return DatasetPatch( + patch, infos=self._infos, categories=self._categories, updated_items=self._updated_items + ) def flush_changes(self): self._updated_items = {} @@ -472,6 +571,8 @@ def update(self, source: Union[DatasetPatch, IDataset, Iterable[DatasetItem]]): else: self.put(source.data.get(*item_id)) elif isinstance(source, IDataset): + from datumaro.plugins.transforms import ProjectLabels + for item in ProjectLabels( source, self.categories().get(AnnotationType.label, LabelCategories()) ): @@ -479,3 +580,179 @@ def update(self, source: Union[DatasetPatch, IDataset, Iterable[DatasetItem]]): else: for item in source: self.put(item) + + def _drop_malformed_transforms(self, malformed_transform_indices: Dict[int, Exception]) -> None: + safe_transforms = [] + for idx, transform in enumerate(self._transforms): + if idx in malformed_transform_indices: + log.error( + f"Automatically drop {transform} from the transform stack " + "because an error is raised. Therefore, the dataset will not be " + "transformed by this transformation since it is dropped.", + exc_info=malformed_transform_indices[idx], + ) + continue + + safe_transforms += [transform] + + self._transforms = safe_transforms + + def __getitem__(self, idx: int) -> DatasetItem: + try: + return self._storage[idx] + except IndexError: # Data storage should be initialized + self.init_cache() + return self._storage[idx] + + +class StreamSubset(IDataset): + def __init__(self, source: IDataset, subset: str) -> None: + if not source.is_stream: + raise ValueError("source should be a stream.") + self._source = source + self._subset = subset + self._length = None + + def __iter__(self) -> Iterator[DatasetItem]: + for item in self._source: + if item.subset == self._subset: + yield item + + def __len__(self) -> int: + if self._length is None: + self._length = sum(1 for _ in self) + return self._length + + def subsets(self) -> Dict[str, IDataset]: + raise NotAvailableError("Cannot get subsets of the subset.") + + def get_subset(self, name) -> IDataset: + raise NotAvailableError("Cannot get a subset of the subset.") + + def infos(self) -> DatasetInfo: + return self._source.infos() + + def categories(self) -> CategoriesInfo: + return self._source.categories() + + def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: + raise NotAvailableError( + "Random access to the dataset item is not allowed in streaming. " + "You can access to the dataset item only by using its iterator." 
+ ) + + def media_type(self) -> Type[MediaElement]: + return self._source.media_type() + + def ann_types(self) -> Set[AnnotationType]: + return self._source.ann_types() + + @property + def is_stream(self) -> bool: + return True + + +class StreamDatasetStorage(DatasetStorage): + def __init__( + self, + source: IDataset, + infos: Optional[DatasetInfo] = None, + categories: Optional[CategoriesInfo] = None, + media_type: Optional[Type[MediaElement]] = None, + ann_types: Optional[Set[AnnotationType]] = None, + ): + if not source.is_stream: + raise ValueError("source should be a stream.") + self._subset_names = list(source.subsets().keys()) + self._transform_ids_for_latest_subset_names = [] + super().__init__(source, infos, categories, media_type, ann_types) + + def is_cache_initialized(self) -> bool: + log.debug("This function has no effect on streaming.") + return True + + def init_cache(self) -> None: + log.debug("This function has no effect on streaming.") + pass + + @property + def stacked_transform(self) -> IDataset: + if self._transforms: + transform = _StackedTransform(self._source, self._transforms) + self._drop_malformed_transforms(transform.malformed_transform_indices) + else: + transform = self._source + + self._flush_changes = True + return transform + + def __iter__(self) -> Iterator[DatasetItem]: + for item in self.stacked_transform: + yield item + + for ann in item.annotations: + if ann.type == AnnotationType.hash_key: + continue + self._ann_types.add(ann.type) + + def __len__(self) -> int: + if self._length is None: + self._length = len(self._source) + return self._length + + def put(self, item: DatasetItem) -> None: + raise NotAvailableError("Drop-in replacement is not allowed in streaming.") + + def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: + raise NotAvailableError( + "Random access to the dataset item is not allowed in streaming. " + "You can access to the dataset item only by using its iterator." 
+ ) + + def remove(self, id: str, subset: Optional[str] = None) -> None: + raise NotAvailableError("Drop-in removal is not allowed in streaming.") + + def get_subset(self, name: str) -> IDataset: + return self.subsets()[name] + + @property + def subset_names(self): + if self._transform_ids_for_latest_subset_names != [id(t) for t in self._transforms]: + self._subset_names = {item.subset for item in self} + self._transform_ids_for_latest_subset_names = [id(t) for t in self._transforms] + + return self._subset_names + + def subsets(self) -> Dict[str, IDataset]: + return {subset: StreamSubset(self, subset) for subset in self.subset_names} + + def transform(self, method: Type[Transform], *args, **kwargs) -> None: + super().transform(method, *args, **kwargs) + + def get_annotated_items(self) -> int: + return super().get_annotated_items() + + def get_annotations(self) -> int: + return super().get_annotations() + + def get_datasetitem_by_path(self, path: str) -> Optional[DatasetItem]: + raise NotAvailableError("Get dataset item by path is not allowed in streaming.") + + def get_patch(self): + raise NotAvailableError("Get patch is not allowed in streaming.") + + def flush_changes(self): + raise NotAvailableError("Flush changes is not allowed in streaming.") + + def update(self, source: Union[DatasetPatch, IDataset, Iterable[DatasetItem]]): + raise NotAvailableError("Update is not allowed in streaming.") + + def infos(self) -> DatasetInfo: + return self.stacked_transform.infos() + + def categories(self) -> CategoriesInfo: + return self.stacked_transform.categories() + + @property + def is_stream(self) -> bool: + return True From ddb46b64c582a26df9d661404cc45baf658878d6 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Wed, 15 Jan 2025 20:47:19 +0400 Subject: [PATCH 10/25] sync components/merge (extracting from components/operations.py) --- src/datumaro/cli/commands/merge.py | 6 +- src/datumaro/components/dataset.py | 16 +- src/datumaro/components/hl_ops/__init__.py | 4 +- src/datumaro/components/merge/__init__.py | 89 ++ src/datumaro/components/merge/base.py | 125 +++ src/datumaro/components/merge/exact_merge.py | 319 ++++++ .../components/merge/extractor_merger.py | 86 ++ .../components/merge/intersect_merge.py | 656 +++++++++++++ src/datumaro/components/merge/union_merge.py | 88 ++ src/datumaro/components/operations.py | 905 +----------------- src/datumaro/components/project.py | 6 + tests/unit/test_api.py | 3 +- tests/unit/test_ops.py | 26 +- 13 files changed, 1408 insertions(+), 921 deletions(-) create mode 100644 src/datumaro/components/merge/__init__.py create mode 100644 src/datumaro/components/merge/base.py create mode 100644 src/datumaro/components/merge/exact_merge.py create mode 100644 src/datumaro/components/merge/extractor_merger.py create mode 100644 src/datumaro/components/merge/intersect_merge.py create mode 100644 src/datumaro/components/merge/union_merge.py diff --git a/src/datumaro/cli/commands/merge.py b/src/datumaro/cli/commands/merge.py index d2f3c67d0b..5188fb5825 100644 --- a/src/datumaro/cli/commands/merge.py +++ b/src/datumaro/cli/commands/merge.py @@ -8,10 +8,10 @@ import os.path as osp from collections import OrderedDict -from datumaro.components.dataset import DEFAULT_FORMAT +from datumaro.components.dataset import DEFAULT_FORMAT, Dataset from datumaro.components.environment import Environment from datumaro.components.errors import DatasetMergeError, DatasetQualityError, ProjectNotFoundError -from datumaro.components.operations import IntersectMerge +from 
datumaro.components.merge.intersect_merge import IntersectMerge from datumaro.components.project import ProjectBuildTargets from datumaro.util import dump_json_file from datumaro.util.scope import scope_add, scoped @@ -230,7 +230,7 @@ def merge_command(args): quorum=args.quorum, ) ) - merged_dataset = merger(source_datasets) + merged_dataset = Dataset(source=merger(*source_datasets)) merged_dataset.export(save_dir=dst_dir, format=converter, **export_args) diff --git a/src/datumaro/components/dataset.py b/src/datumaro/components/dataset.py index 0fe778a80b..34b628702b 100644 --- a/src/datumaro/components/dataset.py +++ b/src/datumaro/components/dataset.py @@ -76,12 +76,18 @@ def subsets(self): return self.parent.subsets() return {self.name: self} + def infos(self): + return {} + def categories(self): return self.parent.categories() def media_type(self): return self.parent.media_type() + def ann_types(self): + return [] + def as_dataset(self) -> Dataset: return Dataset.from_extractors(self, env=self.parent.env) @@ -177,10 +183,10 @@ def from_extractors(*sources: IDataset, env: Optional[Environment] = None) -> Da source = sources[0] dataset = Dataset(source=source, env=env) else: - from datumaro.components.operations import ExactMerge + from datumaro.components.merge.exact_merge import ExactMerge media_type = ExactMerge.merge_media_types(sources) - source = ExactMerge.merge(*sources) + source = ExactMerge.merge(sources) categories = ExactMerge.merge_categories(s.categories() for s in sources) dataset = Dataset(source=source, categories=categories, media_type=media_type, env=env) return dataset @@ -225,12 +231,18 @@ def get_subset(self, name) -> DatasetSubset: def subsets(self) -> Dict[str, DatasetSubset]: return {k: self.get_subset(k) for k in self._data.subsets()} + def infos(self): + return {} + def categories(self) -> CategoriesInfo: return self._data.categories() def media_type(self) -> Type[MediaElement]: return self._data.media_type() + def ann_types(self): + return [] + def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: return self._data.get(id, subset) diff --git a/src/datumaro/components/hl_ops/__init__.py b/src/datumaro/components/hl_ops/__init__.py index a5c61dca74..19f5536e25 100644 --- a/src/datumaro/components/hl_ops/__init__.py +++ b/src/datumaro/components/hl_ops/__init__.py @@ -14,7 +14,7 @@ from datumaro.components.exporter import Exporter from datumaro.components.filter import XPathAnnotationsFilter, XPathDatasetFilter from datumaro.components.launcher import Launcher -from datumaro.components.operations import ExactMerge +from datumaro.components.merge.exact_merge import ExactMerge from datumaro.components.transformer import ModelTransform, Transform from datumaro.components.validator import TaskType, Validator from datumaro.util import parse_str_enum_value @@ -109,7 +109,7 @@ def merge(*datasets: IDataset) -> IDataset: categories = ExactMerge.merge_categories(d.categories() for d in datasets) media_type = ExactMerge.merge_media_types(datasets) return DatasetItemStorageDatasetView( - parent=ExactMerge.merge(*datasets), + parent=ExactMerge.merge(datasets), infos={}, categories=categories, media_type=media_type, diff --git a/src/datumaro/components/merge/__init__.py b/src/datumaro/components/merge/__init__.py new file mode 100644 index 0000000000..cb0d08455f --- /dev/null +++ b/src/datumaro/components/merge/__init__.py @@ -0,0 +1,89 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from .base import Merger +from 
.exact_merge import ExactMerge
+from .intersect_merge import IntersectMerge
+from .union_merge import UnionMerge
+
+DEFAULT_MERGE_POLICY = "exact"
+
+
+def get_merger(merge_policy: str = DEFAULT_MERGE_POLICY, *args, **kwargs) -> Merger:
+    """
+    Get :class:`Merger` according to `merge_policy`. You have to choose an appropriate `Merger`
+    for your purpose. The available merge policies are "union", "intersect", and "exact".
+
+    1. :class:`UnionMerge`
+
+    Merge several datasets with "union" policy:
+
+    - Label categories are merged according to the union of their label names.
+      For example, if Dataset-A has {"car", "cat", "dog"} and Dataset-B has
+      {"car", "bus", "truck"} labels, the merged dataset will have
+      {"bus", "car", "cat", "dog", "truck"} labels.
+
+    - If there are two or more dataset items whose (id, subset) pairs match each other,
+      both are included in the merged dataset. At this time, since the same (id, subset)
+      pair cannot be duplicated in the dataset, we add a suffix to the id of each source item.
+      For example, if Dataset-A has DatasetItem(id="magic", subset="train") and Dataset-B has
+      also DatasetItem(id="magic", subset="train"), the merged dataset will have
+      DatasetItem(id="magic-0", subset="train") and DatasetItem(id="magic-1", subset="train").
+
+    2. :class:`IntersectMerge`
+
+    Merge several datasets with "intersect" policy:
+
+    - If there are two or more dataset items whose (id, subset) pairs match each other,
+      we can consider this as having an intersection in our dataset. This method merges
+      the annotations of the corresponding :class:`DatasetItem` into one :class:`DatasetItem`
+      to handle this intersection. The rule to handle merging annotations is provided by
+      :class:`AnnotationMerger` according to their annotation types. For example,
+      DatasetItem(id="item_1", subset="train", annotations=[Bbox(0, 0, 1, 1)]) from Dataset-A and
+      DatasetItem(id="item_1", subset="train", annotations=[Bbox(.5, .5, 1, 1)]) from Dataset-B can be
+      merged into DatasetItem(id="item_1", subset="train", annotations=[Bbox(0, 0, 1, 1)]).
+
+    - Label categories are merged according to the union of their label names
+      (Same as `UnionMerge`). For example, if Dataset-A has {"car", "cat", "dog"}
+      and Dataset-B has {"car", "bus", "truck"} labels, the merged dataset will have
+      {"bus", "car", "cat", "dog", "truck"} labels.
+
+    - This merge has configuration parameters (`conf`) to control the annotation merge behaviors.
+
+      For example,
+
+      ```python
+      merge = IntersectMerge(
+          conf=IntersectMerge.Conf(
+              pairwise_dist=0.25,
+              groups=[],
+              output_conf_thresh=0.0,
+              quorum=0,
+          )
+      )
+      ```
+
+      For more details on the parameters, please refer to :class:`IntersectMerge.Conf`.
+
+    3. :class:`ExactMerge`
+
+    Merges several datasets using the "simple" algorithm:
+
+    - All datasets should have the same categories
+    - items are matched by (id, subset) pairs
+    - matching items share the media info available:
+      - nothing + nothing = nothing
+      - nothing + something = something
+      - something A + something B = conflict
+    - annotations are matched by value and shared
+    - in case of conflicts, throws an error
+    """
+    if merge_policy == "union":
+        return UnionMerge(*args, **kwargs)
+    elif merge_policy == "intersect":
+        return IntersectMerge(*args, **kwargs)
+    elif merge_policy == "exact":
+        return ExactMerge(*args, **kwargs)
+
+    raise ValueError(f"{merge_policy} is invalid Merger name.")
diff --git a/src/datumaro/components/merge/base.py b/src/datumaro/components/merge/base.py
new file mode 100644
index 0000000000..e985c11d40
--- /dev/null
+++ b/src/datumaro/components/merge/base.py
@@ -0,0 +1,125 @@
+# Copyright (C) 2020-2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+import logging as log
+import os
+from collections import OrderedDict
+from typing import Dict, Optional, Sequence, Set, Type
+
+from datumaro.components.abstracts.merger import IMergerContext
+from datumaro.components.annotation import AnnotationType
+from datumaro.components.cli_plugin import CliPlugin
+from datumaro.components.dataset_base import CategoriesInfo, DatasetInfo, IDataset
+from datumaro.components.dataset_item_storage import DatasetItemStorageDatasetView
+from datumaro.components.errors import (
+    ConflictingCategoriesError,
+    DatasetMergeError,
+    DatasetQualityError,
+    MediaTypeError,
+)
+from datumaro.components.media import MediaElement
+from datumaro.util import dump_json_file
+
+
+class Merger(IMergerContext, CliPlugin):
+    """Merge multiple datasets into one dataset"""
+
+    def __init__(self, **options):
+        super().__init__(**options)
+        self.__dict__["_sources"] = None
+        self.errors = []
+
+    @staticmethod
+    def merge_infos(sources: Sequence[DatasetInfo]) -> Dict:
+        """Merge the "infos" fields of several datasets into one dict"""
+        infos = {}
+        for source in sources:
+            for k, v in source.items():
+                if k in infos:
+                    log.warning(
+                        "Duplicated infos field %s: overwrite from %s to %s", k, infos[k], v
+                    )
+                infos[k] = v
+        return infos
+
+    @staticmethod
+    def merge_categories(sources: Sequence[CategoriesInfo]) -> Dict:
+        categories = {}
+        for source_idx, source in enumerate(sources):
+            for cat_type, source_cat in source.items():
+                existing_cat = categories.setdefault(cat_type, source_cat)
+                if existing_cat != source_cat and len(source_cat) != 0:
+                    if len(existing_cat) == 0:
+                        categories[cat_type] = source_cat
+                    else:
+                        raise ConflictingCategoriesError(
+                            "Merging of datasets with different categories is "
+                            "only allowed in 'merge' command.",
+                            sources=list(range(source_idx)),
+                        )
+        return categories
+
+    @staticmethod
+    def merge_media_types(sources: Sequence[IDataset]) -> Optional[Type[MediaElement]]:
+        if sources:
+            media_type = sources[0].media_type()
+            for s in sources:
+                if not issubclass(s.media_type(), media_type) or not issubclass(
+                    media_type, s.media_type()
+                ):
+                    # Symmetric comparison is needed in the case of subclasses:
+                    # eg.
Image and RoIImage + raise MediaTypeError("Datasets have different media types") + return media_type + + return None + + @staticmethod + def merge_ann_types(sources: Sequence[IDataset]) -> Optional[Set[AnnotationType]]: + ann_types = set() + for source in sources: + ann_types.union(source.ann_types()) + return ann_types + + def __call__(self, *datasets: IDataset) -> DatasetItemStorageDatasetView: + infos = self.merge_infos(d.infos() for d in datasets) + categories = self.merge_categories(d.categories() for d in datasets) + media_type = self.merge_media_types(datasets) + ann_types = self.merge_ann_types(datasets) + return DatasetItemStorageDatasetView( + parent=self.merge(datasets), + infos=infos, + categories=categories, + media_type=media_type, + ann_types=ann_types, + ) + + def save_merge_report(self, path: str) -> None: + item_errors = OrderedDict() + source_errors = OrderedDict() + all_errors = [] + + for e in self.errors: + if isinstance(e, DatasetQualityError): + item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 + elif isinstance(e, DatasetMergeError): + for s in e.sources: + source_errors[str(s)] = source_errors.get(s, 0) + 1 + item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 + + all_errors.append(str(e)) + + errors = OrderedDict( + [ + ("Item errors", item_errors), + ("Source errors", source_errors), + ("All errors", all_errors), + ] + ) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + dump_json_file(path, errors, indent=True) + + def get_any_label_name(self, ann, label_id): + raise NotImplementedError diff --git a/src/datumaro/components/merge/exact_merge.py b/src/datumaro/components/merge/exact_merge.py new file mode 100644 index 0000000000..96dfa1b64d --- /dev/null +++ b/src/datumaro/components/merge/exact_merge.py @@ -0,0 +1,319 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union + +from datumaro.components.annotation import Annotation +from datumaro.components.dataset_base import DatasetItem, IDataset +from datumaro.components.dataset_item_storage import DatasetItemStorage +from datumaro.components.errors import ( + DatasetMergeError, + MismatchingAttributesError, + MismatchingImageInfoError, + MismatchingMediaError, + MismatchingMediaPathError, + VideoMergeError, +) +from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video +from datumaro.components.merge import Merger + +__all__ = ["ExactMerge"] + + +class ExactMerge(Merger): + """ + Merges several datasets using the "simple" algorithm: + - All datasets should have the same categories + - items are matched by (id, subset) pairs + - matching items share the media info available: + - nothing + nothing = nothing + - nothing + something = something + - something A + something B = conflict + - annotations are matched by value and shared + - in case of conflicts, throws an error + """ + + def __init__(self, **options): + super().__init__(**options) + + @classmethod + def merge(cls, sources: Sequence[IDataset]) -> DatasetItemStorage: + items = DatasetItemStorage() + for source_idx, source in enumerate(sources): + for item in source: + existing_item = items.get(item.id, item.subset) + if existing_item is not None: + try: + item = cls.merge_items(existing_item, item) + except DatasetMergeError as e: + e.sources = set(range(source_idx)) + raise e + + items.put(item) + return items + + @classmethod + def _match_annotations_equal(cls, a, b): + matches = [] + 
a_unmatched = a[:] + b_unmatched = b[:] + for a_ann in a: + for b_ann in b_unmatched: + if a_ann != b_ann: + continue + + matches.append((a_ann, b_ann)) + a_unmatched.remove(a_ann) + b_unmatched.remove(b_ann) + break + + return matches, a_unmatched, b_unmatched + + @classmethod + def _merge_annotations_equal(cls, a, b): + matches, a_unmatched, b_unmatched = cls._match_annotations_equal(a, b) + return [ann_a for (ann_a, _) in matches] + a_unmatched + b_unmatched + + @classmethod + def merge_items(cls, existing_item: DatasetItem, current_item: DatasetItem) -> DatasetItem: + return existing_item.wrap( + media=cls._merge_media(existing_item, current_item), + attributes=cls._merge_attrs( + existing_item.attributes, + current_item.attributes, + item_id=(existing_item.id, existing_item.subset), + ), + annotations=cls._merge_anno(existing_item.annotations, current_item.annotations), + ) + + @classmethod + def _merge_attrs(cls, a: Dict[str, Any], b: Dict[str, Any], item_id: Tuple[str, str]) -> Dict: + merged = {} + + for name in a.keys() | b.keys(): + a_val = a.get(name, None) + b_val = b.get(name, None) + + if name not in a: + m_val = b_val + elif name not in b: + m_val = a_val + elif a_val != b_val: + raise MismatchingAttributesError(item_id, name, a_val, b_val) + else: + m_val = a_val + + merged[name] = m_val + + return merged + + @classmethod + def _merge_media( + cls, item_a: DatasetItem, item_b: DatasetItem + ) -> Union[Image, PointCloud, Video]: + if (not item_a.media or isinstance(item_a.media, Image)) and ( + not item_b.media or isinstance(item_b.media, Image) + ): + media = cls._merge_images(item_a, item_b) + elif (not item_a.media or isinstance(item_a.media, PointCloud)) and ( + not item_b.media or isinstance(item_b.media, PointCloud) + ): + media = cls._merge_point_clouds(item_a, item_b) + elif (not item_a.media or isinstance(item_a.media, Video)) and ( + not item_b.media or isinstance(item_b.media, Video) + ): + media = cls._merge_videos(item_a, item_b) + elif (not item_a.media or isinstance(item_a.media, MultiframeImage)) and ( + not item_b.media or isinstance(item_b.media, MultiframeImage) + ): + media = cls._merge_multiframe_images(item_a, item_b) + elif (not item_a.media or isinstance(item_a.media, MediaElement)) and ( + not item_b.media or isinstance(item_b.media, MediaElement) + ): + if isinstance(item_a.media, MediaElement) and isinstance(item_b.media, MediaElement): + item_a_path = getattr(item_a.media, "path", None) + item_b_path = getattr(item_b.media, "path", None) + + if item_a_path and item_b_path and item_a_path != item_b_path: + raise MismatchingMediaPathError( + (item_a.id, item_a.subset), item_a_path, item_b_path + ) + elif item_a_path is None and item_b_path is None: + raise MismatchingMediaError( + (item_a.id, item_a.subset), item_a.media, item_b.media + ) + + media = item_a.media if item_a_path else item_b.media + + elif isinstance(item_a.media, MediaElement): + media = item_a.media + else: + media = item_b.media + else: + raise MismatchingMediaError((item_a.id, item_a.subset), item_a.media, item_b.media) + return media + + @classmethod + def _merge_images(cls, item_a: DatasetItem, item_b: DatasetItem) -> Image: + media = None + + if isinstance(item_a.media, Image) and isinstance(item_b.media, Image): + item_a_path = getattr(item_a.media, "path", None) + item_b_path = getattr(item_b.media, "path", None) + + if ( + item_a_path + and item_b_path + and item_a_path != item_b_path + and item_a.media.has_data is item_b.media.has_data + ): + # We use has_data as a 
replacement for path existence check
+                # - If only one image has data, we'll use it. The other
+                #   one is just a path metainfo, which is not significant
+                #   in this case.
+                # - If both images have data or both don't, we need
+                #   to compare paths.
+                #
+                # Different paths can actually point to the same file,
+                # but it's not the case we'd like to allow here to be
+                # a "simple" merging strategy used for extractor joining
+                raise MismatchingMediaPathError(
+                    (item_a.id, item_a.subset), item_a_path, item_b_path
+                )
+
+            if (
+                item_a.media.has_size
+                and item_b.media.has_size
+                and item_a.media.size != item_b.media.size
+            ):
+                raise MismatchingImageInfoError(
+                    (item_a.id, item_a.subset), item_a.media.size, item_b.media.size
+                )
+
+            # Avoid direct comparison here for better performance
+            # If there are 2 "data-only" images, they won't be compared and
+            # we just use the first one
+            if item_a.media.has_data:
+                media = item_a.media
+            elif item_b.media.has_data:
+                media = item_b.media
+            elif item_a_path:
+                media = item_a.media
+            elif item_b_path:
+                media = item_b.media
+            elif item_a.media.has_size:
+                media = item_a.media
+            elif item_b.media.has_size:
+                media = item_b.media
+            else:
+                raise MismatchingMediaError((item_a.id, item_a.subset), item_a.media, item_b.media)
+
+            if not media.has_data or not media.has_size:
+                if item_a.media._size:
+                    media._size = item_a.media._size
+                elif item_b.media._size:
+                    media._size = item_b.media._size
+        elif isinstance(item_a.media, Image):
+            media = item_a.media
+        else:
+            media = item_b.media
+
+        return media
+
+    @classmethod
+    def _merge_point_clouds(cls, item_a: DatasetItem, item_b: DatasetItem) -> PointCloud:
+        media = None
+
+        if isinstance(item_a.media, PointCloud) and isinstance(item_b.media, PointCloud):
+            item_a_path = getattr(item_a.media, "path", None)
+            item_b_path = getattr(item_b.media, "path", None)
+
+            if item_a_path and item_b_path and item_a_path != item_b_path:
+                raise MismatchingMediaPathError(
+                    (item_a.id, item_a.subset), item_a_path, item_b_path
+                )
+
+            # Avoid direct comparison here for better performance
+            # If there are 2 "data-only" pointclouds, they won't be compared and
+            # we just use the first one
+            if item_a.media.has_data or item_a.media.extra_images:
+                media = item_a.media
+
+                if item_b.media.extra_images:
+                    for image in item_b.media.extra_images:
+                        if image not in media.extra_images:
+                            media.extra_images.append(image)
+            elif item_b.media.has_data or item_b.media.extra_images:
+                media = item_b.media
+
+                if item_a.media.extra_images:
+                    for image in item_a.media.extra_images:
+                        if image not in media.extra_images:
+                            media.extra_images.append(image)
+            else:
+                raise MismatchingMediaError((item_a.id, item_a.subset), item_a.media, item_b.media)
+
+        elif isinstance(item_a.media, PointCloud):
+            media = item_a.media
+        else:
+            media = item_b.media
+
+        return media
+
+    @classmethod
+    def _merge_videos(cls, item_a: DatasetItem, item_b: DatasetItem) -> Video:
+        media = None
+
+        if isinstance(item_a.media, Video) and isinstance(item_b.media, Video):
+            if (
+                item_a.media.path is not item_b.media.path
+                or item_a.media._start_frame is not item_b.media._start_frame
+                or item_a.media._end_frame is not item_b.media._end_frame
+                or item_a.media._step is not item_b.media._step
+            ):
+                raise VideoMergeError(item_a.id)
+
+            media = item_a.media
+        elif isinstance(item_a.media, Video):
+            media = item_a.media
+        else:
+            media = item_b.media
+
+        return media
+
+    @classmethod
+    def _merge_multiframe_images(cls, item_a: DatasetItem, item_b: DatasetItem) -> 
MultiframeImage:
+        media = None
+
+        if isinstance(item_a.media, MultiframeImage) and isinstance(item_b.media, MultiframeImage):
+            if item_a.media.path and item_b.media.path and item_a.media.path != item_b.media.path:
+                raise MismatchingMediaPathError(
+                    (item_a.id, item_a.subset), item_a.media.path, item_b.media.path
+                )
+
+            if item_a.media.path or item_a.media.data:
+                media = item_a.media
+
+                if item_b.media.data:
+                    for image in item_b.media.data:
+                        if image not in media.data:
+                            media.data.append(image)
+            else:
+                media = item_b.media
+
+                if item_a.media.data:
+                    for image in item_a.media.data:
+                        if image not in media.data:
+                            media.data.append(image)
+
+        elif isinstance(item_a.media, MultiframeImage):
+            media = item_a.media
+        else:
+            media = item_b.media
+
+        return media
+
+    @classmethod
+    def _merge_anno(cls, a: Iterable[Annotation], b: Iterable[Annotation]) -> List[Annotation]:
+        return cls._merge_annotations_equal(a, b)
diff --git a/src/datumaro/components/merge/extractor_merger.py b/src/datumaro/components/merge/extractor_merger.py
new file mode 100644
index 0000000000..f5c3a873c3
--- /dev/null
+++ b/src/datumaro/components/merge/extractor_merger.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from collections import defaultdict
+from typing import Dict, Iterator, List, Optional, Sequence, TypeVar
+
+from datumaro.components.contexts.importer import _ImportFail
+from datumaro.components.dataset_base import (
+    CategoriesInfo,
+    DatasetBase,
+    DatasetInfo,
+    DatasetItem,
+    SubsetBase,
+)
+
+T = TypeVar("T")
+
+
+def check_identicalness(seq: Sequence[T], raise_error_on_empty: bool = True) -> Optional[T]:
+    if len(seq) == 0 and raise_error_on_empty:
+        raise _ImportFail("It should not be empty.")
+    elif len(seq) == 0 and not raise_error_on_empty:
+        return None
+
+    if seq.count(seq[0]) != len(seq):
+        raise _ImportFail("All items in the sequence should be identical.")
+
+    return seq[0]
+
+
+class ExtractorMerger(DatasetBase):
+    """A simple class to merge single-subset extractors."""
+
+    def __init__(
+        self,
+        sources: Sequence[SubsetBase],
+    ):
+        if len(sources) == 0:
+            raise _ImportFail("It should not be empty.")
+
+        self._infos = check_identicalness([s.infos() for s in sources])
+        self._categories = check_identicalness([s.categories() for s in sources])
+        self._media_type = check_identicalness([s.media_type() for s in sources])
+
+        ann_types = set()
+        for source in sources:
+            ann_types.update(source.ann_types())
+        self._ann_types = ann_types
+
+        self._is_stream = check_identicalness([s.is_stream for s in sources])
+
+        self._subsets: Dict[str, List[SubsetBase]] = defaultdict(list)
+        for source in sources:
+            self._subsets[source.subset] += [source]
+
+    def infos(self) -> DatasetInfo:
+        return self._infos
+
+    def categories(self) -> CategoriesInfo:
+        return self._categories
+
+    def __iter__(self) -> Iterator[DatasetItem]:
+        for sources in self._subsets.values():
+            for source in sources:
+                yield from source
+
+    def __len__(self) -> int:
+        return sum(len(source) for sources in self._subsets.values() for source in sources)
+
+    def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
+        if subset is not None and (sources := self._subsets.get(subset, [])):
+            for source in sources:
+                if item := source.get(id, subset):
+                    return item
+
+        for sources in self._subsets.values():
+            for source in sources:
+                if item := source.get(id=id, subset=source.subset):
+                    return item
+
+        return None
+
+    @property
+    def is_stream(self) -> bool:
+        
return self._is_stream
diff --git a/src/datumaro/components/merge/intersect_merge.py b/src/datumaro/components/merge/intersect_merge.py
new file mode 100644
index 0000000000..6d4f1e9605
--- /dev/null
+++ b/src/datumaro/components/merge/intersect_merge.py
@@ -0,0 +1,656 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import logging as log
+from collections import OrderedDict
+from typing import Dict, Sequence
+
+import attr
+from attr import attrib, attrs
+
+from datumaro.components.annotation import (
+    AnnotationType,
+    LabelCategories,
+    MaskCategories,
+    PointsCategories,
+)
+from datumaro.components.annotations.merger import (
+    AnnotationMerger,
+    BboxMerger,
+    CaptionsMerger,
+    Cuboid2DMerger,
+    Cuboid3dMerger,
+    EllipseMerger,
+    FeatureVectorMerger,
+    HashKeyMerger,
+    ImageAnnotationMerger,
+    LabelMerger,
+    LineMerger,
+    MaskMerger,
+    PointsMerger,
+    PolygonMerger,
+    RotatedBboxMerger,
+    TabularMerger,
+)
+from datumaro.components.dataset_base import DatasetItem, IDataset
+from datumaro.components.dataset_item_storage import (
+    DatasetItemStorage,
+    DatasetItemStorageDatasetView,
+)
+from datumaro.components.errors import (
+    AnnotationsTooCloseError,
+    ConflictingCategoriesError,
+    FailedAttrVotingError,
+    NoMatchingAnnError,
+    NoMatchingItemError,
+    WrongGroupError,
+)
+from datumaro.components.merge import Merger
+from datumaro.util import find
+from datumaro.util.annotation_util import find_instances, max_bbox
+from datumaro.util.attrs_util import ensure_cls
+
+__all__ = ["IntersectMerge"]
+
+
+@attrs
+class IntersectMerge(Merger):
+    """
+    Merge several datasets with "intersect" policy:
+
+    - If there are two or more dataset items whose (id, subset) pairs match each other,
+    we can consider this as having an intersection in our dataset. This method merges
+    the annotations of the corresponding :class:`DatasetItem` into one :class:`DatasetItem`
+    to handle this intersection. The rule to handle merging annotations is provided by
+    :class:`AnnotationMerger` according to their annotation types. For example,
+    DatasetItem(id="item_1", subset="train", annotations=[Bbox(0, 0, 1, 1)]) from Dataset-A and
+    DatasetItem(id="item_1", subset="train", annotations=[Bbox(.5, .5, 1, 1)]) from Dataset-B can be
+    merged into DatasetItem(id="item_1", subset="train", annotations=[Bbox(0, 0, 1, 1)]).
+
+    - Label categories are merged according to the union of their label names
+    (same as `UnionMerge`). For example, if Dataset-A has {"car", "cat", "dog"}
+    and Dataset-B has {"car", "bus", "truck"} labels, the merged dataset will have
+    {"bus", "car", "cat", "dog", "truck"} labels.
+
+    - This merge has configuration parameters (`conf`) to control the annotation merge behaviors.
+
+    For example,
+
+    ```python
+    merge = IntersectMerge(
+        conf=IntersectMerge.Conf(
+            pairwise_dist=0.25,
+            groups=[],
+            output_conf_thresh=0.0,
+            quorum=0,
+        )
+    )
+    ```
+
+    For more details about the parameters, please refer to :class:`IntersectMerge.Conf`.
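+
+    A minimal usage sketch (the `dataset_a` and `dataset_b` objects below are
+    hypothetical input datasets that share some (id, subset) pairs; they are not
+    defined in this module):
+
+    ```python
+    merger = IntersectMerge(conf=IntersectMerge.Conf(pairwise_dist=0.25, quorum=2))
+    merged = merger(dataset_a, dataset_b)  # DatasetItemStorageDatasetView
+    for error in merger.errors:  # conflicts collected while merging
+        print(error)
+    ```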
+ """ + + def __init__(self, **options): + super().__init__(**options) + + @attrs(repr_ns="IntersectMerge", kw_only=True) + class Conf: + """ + Parameters + ---------- + pairwise_dist + IoU match threshold for segments + sigma + Parameter for Object Keypoint Similarity metric + (https://cocodataset.org/#keypoints-eval) + output_conf_thresh + Confidence threshold for output annotations + quorum + Minimum count for a label and attribute voting results to be counted + ignored_attributes + Attributes to be ignored in the merged :class:`DatasetItem` + groups + A comma-separated list of labels in annotation groups to check. + '?' postfix can be added to a label to make it optional in the group (repeatable) + close_distance + Distance threshold between annotations to decide their closeness. If they are decided + to be close, it will be enrolled to the error tracker. + """ + + pairwise_dist = attrib(converter=float, default=0.5) + sigma = attrib(converter=list, factory=list) + + output_conf_thresh = attrib(converter=float, default=0) + quorum = attrib(converter=int, default=0) + ignored_attributes = attrib(converter=set, factory=set) + + def _groups_converter(value): + result = [] + for group in value: + rg = set() + for label in group: + optional = label.endswith("?") + name = label if not optional else label[:-1] + rg.add((name, optional)) + result.append(rg) + return result + + groups = attrib(converter=_groups_converter, factory=list) + close_distance = attrib(converter=float, default=0.75) + + conf = attrib(converter=ensure_cls(Conf), factory=Conf) + + # Error trackers: + errors = attrib(factory=list, init=False) + + def add_item_error(self, error, *args, **kwargs): + self.errors.append(error(self._item_id, *args, **kwargs)) + + # Indexes: + _dataset_map = attrib(init=False) # id(dataset) -> (dataset, index) + _item_map = attrib(init=False) # id(item) -> (item, id(dataset)) + _ann_map = attrib(init=False) # id(ann) -> (ann, id(item)) + _item_id = attrib(init=False) + _item = attrib(init=False) + + # Misc. + _infos = attrib(init=False) # merged infos + _categories = attrib(init=False) # merged categories + + def merge(self, sources: Sequence[IDataset]) -> DatasetItemStorage: + self._infos = self.merge_infos([d.infos() for d in sources]) + self._categories = self.merge_categories([d.categories() for d in sources]) + merged = DatasetItemStorage() + self._check_groups_definition() + + item_matches, item_map = self.match_items(sources) + self._item_map = item_map + self._dataset_map = {id(d): (d, i) for i, d in enumerate(sources)} + + for item_id, items in item_matches.items(): + self._item_id = item_id + + if len(items) < len(sources): + missing_sources = set(id(s) for s in sources) - set(items) + missing_sources = [self._dataset_map[s][1] for s in missing_sources] + self.add_item_error(NoMatchingItemError, sources=missing_sources) + merged.put(self.merge_items(items)) + + return merged + + def get_ann_source(self, ann_id): + return self._item_map[self._ann_map[ann_id][1]][1] + + def __call__(self, *datasets: IDataset) -> DatasetItemStorageDatasetView: + # TODO: self.merge() should be the first since this order matters for + # IntersectMerge. 
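+        # Note: self.merge() computes and caches self._categories as a side effect,
+        # and merge_categories() below returns that cached value, so the merge step
+        # has to run before the category merging call.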
+ merged = self.merge(datasets) + infos = self.merge_infos(d.infos() for d in datasets) + categories = self.merge_categories(d.categories() for d in datasets) + media_type = self.merge_media_types(datasets) + ann_types = self.merge_ann_types(datasets) + return DatasetItemStorageDatasetView( + parent=merged, + infos=infos, + categories=categories, + media_type=media_type, + ann_types=ann_types, + ) + + def merge_categories(self, sources: Sequence[IDataset]) -> Dict: + # TODO: This is a temporary workaround to minimize code changes. + # We have to revisit it to make this class stateless. + if hasattr(self, "_categories"): + return self._categories + + dst_categories = {} + + label_cat = self._merge_label_categories(sources) + if label_cat is None: + label_cat = LabelCategories() + dst_categories[AnnotationType.label] = label_cat + + points_cat = self._merge_point_categories(sources, label_cat) + if points_cat is not None: + dst_categories[AnnotationType.points] = points_cat + + mask_cat = self._merge_mask_categories(sources, label_cat) + if mask_cat is not None: + dst_categories[AnnotationType.mask] = mask_cat + + return dst_categories + + def merge_items(self, items: Dict[int, DatasetItem]) -> DatasetItem: + self._item = next(iter(items.values())) + + self._ann_map = {} + sources = [] + for item in items.values(): + self._ann_map.update({id(a): (a, id(item)) for a in item.annotations}) + sources.append(item.annotations) + log.debug( + "Merging item %s: source annotations %s" % (self._item_id, list(map(len, sources))) + ) + + annotations = self.merge_annotations(sources) + + annotations = [ + a for a in annotations if self.conf.output_conf_thresh <= a.attributes.get("score", 1) + ] + + return self._item.wrap(annotations=annotations) + + def merge_annotations(self, sources): + self._make_mergers(sources) + + clusters = self._match_annotations(sources) + + joined_clusters = sum(clusters.values(), []) + group_map = self._find_cluster_groups(joined_clusters) + + annotations = [] + for t, clusters in clusters.items(): + for cluster in clusters: + self._check_cluster_sources(cluster) + + merged_clusters = self._merge_clusters(t, clusters) + + for merged_ann, cluster in zip(merged_clusters, clusters): + attributes = self._find_cluster_attrs(cluster, merged_ann) + attributes = { + k: v for k, v in attributes.items() if k not in self.conf.ignored_attributes + } + attributes.update(merged_ann.attributes) + merged_ann.attributes = attributes + + new_group_id = find(enumerate(group_map), lambda e: id(cluster) in e[1][0]) + if new_group_id is None: + new_group_id = 0 + else: + new_group_id = new_group_id[0] + 1 + merged_ann.group = new_group_id + + if self.conf.close_distance: + self._check_annotation_distance(t, merged_clusters) + + annotations += merged_clusters + + if self.conf.groups: + self._check_groups(annotations) + + return annotations + + def match_items(self, datasets): + item_ids = set((item.id, item.subset) for d in datasets for item in d) + + item_map = {} # id(item) -> (item, id(dataset)) + + matches = OrderedDict() + for item_id, item_subset in sorted(item_ids, key=lambda e: e[0]): + items = {} + for d in datasets: + item = d.get(item_id, subset=item_subset) + if item: + items[id(d)] = item + item_map[id(item)] = (item, id(d)) + matches[(item_id, item_subset)] = items + + return matches, item_map + + def _merge_label_categories(self, sources): + same = True + common = None + for src_categories in sources: + src_cat = src_categories.get(AnnotationType.label) + if common is None: + common = 
src_cat + elif common != src_cat: + same = False + break + + if same: + return common + + dst_cat = LabelCategories() + for src_id, src_categories in enumerate(sources): + src_cat = src_categories.get(AnnotationType.label) + if src_cat is None: + continue + + for src_label in src_cat.items: + dst_label = dst_cat.find(src_label.name, src_label.parent)[1] + if dst_label is not None: + if dst_label != src_label: + if ( + src_label.parent + and dst_label.parent + and src_label.parent != dst_label.parent + ): + raise ConflictingCategoriesError( + "Can't merge label category %s (from #%s): " + "parent label conflict: %s vs. %s" + % (src_label.name, src_id, src_label.parent, dst_label.parent), + sources=list(range(src_id)), + ) + dst_label.parent = dst_label.parent or src_label.parent + dst_label.attributes |= src_label.attributes + else: + pass + else: + dst_cat.add(src_label.name, src_label.parent, src_label.attributes) + + return dst_cat + + def _merge_point_categories(self, sources, label_cat): + dst_point_cat = PointsCategories() + + for src_id, src_categories in enumerate(sources): + src_label_cat = src_categories.get(AnnotationType.label) + src_point_cat = src_categories.get(AnnotationType.points) + if src_label_cat is None or src_point_cat is None: + continue + + for src_label_id, src_cat in src_point_cat.items.items(): + src_label = src_label_cat.items[src_label_id].name + src_parent_label = src_label_cat.items[src_label_id].parent + dst_label_id = label_cat.find(src_label, src_parent_label)[0] + dst_cat = dst_point_cat.items.get(dst_label_id) + if dst_cat is not None: + if dst_cat != src_cat: + raise ConflictingCategoriesError( + "Can't merge point category for label " + "%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat), + sources=list(range(src_id)), + ) + else: + pass + else: + dst_point_cat.add(dst_label_id, src_cat.labels, src_cat.joints) + + if len(dst_point_cat.items) == 0: + return None + + return dst_point_cat + + def _merge_mask_categories(self, sources, label_cat): + dst_mask_cat = MaskCategories() + + for src_id, src_categories in enumerate(sources): + src_label_cat = src_categories.get(AnnotationType.label) + src_mask_cat = src_categories.get(AnnotationType.mask) + if src_label_cat is None or src_mask_cat is None: + continue + + for src_label_id, src_cat in src_mask_cat.colormap.items(): + src_label = src_label_cat.items[src_label_id].name + src_parent_label = src_label_cat.items[src_label_id].parent + dst_label_id = label_cat.find(src_label, src_parent_label)[0] + dst_cat = dst_mask_cat.colormap.get(dst_label_id) + if dst_cat is not None: + if dst_cat != src_cat: + raise ConflictingCategoriesError( + "Can't merge mask category for label " + "%s (from #%s): %s vs. 
%s" % (src_label, src_id, src_cat, dst_cat), + sources=list(range(src_id)), + ) + else: + pass + else: + dst_mask_cat.colormap[dst_label_id] = src_cat + + if len(dst_mask_cat.colormap) == 0: + return None + + return dst_mask_cat + + def _match_annotations(self, sources): + all_by_type = {} + for s in sources: + src_by_type = {} + for a in s: + src_by_type.setdefault(a.type, []).append(a) + for k, v in src_by_type.items(): + all_by_type.setdefault(k, []).append(v) + + clusters = {} + for k, v in all_by_type.items(): + clusters.setdefault(k, []).extend(self._match_ann_type(k, v)) + + return clusters + + def _make_mergers(self, sources): + def _make(c, **kwargs): + kwargs.update(attr.asdict(self.conf)) + fields = attr.fields_dict(c) + return c(**{k: v for k, v in kwargs.items() if k in fields}, context=self) + + def _for_type(t, **kwargs): + if t is AnnotationType.unknown: + return _make(AnnotationMerger, **kwargs) + elif t is AnnotationType.label: + return _make(LabelMerger, **kwargs) + elif t is AnnotationType.bbox: + return _make(BboxMerger, **kwargs) + elif t is AnnotationType.mask: + return _make(MaskMerger, **kwargs) + elif t is AnnotationType.polygon: + return _make(PolygonMerger, **kwargs) + elif t is AnnotationType.polyline: + return _make(LineMerger, **kwargs) + elif t is AnnotationType.points: + return _make(PointsMerger, **kwargs) + elif t is AnnotationType.caption: + return _make(CaptionsMerger, **kwargs) + elif t is AnnotationType.cuboid_3d: + return _make(Cuboid3dMerger, **kwargs) + elif t is AnnotationType.super_resolution_annotation: + return _make(ImageAnnotationMerger, **kwargs) + elif t is AnnotationType.depth_annotation: + return _make(ImageAnnotationMerger, **kwargs) + elif t is AnnotationType.ellipse: + return _make(EllipseMerger, **kwargs) + elif t is AnnotationType.hash_key: + return _make(HashKeyMerger, **kwargs) + elif t is AnnotationType.feature_vector: + return _make(FeatureVectorMerger, **kwargs) + elif t is AnnotationType.tabular: + return _make(TabularMerger, **kwargs) + elif t is AnnotationType.rotated_bbox: + return _make(RotatedBboxMerger, **kwargs) + elif t is AnnotationType.cuboid_2d: + return _make(Cuboid2DMerger, **kwargs) + elif t is AnnotationType.skeleton: + # to do: add skeletons merge + return _make(ImageAnnotationMerger, **kwargs) + else: + raise NotImplementedError("Type %s is not supported" % t) + + instance_map = {} + for s in sources: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox( + [ + a + for a in inst + if a.type + in {AnnotationType.polygon, AnnotationType.mask, AnnotationType.bbox} + ] + ) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + + self._mergers = {t: _for_type(t, instance_map=instance_map) for t in AnnotationType} + + def _match_ann_type(self, t, sources): + return self._mergers[t].match_annotations(sources) + + def _merge_clusters(self, t, clusters): + return self._mergers[t].merge_clusters(clusters) + + def _find_cluster_groups(self, clusters): + cluster_groups = [] + visited = set() + for a_idx, cluster_a in enumerate(clusters): + if a_idx in visited: + continue + visited.add(a_idx) + + cluster_group = {id(cluster_a)} + + # find segment groups in the cluster group + a_groups = set(ann.group for ann in cluster_a) + for cluster_b in clusters[a_idx + 1 :]: + b_groups = set(ann.group for ann in cluster_b) + if a_groups & b_groups: + a_groups |= b_groups + + # now we know all the segment groups in this cluster group + # so we can find adjacent clusters + for b_idx, cluster_b in 
enumerate(clusters[a_idx + 1 :]): + b_idx = a_idx + 1 + b_idx + b_groups = set(ann.group for ann in cluster_b) + if a_groups & b_groups: + cluster_group.add(id(cluster_b)) + visited.add(b_idx) + + if a_groups == {0}: + continue # skip annotations without a group + cluster_groups.append((cluster_group, a_groups)) + return cluster_groups + + def _find_cluster_attrs(self, cluster, ann): + quorum = self.conf.quorum or 0 + + # TODO: when attribute types are implemented, add linear + # interpolation for contiguous values + + attr_votes = {} # name -> { value: score , ... } + for s in cluster: + for name, value in s.attributes.items(): + votes = attr_votes.get(name, {}) + votes[value] = 1 + votes.get(value, 0) + attr_votes[name] = votes + + attributes = {} + for name, votes in attr_votes.items(): + winner, count = max(votes.items(), key=lambda e: e[1]) + if count < quorum: + if sum(votes.values()) < quorum: + # blame provokers + missing_sources = set( + self.get_ann_source(id(a)) + for a in cluster + if s.attributes.get(name) == winner + ) + else: + # blame outliers + missing_sources = set( + self.get_ann_source(id(a)) + for a in cluster + if s.attributes.get(name) != winner + ) + missing_sources = [self._dataset_map[s][1] for s in missing_sources] + self.add_item_error( + FailedAttrVotingError, name, votes, ann, sources=missing_sources + ) + continue + attributes[name] = winner + + return attributes + + def _check_cluster_sources(self, cluster): + if len(cluster) == len(self._dataset_map): + return + + def _has_item(s): + item = self._dataset_map[s][0].get(*self._item_id) + if not item: + return False + if len(item.annotations) == 0: + return False + return True + + missing_sources = set(self._dataset_map) - set(self.get_ann_source(id(a)) for a in cluster) + missing_sources = [self._dataset_map[s][1] for s in missing_sources if _has_item(s)] + if missing_sources: + self.add_item_error(NoMatchingAnnError, cluster[0], sources=missing_sources) + + def _check_annotation_distance(self, t, annotations): + for a_idx, a_ann in enumerate(annotations): + for b_ann in annotations[a_idx + 1 :]: + d = self._mergers[t].distance(a_ann, b_ann) + if self.conf.close_distance < d: + self.add_item_error(AnnotationsTooCloseError, a_ann, b_ann, d) + + def _check_groups(self, annotations): + check_groups = [] + for check_group_raw in self.conf.groups: + check_group = set(l[0] for l in check_group_raw) + optional = set(l[0] for l in check_group_raw if l[1]) + check_groups.append((check_group, optional)) + + def _check_group(group_labels, group): + for check_group, optional in check_groups: + common = check_group & group_labels + real_miss = check_group - common - optional + extra = group_labels - check_group + if common and (extra or real_miss): + self.add_item_error(WrongGroupError, group_labels, check_group, group) + break + + groups = find_instances(annotations) + for group in groups: + group_labels = set() + for ann in group: + if not hasattr(ann, "label"): + continue + label = self._get_label_name(ann.label) + + if ann.group: + group_labels.add(label) + else: + _check_group({label}, [ann]) + + if not group_labels: + continue + _check_group(group_labels, group) + + def _get_label_name(self, label_id): + if label_id is None: + return None + return self._categories[AnnotationType.label].items[label_id].name + + def _get_label_id(self, label, parent=""): + if label is not None: + return self._categories[AnnotationType.label].find(label, parent)[0] + return None + + def _get_src_label_name(self, ann, label_id): + if 
label_id is None:
+            return None
+        item_id = self._ann_map[id(ann)][1]
+        dataset_id = self._item_map[item_id][1]
+        return (
+            self._dataset_map[dataset_id][0].categories()[AnnotationType.label].items[label_id].name
+        )
+
+    def get_any_label_name(self, ann, label_id):
+        if label_id is None:
+            return None
+        try:
+            return self._get_src_label_name(ann, label_id)
+        except KeyError:
+            return self._get_label_name(label_id)
+
+    def _check_groups_definition(self):
+        for group in self.conf.groups:
+            for label, _ in group:
+                _, entry = self._categories[AnnotationType.label].find(label)
+                if entry is None:
+                    raise ValueError(
+                        "Datasets do not contain "
+                        "label '%s', available labels %s"
+                        % (label, [i.name for i in self._categories[AnnotationType.label].items])
+                    )
diff --git a/src/datumaro/components/merge/union_merge.py b/src/datumaro/components/merge/union_merge.py
new file mode 100644
index 0000000000..e4c4d0cf02
--- /dev/null
+++ b/src/datumaro/components/merge/union_merge.py
@@ -0,0 +1,88 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from collections import defaultdict
+from typing import Dict, List, Sequence, Tuple
+
+from datumaro.components.annotation import AnnotationType, LabelCategories
+from datumaro.components.dataset_base import DatasetItem, IDataset
+from datumaro.components.dataset_item_storage import DatasetItemStorage
+from datumaro.components.merge import Merger
+
+__all__ = ["UnionMerge"]
+
+
+class UnionMerge(Merger):
+    """
+    Merge several datasets with "union" policy:
+
+    - Label categories are merged according to the union of their label names.
+    For example, if Dataset-A has {"car", "cat", "dog"} and Dataset-B has
+    {"car", "bus", "truck"} labels, the merged dataset will have
+    {"bus", "car", "cat", "dog", "truck"} labels.
+
+    - If there are two or more dataset items whose (id, subset) pairs match each other,
+    both are included in the merged dataset. In that case, since the same (id, subset)
+    pair cannot be duplicated in the dataset, we add a suffix to the id of each source item.
+    For example, if Dataset-A has DatasetItem(id="magic", subset="train") and Dataset-B
+    also has DatasetItem(id="magic", subset="train"), the merged dataset will have
+    DatasetItem(id="magic-0", subset="train") and DatasetItem(id="magic-1", subset="train").
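+
+    A minimal usage sketch (the `dataset_a` and `dataset_b` objects below are
+    hypothetical input datasets, not defined in this module):
+
+    ```python
+    merger = UnionMerge()
+    merged = merger(dataset_a, dataset_b)  # labels are remapped to the united categories
+    ```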
+ """ + + def __init__(self, **options): + super().__init__(**options) + self._matching_table = {} + + def merge(self, sources: Sequence[IDataset]) -> DatasetItemStorage: + dict_items: Dict[Tuple[str, str], List[DatasetItem]] = defaultdict(list) + + for source_idx, source in enumerate(sources): + for item in source: + if self._matching_table.get(source_idx, None): + for ann in item.annotations: + ann.label = self._matching_table[source_idx][ann.label] + dict_items[item.id, item.subset].append(item) + + item_storage = DatasetItemStorage() + + for items in dict_items.values(): + if len(items) == 1: + item_storage.put(items[0]) + else: + for idx, item in enumerate(items): + # Add prefix + item_storage.put(item.wrap(id=f"{item.id}-{idx}")) + + return item_storage + + def merge_categories(self, sources: Sequence[IDataset]) -> Dict: + dst_categories = {} + + label_cat = self._merge_label_categories(sources) + if label_cat is None: + label_cat = LabelCategories() + dst_categories[AnnotationType.label] = label_cat + + return dst_categories + + def _merge_label_categories(self, sources: Sequence[IDataset]) -> LabelCategories: + dst_cat = LabelCategories() + for src_id, src_categories in enumerate(sources): + src_cat = src_categories.get(AnnotationType.label) + if src_cat is None: + continue + + for src_label in src_cat.items: + src_idx = src_cat.find(src_label.name)[0] + dst_idx = dst_cat.find(src_label.name)[0] + if dst_idx is None: + dst_cat.add(src_label.name, src_label.parent, src_label.attributes) + dst_idx = dst_cat.find(src_label.name)[0] + + if self._matching_table.get(src_id, None): + self._matching_table[src_id].update({src_idx: dst_idx}) + else: + self._matching_table[src_id] = {src_idx: dst_idx} + + return dst_cat diff --git a/src/datumaro/components/operations.py b/src/datumaro/components/operations.py index 420c203453..b1150a839b 100644 --- a/src/datumaro/components/operations.py +++ b/src/datumaro/components/operations.py @@ -5,911 +5,16 @@ import hashlib import logging as log -from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Callable, Dict, Optional, Set, Tuple -import attr import cv2 import numpy as np -from attr import attrib, attrs - -from datumaro.components.annotation import ( - Annotation, - AnnotationType, - LabelCategories, - MaskCategories, - PointsCategories, -) -from datumaro.components.annotations.merger import ( - BboxMerger, - CaptionsMerger, - Cuboid3dMerger, - ImageAnnotationMerger, - LabelMerger, - LineMerger, - MaskMerger, - PointsMerger, - PolygonMerger, -) -from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.dataset import Dataset, IDataset -from datumaro.components.dataset_base import CategoriesInfo, DatasetItem -from datumaro.components.dataset_item_storage import DatasetItemStorage -from datumaro.components.errors import ( - AnnotationsTooCloseError, - ConflictingCategoriesError, - DatasetMergeError, - FailedAttrVotingError, - MediaTypeError, - MismatchingAttributesError, - MismatchingImageInfoError, - MismatchingMediaError, - MismatchingMediaPathError, - NoMatchingAnnError, - NoMatchingItemError, - VideoMergeError, - WrongGroupError, -) -from datumaro.components.media import Image, MediaElement, MultiframeImage, PointCloud, Video -from datumaro.util import find -from datumaro.util.annotation_util import find_instances, max_bbox -from datumaro.util.attrs_util import ensure_cls - - -def match_annotations_equal(a, 
b): - matches = [] - a_unmatched = a[:] - b_unmatched = b[:] - for a_ann in a: - for b_ann in b_unmatched: - if a_ann != b_ann: - continue - - matches.append((a_ann, b_ann)) - a_unmatched.remove(a_ann) - b_unmatched.remove(b_ann) - break - - return matches, a_unmatched, b_unmatched - - -def merge_annotations_equal(a, b): - matches, a_unmatched, b_unmatched = match_annotations_equal(a, b) - return [ann_a for (ann_a, _) in matches] + a_unmatched + b_unmatched - - -def merge_categories(sources): - categories = {} - for source_idx, source in enumerate(sources): - for cat_type, source_cat in source.items(): - existing_cat = categories.setdefault(cat_type, source_cat) - if existing_cat != source_cat and len(source_cat) != 0: - if len(existing_cat) == 0: - categories[cat_type] = source_cat - else: - raise ConflictingCategoriesError( - "Merging of datasets with different categories is " - "only allowed in 'merge' command.", - sources=list(range(source_idx)), - ) - return categories - - -class MergingStrategy(CliPlugin): - @classmethod - def merge(cls, sources, **options): - instance = cls(**options) - return instance(sources) - - def __init__(self, **options): - super().__init__(**options) - self.__dict__["_sources"] = None - - def __call__(self, sources): - raise NotImplementedError() - - -class ExactMerge: - """ - Merges several datasets using the "simple" algorithm: - - items are matched by (id, subset) pairs - - matching items share the media info available: - - - nothing + nothing = nothing - - nothing + something = something - - something A + something B = conflict - - annotations are matched by value and shared - - in case of conflicts, throws an error - """ - - @classmethod - def merge(cls, *sources: IDataset) -> DatasetItemStorage: - items = DatasetItemStorage() - for source_idx, source in enumerate(sources): - for item in source: - existing_item = items.get(item.id, item.subset) - if existing_item is not None: - try: - item = cls._merge_items(existing_item, item) - except DatasetMergeError as e: - e.sources = set(range(source_idx)) - raise e - - items.put(item) - return items - - @classmethod - def _merge_items(cls, existing_item: DatasetItem, current_item: DatasetItem) -> DatasetItem: - return existing_item.wrap( - media=cls._merge_media(existing_item, current_item), - attributes=cls._merge_attrs( - existing_item.attributes, - current_item.attributes, - item_id=(existing_item.id, existing_item.subset), - ), - annotations=cls._merge_anno(existing_item.annotations, current_item.annotations), - ) - - @staticmethod - def _merge_attrs(a: Dict[str, Any], b: Dict[str, Any], item_id: Tuple[str, str]) -> Dict: - merged = {} - - for name in a.keys() | b.keys(): - a_val = a.get(name, None) - b_val = b.get(name, None) - - if name not in a: - m_val = b_val - elif name not in b: - m_val = a_val - elif a_val != b_val: - raise MismatchingAttributesError(item_id, name, a_val, b_val) - else: - m_val = a_val - - merged[name] = m_val - - return merged - - @classmethod - def _merge_media( - cls, item_a: DatasetItem, item_b: DatasetItem - ) -> Union[Image, PointCloud, Video]: - if (not item_a.media or isinstance(item_a.media, Image)) and ( - not item_b.media or isinstance(item_b.media, Image) - ): - media = cls._merge_images(item_a, item_b) - elif (not item_a.media or isinstance(item_a.media, PointCloud)) and ( - not item_b.media or isinstance(item_b.media, PointCloud) - ): - media = cls._merge_point_clouds(item_a, item_b) - elif (not item_a.media or isinstance(item_a.media, Video)) and ( - not item_b.media 
or isinstance(item_b.media, Video) - ): - media = cls._merge_videos(item_a, item_b) - elif (not item_a.media or isinstance(item_a.media, MultiframeImage)) and ( - not item_b.media or isinstance(item_b.media, MultiframeImage) - ): - media = cls._merge_multiframe_images(item_a, item_b) - elif (not item_a.media or isinstance(item_a.media, MediaElement)) and ( - not item_b.media or isinstance(item_b.media, MediaElement) - ): - if isinstance(item_a.media, MediaElement) and isinstance(item_b.media, MediaElement): - if ( - item_a.media.path - and item_b.media.path - and item_a.media.path != item_b.media.path - ): - raise MismatchingMediaPathError( - (item_a.id, item_a.subset), item_a.media.path, item_b.media.path - ) - - if item_a.media.path: - media = item_a.media - else: - media = item_b.media - - elif isinstance(item_a.media, MediaElement): - media = item_a.media - else: - media = item_b.media - else: - raise MismatchingMediaError((item_a.id, item_a.subset), item_a.media, item_b.media) - return media - - @staticmethod - def _merge_images(item_a: DatasetItem, item_b: DatasetItem) -> Image: - media = None - - if isinstance(item_a.media, Image) and isinstance(item_b.media, Image): - if ( - item_a.media.path - and item_b.media.path - and item_a.media.path != item_b.media.path - and item_a.media.has_data is item_b.media.has_data - ): - # We use has_data as a replacement for path existence check - # - If only one image has data, we'll use it. The other - # one is just a path metainfo, which is not significant - # in this case. - # - If both images have data or both don't, we need - # to compare paths. - # - # Different paths can aclually point to the same file, - # but it's not the case we'd like to allow here to be - # a "simple" merging strategy used for extractor joining - raise MismatchingMediaPathError( - (item_a.id, item_a.subset), item_a.media.path, item_b.media.path - ) - - if ( - item_a.media.has_size - and item_b.media.has_size - and item_a.media.size != item_b.media.size - ): - raise MismatchingImageInfoError( - (item_a.id, item_a.subset), item_a.media.size, item_b.media.size - ) - - # Avoid direct comparison here for better performance - # If there are 2 "data-only" images, they won't be compared and - # we just use the first one - if item_a.media.has_data: - media = item_a.media - elif item_b.media.has_data: - media = item_b.media - elif item_a.media.path: - media = item_a.media - elif item_b.media.path: - media = item_b.media - elif item_a.media.has_size: - media = item_a.media - elif item_b.media.has_size: - media = item_b.media - else: - assert False, "Unknown image field combination" - - if not media.has_data or not media.has_size: - if item_a.media._size: - media._size = item_a.media._size - elif item_b.media._size: - media._size = item_b.media._size - elif isinstance(item_a.media, Image): - media = item_a.media - else: - media = item_b.media - - return media - - @staticmethod - def _merge_point_clouds(item_a: DatasetItem, item_b: DatasetItem) -> PointCloud: - media = None - - if isinstance(item_a.media, PointCloud) and isinstance(item_b.media, PointCloud): - if item_a.media.path and item_b.media.path and item_a.media.path != item_b.media.path: - raise MismatchingMediaPathError( - (item_a.id, item_a.subset), item_a.media.path, item_b.media.path - ) - - if item_a.media.path or item_a.media.extra_images: - media = item_a.media - - if item_b.media.extra_images: - for image in item_b.media.extra_images: - if image not in media.extra_images: - media.extra_images.append(image) - else: - 
media = item_b.media - - if item_a.media.extra_images: - for image in item_a.media.extra_images: - if image not in media.extra_images: - media.extra_images.append(image) - - elif isinstance(item_a.media, PointCloud): - media = item_a.media - else: - media = item_b.media - - return media - - @staticmethod - def _merge_videos(item_a: DatasetItem, item_b: DatasetItem) -> Video: - media = None - - if isinstance(item_a.media, Video) and isinstance(item_b.media, Video): - if ( - item_a.media.path is not item_b.media.path - or item_a.media._start_frame is not item_b.media._start_frame - or item_a.media._end_frame is not item_b.media._end_frame - or item_a.media._step is not item_b.media._step - ): - raise VideoMergeError(item_a.id) - - media = item_a.media - elif isinstance(item_a.media, Video): - media = item_a.media - else: - media = item_b.media - - return media - - @staticmethod - def _merge_multiframe_images(item_a: DatasetItem, item_b: DatasetItem) -> MultiframeImage: - media = None - - if isinstance(item_a.media, MultiframeImage) and isinstance(item_b.media, MultiframeImage): - if item_a.media.path and item_b.media.path and item_a.media.path != item_b.media.path: - raise MismatchingMediaPathError( - (item_a.id, item_a.subset), item_a.media.path, item_b.media.path - ) - - if item_a.media.path or item_a.media.data: - media = item_a.media - - if item_b.media.data: - for image in item_b.media.data: - if image not in media.data: - media.data.append(image) - else: - media = item_b.media - - if item_a.media.data: - for image in item_a.media.data: - if image not in media.data: - media.data.append(image) - - elif isinstance(item_a.media, MultiframeImage): - media = item_a.media - else: - media = item_b.media - - return media - - @staticmethod - def _merge_anno(a: Iterable[Annotation], b: Iterable[Annotation]) -> List[Annotation]: - return merge_annotations_equal(a, b) - - @staticmethod - def merge_categories(sources: Iterable[IDataset]) -> CategoriesInfo: - return merge_categories(sources) - - @staticmethod - def merge_media_types(sources: Iterable[IDataset]) -> Type[MediaElement]: - if sources: - media_type = sources[0].media_type() - for s in sources: - if not issubclass(s.media_type(), media_type) or not issubclass( - media_type, s.media_type() - ): - # Symmetric comparision is needed in the case of subclasses: - # eg. 
Image and ByteImage - raise MediaTypeError("Datasets have different media types") - return media_type - - return None - - -@attrs -class IntersectMerge(MergingStrategy): - @attrs(repr_ns="IntersectMerge", kw_only=True) - class Conf: - pairwise_dist = attrib(converter=float, default=0.5) - sigma = attrib(converter=list, factory=list) - - output_conf_thresh = attrib(converter=float, default=0) - quorum = attrib(converter=int, default=0) - ignored_attributes = attrib(converter=set, factory=set) - - def _groups_converter(value): - result = [] - for group in value: - rg = set() - for label in group: - optional = label.endswith("?") - name = label if not optional else label[:-1] - rg.add((name, optional)) - result.append(rg) - return result - - groups = attrib(converter=_groups_converter, factory=list) - close_distance = attrib(converter=float, default=0.75) - - conf = attrib(converter=ensure_cls(Conf), factory=Conf) - - # Error trackers: - errors = attrib(factory=list, init=False) - - def add_item_error(self, error, *args, **kwargs): - self.errors.append(error(self._item_id, *args, **kwargs)) - - # Indexes: - _dataset_map = attrib(init=False) # id(dataset) -> (dataset, index) - _item_map = attrib(init=False) # id(item) -> (item, id(dataset)) - _ann_map = attrib(init=False) # id(ann) -> (ann, id(item)) - _item_id = attrib(init=False) - _item = attrib(init=False) - - # Misc. - _categories = attrib(init=False) # merged categories - - def __call__(self, datasets): - self._categories = self._merge_categories([d.categories() for d in datasets]) - merged = Dataset( - categories=self._categories, media_type=ExactMerge.merge_media_types(datasets) - ) - - self._check_groups_definition() - - item_matches, item_map = self.match_items(datasets) - self._item_map = item_map - self._dataset_map = {id(d): (d, i) for i, d in enumerate(datasets)} - - for item_id, items in item_matches.items(): - self._item_id = item_id - - if len(items) < len(datasets): - missing_sources = set(id(s) for s in datasets) - set(items) - missing_sources = [self._dataset_map[s][1] for s in missing_sources] - self.add_item_error(NoMatchingItemError, sources=missing_sources) - merged.put(self.merge_items(items)) - - return merged - - def get_ann_source(self, ann_id): - return self._item_map[self._ann_map[ann_id][1]][1] - - def merge_items(self, items): - self._item = next(iter(items.values())) - self._ann_map = {} - sources = [] - for item in items.values(): - self._ann_map.update({id(a): (a, id(item)) for a in item.annotations}) - sources.append(item.annotations) - log.debug( - "Merging item %s: source annotations %s" % (self._item_id, list(map(len, sources))) - ) - - annotations = self.merge_annotations(sources) - - annotations = [ - a for a in annotations if self.conf.output_conf_thresh <= a.attributes.get("score", 1) - ] - - return self._item.wrap(annotations=annotations) - - def merge_annotations(self, sources): - self._make_mergers(sources) - - clusters = self._match_annotations(sources) - - joined_clusters = sum(clusters.values(), []) - group_map = self._find_cluster_groups(joined_clusters) - - annotations = [] - for t, clusters in clusters.items(): - for cluster in clusters: - self._check_cluster_sources(cluster) - - merged_clusters = self._merge_clusters(t, clusters) - - for merged_ann, cluster in zip(merged_clusters, clusters): - attributes = self._find_cluster_attrs(cluster, merged_ann) - attributes = { - k: v for k, v in attributes.items() if k not in self.conf.ignored_attributes - } - attributes.update(merged_ann.attributes) 
- merged_ann.attributes = attributes - - new_group_id = find(enumerate(group_map), lambda e: id(cluster) in e[1][0]) - if new_group_id is None: - new_group_id = 0 - else: - new_group_id = new_group_id[0] + 1 - merged_ann.group = new_group_id - - if self.conf.close_distance: - self._check_annotation_distance(t, merged_clusters) - - annotations += merged_clusters - - if self.conf.groups: - self._check_groups(annotations) - - return annotations - - @staticmethod - def match_items(datasets): - item_ids = set((item.id, item.subset) for d in datasets for item in d) - - item_map = {} # id(item) -> (item, id(dataset)) - - matches = OrderedDict() - for item_id, item_subset in sorted(item_ids, key=lambda e: e[0]): - items = {} - for d in datasets: - item = d.get(item_id, subset=item_subset) - if item: - items[id(d)] = item - item_map[id(item)] = (item, id(d)) - matches[(item_id, item_subset)] = items - - return matches, item_map - - def _merge_label_categories(self, sources): - same = True - common = None - for src_categories in sources: - src_cat = src_categories.get(AnnotationType.label) - if common is None: - common = src_cat - elif common != src_cat: - same = False - break - - if same: - return common - - dst_cat = LabelCategories() - for src_id, src_categories in enumerate(sources): - src_cat = src_categories.get(AnnotationType.label) - if src_cat is None: - continue - - for src_label in src_cat.items: - dst_label = dst_cat.find(src_label.name, src_label.parent)[1] - if dst_label is not None: - if dst_label != src_label: - if ( - src_label.parent - and dst_label.parent - and src_label.parent != dst_label.parent - ): - raise ConflictingCategoriesError( - "Can't merge label category %s (from #%s): " - "parent label conflict: %s vs. %s" - % (src_label.name, src_id, src_label.parent, dst_label.parent), - sources=list(range(src_id)), - ) - dst_label.parent = dst_label.parent or src_label.parent - dst_label.attributes |= src_label.attributes - else: - pass - else: - dst_cat.add(src_label.name, src_label.parent, src_label.attributes) - - return dst_cat - - def _merge_point_categories(self, sources, label_cat): - dst_point_cat = PointsCategories() - - for src_id, src_categories in enumerate(sources): - src_label_cat = src_categories.get(AnnotationType.label) - src_point_cat = src_categories.get(AnnotationType.points) - if src_label_cat is None or src_point_cat is None: - continue - - for src_label_id, src_cat in src_point_cat.items.items(): - src_label = src_label_cat.items[src_label_id].name - src_parent_label = src_label_cat.items[src_label_id].parent - dst_label_id = label_cat.find(src_label, src_parent_label)[0] - dst_cat = dst_point_cat.items.get(dst_label_id) - if dst_cat is not None: - if dst_cat != src_cat: - raise ConflictingCategoriesError( - "Can't merge point category for label " - "%s (from #%s): %s vs. 
%s" % (src_label, src_id, src_cat, dst_cat), - sources=list(range(src_id)), - ) - else: - pass - else: - dst_point_cat.add(dst_label_id, src_cat.labels, src_cat.joints) - - if len(dst_point_cat.items) == 0: - return None - - return dst_point_cat - - def _merge_mask_categories(self, sources, label_cat): - dst_mask_cat = MaskCategories() - - for src_id, src_categories in enumerate(sources): - src_label_cat = src_categories.get(AnnotationType.label) - src_mask_cat = src_categories.get(AnnotationType.mask) - if src_label_cat is None or src_mask_cat is None: - continue - - for src_label_id, src_cat in src_mask_cat.colormap.items(): - src_label = src_label_cat.items[src_label_id].name - src_parent_label = src_label_cat.items[src_label_id].parent - dst_label_id = label_cat.find(src_label, src_parent_label)[0] - dst_cat = dst_mask_cat.colormap.get(dst_label_id) - if dst_cat is not None: - if dst_cat != src_cat: - raise ConflictingCategoriesError( - "Can't merge mask category for label " - "%s (from #%s): %s vs. %s" % (src_label, src_id, src_cat, dst_cat), - sources=list(range(src_id)), - ) - else: - pass - else: - dst_mask_cat.colormap[dst_label_id] = src_cat - - if len(dst_mask_cat.colormap) == 0: - return None - - return dst_mask_cat - - def _merge_categories(self, sources): - dst_categories = {} - - label_cat = self._merge_label_categories(sources) - if label_cat is None: - label_cat = LabelCategories() - dst_categories[AnnotationType.label] = label_cat - - points_cat = self._merge_point_categories(sources, label_cat) - if points_cat is not None: - dst_categories[AnnotationType.points] = points_cat - - mask_cat = self._merge_mask_categories(sources, label_cat) - if mask_cat is not None: - dst_categories[AnnotationType.mask] = mask_cat - - return dst_categories - - def _match_annotations(self, sources): - all_by_type = {} - for s in sources: - src_by_type = {} - for a in s: - src_by_type.setdefault(a.type, []).append(a) - for k, v in src_by_type.items(): - all_by_type.setdefault(k, []).append(v) - - clusters = {} - for k, v in all_by_type.items(): - clusters.setdefault(k, []).extend(self._match_ann_type(k, v)) - - return clusters - - def _make_mergers(self, sources): - def _make(c, **kwargs): - kwargs.update(attr.asdict(self.conf)) - fields = attr.fields_dict(c) - return c(**{k: v for k, v in kwargs.items() if k in fields}, context=self) - - def _for_type(t, **kwargs): - if t is AnnotationType.label: - return _make(LabelMerger, **kwargs) - elif t is AnnotationType.bbox: - return _make(BboxMerger, **kwargs) - elif t is AnnotationType.mask: - return _make(MaskMerger, **kwargs) - elif t is AnnotationType.polygon: - return _make(PolygonMerger, **kwargs) - elif t is AnnotationType.polyline: - return _make(LineMerger, **kwargs) - elif t is AnnotationType.points: - return _make(PointsMerger, **kwargs) - elif t is AnnotationType.caption: - return _make(CaptionsMerger, **kwargs) - elif t is AnnotationType.cuboid_3d: - return _make(Cuboid3dMerger, **kwargs) - elif t is AnnotationType.super_resolution_annotation: - return _make(ImageAnnotationMerger, **kwargs) - elif t is AnnotationType.depth_annotation: - return _make(ImageAnnotationMerger, **kwargs) - elif t is AnnotationType.skeleton: - # to do: add skeletons merge - return _make(ImageAnnotationMerger, **kwargs) - # TODO: remove later - elif ( - t is AnnotationType.unknown - or t is AnnotationType.ellipse - or t is AnnotationType.hash_key - or t is AnnotationType.feature_vector - or t is AnnotationType.tabular - or t is AnnotationType.rotated_bbox - or 
t is AnnotationType.cuboid_2d - ): - return None - else: - raise NotImplementedError("Type %s is not supported" % t) - - instance_map = {} - for s in sources: - s_instances = find_instances(s) - for inst in s_instances: - inst_bbox = max_bbox( - [ - a - for a in inst - if a.type - in {AnnotationType.polygon, AnnotationType.mask, AnnotationType.bbox} - ] - ) - for ann in inst: - instance_map[id(ann)] = [inst, inst_bbox] - - self._mergers = {t: _for_type(t, instance_map=instance_map) for t in AnnotationType} - - def _match_ann_type(self, t, sources): - return self._mergers[t].match_annotations(sources) - - def _merge_clusters(self, t, clusters): - return self._mergers[t].merge_clusters(clusters) - - @staticmethod - def _find_cluster_groups(clusters): - cluster_groups = [] - visited = set() - for a_idx, cluster_a in enumerate(clusters): - if a_idx in visited: - continue - visited.add(a_idx) - - cluster_group = {id(cluster_a)} - - # find segment groups in the cluster group - a_groups = set(ann.group for ann in cluster_a) - for cluster_b in clusters[a_idx + 1 :]: - b_groups = set(ann.group for ann in cluster_b) - if a_groups & b_groups: - a_groups |= b_groups - - # now we know all the segment groups in this cluster group - # so we can find adjacent clusters - for b_idx, cluster_b in enumerate(clusters[a_idx + 1 :]): - b_idx = a_idx + 1 + b_idx - b_groups = set(ann.group for ann in cluster_b) - if a_groups & b_groups: - cluster_group.add(id(cluster_b)) - visited.add(b_idx) - - if a_groups == {0}: - continue # skip annotations without a group - cluster_groups.append((cluster_group, a_groups)) - return cluster_groups - - def _find_cluster_attrs(self, cluster, ann): - quorum = self.conf.quorum or 0 - - # TODO: when attribute types are implemented, add linear - # interpolation for contiguous values - - attr_votes = {} # name -> { value: score , ... 
} - for s in cluster: - for name, value in s.attributes.items(): - votes = attr_votes.get(name, {}) - votes[value] = 1 + votes.get(value, 0) - attr_votes[name] = votes - - attributes = {} - for name, votes in attr_votes.items(): - winner, count = max(votes.items(), key=lambda e: e[1]) - if count < quorum: - if sum(votes.values()) < quorum: - # blame provokers - missing_sources = set( - self.get_ann_source(id(a)) - for a in cluster - if s.attributes.get(name) == winner - ) - else: - # blame outliers - missing_sources = set( - self.get_ann_source(id(a)) - for a in cluster - if s.attributes.get(name) != winner - ) - missing_sources = [self._dataset_map[s][1] for s in missing_sources] - self.add_item_error( - FailedAttrVotingError, name, votes, ann, sources=missing_sources - ) - continue - attributes[name] = winner - - return attributes - - def _check_cluster_sources(self, cluster): - if len(cluster) == len(self._dataset_map): - return - - def _has_item(s): - item = self._dataset_map[s][0].get(*self._item_id) - if not item: - return False - if len(item.annotations) == 0: - return False - return True - - missing_sources = set(self._dataset_map) - set(self.get_ann_source(id(a)) for a in cluster) - missing_sources = [self._dataset_map[s][1] for s in missing_sources if _has_item(s)] - if missing_sources: - self.add_item_error(NoMatchingAnnError, cluster[0], sources=missing_sources) - - def _check_annotation_distance(self, t, annotations): - for a_idx, a_ann in enumerate(annotations): - for b_ann in annotations[a_idx + 1 :]: - d = self._mergers[t].distance(a_ann, b_ann) - if self.conf.close_distance < d: - self.add_item_error(AnnotationsTooCloseError, a_ann, b_ann, d) - - def _check_groups(self, annotations): - check_groups = [] - for check_group_raw in self.conf.groups: - check_group = set(l[0] for l in check_group_raw) - optional = set(l[0] for l in check_group_raw if l[1]) - check_groups.append((check_group, optional)) - - def _check_group(group_labels, group): - for check_group, optional in check_groups: - common = check_group & group_labels - real_miss = check_group - common - optional - extra = group_labels - check_group - if common and (extra or real_miss): - self.add_item_error(WrongGroupError, group_labels, check_group, group) - break - - groups = find_instances(annotations) - for group in groups: - group_labels = set() - for ann in group: - if not hasattr(ann, "label"): - continue - label = self._get_label_name(ann.label) - - if ann.group: - group_labels.add(label) - else: - _check_group({label}, [ann]) - - if not group_labels: - continue - _check_group(group_labels, group) - - def _get_label_name(self, label_id): - if label_id is None: - return None - return self._categories[AnnotationType.label].items[label_id].name - - def _get_label_id(self, label, parent=""): - if label is not None: - return self._categories[AnnotationType.label].find(label, parent)[0] - return None - - def _get_src_label_name(self, ann, label_id): - if label_id is None: - return None - item_id = self._ann_map[id(ann)][1] - dataset_id = self._item_map[item_id][1] - return ( - self._dataset_map[dataset_id][0].categories()[AnnotationType.label].items[label_id].name - ) - - def get_any_label_name(self, ann, label_id): - if label_id is None: - return None - try: - return self._get_src_label_name(ann, label_id) - except KeyError: - return self._get_label_name(label_id) - - def _check_groups_definition(self): - for group in self.conf.groups: - for label, _ in group: - _, entry = 
self._categories[AnnotationType.label].find(label) - if entry is None: - raise ValueError( - "Datasets do not contain " - "label '%s', available labels %s" - % (label, [i.name for i in self._categories[AnnotationType.label].items]) - ) +from datumaro.components.annotation import AnnotationType, LabelCategories +from datumaro.components.dataset import IDataset +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.media import Image def mean_std(dataset: IDataset): diff --git a/src/datumaro/components/project.py b/src/datumaro/components/project.py index 9dcaa20fab..e092f8d70d 100644 --- a/src/datumaro/components/project.py +++ b/src/datumaro/components/project.py @@ -138,6 +138,9 @@ def __len__(self): def subsets(self): return self._dataset.subsets() + def infos(self): + return {} + def get_subset(self, name): return self._dataset.get_subset(name) @@ -150,6 +153,9 @@ def get(self, id, subset=None): def media_type(self): return self._dataset.media_type() + def ann_types(self): + return [] + class IgnoreMode(Enum): rewrite = auto() diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 46ef2351ab..bd3b960503 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -16,13 +16,12 @@ def test_can_import_core(self): def test_can_reach_module_alias_symbols_from_base(self): import datumaro as dm - assert hasattr(dm.ops, "ExactMerge") assert hasattr(dm.project, "Project") assert hasattr(dm.errors, "DatumaroError") @mark_requirement(Requirements.DATUM_API) def test_can_import_from_module_aliases(self): # pylint: disable=unused-import + from datumaro.components.merge.exact_merge import ExactMerge from datumaro.errors import DatumaroError - from datumaro.ops import ExactMerge from datumaro.project import Project diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index b96764769a..aff6aa1e50 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -19,13 +19,15 @@ from datumaro.components.annotations import match_segments_pair from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem -from datumaro.components.media import Image, MultiframeImage, PointCloud -from datumaro.components.operations import ( +from datumaro.components.errors import ( FailedAttrVotingError, - IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError, +) +from datumaro.components.media import Image, MultiframeImage, PointCloud +from datumaro.components.merge.intersect_merge import IntersectMerge +from datumaro.components.operations import ( compute_ann_statistics, compute_image_statistics, find_unique_images, @@ -557,7 +559,7 @@ def test_can_match_items(self): ) merger = IntersectMerge() - merged = merger([source0, source1, source2]) + merged = merger(source0, source1, source2) compare_datasets(self, expected, merged) self.assertEqual( @@ -710,7 +712,7 @@ def test_can_match_shapes(self): ) merger = IntersectMerge(conf={"quorum": 1, "pairwise_dist": 0.1}) - merged = merger([source0, source1, source2]) + merged = merger(source0, source1, source2) compare_datasets(self, expected, merged, ignored_attrs={"score"}) self.assertEqual( @@ -769,7 +771,7 @@ def test_can_match_lines_when_line_not_approximated(self): ) merger = IntersectMerge(conf={"quorum": 1, "pairwise_dist": 0.1}) - merged = merger([source0, source1]) + merged = merger(source0, source1) compare_datasets(self, expected, merged, ignored_attrs={"score"}) self.assertEqual(0, len(merger.errors)) @@ -846,7 +848,7 @@ def test_attributes(self): ) 
merger = IntersectMerge(conf={"quorum": 3, "ignored_attributes": {"ignored"}}) - merged = merger([source0, source1, source2]) + merged = merger(source0, source1, source2) compare_datasets(self, expected, merged, ignored_attrs={"score"}) self.assertEqual(2, len([e for e in merger.errors if isinstance(e, FailedAttrVotingError)])) @@ -873,7 +875,7 @@ def test_group_checks(self): ) merger = IntersectMerge(conf={"groups": [["a", "a_g1", "a_g2_opt?"], ["c", "c_g1_opt?"]]}) - merger([dataset, dataset]) + merger(dataset, dataset) self.assertEqual( 3, len([e for e in merger.errors if isinstance(e, WrongGroupError)]), merger.errors @@ -927,7 +929,7 @@ def test_can_merge_classes(self): ) merger = IntersectMerge() - merged = merger([source0, source1]) + merged = merger(source0, source1) compare_datasets(self, expected, merged, ignored_attrs={"score"}) @@ -1015,7 +1017,7 @@ def test_can_merge_categories(self): ) merger = IntersectMerge() - merged = merger([source0, source1]) + merged = merger(source0, source1) compare_datasets(self, expected, merged, ignored_attrs={"score"}) @@ -1069,7 +1071,7 @@ def test_can_merge_point_clouds(self): ) merger = IntersectMerge() - merged = merger([source0, source1]) + merged = merger(source0, source1) compare_datasets(self, expected, merged) @@ -1112,6 +1114,6 @@ def test_can_merge_multiframe_images(self): ) merger = IntersectMerge() - merged = merger([source0, source1]) + merged = merger(source0, source1) compare_datasets(self, expected, merged) From 68767ec8879d45ae13b3db37c6981f163e7b1137 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Thu, 16 Jan 2025 18:22:24 +0400 Subject: [PATCH 11/25] some typing fixes --- src/datumaro/components/abstracts/merger.py | 8 ++++--- .../components/annotations/matcher.py | 24 +++++++++---------- src/datumaro/components/annotations/merger.py | 8 +++---- src/datumaro/components/merge/base.py | 4 ++-- src/datumaro/util/annotation_util.py | 6 ++--- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/datumaro/components/abstracts/merger.py b/src/datumaro/components/abstracts/merger.py index 714e209641..cca8849efc 100644 --- a/src/datumaro/components/abstracts/merger.py +++ b/src/datumaro/components/abstracts/merger.py @@ -6,7 +6,7 @@ from typing import Dict, Optional, Sequence, Type from datumaro.components.annotation import Annotation -from datumaro.components.dataset_base import IDataset +from datumaro.components.dataset_base import DatasetInfo, IDataset from datumaro.components.dataset_item_storage import ( DatasetItemStorage, DatasetItemStorageDatasetView, @@ -23,16 +23,18 @@ def get_any_label_name(self, ann: Annotation, label_id: int) -> str: class IMergerContext(IMatcherContext): + @staticmethod @abstractmethod - def merge_infos(self, sources: Sequence[IDataset]) -> Dict: + def merge_infos(sources: Sequence[DatasetInfo]) -> Dict: raise NotImplementedError @abstractmethod def merge_categories(self, sources: Sequence[IDataset]) -> Dict: raise NotImplementedError + @staticmethod @abstractmethod - def merge_media_types(self, sources: Sequence[IDataset]) -> Optional[Type[MediaElement]]: + def merge_media_types(sources: Sequence[IDataset]) -> Optional[Type[MediaElement]]: raise NotImplementedError @abstractmethod diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py index eb7c874cc4..5535db5935 100644 --- a/src/datumaro/components/annotations/matcher.py +++ b/src/datumaro/components/annotations/matcher.py @@ -9,7 +9,7 @@ from datumaro.components.abstracts import 
IMergerContext from datumaro.components.abstracts.merger import IMatcherContext -from datumaro.components.annotation import Annotation, Points +from datumaro.components.annotation import Annotation, Label, Points, RotatedBbox, Shape from datumaro.util.annotation_util import ( OKS, approximate_line, @@ -158,18 +158,18 @@ def match_segments_more_than_pair( class AnnotationMatcher: _context: Optional[Union[IMatcherContext, IMergerContext]] = attrib(default=None) - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @attrs class LabelMatcher(AnnotationMatcher): - def distance(self, a, b): + def distance(self, a: Label, b: Label) -> bool: a_label = self._context.get_any_label_name(a, a.label) b_label = self._context.get_any_label_name(b, b.label) return a_label == b_label - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: return [sum(sources, [])] @@ -250,7 +250,7 @@ def _has_same_source(cluster, extra_id): return clusters - def distance(self, a, b): + def distance(self, a: Shape, b: Shape) -> float: return segment_iou(a, b) def label_matcher(self, a, b): @@ -279,7 +279,7 @@ class PointsMatcher(ShapeMatcher): sigma: Optional[list] = attrib(default=None) instance_map = attrib(converter=dict) - def distance(self, a, b): + def distance(self, a: Points, b: Points) -> int: a_bbox = self.instance_map[id(a)][1] b_bbox = self.instance_map[id(b)][1] if bbox_iou(a_bbox, b_bbox) <= 0: @@ -336,7 +336,7 @@ def _approx(line, segments): @attrs class CaptionsMatcher(AnnotationMatcher): - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @@ -348,25 +348,25 @@ def distance(self, a, b): @attrs class ImageAnnotationMatcher(AnnotationMatcher): - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @attrs class HashKeyMatcher(AnnotationMatcher): - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @attrs class FeatureVectorMatcher(AnnotationMatcher): - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @attrs class TabularMatcher(AnnotationMatcher): - def match_annotations(self, sources): + def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]: raise NotImplementedError() @@ -374,7 +374,7 @@ def match_annotations(self, sources): class RotatedBboxMatcher(ShapeMatcher): sigma: Optional[list] = attrib(default=None) - def distance(self, a, b): + def distance(self, a: RotatedBbox, b: RotatedBbox) -> int: a = Points([p for pt in a.as_polygon() for p in pt]) b = Points([p for pt in b.as_polygon() for p in pt]) diff --git a/src/datumaro/components/annotations/merger.py b/src/datumaro/components/annotations/merger.py index 8ff7593a61..e17f98bada 100644 --- a/src/datumaro/components/annotations/merger.py +++ b/src/datumaro/components/annotations/merger.py @@ -4,7 +4,7 @@ from attr import attrib, attrs -from datumaro.components.annotation import Bbox, Label +from datumaro.components.annotation import Annotation, Bbox, Label, Shape from datumaro.components.errors import FailedLabelVotingError from 
datumaro.util.annotation_util import mean_bbox, segment_iou @@ -47,7 +47,7 @@ @attrs(kw_only=True) class AnnotationMerger(AnnotationMatcher): - def merge_clusters(self, clusters): + def merge_clusters(self, clusters: list[list[Annotation]]) -> list[Annotation]: raise NotImplementedError() @@ -55,7 +55,7 @@ def merge_clusters(self, clusters): class LabelMerger(AnnotationMerger, LabelMatcher): quorum = attrib(converter=int, default=0) - def merge_clusters(self, clusters): + def merge_clusters(self, clusters: list[list[Label]]) -> list[Label]: assert len(clusters) <= 1 if len(clusters) == 0: return [] @@ -91,7 +91,7 @@ def merge_clusters(self, clusters): class _ShapeMerger(AnnotationMerger, ShapeMatcher): quorum = attrib(converter=int, default=0) - def merge_clusters(self, clusters): + def merge_clusters(self, clusters: list[list[Shape]]) -> list[Shape]: return list(map(self.merge_cluster, clusters)) def find_cluster_label(self, cluster): diff --git a/src/datumaro/components/merge/base.py b/src/datumaro/components/merge/base.py index e985c11d40..6cb05f37f2 100644 --- a/src/datumaro/components/merge/base.py +++ b/src/datumaro/components/merge/base.py @@ -24,8 +24,8 @@ class Merger(IMergerContext, CliPlugin): """Merge multiple datasets into one dataset""" - def __init__(self, **options): - super().__init__(**options) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.__dict__["_sources"] = None self.errors = [] diff --git a/src/datumaro/util/annotation_util.py b/src/datumaro/util/annotation_util.py index fa32129ec2..f9bbb2b71a 100644 --- a/src/datumaro/util/annotation_util.py +++ b/src/datumaro/util/annotation_util.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT from itertools import groupby -from typing import Callable, Dict, Iterable, NewType, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union import numpy as np from typing_extensions import Literal @@ -22,9 +22,7 @@ BboxCoords = Tuple[float, float, float, float] "A tuple of bounding box coordinates, (x, y, w, h)" -_Shape = NewType("_Shape", Shape) - -SpatialAnnotation = Union[_Shape, Mask] +SpatialAnnotation = Union[Shape, Mask] def find_instances(instance_anns: Sequence[Annotation]) -> Sequence[Sequence[Annotation]]: From 2eb7b05a5edcca7db33a5df144e730c8519e01f6 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Thu, 16 Jan 2025 18:37:03 +0400 Subject: [PATCH 12/25] removing debug artifacts --- tests/unit/test_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index aff6aa1e50..1ddc7cdb61 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -311,7 +311,6 @@ def test_stats(self): } actual = compute_ann_statistics(dataset) - self.maxDiff = None self.assertEqual(expected, actual) @@ -389,7 +388,6 @@ def test_stats_with_empty_dataset(self): } actual = compute_ann_statistics(dataset) - self.maxDiff = None self.assertEqual(expected, actual) From 08ba03f0a4add8ce9ade57b5a4238099f1cc84e3 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Thu, 16 Jan 2025 20:50:34 +0400 Subject: [PATCH 13/25] fixing tests (partially syncing components/extractor_tfds.py) --- src/datumaro/components/extractor_tfds.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/datumaro/components/extractor_tfds.py b/src/datumaro/components/extractor_tfds.py index 037c643f66..fce17e1a3c 100644 --- a/src/datumaro/components/extractor_tfds.py +++ 
b/src/datumaro/components/extractor_tfds.py @@ -8,13 +8,13 @@ import logging as log import os.path as osp from types import SimpleNamespace as namespace -from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Set, Union import attrs from attrs import field, frozen from datumaro.components.annotation import AnnotationType, Bbox, Label, LabelCategories -from datumaro.components.dataset_base import CategoriesInfo, DatasetItem, IDataset +from datumaro.components.dataset_base import CategoriesInfo, DatasetInfo, DatasetItem, IDataset from datumaro.components.media import ByteImage, Image, MediaElement from datumaro.util.tf_util import import_tf @@ -407,6 +407,9 @@ def __iter__(self) -> Iterator[DatasetItem]: yield dm_item + def infos(self) -> DatasetInfo: + return self._parent.infos() + def categories(self) -> CategoriesInfo: return self._parent.categories() @@ -430,15 +433,20 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: def media_type(self) -> Type[MediaElement]: return self._parent._media_type + def ann_types(self) -> Set[AnnotationType]: + return self._parent.ann_types() + class _TfdsExtractor(IDataset): _categories: CategoriesInfo + _infos: DatasetInfo def __init__(self, tfds_ds_name: str) -> None: self._adapter = _TFDS_ADAPTERS[tfds_ds_name] tfds_builder = tfds.builder(tfds_ds_name) tfds_ds_info = tfds_builder.info + self._infos = {} self._categories = {} self._state = namespace() self._adapter.transform_categories(tfds_builder, self._categories, self._state) @@ -466,6 +474,9 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[DatasetItem]: return itertools.chain.from_iterable(self._split_extractors.values()) + def infos(self) -> DatasetInfo: + return self._infos + def categories(self) -> CategoriesInfo: return self._categories @@ -490,6 +501,14 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: def media_type(self) -> Type[MediaElement]: return self._media_type + def ann_types(self) -> Set[AnnotationType]: + ann_types = set() + for items in self._split_extractors.values(): + for item in items: + for ann in item.annotations: + ann_types.add(ann.type) + return ann_types + # Some dataset metadata elements are either inconvenient to hardcode, or may change # depending on the version of TFDS. 
We fetch them from the attributes of the `tfds.Builder` From c246d3cb118ff33e7abf28e798ebce33d7f5ebc1 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Thu, 16 Jan 2025 21:07:41 +0400 Subject: [PATCH 14/25] linters fix --- src/datumaro/components/extractor_tfds.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/datumaro/components/extractor_tfds.py b/src/datumaro/components/extractor_tfds.py index fce17e1a3c..6752b43c3b 100644 --- a/src/datumaro/components/extractor_tfds.py +++ b/src/datumaro/components/extractor_tfds.py @@ -8,7 +8,19 @@ import logging as log import os.path as osp from types import SimpleNamespace as namespace -from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Set, Union +from typing import ( + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) import attrs from attrs import field, frozen From d16d8a3b7fa95d2ee643543e2efcc46682628206 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Fri, 17 Jan 2025 15:19:08 +0400 Subject: [PATCH 15/25] correctly process missing openvino --- src/datumaro/plugins/openvino_plugin/launcher.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/datumaro/plugins/openvino_plugin/launcher.py b/src/datumaro/plugins/openvino_plugin/launcher.py index 9802ab0ca6..560bdbb6ff 100644 --- a/src/datumaro/plugins/openvino_plugin/launcher.py +++ b/src/datumaro/plugins/openvino_plugin/launcher.py @@ -13,7 +13,6 @@ from typing import Dict, List, Optional import numpy as np -from openvino.runtime import Core from tqdm import tqdm from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred @@ -23,6 +22,14 @@ from datumaro.util.definitions import get_datumaro_cache_dir from datumaro.util.samples import get_samples_path +try: + from openvino.runtime import Core +except ImportError: + log.debug("Unable to import openvino.") + OPENVINO_AVAILABLE = False +else: + OPENVINO_AVAILABLE = True + class _OpenvinoImporter(CliPlugin): @staticmethod From 53d8c2c2208104105c9a94f872a44c89637599d1 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Fri, 17 Jan 2025 16:44:58 +0400 Subject: [PATCH 16/25] typing for some things --- src/datumaro/components/dataset.py | 15 +-- src/datumaro/components/dataset_base.py | 10 +- .../components/dataset_item_storage.py | 8 +- src/datumaro/components/dataset_storage.py | 12 ++- src/datumaro/components/launcher.py | 6 +- src/datumaro/components/project.py | 9 +- src/datumaro/components/transformer.py | 31 +++--- src/datumaro/plugins/transforms.py | 94 +++++++++++-------- 8 files changed, 109 insertions(+), 76 deletions(-) diff --git a/src/datumaro/components/dataset.py b/src/datumaro/components/dataset.py index 34b628702b..6f3f298098 100644 --- a/src/datumaro/components/dataset.py +++ b/src/datumaro/components/dataset.py @@ -11,7 +11,7 @@ import warnings from contextlib import contextmanager from copy import copy -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.config_model import Source @@ -19,6 +19,7 @@ from datumaro.components.dataset_base import ( CategoriesInfo, DatasetBase, + DatasetInfo, DatasetItem, IDataset, ImportContext, @@ -76,7 +77,7 @@ def subsets(self): return self.parent.subsets() 
return {self.name: self} - def infos(self): + def infos(self) -> DatasetInfo: return {} def categories(self): @@ -85,8 +86,8 @@ def categories(self): def media_type(self): return self.parent.media_type() - def ann_types(self): - return [] + def ann_types(self) -> Set[AnnotationType]: + return set() def as_dataset(self) -> Dataset: return Dataset.from_extractors(self, env=self.parent.env) @@ -231,7 +232,7 @@ def get_subset(self, name) -> DatasetSubset: def subsets(self) -> Dict[str, DatasetSubset]: return {k: self.get_subset(k) for k in self._data.subsets()} - def infos(self): + def infos(self) -> DatasetInfo: return {} def categories(self) -> CategoriesInfo: @@ -240,8 +241,8 @@ def categories(self) -> CategoriesInfo: def media_type(self) -> Type[MediaElement]: return self._data.media_type() - def ann_types(self): - return [] + def ann_types(self) -> Set[AnnotationType]: + return set() def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: return self._data.get(id, subset) diff --git a/src/datumaro/components/dataset_base.py b/src/datumaro/components/dataset_base.py index 8ae2caa332..3d5ed679d0 100644 --- a/src/datumaro/components/dataset_base.py +++ b/src/datumaro/components/dataset_base.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any, Dict, Iterator, List, Optional, Sequence, Type, TypeVar, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union, cast import attr from attr import attrs, field @@ -109,7 +109,7 @@ def media_type(self) -> Type[MediaElement]: """ raise NotImplementedError() - def ann_types(self) -> List[AnnotationType]: + def ann_types(self) -> Set[AnnotationType]: """ Returns available task type from dataset annotation types. """ @@ -225,10 +225,10 @@ def __init__( self._media_type = media_type self._ann_types = ann_types if ann_types else set() - def media_type(self): + def media_type(self) -> Type[MediaElement]: return self._media_type - def ann_types(self): + def ann_types(self) -> Set[AnnotationType]: return self._ann_types @@ -260,7 +260,7 @@ def __init__( self._categories = {} self._items = [] - def infos(self): + def infos(self) -> DatasetInfo: return self._infos def categories(self): diff --git a/src/datumaro/components/dataset_item_storage.py b/src/datumaro/components/dataset_item_storage.py index 1b2a6cfc17..5cfc33e505 100644 --- a/src/datumaro/components/dataset_item_storage.py +++ b/src/datumaro/components/dataset_item_storage.py @@ -160,7 +160,7 @@ def get_subset(self, name): def subsets(self): return {self.name or DEFAULT_SUBSET_NAME: self} - def infos(self): + def infos(self) -> DatasetInfo: return self.parent.infos() def categories(self): @@ -169,7 +169,7 @@ def categories(self): def media_type(self): return self.parent.media_type() - def ann_types(self): + def ann_types(self) -> Set[AnnotationType]: return self.parent.ann_types() def __init__( @@ -192,7 +192,7 @@ def __iter__(self): def __len__(self): return len(self._parent) - def infos(self): + def infos(self) -> DatasetInfo: return self._infos def categories(self): @@ -213,5 +213,5 @@ def get(self, id, subset=None): def media_type(self): return self._media_type - def ann_types(self): + def ann_types(self) -> Set[AnnotationType]: return self._ann_types diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 80d0b902c6..395d8a2c77 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -79,12 +79,15 @@ 
def as_dataset(self, parent: IDataset) -> IDataset: return __class__.DatasetPatchWrapper(self, parent) +PostponedTransform = tuple[Type[Transform], tuple, dict] + + class _StackedTransform(Transform): - def __init__(self, source: IDataset, transforms: List[Transform]): + def __init__(self, source: IDataset, transforms: List[PostponedTransform]): super().__init__(source) self.is_local = True - self.transforms: List[Transform] = [] + self.transforms: List[IDataset] = [] self.malformed_transform_indices: Dict[int, Exception] = {} for idx, transform in enumerate(transforms): try: @@ -101,6 +104,7 @@ def transform_item(self, item: DatasetItem) -> DatasetItem: for t in self.transforms: if item is None: break + t: ItemTransform item = t.transform_item(item) return item @@ -172,7 +176,7 @@ def __init__( else: self._source = source self._storage = DatasetItemStorage() # patch or cache - self._transforms = [] # A stack of postponed transforms + self._transforms: list[PostponedTransform] = [] # A stack of postponed transforms # Describes changes in the dataset since initialization self._updated_items = {} # (id, subset) -> ItemStatus @@ -255,7 +259,7 @@ def _add_ann_types(item: DatasetItem): media_type=media_type, ann_types=self._ann_types, ) - transform = None + transform: Optional[_StackedTransform] = None old_ids = None if self._transforms: transform = _StackedTransform(source, self._transforms) diff --git a/src/datumaro/components/launcher.py b/src/datumaro/components/launcher.py index 5e40c6bdd2..3fd01d8162 100644 --- a/src/datumaro/components/launcher.py +++ b/src/datumaro/components/launcher.py @@ -19,7 +19,7 @@ ) from datumaro.components.annotation import Annotation from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.dataset_base import DatasetItem +from datumaro.components.dataset_base import DatasetInfo, DatasetItem from datumaro.errors import DatumaroError @@ -116,8 +116,8 @@ def launch(self, batch: Sequence[DatasetItem], stack: bool = True) -> List[List[ return [self.postprocess(pred, info) for pred, info in zip(preds, inputs_info)] - def infos(self): - return None + def infos(self) -> DatasetInfo: + return {} def categories(self): return None diff --git a/src/datumaro/components/project.py b/src/datumaro/components/project.py index e092f8d70d..89878991ad 100644 --- a/src/datumaro/components/project.py +++ b/src/datumaro/components/project.py @@ -22,6 +22,7 @@ List, NewType, Optional, + Set, Tuple, TypeVar, Union, @@ -30,6 +31,7 @@ import networkx as nx import ruamel.yaml as yaml +from datumaro import AnnotationType from datumaro.components.config import Config from datumaro.components.config_model import ( BuildStage, @@ -43,6 +45,7 @@ TreeLayout, ) from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, IDataset +from datumaro.components.dataset_base import DatasetInfo from datumaro.components.environment import Environment from datumaro.components.errors import ( DatasetMergeError, @@ -138,7 +141,7 @@ def __len__(self): def subsets(self): return self._dataset.subsets() - def infos(self): + def infos(self) -> DatasetInfo: return {} def get_subset(self, name): @@ -153,8 +156,8 @@ def get(self, id, subset=None): def media_type(self): return self._dataset.media_type() - def ann_types(self): - return [] + def ann_types(self) -> Set[AnnotationType]: + return set() class IgnoreMode(Enum): diff --git a/src/datumaro/components/transformer.py b/src/datumaro/components/transformer.py index 3d9b91c660..9d836a4147 100644 --- 
a/src/datumaro/components/transformer.py +++ b/src/datumaro/components/transformer.py @@ -2,13 +2,20 @@ # # SPDX-License-Identifier: MIT from multiprocessing.pool import ThreadPool -from typing import Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Type import numpy as np +from datumaro import MediaElement from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.dataset_base import DatasetBase, DatasetItem, IDataset +from datumaro.components.dataset_base import ( + CategoriesInfo, + DatasetBase, + DatasetInfo, + DatasetItem, + IDataset, +) from datumaro.components.launcher import Launcher from datumaro.util import is_method_redefined, take_by from datumaro.util.multi_procs_util import consumer_generator @@ -21,7 +28,7 @@ class Transform(DatasetBase, CliPlugin): """ @staticmethod - def wrap_item(item, **kwargs): + def wrap_item(item: DatasetItem, **kwargs) -> DatasetItem: return item.wrap(**kwargs) def __init__(self, extractor: IDataset): @@ -29,15 +36,15 @@ def __init__(self, extractor: IDataset): self._extractor = extractor - def categories(self): + def categories(self) -> CategoriesInfo: return self._extractor.categories() - def subsets(self): + def subsets(self) -> Dict[str, IDataset]: if self._subsets is None: self._subsets = set(self._extractor.subsets()) return super().subsets() - def __len__(self): + def __len__(self) -> int: assert self._length in {None, "parent"} or isinstance(self._length, int) if ( self._length is None @@ -47,10 +54,10 @@ def __len__(self): self._length = len(self._extractor) return super().__len__() - def media_type(self): + def media_type(self) -> Type[MediaElement]: return self._extractor.media_type() - def infos(self): + def infos(self) -> DatasetInfo: return self._extractor.infos() @@ -65,7 +72,7 @@ def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: raise NotImplementedError() - def __iter__(self): + def __iter__(self) -> Iterator[DatasetItem]: for item in self._extractor: item = self.transform_item(item) if item is not None: @@ -233,19 +240,19 @@ def get_subset(self, name): subset = self._extractor.get_subset(name) return __class__(subset, self._launcher, self._batch_size) - def infos(self): + def infos(self) -> DatasetInfo: launcher_override = self._launcher.infos() if launcher_override is not None: return launcher_override return self._extractor.infos() - def categories(self): + def categories(self) -> CategoriesInfo: launcher_override = self._launcher.categories() if launcher_override is not None: return launcher_override return self._extractor.categories() - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: inputs = np.expand_dims(item.media, axis=0) annotations = self._launcher.launch(inputs)[0] return self.wrap_item(item, annotations=annotations) diff --git a/src/datumaro/plugins/transforms.py b/src/datumaro/plugins/transforms.py index acb013467c..e26345f2fc 100644 --- a/src/datumaro/plugins/transforms.py +++ b/src/datumaro/plugins/transforms.py @@ -14,7 +14,7 @@ from copy import deepcopy from enum import Enum, auto from itertools import chain -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union import cv2 import numpy as np @@ -22,6 +22,7 @@ import datumaro.util.mask_tools as mask_tools from datumaro.components.annotation import ( + 
Annotation, AnnotationType, Bbox, Caption, @@ -34,9 +35,10 @@ Polygon, PolyLine, RleMask, + Shape, ) from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.dataset_base import DatasetItem, IDataset +from datumaro.components.dataset_base import CategoriesInfo, DatasetItem, IDataset from datumaro.components.errors import DatumaroError from datumaro.components.media import Image from datumaro.components.transformer import ItemTransform, Transform @@ -65,12 +67,12 @@ def build_cmdline_parser(cls, **kwargs): ) return parser - def __init__(self, extractor, allow_removal=False): + def __init__(self, extractor: IDataset, allow_removal: bool = False): super().__init__(extractor) self._allow_removal = allow_removal - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] segments = [] for ann in item.annotations: @@ -91,7 +93,13 @@ def transform_item(self, item): @classmethod def crop_segments( - cls, segment_anns, img_width, img_height, *, item: DatasetItem, allow_removal: bool = False + cls, + segment_anns: list[Annotation], + img_width: int, + img_height: int, + *, + item: DatasetItem, + allow_removal: bool = False, ): segment_anns = sorted(segment_anns, key=lambda x: x.z_order) @@ -141,7 +149,7 @@ def crop_segments( return new_anns @staticmethod - def _make_group_id(anns, ann_id): + def _make_group_id(anns: list[Annotation], ann_id: int): if ann_id: return ann_id max_gid = max(anns, default=0, key=lambda x: x.group) @@ -162,12 +170,12 @@ def build_cmdline_parser(cls, **kwargs): parser.add_argument("--include-polygons", action="store_true", help="Include polygons") return parser - def __init__(self, extractor, include_polygons=False): + def __init__(self, extractor: IDataset, include_polygons: bool = False): super().__init__(extractor) self._include_polygons = include_polygons - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] segments = [] for ann in item.annotations: @@ -189,7 +197,13 @@ def transform_item(self, item): return self.wrap_item(item, annotations=annotations) @classmethod - def merge_segments(cls, instance, img_width, img_height, include_polygons=False): + def merge_segments( + cls, + instance: Sequence[Annotation], + img_width: int, + img_height: int, + include_polygons: bool = False, + ): polygons = [a for a in instance if a.type == AnnotationType.polygon] masks = [a for a in instance if a.type == AnnotationType.mask] if not polygons and not masks: @@ -233,14 +247,14 @@ def merge_segments(cls, instance, img_width, img_height, include_polygons=False) return instance @staticmethod - def find_instances(annotations): + def find_instances(annotations: Sequence[Annotation]) -> Sequence[Sequence[Annotation]]: return find_instances( a for a in annotations if a.type in {AnnotationType.polygon, AnnotationType.mask} ) class PolygonsToMasks(ItemTransform, CliPlugin): - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] for ann in item.annotations: if ann.type == AnnotationType.polygon: @@ -254,7 +268,7 @@ def transform_item(self, item): return self.wrap_item(item, annotations=annotations) @staticmethod - def convert_polygon(polygon, img_h, img_w): + def convert_polygon(polygon: Polygon, img_h: int, img_w: int): rle = mask_utils.frPyObjects([polygon.points], img_h, img_w)[0] return RleMask( @@ -268,7 +282,7 @@ def convert_polygon(polygon, img_h, img_w): class BoxesToMasks(ItemTransform, 
CliPlugin): - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] for ann in item.annotations: if ann.type == AnnotationType.bbox: @@ -282,7 +296,7 @@ def transform_item(self, item): return self.wrap_item(item, annotations=annotations) @staticmethod - def convert_bbox(bbox, img_h, img_w): + def convert_bbox(bbox: Bbox, img_h: int, img_w: int): rle = mask_utils.frPyObjects([bbox.as_polygon()], img_h, img_w)[0] return RleMask( @@ -296,7 +310,7 @@ def convert_bbox(bbox, img_h, img_w): class MasksToPolygons(ItemTransform, CliPlugin): - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] for ann in item.annotations: if ann.type == AnnotationType.mask: @@ -314,7 +328,7 @@ def transform_item(self, item): return self.wrap_item(item, annotations=annotations) @staticmethod - def convert_mask(mask): + def convert_mask(mask: Mask) -> list[Polygon]: polygons = mask_tools.mask_to_polygons(mask.image) return [ @@ -331,7 +345,7 @@ def convert_mask(mask): class ShapesToBoxes(ItemTransform, CliPlugin): - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] for ann in item.annotations: if ann.type in { @@ -347,7 +361,7 @@ def transform_item(self, item): return self.wrap_item(item, annotations=annotations) @staticmethod - def convert_shape(shape): + def convert_shape(shape: Shape) -> Bbox: bbox = shape.get_bbox() return Bbox( *bbox, @@ -370,12 +384,12 @@ def build_cmdline_parser(cls, **kwargs): parser.add_argument("-s", "--start", type=int, default=1, help="Start value for item ids") return parser - def __init__(self, extractor, start=1): + def __init__(self, extractor: IDataset, start: int = 1): super().__init__(extractor) self._length = "parent" self._start = start - def __iter__(self): + def __iter__(self) -> Iterator[DatasetItem]: for i, item in enumerate(self._extractor): yield self.wrap_item(item, id=i + self._start) @@ -405,7 +419,11 @@ def build_cmdline_parser(cls, **kwargs): ) return parser - def __init__(self, extractor, mapping=None): + def __init__( + self, + extractor: IDataset, + mapping: Dict[str, str] | List[Tuple[str, str]] | None = None, + ): super().__init__(extractor) if mapping is None: @@ -420,7 +438,7 @@ def __init__(self, extractor, mapping=None): self._length = "parent" self._subsets = set(counts) - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: return self.wrap_item(item, subset=self._mapping.get(item.subset, item.subset)) @@ -461,7 +479,7 @@ def build_cmdline_parser(cls, **kwargs): parser.add_argument("--seed", type=int, help="Random seed") return parser - def __init__(self, extractor, splits, seed=None): + def __init__(self, extractor: IDataset, splits: list[tuple[str, float]], seed=None): super().__init__(extractor) if splits is None: @@ -504,7 +522,7 @@ def _find_split(self, index): return subset return subset # all the possible remainder goes to the last split - def __iter__(self): + def __iter__(self) -> Iterator[DatasetItem]: for i, item in enumerate(self._extractor): yield self.wrap_item(item, subset=self._find_split(i)) @@ -514,7 +532,7 @@ class IdFromImageName(ItemTransform, CliPlugin): Renames items in the dataset using image file name (without extension). 
""" - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: if isinstance(item.media, Image) and item.media.path: name = osp.splitext(osp.basename(item.media.path))[0] return self.wrap_item(item, id=name) @@ -561,7 +579,7 @@ def build_cmdline_parser(cls, **kwargs): ) return parser - def __init__(self, extractor, regex): + def __init__(self, extractor: IDataset, regex: str): super().__init__(extractor) assert regex and isinstance(regex, str) @@ -570,7 +588,7 @@ def __init__(self, extractor, regex): self._re = re.compile(regex) self._sub = sub - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: return self.wrap_item(item, id=self._re.sub(self._sub, item.id).format(item=item)) @@ -785,7 +803,7 @@ def __init__( src_categories = self._extractor.categories() - src_label_cat = src_categories.get(AnnotationType.label) + src_label_cat: LabelCategories = src_categories.get(AnnotationType.label) if isinstance(dst_labels, LabelCategories): dst_label_cat = deepcopy(dst_labels) @@ -856,17 +874,17 @@ def __init__( self._categories[AnnotationType.points] = dst_point_cat - def _make_label_id_map(self, src_label_cat, dst_label_cat): + def _make_label_id_map(self, src_label_cat: LabelCategories, dst_label_cat: LabelCategories): id_mapping = { src_id: dst_label_cat.find(src_label_cat[src_id].name, src_label_cat[src_id].parent)[0] for src_id in range(len(src_label_cat or ())) } self._map_id = lambda src_id: id_mapping.get(src_id, None) - def categories(self): + def categories(self) -> CategoriesInfo: return self._categories - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [] for ann in item.annotations: if getattr(ann, "label", None) is not None: @@ -884,7 +902,7 @@ class AnnsToLabels(ItemTransform, CliPlugin): transforms them into a set of annotations of type Label """ - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: labels = set(p.label for p in item.annotations if getattr(p, "label") is not None) annotations = [] for label in labels: @@ -898,7 +916,7 @@ class BboxValuesDecrement(ItemTransform, CliPlugin): Subtracts one from the coordinates of bounding boxes """ - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: annotations = [p for p in item.annotations if p.type != AnnotationType.bbox] bboxes = [p for p in item.annotations if p.type == AnnotationType.bbox] for bbox in bboxes: @@ -969,7 +987,7 @@ def __init__( self._scale_y = scale_y @staticmethod - def _lazy_resize_image(image, new_size): + def _lazy_resize_image(image: Image, new_size: tuple[int, int]) -> Image: def _resize_image(_): h, w = image.size yscale = new_size[0] / float(h) @@ -1007,7 +1025,7 @@ def _lazy_encode(): return _lazy_encode - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> DatasetItem: if not isinstance(item.media, Image): raise DatumaroError("Item %s: image info is required for this transform" % (item.id,)) @@ -1104,7 +1122,7 @@ def __init__(self, extractor: IDataset, ids: Iterable[Tuple[str, str]]): super().__init__(extractor) self._ids = set(tuple(v) for v in (ids or [])) - def transform_item(self, item): + def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: if (item.id, item.subset) in self._ids: return None return item @@ -1151,7 +1169,7 @@ def __init__(self, extractor: IDataset, *, ids: Optional[Iterable[Tuple[str, str super().__init__(extractor) 
self._ids = set(tuple(v) for v in (ids or [])) - def transform_item(self, item: DatasetItem): + def transform_item(self, item: DatasetItem) -> DatasetItem: if not self._ids or (item.id, item.subset) in self._ids: return item.wrap(annotations=[]) return item @@ -1225,7 +1243,7 @@ def _filter_attrs(self, attrs): else: return filter_dict(attrs, exclude_keys=self._attributes) - def transform_item(self, item: DatasetItem): + def transform_item(self, item: DatasetItem) -> DatasetItem: if not self._ids or (item.id, item.subset) in self._ids: filtered_annotations = [] for ann in item.annotations: From 662be838fb72e33cc589794c580351eacc10782c Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Tue, 21 Jan 2025 22:24:19 +0300 Subject: [PATCH 17/25] typing and import fixes --- src/datumaro/components/dataset.py | 8 ++--- src/datumaro/components/dataset_base.py | 4 +-- .../components/dataset_item_storage.py | 4 +-- src/datumaro/components/dataset_storage.py | 13 ++++---- src/datumaro/components/extractor_tfds.py | 22 +++---------- src/datumaro/components/importer.py | 2 +- src/datumaro/components/project.py | 4 +-- src/datumaro/components/transformer.py | 2 +- src/datumaro/plugins/transforms.py | 32 +++++++++---------- 9 files changed, 38 insertions(+), 53 deletions(-) diff --git a/src/datumaro/components/dataset.py b/src/datumaro/components/dataset.py index 6f3f298098..d35cb009ba 100644 --- a/src/datumaro/components/dataset.py +++ b/src/datumaro/components/dataset.py @@ -11,7 +11,7 @@ import warnings from contextlib import contextmanager from copy import copy -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.config_model import Source @@ -86,7 +86,7 @@ def categories(self): def media_type(self): return self.parent.media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return set() def as_dataset(self) -> Dataset: @@ -238,10 +238,10 @@ def infos(self) -> DatasetInfo: def categories(self) -> CategoriesInfo: return self._data.categories() - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._data.media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return set() def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: diff --git a/src/datumaro/components/dataset_base.py b/src/datumaro/components/dataset_base.py index 3d5ed679d0..227e7b95ad 100644 --- a/src/datumaro/components/dataset_base.py +++ b/src/datumaro/components/dataset_base.py @@ -225,10 +225,10 @@ def __init__( self._media_type = media_type self._ann_types = ann_types if ann_types else set() - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._media_type - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self._ann_types diff --git a/src/datumaro/components/dataset_item_storage.py b/src/datumaro/components/dataset_item_storage.py index 5cfc33e505..feb13215a2 100644 --- a/src/datumaro/components/dataset_item_storage.py +++ b/src/datumaro/components/dataset_item_storage.py @@ -169,7 +169,7 @@ def categories(self): def media_type(self): return self.parent.media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self.parent.ann_types() def __init__( @@ -213,5 +213,5 @@ def get(self, id, subset=None): def 
media_type(self): return self._media_type - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self._ann_types diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 395d8a2c77..011c908c9c 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -104,7 +104,6 @@ def transform_item(self, item: DatasetItem) -> DatasetItem: for t in self.transforms: if item is None: break - t: ItemTransform item = t.transform_item(item) return item @@ -117,10 +116,10 @@ def infos(self) -> DatasetInfo: def categories(self) -> CategoriesInfo: return self.transforms[-1].categories() - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self.transforms[-1].media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self.transforms[-1].ann_types() @@ -434,10 +433,10 @@ def define_categories(self, categories: CategoriesInfo): raise CategoriesRedefinedError() self._categories = categories - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._media_type - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self._ann_types def put(self, item: DatasetItem) -> None: @@ -645,10 +644,10 @@ def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: "You can access to the dataset item only by using its iterator." ) - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._source.media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self._source.ann_types() @property diff --git a/src/datumaro/components/extractor_tfds.py b/src/datumaro/components/extractor_tfds.py index 6752b43c3b..83066a89b2 100644 --- a/src/datumaro/components/extractor_tfds.py +++ b/src/datumaro/components/extractor_tfds.py @@ -8,19 +8,7 @@ import logging as log import os.path as osp from types import SimpleNamespace as namespace -from typing import ( - Any, - Callable, - Dict, - Iterator, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union import attrs from attrs import field, frozen @@ -442,10 +430,10 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: return None - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._parent._media_type - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return self._parent.ann_types() @@ -510,10 +498,10 @@ def get(self, id, subset=None) -> Optional[DatasetItem]: return None return self._split_extractors[subset].get(id) - def media_type(self) -> Type[MediaElement]: + def media_type(self): return self._media_type - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): ann_types = set() for items in self._split_extractors.values(): for item in items: diff --git a/src/datumaro/components/importer.py b/src/datumaro/components/importer.py index ce50e488f2..688f3e4cd1 100644 --- a/src/datumaro/components/importer.py +++ b/src/datumaro/components/importer.py @@ -5,7 +5,7 @@ from os import path as osp from typing import Callable, Dict, List, Optional -from datumaro import CliPlugin +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.errors import DatasetNotFoundError from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext diff --git a/src/datumaro/components/project.py 
b/src/datumaro/components/project.py index 89878991ad..c1408937ec 100644 --- a/src/datumaro/components/project.py +++ b/src/datumaro/components/project.py @@ -22,7 +22,6 @@ List, NewType, Optional, - Set, Tuple, TypeVar, Union, @@ -31,7 +30,6 @@ import networkx as nx import ruamel.yaml as yaml -from datumaro import AnnotationType from datumaro.components.config import Config from datumaro.components.config_model import ( BuildStage, @@ -156,7 +154,7 @@ def get(self, id, subset=None): def media_type(self): return self._dataset.media_type() - def ann_types(self) -> Set[AnnotationType]: + def ann_types(self): return set() diff --git a/src/datumaro/components/transformer.py b/src/datumaro/components/transformer.py index 9d836a4147..80f35fa049 100644 --- a/src/datumaro/components/transformer.py +++ b/src/datumaro/components/transformer.py @@ -6,7 +6,6 @@ import numpy as np -from datumaro import MediaElement from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset_base import ( @@ -17,6 +16,7 @@ IDataset, ) from datumaro.components.launcher import Launcher +from datumaro.components.media import MediaElement from datumaro.util import is_method_redefined, take_by from datumaro.util.multi_procs_util import consumer_generator diff --git a/src/datumaro/plugins/transforms.py b/src/datumaro/plugins/transforms.py index e26345f2fc..569526b054 100644 --- a/src/datumaro/plugins/transforms.py +++ b/src/datumaro/plugins/transforms.py @@ -72,7 +72,7 @@ def __init__(self, extractor: IDataset, allow_removal: bool = False): self._allow_removal = allow_removal - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] segments = [] for ann in item.annotations: @@ -175,7 +175,7 @@ def __init__(self, extractor: IDataset, include_polygons: bool = False): self._include_polygons = include_polygons - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] segments = [] for ann in item.annotations: @@ -254,7 +254,7 @@ def find_instances(annotations: Sequence[Annotation]) -> Sequence[Sequence[Annot class PolygonsToMasks(ItemTransform, CliPlugin): - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] for ann in item.annotations: if ann.type == AnnotationType.polygon: @@ -282,7 +282,7 @@ def convert_polygon(polygon: Polygon, img_h: int, img_w: int): class BoxesToMasks(ItemTransform, CliPlugin): - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] for ann in item.annotations: if ann.type == AnnotationType.bbox: @@ -310,7 +310,7 @@ def convert_bbox(bbox: Bbox, img_h: int, img_w: int): class MasksToPolygons(ItemTransform, CliPlugin): - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] for ann in item.annotations: if ann.type == AnnotationType.mask: @@ -345,7 +345,7 @@ def convert_mask(mask: Mask) -> list[Polygon]: class ShapesToBoxes(ItemTransform, CliPlugin): - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] for ann in item.annotations: if ann.type in { @@ -438,7 +438,7 @@ def __init__( self._length = "parent" self._subsets = set(counts) - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): return self.wrap_item(item, 
subset=self._mapping.get(item.subset, item.subset)) @@ -532,7 +532,7 @@ class IdFromImageName(ItemTransform, CliPlugin): Renames items in the dataset using image file name (without extension). """ - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): if isinstance(item.media, Image) and item.media.path: name = osp.splitext(osp.basename(item.media.path))[0] return self.wrap_item(item, id=name) @@ -588,7 +588,7 @@ def __init__(self, extractor: IDataset, regex: str): self._re = re.compile(regex) self._sub = sub - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): return self.wrap_item(item, id=self._re.sub(self._sub, item.id).format(item=item)) @@ -884,7 +884,7 @@ def _make_label_id_map(self, src_label_cat: LabelCategories, dst_label_cat: Labe def categories(self) -> CategoriesInfo: return self._categories - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [] for ann in item.annotations: if getattr(ann, "label", None) is not None: @@ -902,7 +902,7 @@ class AnnsToLabels(ItemTransform, CliPlugin): transforms them into a set of annotations of type Label """ - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): labels = set(p.label for p in item.annotations if getattr(p, "label") is not None) annotations = [] for label in labels: @@ -916,7 +916,7 @@ class BboxValuesDecrement(ItemTransform, CliPlugin): Subtracts one from the coordinates of bounding boxes """ - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): annotations = [p for p in item.annotations if p.type != AnnotationType.bbox] bboxes = [p for p in item.annotations if p.type == AnnotationType.bbox] for bbox in bboxes: @@ -1025,7 +1025,7 @@ def _lazy_encode(): return _lazy_encode - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): if not isinstance(item.media, Image): raise DatumaroError("Item %s: image info is required for this transform" % (item.id,)) @@ -1122,7 +1122,7 @@ def __init__(self, extractor: IDataset, ids: Iterable[Tuple[str, str]]): super().__init__(extractor) self._ids = set(tuple(v) for v in (ids or [])) - def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]: + def transform_item(self, item): if (item.id, item.subset) in self._ids: return None return item @@ -1169,7 +1169,7 @@ def __init__(self, extractor: IDataset, *, ids: Optional[Iterable[Tuple[str, str super().__init__(extractor) self._ids = set(tuple(v) for v in (ids or [])) - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): if not self._ids or (item.id, item.subset) in self._ids: return item.wrap(annotations=[]) return item @@ -1243,7 +1243,7 @@ def _filter_attrs(self, attrs): else: return filter_dict(attrs, exclude_keys=self._attributes) - def transform_item(self, item: DatasetItem) -> DatasetItem: + def transform_item(self, item): if not self._ids or (item.id, item.subset) in self._ids: filtered_annotations = [] for ann in item.annotations: From 0edb281cb9d271dc0178f41587321ab43316ce42 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Mon, 27 Jan 2025 12:28:42 +0300 Subject: [PATCH 18/25] Update src/datumaro/components/comparator.py Co-authored-by: Maxim Zhiltsov --- src/datumaro/components/comparator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datumaro/components/comparator.py b/src/datumaro/components/comparator.py 
index b33b205b64..0d72b5b681 100644 --- a/src/datumaro/components/comparator.py +++ b/src/datumaro/components/comparator.py @@ -116,6 +116,7 @@ class EqualityComparator: ignored_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) ignored_item_attrs = attrib(kw_only=True, factory=set, validator=default_if_none(set)) all = attrib(kw_only=True, default=False) + "Include matches in the output" _test: TestCase = attrib(init=False) errors: list = attrib(init=False) From 64d00207873b5f88ae8e6957e4c30c3836809332 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Mon, 27 Jan 2025 13:58:53 +0300 Subject: [PATCH 19/25] handling of malformed transforms --- src/datumaro/components/dataset_storage.py | 30 +++++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 011c908c9c..3228096b4c 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -83,7 +83,12 @@ def as_dataset(self, parent: IDataset) -> IDataset: class _StackedTransform(Transform): - def __init__(self, source: IDataset, transforms: List[PostponedTransform]): + def __init__( + self, + source: IDataset, + transforms: List[PostponedTransform], + raise_on_malformed_transform: bool = True, + ): super().__init__(source) self.is_local = True @@ -93,6 +98,8 @@ def __init__(self, source: IDataset, transforms: List[PostponedTransform]): try: source = transform[0](source, *transform[1], **transform[2]) except Exception as e: + if raise_on_malformed_transform: + raise self.malformed_transform_indices[idx] = e self.transforms.append(source) @@ -131,7 +138,10 @@ def __init__( categories: Optional[CategoriesInfo] = None, media_type: Optional[Type[MediaElement]] = None, ann_types: Optional[Set[AnnotationType]] = None, + raise_on_malformed_transform: bool = True, ): + self._raise_on_malformed_transform = raise_on_malformed_transform + if source is None and categories is None: categories = {} elif isinstance(source, IDataset) and categories is not None: @@ -261,7 +271,9 @@ def _add_ann_types(item: DatasetItem): transform: Optional[_StackedTransform] = None old_ids = None if self._transforms: - transform = _StackedTransform(source, self._transforms) + transform = _StackedTransform( + source, self._transforms, self._raise_on_malformed_transform + ) if transform.is_local: # An optimized way to find modified items: # Transform items inplace and analyze transform outputs @@ -663,12 +675,20 @@ def __init__( categories: Optional[CategoriesInfo] = None, media_type: Optional[Type[MediaElement]] = None, ann_types: Optional[Set[AnnotationType]] = None, + raise_on_malformed_transform: bool = True, ): if not source.is_stream: raise ValueError("source should be a stream.") self._subset_names = list(source.subsets().keys()) self._transform_ids_for_latest_subset_names = [] - super().__init__(source, infos, categories, media_type, ann_types) + super().__init__( + source=source, + infos=infos, + categories=categories, + media_type=media_type, + ann_types=ann_types, + raise_on_malformed_transform=raise_on_malformed_transform, + ) def is_cache_initialized(self) -> bool: log.debug("This function has no effect on streaming.") @@ -681,7 +701,9 @@ def init_cache(self) -> None: @property def stacked_transform(self) -> IDataset: if self._transforms: - transform = _StackedTransform(self._source, self._transforms) + transform = _StackedTransform( + self._source, self._transforms, 
self._raise_on_malformed_transform + ) self._drop_malformed_transforms(transform.malformed_transform_indices) else: transform = self._source From dbe2918c18189ac103652710e4f4ceb328b50324 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 28 Jan 2025 13:17:31 +0300 Subject: [PATCH 20/25] Make transform policy kw-only --- src/datumaro/components/dataset_storage.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index 3228096b4c..ea6f89cfd9 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -87,6 +87,7 @@ def __init__( self, source: IDataset, transforms: List[PostponedTransform], + *, raise_on_malformed_transform: bool = True, ): super().__init__(source) @@ -138,6 +139,7 @@ def __init__( categories: Optional[CategoriesInfo] = None, media_type: Optional[Type[MediaElement]] = None, ann_types: Optional[Set[AnnotationType]] = None, + *, raise_on_malformed_transform: bool = True, ): self._raise_on_malformed_transform = raise_on_malformed_transform @@ -272,7 +274,9 @@ def _add_ann_types(item: DatasetItem): old_ids = None if self._transforms: transform = _StackedTransform( - source, self._transforms, self._raise_on_malformed_transform + source, + self._transforms, + raise_on_malformed_transform=self._raise_on_malformed_transform ) if transform.is_local: # An optimized way to find modified items: @@ -675,6 +679,7 @@ def __init__( categories: Optional[CategoriesInfo] = None, media_type: Optional[Type[MediaElement]] = None, ann_types: Optional[Set[AnnotationType]] = None, + *, raise_on_malformed_transform: bool = True, ): if not source.is_stream: From 829245a1a0b14ae4019b37c5c818b31746fbe70f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 28 Jan 2025 13:18:07 +0300 Subject: [PATCH 21/25] Update src/datumaro/components/dataset_storage.py --- src/datumaro/components/dataset_storage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index ea6f89cfd9..b014ef33a9 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -707,7 +707,9 @@ def init_cache(self) -> None: def stacked_transform(self) -> IDataset: if self._transforms: transform = _StackedTransform( - self._source, self._transforms, self._raise_on_malformed_transform + self._source, + self._transforms, + raise_on_malformed_transform=self._raise_on_malformed_transform ) self._drop_malformed_transforms(transform.malformed_transform_indices) else: From 662345b00d37590f9244cfa93b6e5e16978debfb Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 28 Jan 2025 12:21:45 +0200 Subject: [PATCH 22/25] Fix formatting --- src/datumaro/components/dataset_storage.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/datumaro/components/dataset_storage.py b/src/datumaro/components/dataset_storage.py index b014ef33a9..5b130106db 100644 --- a/src/datumaro/components/dataset_storage.py +++ b/src/datumaro/components/dataset_storage.py @@ -274,9 +274,9 @@ def _add_ann_types(item: DatasetItem): old_ids = None if self._transforms: transform = _StackedTransform( - source, - self._transforms, - raise_on_malformed_transform=self._raise_on_malformed_transform + source, + self._transforms, + raise_on_malformed_transform=self._raise_on_malformed_transform, ) if transform.is_local: # An optimized way to 
find modified items: @@ -707,9 +707,9 @@ def init_cache(self) -> None: def stacked_transform(self) -> IDataset: if self._transforms: transform = _StackedTransform( - self._source, - self._transforms, - raise_on_malformed_transform=self._raise_on_malformed_transform + self._source, + self._transforms, + raise_on_malformed_transform=self._raise_on_malformed_transform, ) self._drop_malformed_transforms(transform.malformed_transform_indices) else: From d5e4e7313650f811e146c869e52dfbc360c03ab5 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Tue, 28 Jan 2025 15:32:05 +0300 Subject: [PATCH 23/25] synced util/mask_tools.py::make_index_mask --- src/datumaro/util/mask_tools.py | 62 +++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/src/datumaro/util/mask_tools.py b/src/datumaro/util/mask_tools.py index 68d161e5f8..21b11e49c2 100644 --- a/src/datumaro/util/mask_tools.py +++ b/src/datumaro/util/mask_tools.py @@ -147,8 +147,66 @@ def remap_mask(mask: ColorMask, map_fn) -> ColorMask: return np.array([max(0, map_fn(c)) for c in range(256)], dtype=np.uint8)[mask] -def make_index_mask(binary_mask: BinaryMask, index: int, dtype=None) -> IndexMask: - return binary_mask * np.array([index], dtype=dtype or np.min_scalar_type(index)) +def make_index_mask( + binary_mask: BinaryMask, + index: int, + ignore_index: int = 0, + dtype: Optional[np.dtype] = None, +) -> IndexMask: + """Create an index mask from a binary mask by filling a given index value. + + Args: + binary_mask: Binary mask to create an index mask. + index: Scalar value to fill the ones in the binary mask. + ignore_index: Scalar value to fill in the zeros in the binary mask. + Defaults to 0. + dtype: Data type for the resulting mask. If not specified, + it will be inferred from the provided `index` to hold its value. + For example, if `index=255`, the inferred dtype will be `np.uint8`. + Defaults to None. + + Returns: + np.ndarray: Index mask created from the binary mask. + + Raises: + ValueError: If dtype is not specified and incompatible scalar types are used for index + and ignore_index. + + Examples: + >>> binary_mask = np.eye(2, dtype=np.bool_) + >>> index_mask = make_index_mask(binary_mask, index=10, ignore_index=255, dtype=np.uint8) + >>> print(index_mask) + array([[ 10, 255], + [255, 10]], dtype=uint8) + """ + if dtype is None: + dtype = np.min_scalar_type(index) + if dtype != np.min_scalar_type(ignore_index): + msg = ( + "Given dtype is None, " + "but inferred dtypes from the given index and ignore_index are different each other. 
" + "Please mannually set dtype" + ) + raise ValueError(msg, index, ignore_index) + + flipped_zero_np_scalar = ~np.full(tuple(), fill_value=0, dtype=dtype) + + # NOTE: This dispatching rule is required for a performance boost + if ignore_index == flipped_zero_np_scalar: + flipped_index = ~np.full(tuple(), fill_value=index, dtype=dtype) + return ~(binary_mask * flipped_index) + elif index < ignore_index: + diff = ignore_index - index + mask = ~binary_mask * np.full(tuple(), fill_value=diff, dtype=dtype) + mask += index + return mask + elif index > ignore_index: + diff = index - ignore_index + mask = binary_mask * np.full(tuple(), fill_value=diff, dtype=dtype) + mask += ignore_index + return mask + + return np.full_like(binary_mask, fill_value=index, dtype=dtype) def make_binary_mask(mask: Union[BinaryMask, IndexMask]) -> BinaryMask: From 4c7a12d41b4e04aeaf0b5189d2e539d281a1fb9f Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Tue, 28 Jan 2025 15:52:59 +0300 Subject: [PATCH 24/25] fixed some sonarcloud issues --- src/datumaro/components/annotation.py | 7 +------ src/datumaro/components/annotations/merger.py | 4 ---- src/datumaro/components/merge/base.py | 2 +- src/datumaro/components/merge/extractor_merger.py | 2 +- src/datumaro/plugins/openvino_plugin/launcher.py | 1 - 5 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index 66dac1e0d6..d303076b44 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -973,11 +973,6 @@ def get_area(self): Returns: float: The area of the polygon. """ - # import pycocotools.mask as mask_utils - - # x, y, w, h = self.get_bbox() - # rle = mask_utils.frPyObjects([self.points], y + h, x + w) - # area = mask_utils.area(rle)[0] area = self._get_shoelace_area() return area @@ -1501,7 +1496,7 @@ def _get_3d_points(dim, location, rotation_y, denorm): theta = -1 * math.acos(np.dot(denorm_norm, ori_denorm)) n_vector = np.cross(denorm, ori_denorm) n_vector_norm = n_vector / np.sqrt(n_vector[0] ** 2 + n_vector[1] ** 2 + n_vector[2] ** 2) - rotation_matrix, j = cv2.Rodrigues(theta * n_vector_norm) + rotation_matrix, _ = cv2.Rodrigues(theta * n_vector_norm) corners_3d = np.dot(rotation_matrix, corners_3d) corners_3d = corners_3d + np.array(location, dtype=np.float32).reshape(3, 1) return corners_3d.transpose(1, 0) diff --git a/src/datumaro/components/annotations/merger.py b/src/datumaro/components/annotations/merger.py index e17f98bada..28177a9c1e 100644 --- a/src/datumaro/components/annotations/merger.py +++ b/src/datumaro/components/annotations/merger.py @@ -168,10 +168,6 @@ class Cuboid3dMerger(_ShapeMerger, Cuboid3dMatcher): @staticmethod def _merge_cluster_shape_mean_box_nearest(cluster): raise NotImplementedError() - # mbbox = Bbox(*mean_cuboid(cluster)) - # dist = (segment_iou(mbbox, s) for s in cluster) - # nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) - # return cluster[nearest_pos] def merge_cluster(self, cluster): label, label_score = self.find_cluster_label(cluster) diff --git a/src/datumaro/components/merge/base.py b/src/datumaro/components/merge/base.py index 6cb05f37f2..1238d1d609 100644 --- a/src/datumaro/components/merge/base.py +++ b/src/datumaro/components/merge/base.py @@ -78,7 +78,7 @@ def merge_media_types(sources: Sequence[IDataset]) -> Optional[Type[MediaElement def merge_ann_types(sources: Sequence[IDataset]) -> Optional[Set[AnnotationType]]: ann_types = set() for source in sources: - 
ann_types.union(source.ann_types()) + ann_types = ann_types.union(source.ann_types()) return ann_types def __call__(self, *datasets: IDataset) -> DatasetItemStorageDatasetView: diff --git a/src/datumaro/components/merge/extractor_merger.py b/src/datumaro/components/merge/extractor_merger.py index f5c3a873c3..91396e4ada 100644 --- a/src/datumaro/components/merge/extractor_merger.py +++ b/src/datumaro/components/merge/extractor_merger.py @@ -45,7 +45,7 @@ def __init__( ann_types = set() for source in sources: - ann_types.union(source.ann_types()) + ann_types = ann_types.union(source.ann_types()) self._ann_types = ann_types self._is_stream = check_identicalness([s.is_stream for s in sources]) diff --git a/src/datumaro/plugins/openvino_plugin/launcher.py b/src/datumaro/plugins/openvino_plugin/launcher.py index 560bdbb6ff..ff66c96a52 100644 --- a/src/datumaro/plugins/openvino_plugin/launcher.py +++ b/src/datumaro/plugins/openvino_plugin/launcher.py @@ -108,7 +108,6 @@ class BuiltinOpenvinoModelInfo(OpenvinoModelInfo): def create_from_model_name(cls, model_name: str) -> "BuiltinOpenvinoModelInfo": openvino_plugin_samples_dir = get_samples_path() interpreter = osp.join(openvino_plugin_samples_dir, model_name + "_interp.py") - interpreter = interpreter if osp.exists(interpreter) else interpreter model_dir = get_datumaro_cache_dir() From d0cede5df376cbb8ab0d89c8ed1759ec051f7c13 Mon Sep 17 00:00:00 2001 From: Dmitrii Lavrukhin Date: Tue, 28 Jan 2025 15:57:02 +0300 Subject: [PATCH 25/25] linters fix --- src/datumaro/util/mask_tools.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/datumaro/util/mask_tools.py b/src/datumaro/util/mask_tools.py index 21b11e49c2..2c0056dd41 100644 --- a/src/datumaro/util/mask_tools.py +++ b/src/datumaro/util/mask_tools.py @@ -183,9 +183,8 @@ def make_index_mask( dtype = np.min_scalar_type(index) if dtype != np.min_scalar_type(ignore_index): msg = ( - "Given dtype is None, " - "but inferred dtypes from the given index and ignore_index are different each other. " - "Please mannually set dtype" + "Given dtype is None, but inferred dtypes from the given index and " + "ignore_index are different from each other. Please manually set dtype" ) raise ValueError(msg, index, ignore_index)
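
For context, a minimal usage sketch of the synced make_index_mask covered by the last two patches. This snippet is illustrative only and is not part of the patch series; it assumes the patched src/datumaro/util/mask_tools.py is importable as datumaro.util.mask_tools and only exercises the behaviour documented in the added docstring.

    import numpy as np

    from datumaro.util.mask_tools import make_index_mask

    binary_mask = np.eye(2, dtype=np.bool_)

    # True pixels are filled with `index`, False pixels with `ignore_index`.
    index_mask = make_index_mask(binary_mask, index=10, ignore_index=255, dtype=np.uint8)
    assert index_mask.tolist() == [[10, 255], [255, 10]]

    # With dtype=None the dtype is inferred from `index`; the default
    # ignore_index=0 fits the same inferred dtype (uint8), so no error is raised.
    defaulted = make_index_mask(binary_mask, index=10)
    assert defaulted.dtype == np.uint8

    # If dtype is omitted and the dtypes inferred from index and ignore_index
    # disagree (uint8 vs. uint16 here), the function raises the ValueError
    # whose message is reworded in the final patch.
    try:
        make_index_mask(binary_mask, index=10, ignore_index=300)
    except ValueError:
        pass

The uint8 case with ignore_index=255 takes the bitwise-NOT shortcut noted in the patch ("This dispatching rule is required for a performance boost"): when ignore_index equals the all-ones value of the target dtype, the result is computed as ~(binary_mask * ~index) instead of filling both values separately.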