Skip to content

Commit

Permalink
Feature/mx 1380 update pydantic to v2 (#24)
Browse files Browse the repository at this point in the history
This PR adresses the mex-common part of [Ticket
1380](https://jira.rki.local/browse/MX-1380).

Incomplete summary of relevant pydantic changes:
- `pydantic_model.dict()` becomes `pydantic_model.model_dump()`
- `pydantic_model.parse_obj()` becomes `pydantic_model.model_validate()`
- `pydantic_model.construct()` becomes
`pydantic_model.model_construct()`
- `pydantic_model.Config` becomes `pydantic_model.model_config`, which
is of type pydantic.ConfigDict
- optional fields must now have a default value, e.g. `attribute: str |
None = None`
- settings field attribute `env` becomes `validation_alias` or
`serialization_alias`, depending on the direction (serialization or
validation)
- for a more complete overview, see the [pydantic migration
guide](https://docs.pydantic.dev/dev-v2/migration/)

This PR:
- bumps pydantic to v2.4.2
- adds the mypy plugin for pydantic

Open Questions:
- pytest captures warnings in
`tests/backend_api/test_connector.py::test_post_models_mocked`and
`tests/public_api/test_connector.py::test_post_models_mocked`:

```
C:\Users\SchiebenhoeferH\projects\mex-common\.venv\Lib\site-packages\pydantic\main.py:308: UserWarning: Pydantic serializer warnings:
    Expected `list[str]` but got `Email` - serialized value may not be as expected
    Expected `list[str]` but got `str` - serialized value may not be as expected
    Expected `list[str]` but got `str` - serialized value may not be as expected
    return self.__pydantic_serializer__.to_python(
```

I guess these are due to our listyness fix. The only fix I can think of
is to convert attributes to their expected type during serialization
which seems a bit excessive. What do you think?

---------

Signed-off-by: rababerladuseladim <[email protected]>
Co-authored-by: Nicolas Drebenstedt <[email protected]>
  • Loading branch information
rababerladuseladim and cutoffthetop authored Nov 17, 2023
1 parent cd31a22 commit 97a1ef1
Show file tree
Hide file tree
Showing 66 changed files with 1,090 additions and 659 deletions.
1 change: 1 addition & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ python_version = 3.11
follow_imports = silent
show_error_codes = True
strict = True
plugins = pydantic.mypy

[pydantic-mypy]
warn_untyped_fields = True
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ repos:
additional_dependencies:
- "backoff>=2.2.1,<3"
- "click>=8.1.7,<9"
- "pydantic[dotenv,email]>=1.10.12,<2"
- "pandas-stubs>=2.0.3.230814"
- "pydantic>=2.1.1,<3"
- "pydantic-settings>=2.0.2,<3"
- "pytest>=7.4.3,<8"
- "types-pytz>=2023.3.1.1,<2024"
- "types-requests>=2.31.0.8,<3"
Expand Down
4 changes: 2 additions & 2 deletions mex/common/backend_api/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _set_authentication(self) -> None:
def _set_url(self) -> None:
"""Set the backend api url with the version path."""
settings = BaseSettings.get()
self.url = urljoin(settings.backend_api_url, self.API_VERSION)
self.url = urljoin(str(settings.backend_api_url), self.API_VERSION)

def post_models(self, models: list[MExModel]) -> list[Identifier]:
"""Post models to Backend API in a bulk insertion request.
Expand All @@ -50,5 +50,5 @@ def post_models(self, models: list[MExModel]) -> list[Identifier]:
)
},
)
insert_response = BulkInsertResponse.parse_obj(response)
insert_response = BulkInsertResponse.model_validate(response)
return insert_response.identifiers
54 changes: 30 additions & 24 deletions mex/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
import pdb # noqa: T100
import sys
from bdb import BdbQuit
from enum import Enum
from functools import partial
from textwrap import dedent
from traceback import format_exc
from typing import Callable, Union, get_origin
from typing import Any, Callable

import click
from click import Command, Option
from click.core import ParameterSource
from click.exceptions import Abort, Exit
from pydantic import SecretStr
from pydantic.fields import ModelField
from pydantic.fields import FieldInfo

from mex.common.connector import reset_connector_context
from mex.common.logging import echo, logger
Expand All @@ -31,29 +29,32 @@
"""


def _field_to_parameters(field: ModelField) -> list[str]:
def _field_to_parameters(name: str, field: FieldInfo) -> list[str]:
"""Convert a field of a pydantic settings class into parameter declarations.
The field's name and alias are considered. Underscores are replaced with dashes
and single character parameters have two leading dashes while single character
parameters have just one.
Args:
name: name of the Field
field: Field of a Settings definition class
Returns:
List of parameter declaring strings
"""
names = [n.replace("_", "-") for n in sorted({field.name, field.alias}) if n]
names = [name] + ([field.alias] if field.alias else [])
names = [n.replace("_", "-") for n in names]
dashes = ["--" if len(n) > 1 else "-" for n in names]
return [f"{d}{n}" for d, n in zip(dashes, names)]


def _field_to_option(field: ModelField) -> Option:
def _field_to_option(name: str, settings_cls: type[SettingsType]) -> Option:
"""Convert a field of a pydantic settings class into a click option.
Args:
field: Field of a Settings definition class
name: name of the Field
settings_cls: Base settings class or a subclass of it
Returns:
Option: click Option with appropriate attributes
Expand All @@ -63,25 +64,30 @@ def _field_to_option(field: ModelField) -> Option:
# complex fields or type unions are always interpreted as strings
# and add support for SecretStr fields with correct default values
# https://pydantic-docs.helpmanual.io/usage/types/#secret-types
if (
field.is_complex()
or get_origin(field.type_) is Union
or issubclass(field.type_, (str, SecretStr, Enum))
):
field_type = str
default = json.dumps(field.default, cls=MExEncoder).strip('"')
field = settings_cls.model_fields[name]

if field.annotation in (int, bool, float):
field_type: Any = field.annotation
else:
field_type = field.type_
field_type = str

if field.is_required():
default = None
elif field.annotation in (int, bool, float):
default = field.default
else:
default = json.dumps(field.default, cls=MExEncoder).strip('"')

return Option(
_field_to_parameters(field),
_field_to_parameters(name, field),
default=default,
envvar=next(iter(field.field_info.extra["env_names"])).upper(),
help=field.field_info.description,
is_flag=field.type_ is bool and field.default is False,
envvar=settings_cls.get_env_name(name),
help=field.description,
is_flag=field.annotation is bool and field.default is False,
show_default=True,
show_envvar=True,
type=field_type,
required=field.is_required(),
)


Expand Down Expand Up @@ -111,7 +117,7 @@ def _callback(
context.call_on_close(reset_connector_context)

# load settings from parameters and store in ContextVar.
settings = settings_cls.parse_obj(
settings = settings_cls.model_validate(
{
key: value
for key, value in cli_settings.items()
Expand Down Expand Up @@ -181,13 +187,13 @@ def decorator(func: Callable[[], None]) -> Command:
return Command(
func.__name__,
help=HELP_TEMPLATE.format(
doc=func.__doc__, env_file=settings_cls.__config__.env_file
doc=func.__doc__, env_file=settings_cls.model_config.get("env_file")
),
callback=partial(_callback, func, settings_cls),
params=[
*[
_field_to_option(field)
for field in settings_cls.__fields__.values()
_field_to_option(name, settings_cls)
for name in settings_cls.model_fields
],
*meta_parameters,
],
Expand Down
9 changes: 6 additions & 3 deletions mex/common/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ def get_dtypes_for_model(model: type["BaseModel"]) -> dict[str, "Dtype"]:
Returns:
Mapping from field alias to dtype strings
"""
return {f.alias: PANDAS_DTYPE_MAP[f.type_] for f in model.__fields__.values()}
return {
f.alias or name: PANDAS_DTYPE_MAP[f.annotation or type(None)]
for name, f in model.model_fields.items()
}


def parse_csv(
path_or_buffer: Union[str, Path, "ReadCsvBuffer"],
path_or_buffer: Union[str, Path, "ReadCsvBuffer[Any]"],
into: type[ModelT],
chunksize: int = 10,
**kwargs: Any,
Expand All @@ -58,7 +61,7 @@ def parse_csv(
row.replace(to_replace=np.nan, value=None, inplace=True)
row.replace(regex=r"^\s*$", value=None, inplace=True)
try:
model = into.parse_obj(row)
model = into.model_validate(row.to_dict())
echo(f"[parse csv] {into.__name__} {index} OK")
yield model
except ValidationError as error:
Expand Down
2 changes: 1 addition & 1 deletion mex/common/ldap/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _fetch(
)
for entry in entries:
if attributes := entry.get("attributes"):
yield model_cls.parse_obj(attributes)
yield model_cls.model_validate(attributes)

@cache
def _paged_ldap_search(
Expand Down
2 changes: 1 addition & 1 deletion mex/common/ldap/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _get_merged_ids_by_attribute(
Returns:
Mapping from `LDAPPerson[attribute]` to corresponding `Identity.stableTargetId`
"""
if attribute not in LDAPPerson.__fields__:
if attribute not in LDAPPerson.model_fields:
raise RuntimeError(f"Not a valid LDAPPerson field: {attribute}")
merged_ids_by_attribute = defaultdict(list)
provider = get_provider()
Expand Down
2 changes: 1 addition & 1 deletion mex/common/ldap/models/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ class LDAPActor(BaseModel):
@staticmethod
def get_ldap_fields() -> tuple[str, ...]:
"""Return the fields that should be fetched from LDAP."""
return tuple(sorted(LDAPActor.__fields__))
return tuple(sorted(LDAPActor.model_fields))
4 changes: 2 additions & 2 deletions mex/common/ldap/models/person.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class LDAPPerson(LDAPActor):
departmentNumber: str | None = Field(None)
displayName: str | None = Field(None)
employeeID: str = Field(...)
givenName: list[str] = Field(..., min_items=1)
givenName: list[str] = Field(..., min_length=1)
ou: list[str] = Field([])
sn: str = Field(...)

@classmethod
def get_ldap_fields(cls) -> tuple[str, ...]:
"""Return the fields that should be fetched from LDAP."""
return tuple(sorted(cls.__fields__))
return tuple(sorted(cls.model_fields))


class LDAPPersonWithQuery(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion mex/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
"ExtractedContactPoint",
"ExtractedData",
"ExtractedDistribution",
"EXTRACTED_MODEL_CLASSES",
"EXTRACTED_MODEL_CLASSES_BY_NAME",
"ExtractedOrganization",
"ExtractedOrganizationalUnit",
"ExtractedPerson",
Expand All @@ -79,7 +81,7 @@
"MergedContactPoint",
"MergedDistribution",
"MergedItem",
"MergedModel",
"MERGED_MODEL_CLASSES_BY_NAME",
"MergedOrganization",
"MergedOrganizationalUnit",
"MergedPerson",
Expand Down
47 changes: 30 additions & 17 deletions mex/common/models/activity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Annotated

from pydantic import Field

from mex.common.models.base import BaseModel
Expand Down Expand Up @@ -32,20 +34,30 @@ class BaseActivity(BaseModel):

stableTargetId: ActivityID
abstract: list[Text] = []
activityType: list[ActivityType] = Field(
[], examples=["https://mex.rki.de/item/activity-type-1"]
)
activityType: list[
Annotated[
ActivityType, Field(examples=["https://mex.rki.de/item/activity-type-1"])
]
] = []
alternativeTitle: list[Text] = []
contact: list[OrganizationalUnitID | PersonID | ContactPointID] = Field(
contact: list[
Annotated[
OrganizationalUnitID | PersonID | ContactPointID,
Field(examples=[Identifier.generate(seed=42)]),
]
] = Field(
...,
examples=[Identifier.generate(seed=42)],
min_items=1,
min_length=1,
)
documentation: list[Link] = []
end: list[Timestamp] = Field(
[],
examples=["2024-01-17", "2024", "2024-01"],
)
end: list[
Annotated[
Timestamp,
Field(
examples=["2024-01-17", "2024", "2024-01"],
),
]
] = []
externalAssociate: list[OrganizationalUnitID | PersonID] = []
funderOrCommissioner: list[OrganizationID] = []
fundingProgram: list[str] = []
Expand All @@ -55,16 +67,17 @@ class BaseActivity(BaseModel):
publication: list[Link] = []
responsibleUnit: list[OrganizationalUnitID] = Field(
...,
min_items=1,
min_length=1,
)
shortName: list[Text] = []
start: list[Timestamp] = Field(
[],
examples=["2023-01-16", "2023", "2023-02"],
)
start: list[
Annotated[Timestamp, Field(examples=["2023-01-16", "2023", "2023-02"])]
] = []
succeeds: list[ActivityID] = []
theme: list[Theme] = Field([], examples=["https://mex.rki.de/item/theme-1"])
title: list[Text] = Field(..., min_items=1)
theme: list[
Annotated[Theme, Field(examples=["https://mex.rki.de/item/theme-1"])]
] = []
title: list[Text] = Field(..., min_length=1)
website: list[Link] = []


Expand Down
Loading

0 comments on commit 97a1ef1

Please sign in to comment.