Skip to content

Commit

Permalink
keep original rotation on import-export in yolo oriented boxes (#72)
Browse files Browse the repository at this point in the history
* keep original rotation on import-export in yolo oriented boxes

Co-authored-by: Maxim Zhiltsov <[email protected]>
  • Loading branch information
Eldies and zhiltsov-max authored Jan 8, 2025
1 parent 7812654 commit 08e77b2
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 23 deletions.
2 changes: 1 addition & 1 deletion site/source/api/developer_manual.rst
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ or by regular module importing:
.. code-block:: python
import datumaro as dm
from datumaro.plugins.data_formats.yolo.converter import YoloConverter
from datumaro.plugins.data_formats.yolo.exporter import YoloConverter
# Import a dataset
dataset = dm.Dataset.import_from(src_dir, 'voc')
Expand Down
41 changes: 41 additions & 0 deletions src/datumaro/plugins/data_formats/yolo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from datumaro.util.meta_file_util import get_meta_file, has_meta_file, parse_meta_file
from datumaro.util.os_util import split_path

from .exporter import bbox_annotation_as_polygon
from .format import (
YoloPath,
YoloUltralyticsClassificationFormat,
Expand Down Expand Up @@ -521,6 +522,45 @@ def _load_one_annotation(


class YoloUltralyticsOrientedBoxesExtractor(YoloUltralyticsDetectionExtractor):
@staticmethod
def _restore_original_rotation(
imported_points: list[tuple[float, float]],
bbox: Bbox,
) -> Bbox:
exported_points = np.array(list(take_by(bbox_annotation_as_polygon(bbox), count=2)))
best_shift = min(
range(4),
key=lambda shift: sum(
np.linalg.norm(np.roll(imported_points, -shift, axis=0) - exported_points, axis=1)
),
)
if best_shift == 0:
return bbox

x, y = bbox.x, bbox.y
width, height = bbox.w, bbox.h
rotation = bbox.attributes.get("rotation", 0)
center_x = x + width / 2
center_y = y + height / 2
if best_shift == 1:
rotation -= 90
width, height = height, width
elif best_shift == 2:
rotation -= 180
elif best_shift == 3:
rotation -= 270
width, height = height, width
rotation = rotation % 360

return Bbox(
x=center_x - width / 2,
y=center_y - height / 2,
w=width,
h=height,
label=bbox.label,
attributes=(dict(rotation=rotation) if abs(rotation) > 0.00001 else {}),
)

def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Annotation:
Expand Down Expand Up @@ -553,6 +593,7 @@ def _load_one_annotation(
label=label_id,
attributes=(dict(rotation=rotation) if abs(rotation) > 0.00001 else {}),
)
bbox = self._restore_original_rotation(points, bbox)
if len(parts) == 10:
bbox.attributes["track_id"] = self._parse_field(parts[-1], int, "bbox track id")
return bbox
Expand Down
4 changes: 2 additions & 2 deletions src/datumaro/plugins/data_formats/yolo/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _make_yolo_bbox(img_size, box):
return x, y, w, h


def _bbox_annotation_as_polygon(bbox: Bbox) -> List[float]:
def bbox_annotation_as_polygon(bbox: Bbox) -> List[float]:
points = bbox.as_polygon()

def rotate_point(x: float, y: float):
Expand Down Expand Up @@ -388,7 +388,7 @@ class YoloUltralyticsOrientedBoxesConverter(YoloUltralyticsDetectionConverter):
def _make_annotation_line(self, width: int, height: int, anno: Annotation) -> Optional[str]:
if anno.label is None or not isinstance(anno, Bbox):
return
points = _bbox_annotation_as_polygon(anno)
points = bbox_annotation_as_polygon(anno)
values = [value / size for value, size in zip(points, cycle((width, height)))]
string_values = " ".join("%.6f" % p for p in values)
return f"{self._map_labels_for_save[anno.label]} {string_values}{self._make_track_id_suffix(anno)}\n"
Expand Down
24 changes: 4 additions & 20 deletions tests/unit/data_formats/test_yolo_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,10 @@ def compare_rotated_annotations(expected: Bbox, actual: Bbox, ignored_attrs=None
rotation_diff = expected.attributes.get("rotation", 0) - actual.attributes.get(
"rotation", 0
)
rotation_diff %= 180
rotation_diff = min(rotation_diff, 180 - rotation_diff)
assert rotation_diff < 0.01 or abs(rotation_diff - 90) < 0.01
if rotation_diff < 0.01:
return compare_annotations(expected, actual, ignored_attrs=ignored_attrs)
if abs(rotation_diff - 90) < 0.01:
x, y, w, h = actual.get_bbox()
center_x = x + w / 2
center_y = y + h / 2
new_width = h
new_height = w
actual = Bbox(
x=center_x - new_width / 2,
y=center_y - new_height / 2,
w=new_width,
h=new_height,
label=actual.label,
attributes=actual.attributes,
)
return compare_annotations(expected, actual, ignored_attrs=ignored_attrs)
rotation_diff %= 360
rotation_diff = min(rotation_diff, 360 - rotation_diff)
assert rotation_diff < 0.01
return compare_annotations(expected, actual, ignored_attrs=ignored_attrs)

compare_datasets(
self.helper_tc,
Expand Down

0 comments on commit 08e77b2

Please sign in to comment.