Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WeaviateDocumentIngestOperator #36402

Merged
merged 1 commit into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,88 @@ def execute(self, context: Context) -> list:
tenant=self.tenant,
)
return insertion_errors


class WeaviateDocumentIngestOperator(BaseOperator):
    """
    Create or replace objects belonging to documents.

    In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues
    are considered 'documents' here. It's crucial to keep the database objects in sync with these sources.
    If any changes occur in these documents, this operator aims to reflect those changes in the database.

    .. note::

        This operator assumes responsibility for identifying changes in documents, dropping relevant
        database objects, and recreating them based on updated information. It's crucial to handle this
        process with care, ensuring backups and validation are in place to prevent data loss or
        inconsistencies.

    Provides users with multiple ways of dealing with existing values.

    * replace: replace the existing objects with new objects. This option requires identifying the
      objects belonging to a document, which by default is done by using the ``document_column`` field.
    * skip: skip the existing objects and only add the missing objects of a document.
    * error: raise an error if an object belonging to an existing document is tried to be created.

    :param conn_id: ID of the Airflow connection passed to ``WeaviateHook``.
    :param input_data: A single pandas DataFrame, a list of dicts, or a list of pandas DataFrames
        to be ingested. This field is templated.
    :param class_name: Name of the class in Weaviate schema where data is to be ingested.
    :param document_column: Column in DataFrame that identifies the source document.
    :param existing: Strategy for handling existing data: 'skip', 'replace' or 'error'.
        Default is 'skip'.
    :param uuid_column: Column with pre-generated UUIDs. Default is 'id'.
    :param vector_col: Column with embedding vectors for pre-embedded data. Default is 'Vector'.
    :param batch_config_params: Additional parameters for Weaviate batch configuration.
    :param tenant: The tenant to which the object will be added.
    :param verbose: Flag to enable verbose output during the ingestion process.
    :param hook_params: Optional extra keyword arguments forwarded to the ``WeaviateHook``
        constructor. ``None`` is treated the same as an empty dict.
    """

    template_fields: Sequence[str] = ("input_data",)

    def __init__(
        self,
        conn_id: str,
        input_data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame],
        class_name: str,
        document_column: str,
        existing: str = "skip",
        uuid_column: str = "id",
        vector_col: str = "Vector",
        batch_config_params: dict | None = None,
        tenant: str | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        # ``hook_params`` is not a BaseOperator argument, so remove it before delegating.
        # An explicit ``hook_params=None`` is normalised to ``{}``; otherwise the
        # ``**self.hook_params`` expansion in ``hook`` would raise a TypeError.
        self.hook_params = kwargs.pop("hook_params", None) or {}

        super().__init__(**kwargs)

        self.conn_id = conn_id
        self.input_data = input_data
        self.class_name = class_name
        self.document_column = document_column
        self.existing = existing
        self.uuid_column = uuid_column
        self.vector_col = vector_col
        self.batch_config_params = batch_config_params
        self.tenant = tenant
        self.verbose = verbose

    @cached_property
    def hook(self) -> WeaviateHook:
        """Return an instance of the WeaviateHook."""
        return WeaviateHook(conn_id=self.conn_id, **self.hook_params)

    def execute(self, context: Context) -> list:
        """
        Create or replace the objects of the given documents through the hook.

        :param context: Airflow task context; not used by this method.
        :return: list of UUID which failed to create
        """
        self.log.debug("Total input objects : %s", len(self.input_data))
        # The operator's ``vector_col`` is passed to the hook as ``vector_column``.
        insertion_errors = self.hook.create_or_replace_document_objects(
            data=self.input_data,
            class_name=self.class_name,
            document_column=self.document_column,
            existing=self.existing,
            uuid_column=self.uuid_column,
            vector_column=self.vector_col,
            batch_config_params=self.batch_config_params,
            tenant=self.tenant,
            verbose=self.verbose,
        )
        return insertion_errors
51 changes: 50 additions & 1 deletion tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

import pytest

from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator
from airflow.providers.weaviate.operators.weaviate import (
WeaviateDocumentIngestOperator,
WeaviateIngestOperator,
)


class TestWeaviateIngestOperator:
Expand Down Expand Up @@ -73,3 +76,49 @@ def test_templates(self, create_task_instance_of_operator):

assert dag_id == ti.task.input_json
assert dag_id == ti.task.input_data


class TestWeaviateDocumentIngestOperator:
    """Unit tests for ``WeaviateDocumentIngestOperator``."""

    @pytest.fixture
    def operator(self):
        """Build an operator with a one-record payload and explicit column settings."""
        init_kwargs = {
            "task_id": "weaviate_task",
            "conn_id": "weaviate_conn",
            "input_data": [{"data": "sample_data"}],
            "class_name": "my_class",
            "document_column": "docLink",
            "existing": "skip",
            "uuid_column": "id",
            "vector_col": "vector",
            "batch_config_params": {"size": 1000},
        }
        return WeaviateDocumentIngestOperator(**init_kwargs)

    def test_constructor(self, operator):
        """Every constructor argument is stored on the operator unchanged."""
        expected_attributes = {
            "conn_id": "weaviate_conn",
            "input_data": [{"data": "sample_data"}],
            "class_name": "my_class",
            "document_column": "docLink",
            "existing": "skip",
            "uuid_column": "id",
            "vector_col": "vector",
            "batch_config_params": {"size": 1000},
            # hook_params was not supplied, so it falls back to an empty dict.
            "hook_params": {},
        }
        for attribute, expected in expected_attributes.items():
            assert getattr(operator, attribute) == expected

    @patch("airflow.providers.weaviate.operators.weaviate.WeaviateDocumentIngestOperator.log")
    def test_execute_with_input_json(self, mock_log, operator):
        """``execute`` forwards the operator settings to the hook and logs the input size."""
        payload = [{"data": "sample_data"}]
        create_or_replace = MagicMock()
        operator.hook.create_or_replace_document_objects = create_or_replace

        operator.execute(context=None)

        # The operator's ``vector_col`` must reach the hook as ``vector_column``.
        create_or_replace.assert_called_once_with(
            data=payload,
            class_name="my_class",
            document_column="docLink",
            existing="skip",
            uuid_column="id",
            vector_column="vector",
            batch_config_params={"size": 1000},
            tenant=None,
            verbose=False,
        )
        mock_log.debug.assert_called_once_with("Total input objects : %s", len(payload))