Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1415: Add LabelConfig->ResponseModel generator #327

Merged
merged 18 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 357 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Repository = 'https://github.com/HumanSignal/label-studio-sdk'
python = "^3.8"
Pillow = ">=10.0.1"
appdirs = ">=1.4.3"
datamodel-code-generator = "^0.26.0"
httpx = ">=0.21.2"
ijson = ">=3.2.3"
jsonschema = ">=4.23.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
import types
import sys
import functools
from typing import Type, Dict, Any, Tuple, Generator
from pathlib import Path
from tempfile import TemporaryDirectory
from datamodel_code_generator import DataModelType, PythonVersion, LiteralType
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from pydantic import BaseModel
from contextlib import contextmanager


@functools.lru_cache(maxsize=128)
def _generate_model_code(json_schema_str: str, class_name: str = 'MyModel') -> str:
    """
    Render Pydantic v2 model source code from a serialized JSON schema.

    The result is memoized (up to 128 distinct argument pairs), which is why
    the schema is accepted as a hashable JSON string rather than a dict.

    Args:
        json_schema_str (str): The JSON schema, serialized as a JSON string.
        class_name (str, optional): Name of the generated root class. Defaults to 'MyModel'.

    Returns:
        The output of ``JsonSchemaParser.parse()`` for the given schema.
    """
    model_types = get_data_model_types(
        DataModelType.PydanticV2BaseModel,
        target_python_version=PythonVersion.PY_311,
    )

    # Keyword options handed to the parser; enum fields are rendered as Literal types.
    parser_options = dict(
        data_model_type=model_types.data_model,
        data_model_root_type=model_types.root_model,
        data_model_field_type=model_types.field_model,
        data_type_manager_type=model_types.data_type_manager,
        dump_resolve_reference_action=model_types.dump_resolve_reference_action,
        enum_field_as_literal=LiteralType.All,
        class_name=class_name,
    )

    return JsonSchemaParser(json_schema_str, **parser_options).parse()

@contextmanager
def json_schema_to_pydantic(json_schema: dict, class_name: str = 'MyModel') -> Generator[Type[BaseModel], None, None]:
    """
    Context manager that turns a JSON schema into a ready-to-use Pydantic model class.

    The model source is generated from the schema, executed inside a throwaway
    module, and that module is registered in ``sys.modules`` for as long as the
    ``with`` block is active. The registration is removed on exit.

    Args:
        json_schema (dict): The JSON schema to convert.
        class_name (str, optional): The name of the generated Pydantic class. Defaults to 'MyModel'.

    Yields:
        The generated class named ``class_name``.

    Example:
        ```python
        example_schema = {
            "type": "object",
            "properties": {
                "sentiment": {
                    "type": "string",
                    "description": "Sentiment of the text",
                    "enum": ["Positive", "Negative", "Neutral"],
                }
            },
            "required": ["sentiment"]
        }
        with json_schema_to_pydantic(example_schema) as ResponseModel:
            instance = ResponseModel(sentiment='Positive')
            print(instance.model_dump())
        ```
    """
    # The code generator (and its cache) works on a string, so serialize first.
    serialized_schema = json.dumps(json_schema)
    source: str = _generate_model_code(serialized_schema, class_name)

    # The module name is derived from the identity of the serialized string,
    # which stays alive for the whole lifetime of this generator.
    dynamic_name = f'dynamic_module_{id(serialized_schema)}'
    dynamic_module = types.ModuleType(dynamic_name)

    # NOTE(review): this executes generated source code; only feed schemas
    # that come from a trusted origin.
    exec(source, vars(dynamic_module))
    generated_cls = getattr(dynamic_module, class_name)

    # Registering the module makes it importable, which avoids Pydantic errors
    # related to undefined models.
    sys.modules[dynamic_name] = dynamic_module
    try:
        yield generated_cls
    finally:
        sys.modules.pop(dynamic_name, None)

161 changes: 161 additions & 0 deletions src/label_studio_sdk/label_interface/control_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def validate_node(cls, tag: xml.etree.ElementTree.Element) -> bool:
and tag.tag not in _NOT_CONTROL_TAGS
)

def to_json_schema(self):
    """
    Describe this control tag as a JSON Schema.

    This generic representation is a plain string with no further constraints.

    Returns:
        dict: The schema ``{"type": "string"}``.
    """
    schema = dict(type="string")
    return schema

@classmethod
def parse_node(cls, tag: xml.etree.ElementTree.Element, tags_mapping=None) -> "ControlTag":
"""
Expand Down Expand Up @@ -485,6 +494,19 @@ class ChoicesTag(ControlTag):
_label_attr_name: str = "choices"
_value_class: Type[ChoicesValue] = ChoicesValue

def to_json_schema(self):
    """
    Converts the current ChoicesTag instance into a JSON Schema.

    A tag configured with ``choice="multiple"`` lets the annotator pick
    several options, so its schema is an array whose items are restricted
    to the available choices. Every other configuration yields a single
    string restricted to the available choices.

    Returns:
        dict: A dictionary representing the JSON Schema.
    """
    description = f"Choices for {self.to_name[0]}"
    choice_schema = {
        "type": "string",
        "enum": self.labels,
    }

    # Tolerate a missing or empty attribute mapping: treat it as single choice.
    attrs = getattr(self, "attr", None) or {}
    if attrs.get("choice") == "multiple":
        return {
            "type": "array",
            "items": choice_schema,
            "description": description,
        }

    choice_schema["description"] = description
    return choice_schema


class LabelsValue(SpanSelection):
labels: List[str]
Expand All @@ -496,6 +518,41 @@ class LabelsTag(ControlTag):
_label_attr_name: str = "labels"
_value_class: Type[LabelsValue] = LabelsValue

def to_json_schema(self):
    """
    Describe the output of this LabelsTag as a JSON Schema.

    The schema is a list of labelled spans. Each span must carry a
    non-negative integer ``start`` and ``end`` plus a ``labels`` list drawn
    from this tag's labels; ``text`` is an optional string.

    Returns:
        dict: A dictionary representing the JSON Schema.
    """
    span_properties = {
        "start": {"type": "integer", "minimum": 0},
        "end": {"type": "integer", "minimum": 0},
        "labels": {
            "type": "array",
            "items": {"type": "string", "enum": self.labels},
        },
        "text": {"type": "string"},
    }
    span_schema = {
        "type": "object",
        "required": ["start", "end", "labels"],
        "properties": span_properties,
    }
    target = self.to_name[0]
    return {
        "type": "array",
        "items": span_schema,
        "description": f"Labels and span indices for {target}",
    }

## Image tags

Expand Down Expand Up @@ -684,6 +741,26 @@ class NumberTag(ControlTag):
""" """
tag: str = "Number"
_value_class: Type[NumberValue] = NumberValue
_label_attr_name: str = "number"

def to_json_schema(self):
    """
    Describe the output of this NumberTag as a JSON Schema.

    The ``min`` / ``max`` tag attributes, when present, are converted to
    floats and emitted as the ``minimum`` / ``maximum`` bounds.

    Returns:
        dict: A dictionary representing the JSON Schema.
    """
    number_schema = {
        "type": "number",
        "description": f"Number for {self.to_name[0]}",
    }

    # Map tag attribute names onto their JSON Schema keyword counterparts.
    for attr_key, schema_key in (("min", "minimum"), ("max", "maximum")):
        if attr_key in self.attr:
            number_schema[schema_key] = float(self.attr[attr_key])

    return number_schema


class DateTimeValue(BaseModel):
Expand All @@ -694,6 +771,25 @@ class DateTimeTag(ControlTag):
""" """
tag: str = "DateTime"
_value_class: Type[DateTimeValue] = DateTimeValue
_label_attr_name: str = "datetime"

def _label_simple(self, to_name: Optional[str] = None, *args, **kwargs) -> Region:
    """
    Build a region for this tag, collapsing a wrapped ``datetime`` value to a scalar.

    Args:
        to_name: Forwarded unchanged to the parent implementation.
        *args: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded to the parent implementation; a ``datetime``
            entry given as a list or tuple is replaced by its first element.

    Returns:
        Region: Whatever the parent ``_label_simple`` returns.
    """
    # TODO: temporary fix to force datetime to be a string
    # Only unwrap real sequences: indexing a plain string would silently keep
    # just its first character, and a missing key should not raise here.
    wrapped = kwargs.get('datetime')
    if isinstance(wrapped, (list, tuple)):
        kwargs['datetime'] = wrapped[0]
    return super()._label_simple(to_name, *args, **kwargs)

def to_json_schema(self):
    """
    Describe the output of this DateTimeTag as a JSON Schema.

    Returns:
        dict: A string schema carrying the ``date-time`` format and a
        description that names the first target object.
    """
    target = self.to_name[0]
    schema = {"type": "string"}
    schema["format"] = "date-time"
    schema["description"] = f"Date and time for {target}"
    return schema


class HyperTextLabelsValue(SpanSelectionOffsets):
Expand All @@ -715,12 +811,26 @@ class PairwiseTag(ControlTag):
""" """
tag: str = "Pairwise"
_value_class: Type[PairwiseValue] = PairwiseValue
_label_attr_name: str = "selected"

def label(self, side):
    """
    Create a region recording which side was selected.

    Args:
        side: Value stored as ``selected`` in the resulting ``PairwiseValue``.

    Returns:
        Region: A region whose source and target tag are both this tag.
    """
    return Region(
        from_tag=self,
        to_tag=self,
        value=PairwiseValue(selected=side),
    )

def to_json_schema(self):
    """
    Describe the output of this PairwiseTag as a JSON Schema.

    The first two entries of ``to_name`` are named in the description as the
    left and right candidates respectively.

    Returns:
        dict: A string schema limited to ``"left"`` or ``"right"``.
    """
    left_name = self.to_name[0]
    right_name = self.to_name[1]
    sides = ["left", "right"]
    return {
        "type": "string",
        "enum": sides,
        "description": f"Pairwise selection between {left_name} (left) and {right_name} (right)",
    }


class ParagraphLabelsValue(SpanSelectionOffsets):
paragraphlabels: List[str]
Expand Down Expand Up @@ -759,6 +869,22 @@ class RatingTag(ControlTag):
""" """
tag: str = "Rating"
_value_class: Type[RatingValue] = RatingValue
_label_attr_name: str = "rating"

def to_json_schema(self):
    """
    Describe the output of this RatingTag as a JSON Schema.

    The upper bound comes from the ``maxRating`` tag attribute, converted to
    an integer; 5 is used when the attribute is absent.

    Returns:
        dict: An integer schema bounded between 0 and the maximum rating.
    """
    upper_bound = int(self.attr.get('maxRating', 5))
    target = self.to_name[0]

    schema = {"type": "integer", "minimum": 0, "maximum": upper_bound}
    schema["description"] = f"Rating for {target} (0 to {upper_bound})"
    return schema


class RelationsTag(ControlTag):
Expand All @@ -784,6 +910,27 @@ class TaxonomyTag(ControlTag):
""" """
tag: str = "Taxonomy"
_value_class: Type[TaxonomyValue] = TaxonomyValue
_label_attr_name: str = "taxonomy"

def to_json_schema(self):
    """
    Describe the output of this TaxonomyTag as a JSON Schema.

    The schema is a list of paths, where each path is itself a list of
    strings taken from this tag's labels.

    Returns:
        dict: A dictionary representing the JSON Schema.
    """
    # TODO: enforce the order of the enums according to the taxonomy tree
    node_schema = {"type": "string", "enum": self.labels}
    path_schema = {"type": "array", "items": node_schema}
    target = self.to_name[0]
    return {
        "type": "array",
        "items": path_schema,
        "description": f"Taxonomy for {target}. Each item is a path from root to selected node.",
    }


class TextAreaValue(BaseModel):
Expand All @@ -794,6 +941,20 @@ class TextAreaTag(ControlTag):
""" """
tag: str = "TextArea"
_value_class: Type[TextAreaValue] = TextAreaValue
_label_attr_name: str = "text"

def to_json_schema(self):
    """
    Describe the output of this TextAreaTag as a JSON Schema.

    Returns:
        dict: A plain string schema whose description names the first
        target object.
    """
    target = self.to_name[0]
    schema = {"type": "string"}
    schema["description"] = f"Text for {target}"
    return schema



class TimeSeriesValue(SpanSelection):
Expand Down
16 changes: 15 additions & 1 deletion src/label_studio_sdk/label_interface/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,26 @@ def load_task(self, task):
tree.task_loaded = True

for obj in tree.objects:
print(obj.value_is_variable, obj.value_name)
if obj.value_is_variable and obj.value_name in task:
obj.value = task.get(obj.value_name)

return tree

def to_json_schema(self):
    """
    Describe the whole labeling interface as a JSON Schema object.

    Each registered control contributes one property, keyed by the control's
    name and described by that control's own ``to_json_schema()``. Every
    control is listed as required.

    Returns:
        dict: A dictionary representing the JSON Schema.
    """
    properties = {}
    required = []
    for control_name, control in self._controls.items():
        properties[control_name] = control.to_json_schema()
        required.append(control_name)

    return {
        "type": "object",
        "properties": properties,
        "required": required,
    }

def parse(self, config_string: str) -> Tuple[Dict, Dict, Dict, etree._Element]:
"""Parses the received configuration string into dictionaries
of ControlTags, ObjectTags, and Labels, along with an XML tree
Expand Down
2 changes: 1 addition & 1 deletion tests/custom/test_interface/test_control_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_validate():
def test_textarea_label():
    """Smoke test: labeling through a TextArea control with the generic ``label`` keyword does not raise."""
    conf = LabelInterface(c.TEXTAREA_CONF)

    # The value is passed via the generic ``label`` keyword; the superseded
    # ``text=`` call that used to precede this line was immediately overwritten.
    region = conf.get_control(c.FROM_NAME).label(label=["Hello", "World"])


def test_label_with_choices():
Expand Down
14 changes: 7 additions & 7 deletions tests/custom/test_interface/test_control_tags_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@

(OT.ImageTag, CT.RectangleTag, { "x": 10, "y": 10, "width": 10, "height": 10 }, { "x": 10.0, "y": 10.0, "width": 10.0, "height": 10.0, "rotation": 0 }),
(OT.ImageTag, CT.RectangleLabelsTag, { "x": 10, "y": 10, "width": 10, "height": 10, "label": c.LABEL1 }, { "x": 10.0, "y": 10.0, "width": 10.0, "height": 10.0, "rotation": 0, "rectanglelabels": [ c.LABEL1 ] }),
(OT.ImageTag, CT.TaxonomyTag, { "taxonomy": [ [ c.LABEL1 ] ] }, { "taxonomy": [ [ c.LABEL1 ] ] }),
(OT.ImageTag, CT.TextAreaTag, { "text": [ EX_TEXT, EX_TEXT ] }, { "text": [ EX_TEXT, EX_TEXT ] }),
(OT.ImageTag, CT.TaxonomyTag, { "label": [ [ c.LABEL1 ] ] }, { "taxonomy": [ [ c.LABEL1 ] ] }),
(OT.ImageTag, CT.TextAreaTag, { "label": [ EX_TEXT, EX_TEXT ] }, { "text": [ EX_TEXT, EX_TEXT ] }),

(OT.ImageTag, CT.RatingTag, { "rating": 3 }, { "rating": 3 }),
(OT.ImageTag, CT.RatingTag, { "label": 3 }, { "rating": 3 }),

(OT.ImageTag, CT.BrushTag, { "rle": [2,3,3,2] }, { "rle": [2,3,3,2], "format": "rle" }),
(OT.ImageTag, CT.BrushLabelsTag, { "rle": [2,3,3,2], "label": c.LABEL1 }, { "rle": [2,3,3,2], "format": "rle", "brushlabels": [ c.LABEL1 ] }),

## Text labeling
(OT.TextTag, CT.NumberTag, { "number": 5 }, { "number": 5 }),
(OT.TextTag, CT.DateTimeTag, { "datetime": "2024-05-07" }, { "datetime": "2024-05-07" }),
(OT.TextTag, CT.NumberTag, { "label": 5 }, { "number": 5 }),
(OT.TextTag, CT.DateTimeTag, { "label": "2024-05-07" }, { "datetime": "2024-05-07" }),

(OT.TextTag, CT.LabelsTag, { "start": 1, "end": 10, "label": c.LABEL1 }, { "start": 1, "end": 10, "labels": [ c.LABEL1 ] }),
(OT.TextTag, CT.LabelsTag, { "start": 1, "end": 10, "label": [ c.LABEL1, c.LABEL2 ] }, { "start": 1, "end": 10, "labels": [ c.LABEL1, c.LABEL2 ] }),
(OT.TextTag, CT.LabelsTag, { "label": c.LABEL1, "start": 1, "end": 10 }, { "labels": [ c.LABEL1 ], "start": 1, "end": 10 }),
(OT.TextTag, CT.LabelsTag, { "label": [ c.LABEL1, c.LABEL2 ], "start": 1, "end": 10 }, { "labels": [ c.LABEL1, c.LABEL2 ], "start": 1, "end": 10 }),

## Hypertext labeling
(OT.HyperTextTag, CT.HyperTextLabelsTag, { "start": 1, "end": 10, "startOffset": 10, "endOffset": 10, "label": c.LABEL1 }, { "start": 1, "end": 10, "startOffset": 10, "endOffset": 10, "htmllabels": [ c.LABEL1 ] }),
Expand Down
Loading