[ISSUE #6548] make all fields nullable except for pk and cursor field
maxi297 committed Mar 15, 2024
1 parent 609607c commit 5aa5ae8
Showing 6 changed files with 364 additions and 82 deletions.
@@ -24,7 +24,7 @@
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
from airbyte_protocol.models.airbyte_protocol import (
AirbyteControlMessage,
AirbyteLogMessage,
@@ -45,6 +45,18 @@ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit:
self._max_slices = max_slices
self._max_record_limit = max_record_limit

def _to_nested_and_composite_field(self, field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]:
if not field:
return [[]]

if isinstance(field, str):
return [[field]]

if isinstance(field[0], str):
return [field]

return field

def get_message_groups(
self,
source: DeclarativeSource,
@@ -54,7 +66,11 @@ def get_message_groups(
) -> StreamRead:
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
raise ValueError(f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}")
schema_inferrer = SchemaInferrer()
stream = source.streams(config)[0] # The connector builder currently only supports reading from a single stream at a time
schema_inferrer = SchemaInferrer(
self._to_nested_and_composite_field(stream.primary_key),
self._to_nested_and_composite_field(stream.cursor_field),
)
datetime_format_inferrer = DatetimeFormatInferrer()

if record_limit is None:
@@ -88,14 +104,20 @@ get_message_groups
else:
raise ValueError(f"Unknown message group type: {type(message_group)}")

try:
configured_stream = configured_catalog.streams[0] # The connector builder currently only supports reading from a single stream at a time
schema = schema_inferrer.get_stream_schema(configured_stream.stream.name)
except SchemaValidationException as exception:
for validation_error in exception.validation_errors:
log_messages.append(LogMessage(validation_error, "ERROR"))
schema = exception.schema

return StreamRead(
logs=log_messages,
slices=slices,
test_read_limit_reached=self._has_reached_limit(slices),
auxiliary_requests=auxiliary_requests,
inferred_schema=schema_inferrer.get_stream_schema(
configured_catalog.streams[0].stream.name
), # The connector builder currently only supports reading from a single stream at a time
inferred_schema=schema,
latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None,
inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
)
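
For reference, the shape normalization performed by the new `_to_nested_and_composite_field` helper can be summarized with a minimal, self-contained sketch (the sample values below are illustrative and not taken from the commit):

from typing import List, Optional, Union

def to_nested_and_composite_field(field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]:
    # Mirrors MessageGrouper._to_nested_and_composite_field introduced above.
    if not field:
        return [[]]  # no primary key / cursor field configured
    if isinstance(field, str):
        return [[field]]  # "id" -> [["id"]]
    if isinstance(field[0], str):
        return [field]  # ["data", "id"] -> [["data", "id"]], i.e. one nested path
    return field  # already a list of nested paths

assert to_nested_and_composite_field(None) == [[]]
assert to_nested_and_composite_field("id") == [["id"]]
assert to_nested_and_composite_field(["data", "id"]) == [["data", "id"]]
assert to_nested_and_composite_field([["id"], ["name"]]) == [["id"], ["name"]]
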
67 changes: 50 additions & 17 deletions airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
@@ -1,29 +1,62 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from typing import Any, Dict, List
from typing import List, Union, overload

from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode, ConfiguredAirbyteStream


class ConfiguredAirbyteStreamBuilder:
def __init__(self) -> None:
self._stream = {
"stream": {
"name": "any name",
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": "full_refresh",
"destination_sync_mode": "overwrite",
}

def with_name(self, name: str) -> "ConfiguredAirbyteStreamBuilder":
self._stream["stream"]["name"] = name # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self

def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder":
self._stream["sync_mode"] = sync_mode.name
return self

def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder":
self._stream["primary_key"] = pk
self._stream["stream"]["source_defined_primary_key"] = pk # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self

def build(self) -> ConfiguredAirbyteStream:
return ConfiguredAirbyteStream.parse_obj(self._stream)


class CatalogBuilder:
def __init__(self) -> None:
self._streams: List[Dict[str, Any]] = []
self._streams: List[ConfiguredAirbyteStreamBuilder] = []

@overload
def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder":
...

@overload
def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder":
self._streams.append(
{
"stream": {
"name": name,
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": sync_mode.name,
"destination_sync_mode": "overwrite",
}
)
...

def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mode: Union[SyncMode, None] = None) -> "CatalogBuilder":
# As we are introducing a fully fledged ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface
# with_stream(str, SyncMode)

# To avoid a breaking change, `name` has to stay in the API, but it can be either a stream name or a builder
name_or_builder = name
builder = name_or_builder if isinstance(name_or_builder, ConfiguredAirbyteStreamBuilder) else ConfiguredAirbyteStreamBuilder().with_name(name_or_builder).with_sync_mode(sync_mode)
self._streams.append(builder)
return self

def build(self) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog.parse_obj({"streams": self._streams})
return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), self._streams)))
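
As a usage sketch (not part of the diff; the stream name, sync mode and primary key are assumed for illustration), the updated builder can be driven either through the new `ConfiguredAirbyteStreamBuilder` or through the deprecated `(name, sync_mode)` signature:

from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder
from airbyte_protocol.models import SyncMode

# Preferred interface: pass a fully configured stream builder.
catalog = CatalogBuilder().with_stream(
    ConfiguredAirbyteStreamBuilder()
    .with_name("users")
    .with_sync_mode(SyncMode.incremental)
    .with_primary_key([["id"]])
).build()

# Deprecated interface, kept so existing callers do not break.
legacy_catalog = CatalogBuilder().with_stream("users", SyncMode.full_refresh).build()
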
116 changes: 102 additions & 14 deletions airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py
@@ -3,7 +3,7 @@
#

from collections import defaultdict
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional

from airbyte_cdk.models import AirbyteRecordMessage
from genson import SchemaBuilder, SchemaNode
@@ -41,6 +41,26 @@ class NoRequiredSchemaBuilder(SchemaBuilder):
InferredSchema = Dict[str, Any]


class SchemaValidationException(Exception):

@classmethod
def merge_exceptions(cls, exceptions: List["SchemaValidationException"]) -> "SchemaValidationException":
# We assume the schema is the same for all the SchemaValidationExceptions being merged
return SchemaValidationException(exceptions[0].schema, [x for exception in exceptions for x in exception.validation_errors])

def __init__(self, schema: InferredSchema, validation_errors: List[Exception]):
self._schema = schema
self._validation_errors = validation_errors

@property
def schema(self) -> InferredSchema:
return self._schema

@property
def validation_errors(self) -> List[str]:
return list(map(lambda error: str(error), self._validation_errors))


class SchemaInferrer:
"""
This class is used to infer a JSON schema which fits all the records passed into it
@@ -53,23 +73,15 @@ class SchemaInferrer:

stream_to_builder: Dict[str, SchemaBuilder]

def __init__(self) -> None:
def __init__(self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None) -> None:
self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder)
self._pk = pk
self._cursor_field = cursor_field

def accumulate(self, record: AirbyteRecordMessage) -> None:
"""Uses the input record to add to the inferred schemas maintained by this object"""
self.stream_to_builder[record.stream].add_object(record.data)

def get_inferred_schemas(self) -> Dict[str, InferredSchema]:
"""
Returns the JSON schemas for all encountered streams inferred by inspecting all records
passed via the accumulate method
"""
schemas = {}
for stream_name, builder in self.stream_to_builder.items():
schemas[stream_name] = self._clean(builder.to_schema())
return schemas

def _clean(self, node: InferredSchema) -> InferredSchema:
"""
Recursively cleans up a produced schema:
@@ -81,7 +93,7 @@ def _clean(self, node: InferredSchema) -> InferredSchema:
if len(node["anyOf"]) == 2 and {"type": "null"} in node["anyOf"]:
real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == "null" else node["anyOf"][0]
node.update(real_type)
node["type"] = [node["type"], "null"]
node["type"] = ["null", node["type"]]
node.pop("anyOf")
if "properties" in node and isinstance(node["properties"], dict):
for key, value in list(node["properties"].items()):
@@ -91,10 +103,86 @@ def _clean(self, node: InferredSchema) -> InferredSchema:
self._clean(value)
if "items" in node:
self._clean(node["items"])

# this check needs to follow the "anyOf" cleaning as it might populate `type`
if isinstance(node["type"], list):
if "null" not in node["type"]:
node["type"] = ["null"] + node["type"]
else:
node["type"] = ["null", node["type"]]
return node

def _add_required_properties(self, node: InferredSchema) -> InferredSchema:
# Remove the nullable type from the root: `_clean` marks every node as nullable, but the root must remain a plain object
node["type"] = "object"

exceptions = []
for field in [x for x in [self._pk, self._cursor_field] if x]:
try:
self._add_fields_as_required(node, field)
except SchemaValidationException as exception:
exceptions.append(exception)

if exceptions:
raise SchemaValidationException.merge_exceptions(exceptions)

return node

def _add_fields_as_required(self, node: InferredSchema, composite_keys: List[List[str]]) -> None:
errors: List[Exception] = []

for path in composite_keys:
try:
self._add_field_as_required(node, path)
except ValueError as exception:
errors.append(exception)

if errors:
raise SchemaValidationException(node, errors)

def _remove_null_from_type(self, node: InferredSchema) -> None:
if isinstance(node["type"], list):
if "null" in node["type"]:
node["type"].remove("null")
if len(node["type"]) == 1:
node["type"] = node["type"][0]

def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None) -> None:
if self._is_leaf(path):
self._remove_null_from_type(node)
return

if not traveled_path:
traveled_path = []

if "properties" not in node:
# This validation is only relevant when `traveled_path` is empty
raise ValueError(f"Path {traveled_path} does not refer to an object but is `{node}` and hence {path} can't be marked as required.")

next_node = path[0]
if next_node not in node["properties"]:
raise ValueError(f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required.")

if "type" not in node:
# We do not expect this case to happen but we added a specific error message just in case
raise ValueError(f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken")

if node["type"] not in ["object", ["null", "object"]]:
raise ValueError(f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`")

if "required" not in node or not node["required"]:
node["required"] = [next_node]
elif next_node not in node["required"]:
node["required"].append(next_node)

traveled_path.append(next_node)
self._add_field_as_required(node["properties"][next_node], path[1:], traveled_path)

def _is_leaf(self, path: List[str]) -> bool:
return len(path) == 0

def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]:
"""
Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.
"""
return self._clean(self.stream_to_builder[stream_name].to_schema()) if stream_name in self.stream_to_builder else None
return self._add_required_properties(self._clean(self.stream_to_builder[stream_name].to_schema())) if stream_name in self.stream_to_builder else None
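
Putting the pieces together, a hedged end-to-end sketch of the new behaviour (the stream and field names below are invented; only the `SchemaInferrer` API itself comes from this commit):

from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer

inferrer = SchemaInferrer(pk=[["id"]], cursor_field=[["updated_at"]])
inferrer.accumulate(
    AirbyteRecordMessage(stream="users", data={"id": 1, "updated_at": "2024-03-15", "name": "maxi"}, emitted_at=0)
)
schema = inferrer.get_stream_schema("users")

# Non-key fields become nullable, while the primary key and the cursor field keep a single
# concrete type and are marked as required on the root object.
assert schema["properties"]["name"]["type"] == ["null", "string"]
assert schema["properties"]["id"]["type"] == "integer"
assert schema["properties"]["updated_at"]["type"] == "string"
assert sorted(schema["required"]) == ["id", "updated_at"]
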
@@ -499,7 +499,19 @@ def test_config_update():

@patch("traceback.TracebackException.from_exception")
def test_read_returns_error_response(mock_from_exception):
class MockDeclarativeStream:
@property
def primary_key(self):
return [[]]

@property
def cursor_field(self):
return []

class MockManifestDeclarativeSource:
def streams(self, config):
return [MockDeclarativeStream()]

def read(self, logger, config, catalog, state):
raise ValueError("error_message")

@@ -21,6 +21,9 @@
from airbyte_cdk.models import Type as MessageType
from unit_tests.connector_builder.utils import create_configured_catalog

_NO_PK = [[]]
_NO_CURSOR_FIELD = [[]]

MAX_PAGES_PER_SLICE = 4
MAX_SLICES = 3

@@ -96,7 +99,7 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None:
response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}}
expected_schema = {
"$schema": "http://json-schema.org/schema#",
"properties": {"name": {"type": "string"}, "date": {"type": "string"}},
"properties": {"name": {"type": ["null", "string"]}, "date": {"type": ["null", "string"]}},
"type": "object",
}
expected_datetime_fields = {"date": "%Y-%m-%d"}
@@ -636,6 +639,44 @@ def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) ->
assert len(stream_read.slices) == 0


@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read")
def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None:
mock_source = make_mock_source(mock_entrypoint_read, iter([
request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"),
record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}),
record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}),
]))
mock_source.streams.return_value = [Mock()]
mock_source.streams.return_value[0].primary_key = [["id"]]
mock_source.streams.return_value[0].cursor_field = _NO_CURSOR_FIELD
connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)

stream_read: StreamRead = connector_builder_handler.get_message_groups(
source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras")
)

assert stream_read.inferred_schema["required"] == ["id"]


@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read")
def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None:
mock_source = make_mock_source(mock_entrypoint_read, iter([
request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"),
record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}),
record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}),
]))
mock_source.streams.return_value = [Mock()]
mock_source.streams.return_value[0].primary_key = _NO_PK
mock_source.streams.return_value[0].cursor_field = [["date"]]
connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)

stream_read: StreamRead = connector_builder_handler.get_message_groups(
source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras")
)

assert stream_read.inferred_schema["required"] == ["date"]


def make_mock_source(mock_entrypoint_read: Mock, return_value: Iterator[AirbyteMessage]) -> MagicMock:
mock_source = MagicMock()
mock_entrypoint_read.return_value = return_value