From e844370bb42f2e7af757880b9019e91a24d7d16e Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sat, 8 Feb 2025 14:05:40 -0500 Subject: [PATCH 1/3] feat: Add SQLite Offline store Signed-off-by: Francisco Javier Arceo --- .../feast/infra/online_stores/sqlite.py | 175 ++++++++++++++++-- .../registration/test_universal_registry.py | 2 +- .../online_store/test_online_retrieval.py | 62 +++++++ 3 files changed, 219 insertions(+), 20 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 23b4f6db3a..896c20e36d 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -26,6 +26,7 @@ from feast.feature_view import FeatureView from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject from feast.infra.key_encoding_utils import ( + deserialize_entity_key, serialize_entity_key, serialize_f32, ) @@ -91,6 +92,9 @@ class SqliteOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): path: StrictStr = "data/online.db" """ (optional) Path to sqlite db """ + vector_enabled: bool = False + vector_len: Optional[int] = None + class SqliteOnlineStore(OnlineStore): """ @@ -104,22 +108,21 @@ class SqliteOnlineStore(OnlineStore): @staticmethod def _get_db_path(config: RepoConfig) -> str: - assert ( - config.online_store.type == "sqlite" - or config.online_store.type.endswith("SqliteOnlineStore") - ) - - if config.repo_path and not Path(config.online_store.path).is_absolute(): - db_path = str(config.repo_path / config.online_store.path) - else: - db_path = config.online_store.path - return db_path + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if config.repo_path and not Path(online_store.path).is_absolute(): + return str(config.repo_path / online_store.path) + return str(online_store.path) def _get_conn(self, config: RepoConfig): if not self._conn: db_path = self._get_db_path(config) self._conn = _initialize_conn(db_path) - if sys.version_info[0:2] == (3, 10) and config.online_store.vector_enabled: + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if sys.version_info[0:2] == (3, 10) and online_store.vector_enabled: import sqlite_vec # noqa: F401 self._conn.enable_load_extension(True) # type: ignore @@ -141,6 +144,9 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") conn = self._get_conn(config) project = config.project @@ -157,9 +163,13 @@ def online_write_batch( table_name = _table_id(project, table) for feature_name, val in values.items(): - if config.online_store.vector_enabled: + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if online_store.vector_enabled and online_store.vector_len: vector_bin = serialize_f32( - val.float_list_val.val, config.online_store.vector_len + val.float_list_val.val, + online_store.vector_len, ) # type: ignore conn.execute( f""" @@ -356,22 +366,28 @@ def retrieve_online_documents( Returns: List of tuples containing the event timestamp, the document feature, the vector value, and the distance """ - project = config.project - - if not config.online_store.vector_enabled: + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if not online_store.vector_enabled: raise ValueError("sqlite-vss is not enabled in the online store config") conn = self._get_conn(config) cur = conn.cursor() # Convert the embedding to a binary format instead of using SerializeToString() - query_embedding_bin = serialize_f32(embedding, config.online_store.vector_len) - table_name = _table_id(project, table) + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if not online_store.vector_len: + raise ValueError("vector_len is not configured in the online store config") + query_embedding_bin = serialize_f32(embedding, online_store.vector_len) # type: ignore + table_name = _table_id(config.project, table) cur.execute( f""" CREATE VIRTUAL TABLE vec_example using vec0( - vector_value float[{config.online_store.vector_len}] + vector_value float[{online_store.vector_len}] ); """ ) @@ -444,6 +460,127 @@ def retrieve_online_documents( return result + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search. + Args: + config: Feast configuration object + table: FeatureView object as the table to search + requested_features: List of requested features to retrieve + query: Query embedding to search for + top_k: Number of items to return + distance_metric: Distance metric to use (optional) + Returns: + List of tuples containing the event timestamp, entity key, and feature values + """ + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if not online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + + conn = self._get_conn(config) + cur = conn.cursor() + + online_store = config.online_store + if not isinstance(online_store, SqliteOnlineStoreConfig): + raise ValueError("online_store must be SqliteOnlineStoreConfig") + if not online_store.vector_len: + raise ValueError("vector_len is not configured in the online store config") + query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore + table_name = _table_id(config.project, table) + + cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_example using vec0( + vector_value float[{online_store.vector_len}] + ); + """ + ) + + cur.execute( + f""" + INSERT INTO vec_example(rowid, vector_value) + select rowid, vector_value from {table_name} + """ + ) + + cur.execute( + """ + INSERT INTO vec_example(rowid, vector_value) + VALUES (?, ?) + """, + (0, query_embedding_bin), + ) + + cur.execute( + f""" + select + fv.entity_key, + fv.feature_name, + fv.value, + f.distance, + fv.event_ts, + fv.created_ts + from ( + select + rowid, + vector_value, + distance + from vec_example + where vector_value match ? + order by distance + limit ? + ) f + left join {table_name} fv + on f.rowid = fv.rowid + where fv.feature_name in ({",".join(["?" for _ in requested_features])}) + """, + ( + query_embedding_bin, + top_k, + *[f.split(":")[-1] for f in requested_features], + ), + ) + + rows = cur.fetchall() + result: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + for entity_key, feature_name, value_bin, distance, event_ts, created_ts in rows: + val = ValueProto() + val.ParseFromString(value_bin) + entity_key_proto = None + if entity_key: + entity_key_proto = deserialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + res = {feature_name: val} + res["distance"] = ValueProto(float_val=distance) + result.append((event_ts, entity_key_proto, res)) + + return result + def _initialize_conn(db_path: str): try: diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index 3819d168d7..639dcd9149 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -1164,7 +1164,7 @@ def test_registry_cache_thread_async(test_registry): test_registry.teardown() -@pytest.mark.integration +# @pytest.mark.integration @pytest.mark.parametrize( "test_registry", all_fixtures, diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 6b0adb6263..fc64fa7c1c 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -738,6 +738,68 @@ def test_sqlite_vec_import() -> None: assert result == [(2, 2.39), (1, 2.39)] +@pytest.mark.skipif( + sys.version_info[0:2] != (3, 10), + reason="Only works on Python 3.10", +) +def test_sqlite_get_online_documents_v2() -> None: + """Test retrieving documents using v2 method with vector similarity search.""" + n = 10 + vector_length = 8 + runner = CliRunner() + with runner.local_repo( + get_example_repo("example_feature_repo_1.py"), "file" + ) as store: + store.config.online_store.vector_enabled = True + store.config.online_store.vector_len = vector_length + document_embeddings_fv = store.get_feature_view(name="document_embeddings") + + provider = store._get_provider() + + # Create test data + item_keys = [ + EntityKeyProto( + join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + ) + for i in range(n) + ] + data = [] + for item_key in item_keys: + data.append( + ( + item_key, + { + "Embeddings": ValueProto( + float_list_val=FloatListProto( + val=[float(x) for x in np.random.random(vector_length)] + ) + ) + }, + _utc_now(), + _utc_now(), + ) + ) + + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, + ) + + # Test vector similarity search + query_embedding = [float(x) for x in np.random.random(vector_length)] + result = store.retrieve_online_documents_v2( + features=["document_embeddings:Embeddings"], + query=query_embedding, + top_k=3, + ).to_dict() + + assert "Embeddings" in result + assert "distance" in result + assert len(result["distance"]) == 3 + + @pytest.mark.skip(reason="Skipping this test as CI struggles with it") def test_local_milvus() -> None: import random From 798d2c442ea4e417afda198bb41bbf28bb4245a1 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sat, 8 Feb 2025 14:14:39 -0500 Subject: [PATCH 2/3] removing commented out thing Signed-off-by: Francisco Javier Arceo --- .../tests/integration/registration/test_universal_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index 639dcd9149..3819d168d7 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -1164,7 +1164,7 @@ def test_registry_cache_thread_async(test_registry): test_registry.teardown() -# @pytest.mark.integration +@pytest.mark.integration @pytest.mark.parametrize( "test_registry", all_fixtures, From da6c00e6d7a9a26af66a330f616ffa87a0684fa4 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sat, 8 Feb 2025 14:49:41 -0500 Subject: [PATCH 3/3] updated entity key serialization Signed-off-by: Francisco Javier Arceo --- sdk/python/tests/unit/online_store/test_online_retrieval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index fc64fa7c1c..74006e559e 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -752,6 +752,7 @@ def test_sqlite_get_online_documents_v2() -> None: ) as store: store.config.online_store.vector_enabled = True store.config.online_store.vector_len = vector_length + store.config.entity_key_serialization_version = 3 document_embeddings_fv = store.get_feature_view(name="document_embeddings") provider = store._get_provider()