From 4f66746a5f47d3eda7261cafa8c54718e7018ca7 Mon Sep 17 00:00:00 2001 From: Nicolas Drebenstedt <897972+cutoffthetop@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:13:29 +0100 Subject: [PATCH] feature/mx-1708 backend id provider (#366) ### PR Context - this PR reduces code-duplication by moving the backend id provider to mex-common ### Added - port backend identity provider implementation from editor/extractors to common - https://github.com/robert-koch-institut/mex-extractors/blob/0.25.0/mex/extractors/identity.py#L11 - https://github.com/robert-koch-institut/mex-editor/blob/0.8.0/mex/editor/identity.py#L13 ### Changes - allow backend and graph as identity provider setting to simplify setting subclasses, even though graph is not implemented in mex-common - BREAKING: make backend api connector response models generic, to keep DRY --- CHANGELOG.md | 6 ++ mex/common/backend_api/connector.py | 26 +++---- mex/common/backend_api/models.py | 36 +++------ mex/common/identity/backend_api.py | 54 ++++++++++++++ mex/common/identity/registry.py | 21 +++--- mex/common/types/identity.py | 2 + tests/backend_api/test_connector.py | 10 ++- tests/identity/test_backend_api.py | 112 ++++++++++++++++++++++++++++ 8 files changed, 215 insertions(+), 52 deletions(-) create mode 100644 mex/common/identity/backend_api.py create mode 100644 tests/identity/test_backend_api.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4afd5bdc..0569bb65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- port backend identity provider implementation from editor/extractors to common + ### Changes +- allow backend and graph as identity provider setting to simplify setting subclasses, + even though graph is not implemented in mex-common +- BREAKING: make backend api connector response models generic, to keep DRY + ### Deprecated ### Removed diff --git a/mex/common/backend_api/connector.py b/mex/common/backend_api/connector.py index 2e8cb750..351eeeaf 100644 --- a/mex/common/backend_api/connector.py +++ b/mex/common/backend_api/connector.py @@ -3,18 +3,17 @@ from requests.exceptions import HTTPError from mex.common.backend_api.models import ( - ExtractedItemsRequest, - ExtractedItemsResponse, IdentifiersResponse, - MergedItemsResponse, + ItemsContainer, MergedModelTypeAdapter, - PreviewItemsResponse, + PaginatedItemsContainer, RuleSetResponseTypeAdapter, ) from mex.common.connector import HTTPConnector from mex.common.models import ( AnyExtractedModel, AnyMergedModel, + AnyPreviewModel, AnyRuleSetRequest, AnyRuleSetResponse, ) @@ -58,7 +57,7 @@ def post_extracted_items( response = self.request( method="POST", endpoint="ingest", - payload=ExtractedItemsRequest(items=extracted_items), + payload=ItemsContainer[AnyExtractedModel](items=extracted_items), ) return IdentifiersResponse.model_validate(response) @@ -69,7 +68,7 @@ def fetch_extracted_items( entity_type: list[str] | None, skip: int, limit: int, - ) -> ExtractedItemsResponse: + ) -> PaginatedItemsContainer[AnyExtractedModel]: """Fetch extracted items that match the given set of filters. Args: @@ -96,7 +95,7 @@ def fetch_extracted_items( "limit": str(limit), }, ) - return ExtractedItemsResponse.model_validate(response) + return PaginatedItemsContainer[AnyExtractedModel].model_validate(response) def fetch_merged_items( self, @@ -104,7 +103,7 @@ def fetch_merged_items( entity_type: list[str] | None, skip: int, limit: int, - ) -> MergedItemsResponse: + ) -> PaginatedItemsContainer[AnyMergedModel]: """Fetch merged items that match the given set of filters. Args: @@ -129,7 +128,7 @@ def fetch_merged_items( "limit": str(limit), }, ) - return MergedItemsResponse.model_validate(response) + return PaginatedItemsContainer[AnyMergedModel].model_validate(response) def get_merged_item( self, @@ -155,7 +154,9 @@ def get_merged_item( "limit": "1", }, ) - response_model = MergedItemsResponse.model_validate(response) + response_model = PaginatedItemsContainer[AnyMergedModel].model_validate( + response + ) try: return response_model.items[0] except IndexError: @@ -192,7 +193,7 @@ def fetch_preview_items( entity_type: list[str] | None, skip: int, limit: int, - ) -> PreviewItemsResponse: + ) -> PaginatedItemsContainer[AnyPreviewModel]: """Fetch merged item previews that match the given set of filters. Args: @@ -207,7 +208,6 @@ def fetch_preview_items( Returns: One page of preview items and the total count that was matched """ - # Note: this is forward-compat for MX-1649, backend might not support this yet! response = self.request( method="GET", endpoint="preview-item", @@ -218,7 +218,7 @@ def fetch_preview_items( "limit": str(limit), }, ) - return PreviewItemsResponse.model_validate(response) + return PaginatedItemsContainer[AnyPreviewModel].model_validate(response) def get_rule_set( self, diff --git a/mex/common/backend_api/models.py b/mex/common/backend_api/models.py index 2b37ddbf..63593bb3 100644 --- a/mex/common/backend_api/models.py +++ b/mex/common/backend_api/models.py @@ -1,41 +1,23 @@ -from typing import Annotated +from typing import Annotated, Generic, TypeVar from pydantic import Field, TypeAdapter -from mex.common.models import ( - AnyExtractedModel, - AnyMergedModel, - AnyPreviewModel, - AnyRuleSetResponse, - BaseModel, -) +from mex.common.models import AnyMergedModel, AnyRuleSetResponse, BaseModel from mex.common.types import Identifier +T = TypeVar("T") -class ExtractedItemsRequest(BaseModel): - """Request model for a list of extracted items.""" - - items: list[AnyExtractedModel] - - -class ExtractedItemsResponse(BaseModel): - """Response model for a list of extracted items including a total count.""" - - items: list[AnyExtractedModel] - total: int +class ItemsContainer(BaseModel, Generic[T]): + """Generic container that contains items.""" -class MergedItemsResponse(BaseModel): - """Response model for a list of merged items including a total count.""" - - items: list[AnyMergedModel] - total: int + items: list[T] -class PreviewItemsResponse(BaseModel): - """Response model for a list of preview items including a total count.""" +class PaginatedItemsContainer(BaseModel, Generic[T]): + """Generic container that contains items and has a total item count.""" - items: list[AnyPreviewModel] + items: list[T] total: int diff --git a/mex/common/identity/backend_api.py b/mex/common/identity/backend_api.py new file mode 100644 index 00000000..d1dcb757 --- /dev/null +++ b/mex/common/identity/backend_api.py @@ -0,0 +1,54 @@ +from functools import cache + +from mex.common.backend_api.connector import BackendApiConnector +from mex.common.backend_api.models import ItemsContainer +from mex.common.identity.base import BaseProvider +from mex.common.identity.models import Identity +from mex.common.types import Identifier, MergedPrimarySourceIdentifier + + +class BackendApiIdentityProvider(BaseProvider, BackendApiConnector): + """Identity provider that communicates with the backend HTTP API.""" + + @cache # noqa: B019 + def assign( + self, + had_primary_source: MergedPrimarySourceIdentifier, + identifier_in_primary_source: str, + ) -> Identity: + """Find an Identity in a database or assign a new one.""" + response = self.request( + "POST", + "identity", + { + "hadPrimarySource": had_primary_source, + "identifierInPrimarySource": identifier_in_primary_source, + }, + ) + return Identity.model_validate(response) + + def fetch( + self, + *, + had_primary_source: Identifier | None = None, + identifier_in_primary_source: str | None = None, + stable_target_id: Identifier | None = None, + ) -> list[Identity]: + """Find Identity instances matching the given filters. + + Either provide `stableTargetId` or `hadPrimarySource` + and `identifierInPrimarySource` together to get a unique result. + """ + response = self.request( + "GET", + "identity", + params={ + key: str(value) + for key, value in [ + ("hadPrimarySource", had_primary_source), + ("identifierInPrimarySource", identifier_in_primary_source), + ("stableTargetId", stable_target_id), + ] + }, + ) + return ItemsContainer[Identity].model_validate(response).items diff --git a/mex/common/identity/registry.py b/mex/common/identity/registry.py index 0525aa9f..6438e647 100644 --- a/mex/common/identity/registry.py +++ b/mex/common/identity/registry.py @@ -1,20 +1,22 @@ -from collections.abc import Hashable -from typing import Final +from typing import TYPE_CHECKING, Final -from mex.common.identity.base import BaseProvider +from mex.common.identity.backend_api import BackendApiIdentityProvider from mex.common.identity.memory import MemoryIdentityProvider from mex.common.types import IdentityProvider -_PROVIDER_REGISTRY: Final[dict[Hashable, type[BaseProvider]]] = {} +if TYPE_CHECKING: + from mex.common.identity.base import BaseProvider +_PROVIDER_REGISTRY: Final[dict[IdentityProvider, type["BaseProvider"]]] = {} -def register_provider(key: Hashable, provider_cls: type[BaseProvider]) -> None: + +def register_provider( + key: IdentityProvider, provider_cls: type["BaseProvider"] +) -> None: """Register an implementation of an identity provider to a settings key. Args: - key: Possible value of `Settings.identity_provider`, this will be of type - `mex.common.identity.types.IdentityProvider` on the `BaseSettings` - but maybe overwritten in other packages that have their own settings + key: Possible value of `BaseSettings.identity_provider` provider_cls: Implementation of an identity provider Raises: @@ -26,7 +28,7 @@ def register_provider(key: Hashable, provider_cls: type[BaseProvider]) -> None: _PROVIDER_REGISTRY[key] = provider_cls -def get_provider() -> BaseProvider: +def get_provider() -> "BaseProvider": """Get an instance of the identity provider as configured by `identity_provider`. Raises: @@ -48,3 +50,4 @@ def get_provider() -> BaseProvider: # register the default providers shipped with mex-common register_provider(IdentityProvider.MEMORY, MemoryIdentityProvider) +register_provider(IdentityProvider.BACKEND, BackendApiIdentityProvider) diff --git a/mex/common/types/identity.py b/mex/common/types/identity.py index 50a960a6..6179efc1 100644 --- a/mex/common/types/identity.py +++ b/mex/common/types/identity.py @@ -4,4 +4,6 @@ class IdentityProvider(Enum): """Choice of available identity providers.""" + BACKEND = "backend" + GRAPH = "graph" MEMORY = "memory" diff --git a/tests/backend_api/test_connector.py b/tests/backend_api/test_connector.py index 1b7fdf62..296f4284 100644 --- a/tests/backend_api/test_connector.py +++ b/tests/backend_api/test_connector.py @@ -5,8 +5,10 @@ from requests.exceptions import HTTPError from mex.common.backend_api.connector import BackendApiConnector -from mex.common.backend_api.models import ExtractedItemsRequest, PreviewItemsResponse +from mex.common.backend_api.models import ItemsContainer, PaginatedItemsContainer from mex.common.models import ( + AnyExtractedModel, + AnyPreviewModel, ExtractedPerson, MergedPerson, PersonRuleSetRequest, @@ -45,7 +47,7 @@ def test_post_extracted_items_mocked( ) assert ( json.loads(mocked_backend.call_args.kwargs["data"]) - == ExtractedItemsRequest(items=[extracted_person]).model_dump() + == ItemsContainer[AnyExtractedModel](items=[extracted_person]).model_dump() ) @@ -202,7 +204,9 @@ def test_fetch_preview_items_mocked( mocked_backend: MagicMock, preview_person: PreviewPerson, ) -> None: - preview_response = PreviewItemsResponse(items=[preview_person], total=92) + preview_response = PaginatedItemsContainer[AnyPreviewModel]( + items=[preview_person], total=92 + ) mocked_return = preview_response.model_dump() mocked_backend.return_value.json.return_value = mocked_return diff --git a/tests/identity/test_backend_api.py b/tests/identity/test_backend_api.py new file mode 100644 index 00000000..d66de358 --- /dev/null +++ b/tests/identity/test_backend_api.py @@ -0,0 +1,112 @@ +from unittest.mock import MagicMock, Mock + +import pytest +import requests +from pytest import MonkeyPatch + +from mex.common.identity import Identity +from mex.common.identity.backend_api import BackendApiIdentityProvider +from mex.common.models import ExtractedContactPoint +from mex.common.types import MergedPrimarySourceIdentifier + + +@pytest.fixture +def mocked_backend_identity_provider(monkeypatch: MonkeyPatch) -> MagicMock: + mocked_session = MagicMock(spec=requests.Session) + mocked_session.request = MagicMock( + return_value=Mock(spec=requests.Response, status_code=200) + ) + mocked_session.headers = {} + + def set_mocked_session(self: BackendApiIdentityProvider) -> None: + self.session = mocked_session + + monkeypatch.setattr(BackendApiIdentityProvider, "_set_session", set_mocked_session) + return mocked_session + + +def test_assign_mocked( + mocked_backend_identity_provider: requests.Session, +) -> None: + mocked_data = { + "identifier": MergedPrimarySourceIdentifier.generate(seed=962), + "hadPrimarySource": MergedPrimarySourceIdentifier.generate(seed=961), + "identifierInPrimarySource": "test", + "stableTargetId": MergedPrimarySourceIdentifier.generate(seed=963), + } + mocked_response = Mock(spec=requests.Response) + mocked_response.status_code = 200 + mocked_response.json = MagicMock(return_value=mocked_data) + mocked_backend_identity_provider.request = MagicMock(return_value=mocked_response) + + provider = BackendApiIdentityProvider.get() + identity_first = provider.assign( + had_primary_source=MergedPrimarySourceIdentifier.generate(seed=961), + identifier_in_primary_source="test", + ) + + identity = Identity.model_validate(identity_first) + + identity_first_assignment = identity.model_dump() + + assert identity_first_assignment == mocked_data + + identity_second = provider.assign( + had_primary_source=MergedPrimarySourceIdentifier.generate(seed=961), + identifier_in_primary_source="test", + ) + identity_second_assignment = identity_second.model_dump() + + assert identity_second_assignment == identity_first_assignment + + +def test_fetch_mocked( + mocked_backend_identity_provider: requests.Session, +) -> None: + mocked_data = { + "items": [ + { + "identifier": MergedPrimarySourceIdentifier.generate(seed=962), + "hadPrimarySource": MergedPrimarySourceIdentifier.generate(seed=961), + "identifierInPrimarySource": "test", + "stableTargetId": MergedPrimarySourceIdentifier.generate(seed=963), + } + ], + "total": 1, + } + + mocked_response = Mock(spec=requests.Response) + mocked_response.status_code = 200 + mocked_response.json = MagicMock(return_value=mocked_data) + mocked_backend_identity_provider.request = MagicMock(return_value=mocked_response) + + provider = BackendApiIdentityProvider.get() + + contact_point = ExtractedContactPoint( + hadPrimarySource=MergedPrimarySourceIdentifier.generate(seed=961), + identifierInPrimarySource="test", + email=["test@test.de"], + ) + + identities = provider.fetch(stable_target_id=contact_point.stableTargetId) + assert identities == [ + Identity( + stableTargetId=mocked_data["items"][0]["stableTargetId"], + identifier=mocked_data["items"][0]["identifier"], + hadPrimarySource=contact_point.hadPrimarySource, + identifierInPrimarySource=contact_point.identifierInPrimarySource, + ) + ] + + identities = provider.fetch( + had_primary_source=contact_point.hadPrimarySource, + identifier_in_primary_source=contact_point.identifierInPrimarySource, + ) + assert identities == [ + Identity( + stableTargetId=mocked_data["items"][0]["stableTargetId"], + identifier=mocked_data["items"][0]["identifier"], + hadPrimarySource=contact_point.hadPrimarySource, + identifierInPrimarySource=contact_point.identifierInPrimarySource, + ) + ]