feat: Add s3 remote storage export for duckdb #4195

Merged · 1 commit · May 16, 2024
30 changes: 27 additions & 3 deletions sdk/python/feast/infra/offline_stores/duckdb.py
@@ -33,7 +33,11 @@ def _read_data_source(data_source: DataSource) -> Table:
if isinstance(data_source.file_format, ParquetFormat):
return ibis.read_parquet(data_source.path)
elif isinstance(data_source.file_format, DeltaFormat):
return ibis.read_delta(data_source.path)
storage_options = {
"AWS_ENDPOINT_URL": data_source.s3_endpoint_override,
}

return ibis.read_delta(data_source.path, storage_options=storage_options)


def _write_data_source(
@@ -72,10 +76,18 @@ def _write_data_source(
new_table = pyarrow.concat_tables([table, prev_table])
ibis.memtable(new_table).to_parquet(file_options.uri)
elif isinstance(data_source.file_format, DeltaFormat):
storage_options = {
"AWS_ENDPOINT_URL": str(data_source.s3_endpoint_override),
}

if mode == "append":
from deltalake import DeltaTable

prev_schema = DeltaTable(file_options.uri).schema().to_pyarrow()
prev_schema = (
DeltaTable(file_options.uri, storage_options=storage_options)
.schema()
.to_pyarrow()
)
table = table.cast(ibis.Schema.from_pyarrow(prev_schema))
write_mode = "append"
elif mode == "overwrite":
@@ -85,13 +97,19 @@
else "error"
)

table.to_delta(file_options.uri, mode=write_mode)
table.to_delta(
file_options.uri, mode=write_mode, storage_options=storage_options
)


class DuckDBOfflineStoreConfig(FeastConfigBaseModel):
type: StrictStr = "duckdb"
# """ Offline store type selector"""

staging_location: Optional[str] = None

staging_location_endpoint_override: Optional[str] = None


class DuckDBOfflineStore(OfflineStore):
@staticmethod
@@ -116,6 +134,8 @@ def pull_latest_from_table_or_query(
end_date=end_date,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
@@ -138,6 +158,8 @@ def get_historical_features(
full_feature_names=full_feature_names,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
@@ -160,6 +182,8 @@ def pull_all_from_table_or_query(
end_date=end_date,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
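For context, a minimal sketch of how the two new config fields might be set when constructing the DuckDB offline store config in Python; the bucket name and endpoint URL are placeholders, not values from this PR:

```python
from feast.infra.offline_stores.duckdb import DuckDBOfflineStoreConfig

# Hypothetical values: any S3-compatible bucket and endpoint would work here.
config = DuckDBOfflineStoreConfig(
    staging_location="s3://my-bucket/feast-staging",  # where to_remote_storage() writes its output
    staging_location_endpoint_override="http://localhost:9000",  # e.g. a local MinIO endpoint
)
```

When `staging_location` is left as `None`, `supports_remote_storage_export()` stays `False` and the retrieval job behaves as before.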
10 changes: 9 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
@@ -179,7 +179,15 @@ def get_table_column_names_and_types(
elif isinstance(self.file_format, DeltaFormat):
from deltalake import DeltaTable

schema = DeltaTable(self.path).schema().to_pyarrow()
storage_options = {
"AWS_ENDPOINT_URL": str(self.s3_endpoint_override),
}

schema = (
DeltaTable(self.path, storage_options=storage_options)
.schema()
.to_pyarrow()
)
else:
raise Exception(f"Unknown FileFormat -> {self.file_format}")

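The endpoint that ends up in `storage_options` above originates from the data source definition; a hedged sketch of a Delta-backed `FileSource` with an endpoint override (the name, path, and endpoint are placeholders):

```python
from feast import FileSource
from feast.data_format import DeltaFormat

driver_stats = FileSource(
    name="driver_stats",  # placeholder
    path="s3://my-bucket/driver_stats",  # location of the Delta table (placeholder)
    file_format=DeltaFormat(),
    timestamp_field="event_timestamp",
    s3_endpoint_override="http://localhost:9000",  # forwarded as AWS_ENDPOINT_URL when reading the table
)
```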
45 changes: 45 additions & 0 deletions sdk/python/feast/infra/offline_stores/ibis.py
@@ -47,6 +47,8 @@ def pull_latest_from_table_or_query_ibis(
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
fields = join_key_columns + feature_name_columns + [timestamp_field]
if created_timestamp_column:
@@ -82,6 +84,8 @@ def pull_latest_from_table_or_query_ibis(
full_feature_names=False,
metadata=None,
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -140,6 +144,8 @@ def get_historical_features_ibis(
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
full_feature_names: bool = False,
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
entity_schema = _get_entity_schema(
entity_df=entity_df,
@@ -231,6 +237,8 @@ def read_fv(
max_event_timestamp=timestamp_range[1],
),
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -244,6 +252,8 @@ def pull_all_from_table_or_query_ibis(
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
fields = join_key_columns + feature_name_columns + [timestamp_field]
start_date = start_date.astimezone(tz=utc)
@@ -270,6 +280,8 @@ def pull_all_from_table_or_query_ibis(
full_feature_names=False,
metadata=None,
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -411,6 +423,23 @@ def point_in_time_join(
return acc_table


def list_s3_files(path: str, endpoint_url: str) -> List[str]:
import boto3

s3 = boto3.client("s3", endpoint_url=endpoint_url)
if path.startswith("s3://"):
path = path[len("s3://") :]
bucket, prefix = path.split("/", 1)
objects = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
contents = objects["Contents"]
files = [
f"s3://{bucket}/{content['Key']}"
for content in contents
if content["Key"].endswith("parquet")
]
return files


class IbisRetrievalJob(RetrievalJob):
def __init__(
self,
@@ -419,6 +448,8 @@ def __init__(
full_feature_names,
metadata,
data_source_writer,
staging_location,
staging_location_endpoint_override,
) -> None:
super().__init__()
self.table = table
@@ -428,6 +459,8 @@ def __init__(
self._full_feature_names = full_feature_names
self._metadata = metadata
self.data_source_writer = data_source_writer
self.staging_location = staging_location
self.staging_location_endpoint_override = staging_location_endpoint_override

def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
return self.table.execute()
@@ -456,3 +489,15 @@ def persist(
@property
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def supports_remote_storage_export(self) -> bool:
return self.staging_location is not None

def to_remote_storage(self) -> List[str]:
path = self.staging_location + f"/{str(uuid.uuid4())}"

storage_options = {"AWS_ENDPOINT_URL": self.staging_location_endpoint_override}

self.table.to_delta(path, storage_options=storage_options)

return list_s3_files(path, self.staging_location_endpoint_override)
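Putting the new pieces together, a hedged usage sketch (the feature store object, entity dataframe, and feature list are assumed to exist and are not part of this diff): once `staging_location` is configured, a retrieval job can export its result to S3-compatible storage and hand back the parquet file URIs.

```python
from typing import List

from feast import FeatureStore


def export_training_data(store: FeatureStore, entity_df, features) -> List[str]:
    """Sketch: export a historical retrieval to the configured staging location."""
    job = store.get_historical_features(entity_df=entity_df, features=features)
    if not job.supports_remote_storage_export():  # True only when staging_location is set
        raise RuntimeError("offline store has no staging_location configured")
    # Writes the result as a Delta table under staging_location/<uuid> and returns
    # the s3:// URIs of the parquet data files found there.
    return job.to_remote_storage()
```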
8 changes: 7 additions & 1 deletion sdk/python/pytest.ini
@@ -5,4 +5,10 @@ markers =

env =
FEAST_USAGE=False
IS_TEST=True
IS_TEST=True

filterwarnings =
ignore::DeprecationWarning:pyspark.sql.pandas.*:
ignore::DeprecationWarning:pyspark.sql.connect.*:
ignore::DeprecationWarning:httpx.*:
ignore::FutureWarning:ibis_substrait.compiler.*:
8 changes: 7 additions & 1 deletion sdk/python/tests/conftest.py
@@ -13,11 +13,13 @@
# limitations under the License.
import logging
import multiprocessing
import os
import random
from datetime import datetime, timedelta
from multiprocessing import Process
from sys import platform
from typing import Any, Dict, List, Tuple, no_type_check
from unittest import mock

import pandas as pd
import pytest
@@ -180,7 +182,11 @@ def environment(request, worker_id):
request.param, worker_id=worker_id, fixture_request=request
)

yield e
if hasattr(e.data_source_creator, "mock_environ"):
with mock.patch.dict(os.environ, e.data_source_creator.mock_environ):
yield e
else:
yield e

e.feature_store.teardown()
e.data_source_creator.teardown()
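For readers unfamiliar with the pattern used in the fixture above, `mock.patch.dict` temporarily overlays entries onto `os.environ` and restores the original environment on exit; a standalone illustration with made-up credential values:

```python
import os
from unittest import mock

# Hypothetical credentials, standing in for DeltaS3FileSourceCreator.mock_environ.
fake_environ = {"AWS_ACCESS_KEY_ID": "minio", "AWS_SECRET_ACCESS_KEY": "minio123"}

with mock.patch.dict(os.environ, fake_environ):
    # While the test runs, libraries that read credentials from the environment
    # (such as the deltalake S3 backend) pick up the injected values.
    print(os.environ["AWS_ACCESS_KEY_ID"])  # -> minio

# On exit the previous environment is restored automatically.
```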
@@ -33,6 +33,7 @@
from tests.integration.feature_repos.universal.data_sources.file import (
DuckDBDataSourceCreator,
DuckDBDeltaDataSourceCreator,
DuckDBDeltaS3DataSourceCreator,
FileDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.redshift import (
@@ -122,6 +123,14 @@
("local", DuckDBDeltaDataSourceCreator),
]

if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
AVAILABLE_OFFLINE_STORES.extend(
[
("local", DuckDBDeltaS3DataSourceCreator),
]
)


AVAILABLE_ONLINE_STORES: Dict[
str, Tuple[Union[str, Dict[Any, Any]], Optional[Type[OnlineStoreCreator]]]
] = {
@@ -10,6 +10,7 @@
from minio import Minio
from testcontainers.core.generic import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.minio import MinioContainer

from feast import FileSource
from feast.data_format import DeltaFormat, ParquetFormat
@@ -134,6 +135,74 @@ def create_logged_features_destination(self) -> LoggingDestination:
return FileLoggingDestination(path=d)


class DeltaS3FileSourceCreator(FileDataSourceCreator):
def __init__(self, project_name: str, *args, **kwargs):
super().__init__(project_name)
self.minio = MinioContainer()
self.minio.start()
client = self.minio.get_client()
client.make_bucket("test")
host_ip = self.minio.get_container_host_ip()
exposed_port = self.minio.get_exposed_port(self.minio.port)
self.endpoint_url = f"http://{host_ip}:{exposed_port}"

self.mock_environ = {
"AWS_ACCESS_KEY_ID": self.minio.access_key,
"AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
"AWS_EC2_METADATA_DISABLED": "true",
"AWS_REGION": "us-east-1",
"AWS_ALLOW_HTTP": "true",
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}

def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
timestamp_field: Optional[str] = "ts",
) -> DataSource:
from deltalake.writer import write_deltalake

destination_name = self.get_prefixed_table_name(destination_name)

storage_options = {
"AWS_ACCESS_KEY_ID": self.minio.access_key,
"AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
"AWS_ENDPOINT_URL": self.endpoint_url,
}

path = f"s3://test/{str(uuid.uuid4())}/{destination_name}"

write_deltalake(path, df, storage_options=storage_options)

return FileSource(
file_format=DeltaFormat(),
path=path,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping or {"ts_1": "ts"},
s3_endpoint_override=self.endpoint_url,
)

def create_saved_dataset_destination(self) -> SavedDatasetFileStorage:
return SavedDatasetFileStorage(
path=f"s3://test/{str(uuid.uuid4())}",
file_format=DeltaFormat(),
s3_endpoint_override=self.endpoint_url,
)

# LoggingDestination is parquet-only
def create_logged_features_destination(self) -> LoggingDestination:
d = tempfile.mkdtemp(prefix=self.project_name)
self.keep.append(d)
return FileLoggingDestination(path=d)

def teardown(self):
self.minio.stop()


class FileParquetDatasetSourceCreator(FileDataSourceCreator):
def create_data_source(
self,
@@ -273,3 +342,12 @@ class DuckDBDeltaDataSourceCreator(DeltaFileSourceCreator):
def create_offline_store_config(self):
self.duckdb_offline_store_config = DuckDBOfflineStoreConfig()
return self.duckdb_offline_store_config


class DuckDBDeltaS3DataSourceCreator(DeltaS3FileSourceCreator):
def create_offline_store_config(self):
self.duckdb_offline_store_config = DuckDBOfflineStoreConfig(
staging_location="s3://test/staging",
staging_location_endpoint_override=self.endpoint_url,
)
return self.duckdb_offline_store_config
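As an aside, the table written by `create_data_source` can be read back with the same storage options to sanity-check the MinIO round trip; the endpoint, credentials, and path below are assumptions for illustration, not values taken from this PR:

```python
from deltalake import DeltaTable

storage_options = {
    "AWS_ACCESS_KEY_ID": "minioadmin",        # assumed MinIO test credentials
    "AWS_SECRET_ACCESS_KEY": "minioadmin",
    "AWS_ENDPOINT_URL": "http://localhost:9000",
    "AWS_ALLOW_HTTP": "true",                 # MinIO in the test container speaks plain HTTP
    "AWS_REGION": "us-east-1",
}

# Placeholder path: matches the "s3://test/<uuid>/<table>" layout used by the creator.
dt = DeltaTable("s3://test/<uuid>/driver_stats", storage_options=storage_options)
print(dt.to_pyarrow_table().num_rows)
```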
@@ -31,7 +31,7 @@ def test_spark_materialization_consistency():
batch_engine={"type": "spark.engine", "partitions": 10},
)
spark_environment = construct_test_environment(
spark_config, None, entity_key_serialization_version=1
spark_config, None, entity_key_serialization_version=2
)

df = create_basic_driver_dataset()
@@ -139,8 +139,7 @@ def test_historical_features(

if job_from_df.supports_remote_storage_export():
files = job_from_df.to_remote_storage()
print(files)
assert len(files) > 0 # This test should be way more detailed
assert len(files)  # This test should be way more detailed

start_time = datetime.utcnow()
actual_df_from_df_entities = job_from_df.to_df()