Skip to content

Commit

Permalink
typing and import fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Eldies committed Jan 21, 2025
1 parent 53d8c2c commit 662be83
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 53 deletions.
8 changes: 4 additions & 4 deletions src/datumaro/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import warnings
from contextlib import contextmanager
from copy import copy
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union

from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.config_model import Source
Expand Down Expand Up @@ -86,7 +86,7 @@ def categories(self):
def media_type(self):
return self.parent.media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return set()

def as_dataset(self) -> Dataset:
Expand Down Expand Up @@ -238,10 +238,10 @@ def infos(self) -> DatasetInfo:
def categories(self) -> CategoriesInfo:
return self._data.categories()

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._data.media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return set()

def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
Expand Down
4 changes: 2 additions & 2 deletions src/datumaro/components/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def __init__(
self._media_type = media_type
self._ann_types = ann_types if ann_types else set()

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._media_type

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self._ann_types


Expand Down
4 changes: 2 additions & 2 deletions src/datumaro/components/dataset_item_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def categories(self):
def media_type(self):
return self.parent.media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self.parent.ann_types()

def __init__(
Expand Down Expand Up @@ -213,5 +213,5 @@ def get(self, id, subset=None):
def media_type(self):
return self._media_type

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self._ann_types
13 changes: 6 additions & 7 deletions src/datumaro/components/dataset_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def transform_item(self, item: DatasetItem) -> DatasetItem:
for t in self.transforms:
if item is None:
break
t: ItemTransform
item = t.transform_item(item)
return item

Expand All @@ -117,10 +116,10 @@ def infos(self) -> DatasetInfo:
def categories(self) -> CategoriesInfo:
return self.transforms[-1].categories()

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self.transforms[-1].media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self.transforms[-1].ann_types()


Expand Down Expand Up @@ -434,10 +433,10 @@ def define_categories(self, categories: CategoriesInfo):
raise CategoriesRedefinedError()
self._categories = categories

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._media_type

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self._ann_types

def put(self, item: DatasetItem) -> None:
Expand Down Expand Up @@ -645,10 +644,10 @@ def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
"You can access to the dataset item only by using its iterator."
)

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._source.media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self._source.ann_types()

@property
Expand Down
22 changes: 5 additions & 17 deletions src/datumaro/components/extractor_tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,7 @@
import logging as log
import os.path as osp
from types import SimpleNamespace as namespace
from typing import (
Any,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union

import attrs
from attrs import field, frozen
Expand Down Expand Up @@ -442,10 +430,10 @@ def get(self, id, subset=None) -> Optional[DatasetItem]:

return None

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._parent._media_type

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return self._parent.ann_types()


Expand Down Expand Up @@ -510,10 +498,10 @@ def get(self, id, subset=None) -> Optional[DatasetItem]:
return None
return self._split_extractors[subset].get(id)

def media_type(self) -> Type[MediaElement]:
def media_type(self):
return self._media_type

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
ann_types = set()
for items in self._split_extractors.values():
for item in items:
Expand Down
2 changes: 1 addition & 1 deletion src/datumaro/components/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from os import path as osp
from typing import Callable, Dict, List, Optional

from datumaro import CliPlugin
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import DatasetNotFoundError
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext

Expand Down
4 changes: 1 addition & 3 deletions src/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
List,
NewType,
Optional,
Set,
Tuple,
TypeVar,
Union,
Expand All @@ -31,7 +30,6 @@
import networkx as nx
import ruamel.yaml as yaml

from datumaro import AnnotationType
from datumaro.components.config import Config
from datumaro.components.config_model import (
BuildStage,
Expand Down Expand Up @@ -156,7 +154,7 @@ def get(self, id, subset=None):
def media_type(self):
return self._dataset.media_type()

def ann_types(self) -> Set[AnnotationType]:
def ann_types(self):
return set()


Expand Down
2 changes: 1 addition & 1 deletion src/datumaro/components/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np

from datumaro import MediaElement
from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import (
Expand All @@ -17,6 +16,7 @@
IDataset,
)
from datumaro.components.launcher import Launcher
from datumaro.components.media import MediaElement
from datumaro.util import is_method_redefined, take_by
from datumaro.util.multi_procs_util import consumer_generator

Expand Down
32 changes: 16 additions & 16 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, extractor: IDataset, allow_removal: bool = False):

self._allow_removal = allow_removal

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
segments = []
for ann in item.annotations:
Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(self, extractor: IDataset, include_polygons: bool = False):

self._include_polygons = include_polygons

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
segments = []
for ann in item.annotations:
Expand Down Expand Up @@ -254,7 +254,7 @@ def find_instances(annotations: Sequence[Annotation]) -> Sequence[Sequence[Annot


class PolygonsToMasks(ItemTransform, CliPlugin):
def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if ann.type == AnnotationType.polygon:
Expand Down Expand Up @@ -282,7 +282,7 @@ def convert_polygon(polygon: Polygon, img_h: int, img_w: int):


class BoxesToMasks(ItemTransform, CliPlugin):
def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if ann.type == AnnotationType.bbox:
Expand Down Expand Up @@ -310,7 +310,7 @@ def convert_bbox(bbox: Bbox, img_h: int, img_w: int):


class MasksToPolygons(ItemTransform, CliPlugin):
def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if ann.type == AnnotationType.mask:
Expand Down Expand Up @@ -345,7 +345,7 @@ def convert_mask(mask: Mask) -> list[Polygon]:


class ShapesToBoxes(ItemTransform, CliPlugin):
def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if ann.type in {
Expand Down Expand Up @@ -438,7 +438,7 @@ def __init__(
self._length = "parent"
self._subsets = set(counts)

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
return self.wrap_item(item, subset=self._mapping.get(item.subset, item.subset))


Expand Down Expand Up @@ -532,7 +532,7 @@ class IdFromImageName(ItemTransform, CliPlugin):
Renames items in the dataset using image file name (without extension).
"""

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
if isinstance(item.media, Image) and item.media.path:
name = osp.splitext(osp.basename(item.media.path))[0]
return self.wrap_item(item, id=name)
Expand Down Expand Up @@ -588,7 +588,7 @@ def __init__(self, extractor: IDataset, regex: str):
self._re = re.compile(regex)
self._sub = sub

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
return self.wrap_item(item, id=self._re.sub(self._sub, item.id).format(item=item))


Expand Down Expand Up @@ -884,7 +884,7 @@ def _make_label_id_map(self, src_label_cat: LabelCategories, dst_label_cat: Labe
def categories(self) -> CategoriesInfo:
return self._categories

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if getattr(ann, "label", None) is not None:
Expand All @@ -902,7 +902,7 @@ class AnnsToLabels(ItemTransform, CliPlugin):
transforms them into a set of annotations of type Label
"""

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
labels = set(p.label for p in item.annotations if getattr(p, "label") is not None)
annotations = []
for label in labels:
Expand All @@ -916,7 +916,7 @@ class BboxValuesDecrement(ItemTransform, CliPlugin):
Subtracts one from the coordinates of bounding boxes
"""

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
annotations = [p for p in item.annotations if p.type != AnnotationType.bbox]
bboxes = [p for p in item.annotations if p.type == AnnotationType.bbox]
for bbox in bboxes:
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def _lazy_encode():

return _lazy_encode

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
if not isinstance(item.media, Image):
raise DatumaroError("Item %s: image info is required for this transform" % (item.id,))

Expand Down Expand Up @@ -1122,7 +1122,7 @@ def __init__(self, extractor: IDataset, ids: Iterable[Tuple[str, str]]):
super().__init__(extractor)
self._ids = set(tuple(v) for v in (ids or []))

def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
def transform_item(self, item):
if (item.id, item.subset) in self._ids:
return None
return item
Expand Down Expand Up @@ -1169,7 +1169,7 @@ def __init__(self, extractor: IDataset, *, ids: Optional[Iterable[Tuple[str, str
super().__init__(extractor)
self._ids = set(tuple(v) for v in (ids or []))

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
if not self._ids or (item.id, item.subset) in self._ids:
return item.wrap(annotations=[])
return item
Expand Down Expand Up @@ -1243,7 +1243,7 @@ def _filter_attrs(self, attrs):
else:
return filter_dict(attrs, exclude_keys=self._attributes)

def transform_item(self, item: DatasetItem) -> DatasetItem:
def transform_item(self, item):
if not self._ids or (item.id, item.subset) in self._ids:
filtered_annotations = []
for ann in item.annotations:
Expand Down

0 comments on commit 662be83

Please sign in to comment.