Skip to content

Commit

Permalink
Update enum name
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Feb 22, 2024
1 parent 7c8d976 commit 0b8648d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
14 changes: 7 additions & 7 deletions packages/examples/cvat/exchange-oracle/src/core/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DataInfo(BaseModel):
"which provides information about all objects on images"


class LabelType(str, Enum, metaclass=BetterEnumMeta):
class LabelTypes(str, Enum, metaclass=BetterEnumMeta):
plain = "plain"
skeleton = "skeleton"

Expand All @@ -32,15 +32,15 @@ class LabelInfoBase(BaseModel):
name: str = Field(min_length=1)
# https://opencv.github.io/cvat/docs/api_sdk/sdk/reference/models/label/

type: LabelType = LabelType.plain
type: LabelTypes = LabelTypes.plain


class PlainLabelInfo(LabelInfoBase):
type: Literal[LabelType.plain]
type: Literal[LabelTypes.plain]


class SkeletonLabelInfo(LabelInfoBase):
type: Literal[LabelType.skeleton]
type: Literal[LabelTypes.skeleton]

nodes: List[str] = Field(min_items=1)
"""
Expand All @@ -57,8 +57,8 @@ class SkeletonLabelInfo(LabelInfoBase):
@root_validator
@classmethod
def validate_type(cls, values: dict) -> dict:
if values["type"] != LabelType.skeleton:
raise ValueError(f"Label type must be {LabelType.skeleton}")
if values["type"] != LabelTypes.skeleton:
raise ValueError(f"Label type must be {LabelTypes.skeleton}")

skeleton_name = values["name"]

Expand Down Expand Up @@ -139,7 +139,7 @@ def parse_manifest(manifest: Any) -> TaskManifest:
try:
labels = manifest["annotation"]["labels"]
for label_info in labels:
label_info["type"] = label_info.get("type", LabelType.plain)
label_info["type"] = label_info.get("type", LabelTypes.plain)
except KeyError:
pass

Expand Down
14 changes: 7 additions & 7 deletions packages/examples/cvat/recording-oracle/src/core/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DataInfo(BaseModel):
"which provides information about all objects on images"


class LabelType(str, Enum, metaclass=BetterEnumMeta):
class LabelTypes(str, Enum, metaclass=BetterEnumMeta):
plain = "plain"
skeleton = "skeleton"

Expand All @@ -32,15 +32,15 @@ class LabelInfoBase(BaseModel):
name: str = Field(min_length=1)
# https://opencv.github.io/cvat/docs/api_sdk/sdk/reference/models/label/

type: LabelType = LabelType.plain
type: LabelTypes = LabelTypes.plain


class PlainLabelInfo(LabelInfoBase):
type: Literal[LabelType.plain]
type: Literal[LabelTypes.plain]


class SkeletonLabelInfo(LabelInfoBase):
type: Literal[LabelType.skeleton]
type: Literal[LabelTypes.skeleton]

nodes: List[str] = Field(min_items=1)
"""
Expand All @@ -57,8 +57,8 @@ class SkeletonLabelInfo(LabelInfoBase):
@root_validator
@classmethod
def validate_type(cls, values: dict) -> dict:
if values["type"] != LabelType.skeleton:
raise ValueError(f"Label type must be {LabelType.skeleton}")
if values["type"] != LabelTypes.skeleton:
raise ValueError(f"Label type must be {LabelTypes.skeleton}")

skeleton_name = values["name"]

Expand Down Expand Up @@ -139,7 +139,7 @@ def parse_manifest(manifest: Any) -> TaskManifest:
try:
labels = manifest["annotation"]["labels"]
for label_info in labels:
label_info["type"] = label_info.get("type", LabelType.plain)
label_info["type"] = label_info.get("type", LabelTypes.plain)
except KeyError:
pass

Expand Down

0 comments on commit 0b8648d

Please sign in to comment.