From 71a778563ef76eaf4be31b339a59824dff615042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Sun, 17 Dec 2023 12:33:02 -0600 Subject: [PATCH] fix(taps): Respect forced replication method when retrieving state (#2107) --- poetry.lock | 2 +- pyproject.toml | 2 +- singer_sdk/streams/core.py | 6 +++++- tests/core/conftest.py | 15 ++++++++++++- tests/core/test_streams.py | 44 +++++++++++++++++++++++++++----------- 5 files changed, 52 insertions(+), 17 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7ec03b9d5..55c5a32c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3045,4 +3045,4 @@ testing = ["pytest", "pytest-durations"] [metadata] lock-version = "2.0" python-versions = ">=3.7.1" -content-hash = "0362cfe0889096a98e21a1233e31a2410cd5e68aeb4943f274135d0c3f7b7f53" +content-hash = "6b78196e66b711e201275c89c441d72c0d04820abec15cc5ef29437f617cc71f" diff --git a/pyproject.toml b/pyproject.toml index 7bc8e07ad..c1462169d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ requests = ">=2.25.1" simpleeval = ">=0.9.13" simplejson = ">=3.17.6" sqlalchemy = ">=1.4,<3.0" -typing-extensions = ">=4.2.0" +typing-extensions = ">=4.5.0" # urllib3 2.0 is not compatible with botocore urllib3 = ">=1.26,<2" diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index d235987a7..e76d19d80 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -246,7 +246,11 @@ def get_starting_replication_key_value( """ state = self.get_context_state(context) - return get_starting_replication_value(state) + return ( + get_starting_replication_value(state) + if self.replication_method != REPLICATION_FULL_TABLE + else None + ) def get_starting_timestamp(self, context: dict | None) -> datetime.datetime | None: """Get starting replication timestamp. diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 06355ccfe..97eb76e7f 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -3,9 +3,11 @@ from __future__ import annotations import typing as t +from contextlib import contextmanager import pendulum import pytest +from typing_extensions import override from singer_sdk import Stream, Tap from singer_sdk.typing import ( @@ -32,15 +34,24 @@ def __init__(self, tap: Tap): """Create a new stream.""" super().__init__(tap, schema=self.schema, name=self.name) + @override def get_records( self, - context: dict | None, # noqa: ARG002 + context: dict | None, ) -> t.Iterable[dict[str, t.Any]]: """Generate records.""" yield {"id": 1, "value": "Egypt"} yield {"id": 2, "value": "Germany"} yield {"id": 3, "value": "India"} + @contextmanager + def with_replication_method(self, method: str | None) -> t.Iterator[None]: + """Context manager to temporarily override the replication method.""" + original_method = self.forced_replication_method + self.forced_replication_method = method + yield + self.forced_replication_method = original_method + class UnixTimestampIncrementalStream(SimpleTestStream): name = "unix_ts" @@ -55,6 +66,7 @@ class UnixTimestampIncrementalStream(SimpleTestStream): class UnixTimestampIncrementalStream2(UnixTimestampIncrementalStream): name = "unix_ts_override" + @override def compare_start_date(self, value: str, start_date_value: str) -> str: """Compare a value to a start date value.""" @@ -73,6 +85,7 @@ class SimpleTestTap(Tap): additional_properties=False, ).to_dict() + @override def discover_streams(self) -> list[Stream]: """List all streams.""" return [ diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 8a415e55d..f3d9aba84 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -46,19 +46,18 @@ def get_next_page_token( response: requests.Response, previous_token: str | None, # noqa: ARG002 ) -> str | None: - if self.next_page_token_jsonpath: - all_matches = extract_jsonpath( - self.next_page_token_jsonpath, - response.json(), - ) - try: - return first(all_matches) - except StopIteration: - return None - - else: + if not self.next_page_token_jsonpath: return response.headers.get("X-Next-Page", None) + all_matches = extract_jsonpath( + self.next_page_token_jsonpath, + response.json(), + ) + try: + return first(all_matches) + except StopIteration: + return None + class GraphqlTestStream(GraphQLStream): """Test Graphql stream class.""" @@ -111,22 +110,32 @@ def test_stream_apply_catalog(stream: Stream): @pytest.mark.parametrize( - "stream_name,bookmark_value,expected_starting_value", + "stream_name,forced_replication_method,bookmark_value,expected_starting_value", [ pytest.param( "test", None, + None, pendulum.parse(CONFIG_START_DATE), id="datetime-repl-key-no-state", ), pytest.param( "test", + None, "2021-02-01", pendulum.datetime(2021, 2, 1), id="datetime-repl-key-recent-bookmark", ), pytest.param( "test", + REPLICATION_FULL_TABLE, + "2021-02-01", + None, + id="datetime-forced-full-table", + ), + pytest.param( + "test", + None, "2020-01-01", pendulum.parse(CONFIG_START_DATE), id="datetime-repl-key-old-bookmark", @@ -134,17 +143,20 @@ def test_stream_apply_catalog(stream: Stream): pytest.param( "unix_ts", None, + None, CONFIG_START_DATE, id="naive-unix-ts-repl-key-no-state", ), pytest.param( "unix_ts", + None, "1612137600", "1612137600", id="naive-unix-ts-repl-key-recent-bookmark", ), pytest.param( "unix_ts", + None, "1577858400", "1577858400", id="naive-unix-ts-repl-key-old-bookmark", @@ -152,17 +164,20 @@ def test_stream_apply_catalog(stream: Stream): pytest.param( "unix_ts_override", None, + None, CONFIG_START_DATE, id="unix-ts-repl-key-no-state", ), pytest.param( "unix_ts_override", + None, "1612137600", "1612137600", id="unix-ts-repl-key-recent-bookmark", ), pytest.param( "unix_ts_override", + None, "1577858400", pendulum.parse(CONFIG_START_DATE).format("X"), id="unix-ts-repl-key-old-bookmark", @@ -172,6 +187,7 @@ def test_stream_apply_catalog(stream: Stream): def test_stream_starting_timestamp( tap: Tap, stream_name: str, + forced_replication_method: str | None, bookmark_value: str, expected_starting_value: t.Any, ): @@ -194,7 +210,9 @@ def test_stream_starting_timestamp( }, ) stream._write_starting_replication_value(None) - assert get_starting_value(None) == expected_starting_value + + with stream.with_replication_method(forced_replication_method): + assert get_starting_value(None) == expected_starting_value def test_stream_invalid_replication_key(tap: SimpleTestTap):