Skip to content

Commit

Permalink
feature/mx-1702 add utils for field and type analysis (#283)
Browse files Browse the repository at this point in the history
# PR Context
- these are needed in the editor as well as the backend
- so we move them from the backend to common

# Added

- `contains_only_types` to check if fields are annotated as desired
- `group_fields_by_class_name` utility to simplify filtered model/field
lookups
- new parameters to `get_inner_types` to customize what to unpack
  • Loading branch information
cutoffthetop authored Sep 24, 2024
1 parent bdfcb5f commit bc9f9f8
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `contains_only_types` to check whether a field's annotation consists only of the given types
- `group_fields_by_class_name` utility to simplify filtered model/field lookups
- new parameters `include_none`, `unpack_list` and `unpack_literal` for `get_inner_types` to customize what gets unpacked

### Changes

### Deprecated
Expand Down
7 changes: 2 additions & 5 deletions mex/common/models/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import json
from collections.abc import MutableMapping
from functools import cache
from types import UnionType
from typing import Any, Union
from typing import Any

from pydantic import BaseModel as PydanticBaseModel
from pydantic import (
Expand Down Expand Up @@ -100,9 +99,7 @@ def _get_list_field_names(cls) -> list[str]:
"""Build a cached list of fields that look like lists."""
field_names = []
for field_name, field_info in cls.get_all_fields().items():
field_types = get_inner_types(
field_info.annotation, unpack=(Union, UnionType)
)
field_types = get_inner_types(field_info.annotation, unpack_list=False)
if any(
isinstance(field_type, type) and issubclass(field_type, list)
for field_type in field_types
Expand Down
112 changes: 97 additions & 15 deletions mex/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import re
from collections.abc import Container, Generator, Iterable, Iterator
from collections.abc import Callable, Container, Generator, Iterable, Iterator, Mapping
from functools import cache
from itertools import zip_longest
from random import random
from time import sleep
from types import UnionType
from types import NoneType, UnionType
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
TypeVar,
Union,
get_args,
get_origin,
)

if TYPE_CHECKING: # pragma: no cover
from mex.common.models import GenericFieldInfo
from mex.common.models.base.model import BaseModel

T = TypeVar("T")


Expand All @@ -36,23 +42,99 @@ def any_contains_any(bases: Iterable[Container[T] | None], tokens: Iterable[T])
return False


def contains_only_types(field: "GenericFieldInfo", *types: type) -> bool:
    """Return whether a `field` is annotated exclusively with the given `types`.

    Unions, lists and type annotations are checked for their inner types and only the
    non-`NoneType` types are considered for the type-check. Inner types are compared
    by membership in `types`, so a subclass of a given type does not count as a match.

    Args:
        field: A `GenericFieldInfo` instance
        types: Types that the field's annotation is allowed to consist of

    Returns:
        Whether the annotation has at least one non-`NoneType` inner type and
        every such inner type is among the given types
    """
    inner_types = list(get_inner_types(field.annotation, include_none=False))
    # Without this guard `all()` would be vacuously true for annotations that
    # have no non-None inner type at all (e.g. a plain `None` annotation).
    if not inner_types:
        return False
    return all(inner_type in types for inner_type in inner_types)


def get_inner_types(
annotation: Any, unpack: Iterable[Any] = (Union, UnionType, list)
annotation: Any,
include_none: bool = True,
unpack_list: bool = True,
unpack_literal: bool = True,
) -> Generator[type, None, None]:
"""Yield all inner types from annotations and the types in `unpack`."""
origin = get_origin(annotation)
if origin == Annotated:
yield from get_inner_types(get_args(annotation)[0], unpack)
elif origin in unpack:
for arg in get_args(annotation):
yield from get_inner_types(arg, unpack)
elif origin is not None:
yield origin
elif annotation is None:
yield type(None)
else:
"""Recursively yield all inner types from a given type annotation.
Args:
annotation: The type annotation to process
include_none: Whether to include NoneTypes in output
unpack_list: Whether to unpack list types
unpack_literal: Whether to unpack Literal types
Returns:
All inner types found within the annotation
"""
# Check whether to unpack lists in addition to annotations and unions
types_to_unpack = [Annotated, Union, UnionType] + ([list] if unpack_list else [])

# Get the unsubscripted version of the given type annotation
origin_type = get_origin(annotation)

# If the origin should be unpacked
if origin_type in types_to_unpack:
for inner_type in get_args(annotation):
# Recursively process each inner type
yield from get_inner_types(
inner_type, include_none, unpack_list, unpack_literal
)

# Handle Literal types based on the unpack_literal flag
elif origin_type is Literal:
if unpack_literal:
yield origin_type # Return Literal if unpacking is allowed
else:
yield annotation # Return the full annotation if not

# Yield the origin type if present
elif origin_type is not None:
yield origin_type

# Yield the annotation if it is valid type, that isn't NoneType
elif isinstance(annotation, type) and annotation is not NoneType:
yield annotation

# Optionally yield none types
elif include_none and annotation in (None, NoneType):
yield NoneType


def group_fields_by_class_name(
    model_classes_by_name: Mapping[str, type["BaseModel"]],
    predicate: Callable[["GenericFieldInfo"], bool],
) -> dict[str, list[str]]:
    """Group the field names by model class and filter them by the given predicate.

    Every class name of the input mapping is present in the result; a class without
    any matching field is mapped to an empty list.

    Args:
        model_classes_by_name: Map from class names to model classes
        predicate: Function to filter the fields of the classes by

    Returns:
        Dictionary mapping class names to a sorted list of the field names
        for which `predicate` returned a truthy value
    """
    return {
        # Field names come from `.items()` keys, so no extra de-duplication
        # step is needed before sorting.
        name: sorted(
            field_name
            for field_name, field_info in cls.get_all_fields().items()
            if predicate(field_info)
        )
        for name, cls in model_classes_by_name.items()
    }


@cache
def normalize(string: str) -> str:
Expand Down
110 changes: 100 additions & 10 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
import json
import time
from collections.abc import Iterable
from typing import Annotated, Any
from types import NoneType
from typing import Annotated, Any, Literal

import pytest
from pydantic.fields import FieldInfo

from mex.common.models import BaseModel
from mex.common.types import (
MERGED_IDENTIFIER_CLASSES,
Identifier,
MergedPersonIdentifier,
)
from mex.common.utils import (
any_contains_any,
contains_any,
contains_only_types,
get_inner_types,
group_fields_by_class_name,
grouper,
jitter_sleep,
normalize,
)


Expand Down Expand Up @@ -43,18 +54,97 @@ def test_any_contains_any(base: Any, tokens: Iterable[Any], expected: bool) -> N


@pytest.mark.parametrize(
    ("annotation", "types", "expected"),
    (
        (None, [str], False),
        (str, [str], True),
        (str, [Identifier], False),
        (Identifier, [str], False),
        (list[str | int | list[str]], [str, float], False),
        (list[str | int | list[str]], [int, str], True),
        (MergedPersonIdentifier | None, MERGED_IDENTIFIER_CLASSES, True),
    ),
    ids=[
        "static None",
        "simple str",
        "str vs identifier",
        "identifier vs str",
        "complex miss",
        "complex hit",
        "optional identifier",
    ],
)
def test_contains_only_types(
    annotation: Any, types: list[type], expected: bool
) -> None:
    """Check `contains_only_types` against a model field with the given annotation."""

    # Build a throwaway model so the annotation is wrapped in a real field info
    class DummyModel(BaseModel):
        attribute: annotation

    assert contains_only_types(DummyModel.model_fields["attribute"], *types) == expected


@pytest.mark.parametrize(
    ("annotation", "flags", "expected_types"),
    (
        (str, {}, [str]),
        (None, {}, [NoneType]),
        (None, {"include_none": False}, []),
        (str | None, {}, [str, NoneType]),
        (str | None, {"include_none": False}, [str]),
        (list[str] | None, {}, [str, NoneType]),
        (list[str | None], {}, [str, NoneType]),
        (list[int], {"unpack_list": False}, [list]),
        (list[str | int | list[str]], {}, [str, int, str]),
        (Annotated[str | int, FieldInfo(description="str or int")], {}, [str, int]),
        (Annotated[str | int, "This is a string or integer"], {}, [str, int]),
        (Literal["okay"] | None, {}, [Literal, NoneType]),
        (
            Literal["okay"] | None,
            {"unpack_literal": False},
            [Literal["okay"], NoneType],
        ),
    ),
    ids=[
        "string",
        "None allowing None",
        "None skipping None",
        "optional string allowing None",
        "optional string skipping None",
        "optional list of strings",
        "list of optional strings",
        "not unpacking list",
        "list nested in list",
        "annotated string or int with FieldInfo",
        "annotated string or int with plain text",
        "unpacking literal",
        "not unpacking literal",
    ],
)
def test_get_inner_types(
    annotation: Any, flags: dict[str, bool], expected_types: list[type]
) -> None:
    """Check the types yielded by `get_inner_types` for each annotation and flag set."""
    assert list(get_inner_types(annotation, **flags)) == expected_types


def test_group_fields_by_class_name() -> None:
    """Only the `str`-annotated fields should be listed for each model class."""

    class DummyModel(BaseModel):
        number: int
        text: str

    class PseudoModel(BaseModel):
        title: str

    def is_plain_str(field_info: Any) -> bool:
        # Select fields whose annotation is exactly `str`
        return field_info.annotation is str

    grouped = group_fields_by_class_name(
        {"Dummy": DummyModel, "Pseudo": PseudoModel},
        is_plain_str,
    )

    assert grouped == {"Dummy": ["text"], "Pseudo": ["title"]}


@pytest.mark.parametrize(
    ("string", "expected"),
    (("", ""), ("__XYZ__", "xyz"), ("/foo/BAR$42", "foo bar 42")),
)
def test_normalize(string: str, expected: str) -> None:
    """Check that `normalize` maps each input string to the expected output."""
    assert normalize(string) == expected


def test_grouper() -> None:
Expand Down

0 comments on commit bc9f9f8

Please sign in to comment.