Skip to content

Commit

Permalink
fix(taps): Respect forced replication method when retrieving state (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon authored Dec 17, 2023
1 parent ae5b125 commit 71a7785
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 17 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ requests = ">=2.25.1"
simpleeval = ">=0.9.13"
simplejson = ">=3.17.6"
sqlalchemy = ">=1.4,<3.0"
typing-extensions = ">=4.2.0"
typing-extensions = ">=4.5.0"
# urllib3 2.0 is not compatible with botocore
urllib3 = ">=1.26,<2"

Expand Down
6 changes: 5 additions & 1 deletion singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ def get_starting_replication_key_value(
"""
state = self.get_context_state(context)

return get_starting_replication_value(state)
return (
get_starting_replication_value(state)
if self.replication_method != REPLICATION_FULL_TABLE
else None
)

def get_starting_timestamp(self, context: dict | None) -> datetime.datetime | None:
"""Get starting replication timestamp.
Expand Down
15 changes: 14 additions & 1 deletion tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

import typing as t
from contextlib import contextmanager

import pendulum
import pytest
from typing_extensions import override

from singer_sdk import Stream, Tap
from singer_sdk.typing import (
Expand All @@ -32,15 +34,24 @@ def __init__(self, tap: Tap):
"""Create a new stream."""
super().__init__(tap, schema=self.schema, name=self.name)

@override
def get_records(
self,
context: dict | None, # noqa: ARG002
context: dict | None,
) -> t.Iterable[dict[str, t.Any]]:
"""Generate records."""
yield {"id": 1, "value": "Egypt"}
yield {"id": 2, "value": "Germany"}
yield {"id": 3, "value": "India"}

@contextmanager
def with_replication_method(self, method: str | None) -> t.Iterator[None]:
    """Context manager to temporarily override the replication method.

    Args:
        method: Value assigned to ``forced_replication_method`` while the
            ``with`` block runs.

    Yields:
        Nothing; the override is in effect for the duration of the block.
    """
    original_method = self.forced_replication_method
    self.forced_replication_method = method
    try:
        yield
    finally:
        # Restore even when the block raises (e.g. a failing assertion), so
        # the override cannot leak out of the ``with`` block.
        self.forced_replication_method = original_method


class UnixTimestampIncrementalStream(SimpleTestStream):
name = "unix_ts"
Expand All @@ -55,6 +66,7 @@ class UnixTimestampIncrementalStream(SimpleTestStream):
class UnixTimestampIncrementalStream2(UnixTimestampIncrementalStream):
name = "unix_ts_override"

@override
def compare_start_date(self, value: str, start_date_value: str) -> str:
"""Compare a value to a start date value."""

Expand All @@ -73,6 +85,7 @@ class SimpleTestTap(Tap):
additional_properties=False,
).to_dict()

@override
def discover_streams(self) -> list[Stream]:
"""List all streams."""
return [
Expand Down
44 changes: 31 additions & 13 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,18 @@ def get_next_page_token(
response: requests.Response,
previous_token: str | None, # noqa: ARG002
) -> str | None:
if self.next_page_token_jsonpath:
all_matches = extract_jsonpath(
self.next_page_token_jsonpath,
response.json(),
)
try:
return first(all_matches)
except StopIteration:
return None

else:
if not self.next_page_token_jsonpath:
return response.headers.get("X-Next-Page", None)

all_matches = extract_jsonpath(
self.next_page_token_jsonpath,
response.json(),
)
try:
return first(all_matches)
except StopIteration:
return None


class GraphqlTestStream(GraphQLStream):
"""Test Graphql stream class."""
Expand Down Expand Up @@ -111,58 +110,74 @@ def test_stream_apply_catalog(stream: Stream):


@pytest.mark.parametrize(
"stream_name,bookmark_value,expected_starting_value",
"stream_name,forced_replication_method,bookmark_value,expected_starting_value",
[
pytest.param(
"test",
None,
None,
pendulum.parse(CONFIG_START_DATE),
id="datetime-repl-key-no-state",
),
pytest.param(
"test",
None,
"2021-02-01",
pendulum.datetime(2021, 2, 1),
id="datetime-repl-key-recent-bookmark",
),
pytest.param(
"test",
REPLICATION_FULL_TABLE,
"2021-02-01",
None,
id="datetime-forced-full-table",
),
pytest.param(
"test",
None,
"2020-01-01",
pendulum.parse(CONFIG_START_DATE),
id="datetime-repl-key-old-bookmark",
),
pytest.param(
"unix_ts",
None,
None,
CONFIG_START_DATE,
id="naive-unix-ts-repl-key-no-state",
),
pytest.param(
"unix_ts",
None,
"1612137600",
"1612137600",
id="naive-unix-ts-repl-key-recent-bookmark",
),
pytest.param(
"unix_ts",
None,
"1577858400",
"1577858400",
id="naive-unix-ts-repl-key-old-bookmark",
),
pytest.param(
"unix_ts_override",
None,
None,
CONFIG_START_DATE,
id="unix-ts-repl-key-no-state",
),
pytest.param(
"unix_ts_override",
None,
"1612137600",
"1612137600",
id="unix-ts-repl-key-recent-bookmark",
),
pytest.param(
"unix_ts_override",
None,
"1577858400",
pendulum.parse(CONFIG_START_DATE).format("X"),
id="unix-ts-repl-key-old-bookmark",
Expand All @@ -172,6 +187,7 @@ def test_stream_apply_catalog(stream: Stream):
def test_stream_starting_timestamp(
tap: Tap,
stream_name: str,
forced_replication_method: str | None,
bookmark_value: str,
expected_starting_value: t.Any,
):
Expand All @@ -194,7 +210,9 @@ def test_stream_starting_timestamp(
},
)
stream._write_starting_replication_value(None)
assert get_starting_value(None) == expected_starting_value

with stream.with_replication_method(forced_replication_method):
assert get_starting_value(None) == expected_starting_value


def test_stream_invalid_replication_key(tap: SimpleTestTap):
Expand Down

0 comments on commit 71a7785

Please sign in to comment.