Skip to content

Commit

Permalink
add util function for checking if annotation is of a certain type
Browse files Browse the repository at this point in the history
  • Loading branch information
rababerladuseladim committed Nov 16, 2023
1 parent 12ac412 commit b7d2235
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
14 changes: 9 additions & 5 deletions mex/common/public_api/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from mex.common.types import Link, LinkLanguage, Text, TextLanguage


def _is_type(type_: type, annotation: type | None) -> bool:
"""Check if annotation is or contains the provided type."""
return type_ in (annotation, *get_args(annotation))


def transform_mex_model_to_public_api_item(model: MExModel) -> PublicApiItem:
"""Convert an ExtractedData instance into a Public API item.
Expand All @@ -27,9 +32,8 @@ def transform_mex_model_to_public_api_item(model: MExModel) -> PublicApiItem:
model_dict = model.model_dump(exclude_none=True)
for field_name in sorted(model_dict):
field = model.model_fields[field_name]
is_text_or_link = any(
type_ in (Text, Link)
for type_ in [field.annotation, *get_args(field.annotation)]
is_text_or_link = _is_type(Text, field.annotation) or _is_type(
Link, field.annotation
)
if is_text_or_link:
model_values = getattr(model, field_name)
Expand Down Expand Up @@ -76,8 +80,8 @@ def transform_public_api_item_to_mex_model(
for value in api_item.values:
field_name = value.fieldName
annotation = cls.model_fields[field_name].annotation
is_link = any(type_ is Link for type_ in [annotation, *get_args(annotation)])
is_text = any(type_ is Text for type_ in [annotation, *get_args(annotation)])
is_link = _is_type(Link, annotation)
is_text = _is_type(Text, annotation)
if isinstance(value.fieldValue, list):
values = value.fieldValue
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/public_api/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME, MExModel
from mex.common.public_api.models import PublicApiItem
from mex.common.public_api.transform import (
_is_type,
transform_mex_model_to_public_api_item,
transform_public_api_item_to_mex_model,
)
Expand Down Expand Up @@ -126,6 +127,21 @@ def raw_api_item() -> dict[str, Any]:
}


@pytest.mark.parametrize(
("type_", "annotation", "expected"),
[
(str, str, True),
(int, int, True),
(str, int, False),
(str, list[int], False),
(str, list[str], True),
(str, str | None, True),
],
)
def test__is_type(type_: type, annotation: type | None, expected: bool) -> None:
assert _is_type(type_, annotation) is expected


def test_transform_mex_model_to_public_api_item(
raw_mex_model: dict[str, Any], raw_api_item: dict[str, Any]
) -> None:
Expand Down

0 comments on commit b7d2235

Please sign in to comment.