Skip to content

Commit

Permalink
feature/mx-1708 backend id provider (#366)
Browse files Browse the repository at this point in the history
### PR Context

- this PR reduces code-duplication by moving the backend id provider to
mex-common

### Added

- port backend identity provider implementation from editor/extractors
to common
-
https://github.com/robert-koch-institut/mex-extractors/blob/0.25.0/mex/extractors/identity.py#L11
-
https://github.com/robert-koch-institut/mex-editor/blob/0.8.0/mex/editor/identity.py#L13

### Changes

- allow backend and graph as identity provider setting to simplify
setting subclasses,
  even though graph is not implemented in mex-common
- BREAKING: make backend api connector response models generic, to keep
DRY
  • Loading branch information
cutoffthetop authored Jan 27, 2025
1 parent 4b17737 commit 4f66746
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 52 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- port backend identity provider implementation from editor/extractors to common

### Changes

- allow backend and graph as identity provider setting to simplify setting subclasses,
even though graph is not implemented in mex-common
- BREAKING: make backend api connector response models generic, to keep DRY

### Deprecated

### Removed
Expand Down
26 changes: 13 additions & 13 deletions mex/common/backend_api/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from requests.exceptions import HTTPError

from mex.common.backend_api.models import (
ExtractedItemsRequest,
ExtractedItemsResponse,
IdentifiersResponse,
MergedItemsResponse,
ItemsContainer,
MergedModelTypeAdapter,
PreviewItemsResponse,
PaginatedItemsContainer,
RuleSetResponseTypeAdapter,
)
from mex.common.connector import HTTPConnector
from mex.common.models import (
AnyExtractedModel,
AnyMergedModel,
AnyPreviewModel,
AnyRuleSetRequest,
AnyRuleSetResponse,
)
Expand Down Expand Up @@ -58,7 +57,7 @@ def post_extracted_items(
response = self.request(
method="POST",
endpoint="ingest",
payload=ExtractedItemsRequest(items=extracted_items),
payload=ItemsContainer[AnyExtractedModel](items=extracted_items),
)
return IdentifiersResponse.model_validate(response)

Expand All @@ -69,7 +68,7 @@ def fetch_extracted_items(
entity_type: list[str] | None,
skip: int,
limit: int,
) -> ExtractedItemsResponse:
) -> PaginatedItemsContainer[AnyExtractedModel]:
"""Fetch extracted items that match the given set of filters.
Args:
Expand All @@ -96,15 +95,15 @@ def fetch_extracted_items(
"limit": str(limit),
},
)
return ExtractedItemsResponse.model_validate(response)
return PaginatedItemsContainer[AnyExtractedModel].model_validate(response)

def fetch_merged_items(
self,
query_string: str | None,
entity_type: list[str] | None,
skip: int,
limit: int,
) -> MergedItemsResponse:
) -> PaginatedItemsContainer[AnyMergedModel]:
"""Fetch merged items that match the given set of filters.
Args:
Expand All @@ -129,7 +128,7 @@ def fetch_merged_items(
"limit": str(limit),
},
)
return MergedItemsResponse.model_validate(response)
return PaginatedItemsContainer[AnyMergedModel].model_validate(response)

def get_merged_item(
self,
Expand All @@ -155,7 +154,9 @@ def get_merged_item(
"limit": "1",
},
)
response_model = MergedItemsResponse.model_validate(response)
response_model = PaginatedItemsContainer[AnyMergedModel].model_validate(
response
)
try:
return response_model.items[0]
except IndexError:
Expand Down Expand Up @@ -192,7 +193,7 @@ def fetch_preview_items(
entity_type: list[str] | None,
skip: int,
limit: int,
) -> PreviewItemsResponse:
) -> PaginatedItemsContainer[AnyPreviewModel]:
"""Fetch merged item previews that match the given set of filters.
Args:
Expand All @@ -207,7 +208,6 @@ def fetch_preview_items(
Returns:
One page of preview items and the total count that was matched
"""
# Note: this is forward-compat for MX-1649, backend might not support this yet!
response = self.request(
method="GET",
endpoint="preview-item",
Expand All @@ -218,7 +218,7 @@ def fetch_preview_items(
"limit": str(limit),
},
)
return PreviewItemsResponse.model_validate(response)
return PaginatedItemsContainer[AnyPreviewModel].model_validate(response)

def get_rule_set(
self,
Expand Down
36 changes: 9 additions & 27 deletions mex/common/backend_api/models.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
from typing import Annotated
from typing import Annotated, Generic, TypeVar

from pydantic import Field, TypeAdapter

from mex.common.models import (
AnyExtractedModel,
AnyMergedModel,
AnyPreviewModel,
AnyRuleSetResponse,
BaseModel,
)
from mex.common.models import AnyMergedModel, AnyRuleSetResponse, BaseModel
from mex.common.types import Identifier

T = TypeVar("T")

class ExtractedItemsRequest(BaseModel):
"""Request model for a list of extracted items."""

items: list[AnyExtractedModel]


class ExtractedItemsResponse(BaseModel):
"""Response model for a list of extracted items including a total count."""

items: list[AnyExtractedModel]
total: int

class ItemsContainer(BaseModel, Generic[T]):
"""Generic container that contains items."""

class MergedItemsResponse(BaseModel):
"""Response model for a list of merged items including a total count."""

items: list[AnyMergedModel]
total: int
items: list[T]


class PreviewItemsResponse(BaseModel):
"""Response model for a list of preview items including a total count."""
class PaginatedItemsContainer(BaseModel, Generic[T]):
"""Generic container that contains items and has a total item count."""

items: list[AnyPreviewModel]
items: list[T]
total: int


Expand Down
54 changes: 54 additions & 0 deletions mex/common/identity/backend_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from functools import cache

from mex.common.backend_api.connector import BackendApiConnector
from mex.common.backend_api.models import ItemsContainer
from mex.common.identity.base import BaseProvider
from mex.common.identity.models import Identity
from mex.common.types import Identifier, MergedPrimarySourceIdentifier


class BackendApiIdentityProvider(BaseProvider, BackendApiConnector):
"""Identity provider that communicates with the backend HTTP API."""

@cache # noqa: B019
def assign(
self,
had_primary_source: MergedPrimarySourceIdentifier,
identifier_in_primary_source: str,
) -> Identity:
"""Find an Identity in a database or assign a new one."""
response = self.request(
"POST",
"identity",
{
"hadPrimarySource": had_primary_source,
"identifierInPrimarySource": identifier_in_primary_source,
},
)
return Identity.model_validate(response)

def fetch(
self,
*,
had_primary_source: Identifier | None = None,
identifier_in_primary_source: str | None = None,
stable_target_id: Identifier | None = None,
) -> list[Identity]:
"""Find Identity instances matching the given filters.
Either provide `stableTargetId` or `hadPrimarySource`
and `identifierInPrimarySource` together to get a unique result.
"""
response = self.request(
"GET",
"identity",
params={
key: str(value)
for key, value in [
("hadPrimarySource", had_primary_source),
("identifierInPrimarySource", identifier_in_primary_source),
("stableTargetId", stable_target_id),
]
},
)
return ItemsContainer[Identity].model_validate(response).items
21 changes: 12 additions & 9 deletions mex/common/identity/registry.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from collections.abc import Hashable
from typing import Final
from typing import TYPE_CHECKING, Final

from mex.common.identity.base import BaseProvider
from mex.common.identity.backend_api import BackendApiIdentityProvider
from mex.common.identity.memory import MemoryIdentityProvider
from mex.common.types import IdentityProvider

_PROVIDER_REGISTRY: Final[dict[Hashable, type[BaseProvider]]] = {}
if TYPE_CHECKING:
from mex.common.identity.base import BaseProvider

_PROVIDER_REGISTRY: Final[dict[IdentityProvider, type["BaseProvider"]]] = {}

def register_provider(key: Hashable, provider_cls: type[BaseProvider]) -> None:

def register_provider(
key: IdentityProvider, provider_cls: type["BaseProvider"]
) -> None:
"""Register an implementation of an identity provider to a settings key.
Args:
key: Possible value of `Settings.identity_provider`, this will be of type
`mex.common.identity.types.IdentityProvider` on the `BaseSettings`
but maybe overwritten in other packages that have their own settings
key: Possible value of `BaseSettings.identity_provider`
provider_cls: Implementation of an identity provider
Raises:
Expand All @@ -26,7 +28,7 @@ def register_provider(key: Hashable, provider_cls: type[BaseProvider]) -> None:
_PROVIDER_REGISTRY[key] = provider_cls


def get_provider() -> BaseProvider:
def get_provider() -> "BaseProvider":
"""Get an instance of the identity provider as configured by `identity_provider`.
Raises:
Expand All @@ -48,3 +50,4 @@ def get_provider() -> BaseProvider:

# register the default providers shipped with mex-common
register_provider(IdentityProvider.MEMORY, MemoryIdentityProvider)
register_provider(IdentityProvider.BACKEND, BackendApiIdentityProvider)
2 changes: 2 additions & 0 deletions mex/common/types/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
class IdentityProvider(Enum):
"""Choice of available identity providers."""

BACKEND = "backend"
GRAPH = "graph"
MEMORY = "memory"
10 changes: 7 additions & 3 deletions tests/backend_api/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from requests.exceptions import HTTPError

from mex.common.backend_api.connector import BackendApiConnector
from mex.common.backend_api.models import ExtractedItemsRequest, PreviewItemsResponse
from mex.common.backend_api.models import ItemsContainer, PaginatedItemsContainer
from mex.common.models import (
AnyExtractedModel,
AnyPreviewModel,
ExtractedPerson,
MergedPerson,
PersonRuleSetRequest,
Expand Down Expand Up @@ -45,7 +47,7 @@ def test_post_extracted_items_mocked(
)
assert (
json.loads(mocked_backend.call_args.kwargs["data"])
== ExtractedItemsRequest(items=[extracted_person]).model_dump()
== ItemsContainer[AnyExtractedModel](items=[extracted_person]).model_dump()
)


Expand Down Expand Up @@ -202,7 +204,9 @@ def test_fetch_preview_items_mocked(
mocked_backend: MagicMock,
preview_person: PreviewPerson,
) -> None:
preview_response = PreviewItemsResponse(items=[preview_person], total=92)
preview_response = PaginatedItemsContainer[AnyPreviewModel](
items=[preview_person], total=92
)
mocked_return = preview_response.model_dump()
mocked_backend.return_value.json.return_value = mocked_return

Expand Down
Loading

0 comments on commit 4f66746

Please sign in to comment.