Skip to content

Commit

Permalink
feature/mx-1435 computed provenance fields (#205)
Browse files Browse the repository at this point in the history
# PR Context
- this removes the validator that automatically set identifier and
stableTargetId based on the identifierInPrimarySource and
hadPrimarySource attributes of an extracted item
- instead, this logic is moved to computed fields, which allows correct
type signatures for the constructors of extracted items (read:
identifier and stableTargetId are not required anymore when constructing
an extracted item)
- one model validator is still needed though, to allow parsing models
from jsons/dicts where an identifier is already present, but this does
not interfere with the typing

# Migration Guide
- mypy ignores for call-arg and arg-type can be removed for extracted
model instantiation
- calls to `model_construct()` should be avoided, because computed
fields still compute even when validation is skipped: this may result in
issues when provenance fields are missing from the call to construct
- field assignment on extracted models is not supported anymore; try to
collect all values before instantiation or create a copy with the
updated value

# Added
- add validator to base model that verifies computed fields can be set
but not altered
- new class hierarchy for identifiers: ExtractedIdentifier and
MergedIdentifier

# Changes
- use json instead of pickle to calculate checksum of models
- swapped `set_identifiers` validator for computed fields on each
extracted model

# Removed
- removed custom stringify method on base entities that included the
`identifier` field

---------

Signed-off-by: Nicolas Drebenstedt <[email protected]>
Co-authored-by: rababerladuseladim <[email protected]>
  • Loading branch information
cutoffthetop and rababerladuseladim authored Jul 12, 2024
1 parent 6fae42c commit fb5947b
Show file tree
Hide file tree
Showing 31 changed files with 376 additions and 307 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- add validator to base model that verifies computed fields can be set but not altered
- new class hierarchy for identifiers: ExtractedIdentifier and MergedIdentifier

### Changes

- improve typing for methods using `Self`
- make local type variables private
- use json instead of pickle to calculate checksum of models
- replace `set_identifiers` validator with computed fields on each extracted model

### Deprecated

### Removed

- removed custom stringify method on base entities that included the `identifier` field

### Fixed

- fix typing for `__eq__` arguments
Expand Down
8 changes: 5 additions & 3 deletions mex/common/backend_api/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from mex.common.backend_api.models import BulkInsertResponse
from mex.common.connector import HTTPConnector
from mex.common.models import ExtractedData
from mex.common.models import AnyExtractedModel
from mex.common.settings import BaseSettings
from mex.common.types import Identifier
from mex.common.types import AnyExtractedIdentifier


class BackendApiConnector(HTTPConnector):
Expand All @@ -27,7 +27,9 @@ def _set_url(self) -> None:
settings = BaseSettings.get()
self.url = urljoin(str(settings.backend_api_url), self.API_VERSION)

def post_models(self, models: list[ExtractedData]) -> list[Identifier]:
def post_models(
self, models: list[AnyExtractedModel]
) -> list[AnyExtractedIdentifier]:
"""Post models to Backend API in a bulk insertion request.
Args:
Expand Down
4 changes: 2 additions & 2 deletions mex/common/backend_api/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from mex.common.models import BaseModel
from mex.common.types import Identifier
from mex.common.types import AnyExtractedIdentifier


class BulkInsertResponse(BaseModel):
"""Response body for the bulk ingestion endpoint."""

identifiers: list[Identifier]
identifiers: list[AnyExtractedIdentifier]
4 changes: 2 additions & 2 deletions mex/common/extract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from collections.abc import Generator
from pathlib import Path
from os import PathLike
from typing import TYPE_CHECKING, Any, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -38,7 +38,7 @@ def get_dtypes_for_model(model: type["BaseModel"]) -> dict[str, "Dtype"]:


def parse_csv(
path_or_buffer: Union[str, Path, "ReadCsvBuffer[Any]"],
path_or_buffer: Union[str, PathLike[str], "ReadCsvBuffer[Any]"],
into: type[_BaseModelT],
chunksize: int = 10,
**kwargs: Any,
Expand Down
4 changes: 2 additions & 2 deletions mex/common/ldap/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def transform_ldap_person_to_mex_person(
f"'{ldap_person.department}' or departmentNumber "
f"'{ldap_person.departmentNumber}'"
)
return ExtractedPerson( # type: ignore[call-arg]
return ExtractedPerson(
identifierInPrimarySource=str(ldap_person.objectGUID),
hadPrimarySource=primary_source.stableTargetId,
affiliation=[], # TODO resolve organization for person.company/RKI
Expand All @@ -132,7 +132,7 @@ def transform_ldap_actor_to_mex_contact_point(
Returns:
Extracted contact point
"""
return ExtractedContactPoint( # type: ignore[call-arg]
return ExtractedContactPoint(
identifierInPrimarySource=str(ldap_actor.objectGUID),
hadPrimarySource=primary_source.stableTargetId,
email=ldap_actor.mail,
Expand Down
16 changes: 13 additions & 3 deletions mex/common/models/access_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Annotated, ClassVar, Literal

from pydantic import Field
from pydantic import Field, computed_field

from mex.common.models.base import BaseModel
from mex.common.models.extracted_data import ExtractedData
Expand Down Expand Up @@ -92,8 +92,18 @@ class ExtractedAccessPlatform(BaseAccessPlatform, ExtractedData):
entityType: Annotated[
Literal["ExtractedAccessPlatform"], Field(alias="$type", frozen=True)
] = "ExtractedAccessPlatform"
identifier: Annotated[ExtractedAccessPlatformIdentifier, Field(frozen=True)]
stableTargetId: MergedAccessPlatformIdentifier

@computed_field # type: ignore[misc]
@property
def identifier(self) -> ExtractedAccessPlatformIdentifier:
    """Return the computed identifier for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_identifier`, which is defined outside of this view.

    Returns:
        Identifier typed as `ExtractedAccessPlatformIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_identifier(ExtractedAccessPlatformIdentifier)

@computed_field # type: ignore[misc]
@property
def stableTargetId(self) -> MergedAccessPlatformIdentifier: # noqa: N802
    """Return the computed stableTargetId for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_stable_target_id`, which is defined outside of this view.

    Returns:
        Identifier typed as `MergedAccessPlatformIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_stable_target_id(MergedAccessPlatformIdentifier)


class MergedAccessPlatform(BaseAccessPlatform, MergedItem):
Expand Down
16 changes: 13 additions & 3 deletions mex/common/models/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Annotated, ClassVar, Literal

from pydantic import Field
from pydantic import Field, computed_field

from mex.common.models.base import BaseModel
from mex.common.models.extracted_data import ExtractedData
Expand Down Expand Up @@ -97,8 +97,18 @@ class ExtractedActivity(BaseActivity, ExtractedData):
entityType: Annotated[
Literal["ExtractedActivity"], Field(alias="$type", frozen=True)
] = "ExtractedActivity"
identifier: Annotated[ExtractedActivityIdentifier, Field(frozen=True)]
stableTargetId: MergedActivityIdentifier

@computed_field # type: ignore[misc]
@property
def identifier(self) -> ExtractedActivityIdentifier:
    """Return the computed identifier for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_identifier`, which is defined outside of this view.

    Returns:
        Identifier typed as `ExtractedActivityIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_identifier(ExtractedActivityIdentifier)

@computed_field # type: ignore[misc]
@property
def stableTargetId(self) -> MergedActivityIdentifier: # noqa: N802
    """Return the computed stableTargetId for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_stable_target_id`, which is defined outside of this view.

    Returns:
        Identifier typed as `MergedActivityIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_stable_target_id(MergedActivityIdentifier)


class MergedActivity(BaseActivity, MergedItem):
Expand Down
59 changes: 47 additions & 12 deletions mex/common/models/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
import hashlib
import pickle # nosec
import json
from collections.abc import MutableMapping
from functools import cache
from types import UnionType
from typing import (
Any,
TypeVar,
Union,
)
from typing import Any, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import (
BaseModel as PydanticBaseModel,
)
from pydantic import (
ConfigDict,
TypeAdapter,
ValidationError,
ValidatorFunctionWrapHandler,
model_validator,
)
from pydantic.json_schema import DEFAULT_REF_TEMPLATE, JsonSchemaMode
from pydantic.json_schema import GenerateJsonSchema as PydanticJsonSchemaGenerator

from mex.common.models.schema import JsonSchemaGenerator
from mex.common.transform import MExEncoder
from mex.common.utils import get_inner_types

_RawModelDataT = TypeVar("_RawModelDataT")


class BaseModel(PydanticBaseModel):
"""Common base class for all MEx model classes."""
Expand Down Expand Up @@ -142,7 +140,7 @@ def _fix_value_listyness_for_field(cls, field_name: str, value: Any) -> Any:

@model_validator(mode="before")
@classmethod
def fix_listyness(cls, data: _RawModelDataT) -> _RawModelDataT:
def fix_listyness(cls, data: Any) -> Any:
"""Adjust the listyness of to-be-parsed data to match the desired shape.
If that data is a Mapping and the model defines a list[T] field but the raw data
Expand All @@ -156,7 +154,7 @@ def fix_listyness(cls, data: _RawModelDataT) -> _RawModelDataT:
entry however, an error is raised, because we would not know which to choose.
Args:
data: Raw data to be parsed
data: Raw data or instance to be parsed
Returns:
data with fixed list shapes
Expand All @@ -168,9 +166,46 @@ def fix_listyness(cls, data: _RawModelDataT) -> _RawModelDataT:
data[name] = cls._fix_value_listyness_for_field(field_name, value)
return data

@model_validator(mode="wrap")
@classmethod
def verify_computed_field_consistency(
    cls, data: Any, handler: ValidatorFunctionWrapHandler
) -> Any:
    """Validate that parsed values for computed fields are consistent.

    Parsing a dictionary with a value for a computed field that is consistent with
    what that field would have computed anyway is allowed. Omitting values for
    computed fields is perfectly valid as well. However, if the parsed value is
    different from the computed value, a validation error is raised.

    The incoming mapping is never modified: values for computed fields are
    stripped from a shallow copy before it is handed to the inner validation.

    Args:
        data: Raw data or instance to be parsed
        handler: Validator function wrap handler

    Raises:
        AssertionError: If the model has computed fields and `data` is not a
            mutable mapping
        ValueError: If a supplied computed field value differs from the value
            computed by the validated model

    Returns:
        data with consistent computed fields.
    """
    if not cls.model_computed_fields:
        return handler(data)
    if not isinstance(data, MutableMapping):
        raise AssertionError(
            "Input should be a valid dictionary, validating other types is not "
            "supported for models with computed fields."
        )
    # Copy first, so popping the computed field keys below does not remove
    # them from the mapping that the caller passed in.
    payload = dict(data)
    # Falsy values (None, empty string, ...) are treated like omitted values.
    custom_values = {
        field: value
        for field in cls.model_computed_fields
        if (value := payload.pop(field, None))
    }
    # Computed fields must not reach the inner validation as regular input.
    result = handler(payload)
    computed_values = result.model_dump(include=set(custom_values))
    if computed_values != custom_values:
        raise ValueError("Cannot set computed fields to custom values!")
    return result

def checksum(self) -> str:
"""Calculate md5 checksum for this model."""
return hashlib.md5(pickle.dumps(self)).hexdigest() # noqa: S324
json_str = json.dumps(self, sort_keys=True, cls=MExEncoder)
return hashlib.md5(json_str.encode()).hexdigest() # noqa: S324

def __str__(self) -> str:
"""Format this model as a string for logging."""
Expand Down
16 changes: 13 additions & 3 deletions mex/common/models/contact_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Annotated, ClassVar, Literal

from pydantic import Field
from pydantic import Field, computed_field

from mex.common.models.base import BaseModel
from mex.common.models.extracted_data import ExtractedData
Expand Down Expand Up @@ -44,8 +44,18 @@ class ExtractedContactPoint(BaseContactPoint, ExtractedData):
entityType: Annotated[
Literal["ExtractedContactPoint"], Field(alias="$type", frozen=True)
] = "ExtractedContactPoint"
identifier: Annotated[ExtractedContactPointIdentifier, Field(frozen=True)]
stableTargetId: MergedContactPointIdentifier

@computed_field # type: ignore[misc]
@property
def identifier(self) -> ExtractedContactPointIdentifier:
    """Return the computed identifier for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_identifier`, which is defined outside of this view.

    Returns:
        Identifier typed as `ExtractedContactPointIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_identifier(ExtractedContactPointIdentifier)

@computed_field # type: ignore[misc]
@property
def stableTargetId(self) -> MergedContactPointIdentifier: # noqa: N802
    """Return the computed stableTargetId for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_stable_target_id`, which is defined outside of this view.

    Returns:
        Identifier typed as `MergedContactPointIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_stable_target_id(MergedContactPointIdentifier)


class MergedContactPoint(BaseContactPoint, MergedItem):
Expand Down
16 changes: 13 additions & 3 deletions mex/common/models/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Annotated, ClassVar, Literal

from pydantic import Field
from pydantic import Field, computed_field

from mex.common.models.base import BaseModel
from mex.common.models.extracted_data import ExtractedData
Expand Down Expand Up @@ -136,8 +136,18 @@ class ExtractedDistribution(BaseDistribution, ExtractedData):
entityType: Annotated[
Literal["ExtractedDistribution"], Field(alias="$type", frozen=True)
] = "ExtractedDistribution"
identifier: Annotated[ExtractedDistributionIdentifier, Field(frozen=True)]
stableTargetId: MergedDistributionIdentifier

@computed_field # type: ignore[misc]
@property
def identifier(self) -> ExtractedDistributionIdentifier:
    """Return the computed identifier for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_identifier`, which is defined outside of this view.

    Returns:
        Identifier typed as `ExtractedDistributionIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_identifier(ExtractedDistributionIdentifier)

@computed_field # type: ignore[misc]
@property
def stableTargetId(self) -> MergedDistributionIdentifier: # noqa: N802
    """Return the computed stableTargetId for this extracted data item.

    Declared as a pydantic computed field on top of a read-only property
    instead of a regular model field. The value is produced by
    `_get_stable_target_id`, which is defined outside of this view.

    Returns:
        Identifier typed as `MergedDistributionIdentifier`
    """
    # NOTE(review): per the PR description the value is derived from
    # identifierInPrimarySource and hadPrimarySource -- confirm in the helper.
    return self._get_stable_target_id(MergedDistributionIdentifier)


class MergedDistribution(BaseDistribution, MergedItem):
Expand Down
19 changes: 7 additions & 12 deletions mex/common/models/entity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING, ClassVar

from pydantic import ConfigDict

from mex.common.models.base import BaseModel
from mex.common.types import Identifier


class BaseEntity(BaseModel, extra="forbid"):
class BaseEntity(BaseModel):
"""Abstract base model for extracted data, merged item and rule set classes.
This class gives type hints for an `identifier` field, the frozen `entityType` field
Expand All @@ -13,6 +14,10 @@ class BaseEntity(BaseModel, extra="forbid"):
type as well as the correct literal values for the entity and stem types.
"""

model_config = ConfigDict(
extra="forbid",
)

if TYPE_CHECKING: # pragma: no cover
# The frozen `entityType` field is added to all `BaseEntity` subclasses to
# help with assigning the correct class when reading raw JSON entities.
Expand All @@ -26,13 +31,3 @@ class BaseEntity(BaseModel, extra="forbid"):
# type of items. E.g. `ExtractedPerson`, `MergedPerson` and `PreventivePerson`
# all share the same `stemType` of `Person`.
stemType: ClassVar

# A globally unique identifier is added to all `BaseEntity` subclasses and
# should be typed to the correct identifier type. Regardless of the entity-type
# or whether this item was extracted, merged, etc., identifiers will be assigned
# just once and should be declared as `frozen` on subclasses.
identifier: Identifier

def __str__(self) -> str:
"""Format this instance as a string for logging."""
return f"{self.entityType}: {self.identifier}"
Loading

0 comments on commit fb5947b

Please sign in to comment.