Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SQLite retrieve_online_documents_v2 #5032

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 156 additions & 19 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from feast.feature_view import FeatureView
from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject
from feast.infra.key_encoding_utils import (
deserialize_entity_key,
serialize_entity_key,
serialize_f32,
)
Expand Down Expand Up @@ -91,6 +92,9 @@ class SqliteOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
path: StrictStr = "data/online.db"
""" (optional) Path to sqlite db """

vector_enabled: bool = False
vector_len: Optional[int] = None


class SqliteOnlineStore(OnlineStore):
"""
Expand All @@ -104,22 +108,21 @@ class SqliteOnlineStore(OnlineStore):

@staticmethod
def _get_db_path(config: RepoConfig) -> str:
    """
    Resolve the location of the SQLite database file.

    Args:
        config: Feast repo configuration whose ``online_store`` must be a
            ``SqliteOnlineStoreConfig``.

    Returns:
        The database path as a string. A relative ``online_store.path`` is
        joined onto ``config.repo_path`` when a repo path is set; otherwise
        the configured path is returned as-is.

    Raises:
        ValueError: If ``config.online_store`` is not a
            ``SqliteOnlineStoreConfig``.
    """
    online_store = config.online_store
    # Explicit raise rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently disable the validation.
    if not isinstance(online_store, SqliteOnlineStoreConfig):
        raise ValueError("online_store must be SqliteOnlineStoreConfig")
    if config.repo_path and not Path(online_store.path).is_absolute():
        return str(config.repo_path / online_store.path)
    return str(online_store.path)

def _get_conn(self, config: RepoConfig):
if not self._conn:
db_path = self._get_db_path(config)
self._conn = _initialize_conn(db_path)
if sys.version_info[0:2] == (3, 10) and config.online_store.vector_enabled:
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
if sys.version_info[0:2] == (3, 10) and online_store.vector_enabled:
import sqlite_vec # noqa: F401

self._conn.enable_load_extension(True) # type: ignore
Expand All @@ -141,6 +144,9 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
conn = self._get_conn(config)

project = config.project
Expand All @@ -157,9 +163,13 @@ def online_write_batch(

table_name = _table_id(project, table)
for feature_name, val in values.items():
if config.online_store.vector_enabled:
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
if online_store.vector_enabled and online_store.vector_len:
vector_bin = serialize_f32(
val.float_list_val.val, config.online_store.vector_len
val.float_list_val.val,
online_store.vector_len,
) # type: ignore
conn.execute(
f"""
Expand Down Expand Up @@ -356,22 +366,28 @@ def retrieve_online_documents(
Returns:
List of tuples containing the event timestamp, the document feature, the vector value, and the distance
"""
project = config.project

if not config.online_store.vector_enabled:
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
if not online_store.vector_enabled:
raise ValueError("sqlite-vss is not enabled in the online store config")

conn = self._get_conn(config)
cur = conn.cursor()

# Convert the embedding to a binary format instead of using SerializeToString()
query_embedding_bin = serialize_f32(embedding, config.online_store.vector_len)
table_name = _table_id(project, table)
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
if not online_store.vector_len:
raise ValueError("vector_len is not configured in the online store config")
query_embedding_bin = serialize_f32(embedding, online_store.vector_len) # type: ignore
table_name = _table_id(config.project, table)

cur.execute(
f"""
CREATE VIRTUAL TABLE vec_example using vec0(
vector_value float[{config.online_store.vector_len}]
vector_value float[{online_store.vector_len}]
);
"""
)
Expand Down Expand Up @@ -444,6 +460,127 @@ def retrieve_online_documents(

return result

def retrieve_online_documents_v2(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_features: List[str],
    query: List[float],
    top_k: int,
    distance_metric: Optional[str] = None,
) -> List[
    Tuple[
        Optional[datetime],
        Optional[EntityKeyProto],
        Optional[Dict[str, ValueProto]],
    ]
]:
    """
    Retrieve documents using vector similarity search.

    Args:
        config: Feast configuration object
        table: FeatureView object as the table to search
        requested_features: List of requested features to retrieve; a
            ``"view:feature"`` reference is reduced to the part after the
            last ``:`` before being matched against stored feature names
        query: Query embedding to search for
        top_k: Number of nearest vectors to search for
        distance_metric: Distance metric to use (optional; accepted for
            interface compatibility but not used by this implementation)

    Returns:
        List of tuples containing the event timestamp, the entity key
        (``None`` when the stored key is empty) and a dict holding the
        feature value under its feature name plus the match distance under
        the key ``"distance"``.

    Raises:
        ValueError: If the online store is not a ``SqliteOnlineStoreConfig``,
            vector search is disabled, or ``vector_len`` is not configured.
    """
    # Validate the configuration once, up front, before touching the database.
    online_store = config.online_store
    if not isinstance(online_store, SqliteOnlineStoreConfig):
        raise ValueError("online_store must be SqliteOnlineStoreConfig")
    if not online_store.vector_enabled:
        raise ValueError("Vector search is not enabled in the online store config")
    if not online_store.vector_len:
        raise ValueError("vector_len is not configured in the online store config")

    conn = self._get_conn(config)
    cur = conn.cursor()

    query_embedding_bin = serialize_f32(query, online_store.vector_len)  # type: ignore
    table_name = _table_id(config.project, table)

    # The connection is cached on the store, so the scratch index outlives a
    # single call. Rebuild it from scratch every time: otherwise a second
    # call would re-insert rowids that are already present, and a changed
    # vector_len would be ignored by ``IF NOT EXISTS``.
    cur.execute("DROP TABLE IF EXISTS vec_example")
    cur.execute(
        f"""
        CREATE VIRTUAL TABLE vec_example using vec0(
            vector_value float[{online_store.vector_len}]
        );
        """
    )

    cur.execute(
        f"""
        INSERT INTO vec_example(rowid, vector_value)
            select rowid, vector_value from {table_name}
        """
    )

    # NOTE: the query vector is passed only as the MATCH parameter and is
    # deliberately NOT inserted into the scratch table. Inserting it made it
    # its own nearest neighbour (distance 0), consuming one of the top_k
    # slots before being discarded by the feature_name filter below.
    feature_names = [f.split(":")[-1] for f in requested_features]
    placeholders = ",".join("?" for _ in feature_names)

    cur.execute(
        f"""
        select
            fv.entity_key,
            fv.feature_name,
            fv.value,
            f.distance,
            fv.event_ts,
            fv.created_ts
        from (
            select
                rowid,
                vector_value,
                distance
            from vec_example
            where vector_value match ?
            order by distance
            limit ?
        ) f
        left join {table_name} fv
        on f.rowid = fv.rowid
        where fv.feature_name in ({placeholders})
        """,
        (query_embedding_bin, top_k, *feature_names),
    )

    rows = cur.fetchall()
    result: List[
        Tuple[
            Optional[datetime],
            Optional[EntityKeyProto],
            Optional[Dict[str, ValueProto]],
        ]
    ] = []

    for entity_key, feature_name, value_bin, distance, event_ts, _created_ts in rows:
        val = ValueProto()
        val.ParseFromString(value_bin)
        entity_key_proto = None
        if entity_key:
            entity_key_proto = deserialize_entity_key(
                entity_key,
                entity_key_serialization_version=config.entity_key_serialization_version,
            )
        res = {
            feature_name: val,
            "distance": ValueProto(float_val=distance),
        }
        result.append((event_ts, entity_key_proto, res))

    return result


def _initialize_conn(db_path: str):
try:
Expand Down
63 changes: 63 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,69 @@ def test_sqlite_vec_import() -> None:
assert result == [(2, 2.39), (1, 2.39)]


@pytest.mark.skipif(
    sys.version_info[0:2] != (3, 10),
    reason="Only works on Python 3.10",
)
def test_sqlite_get_online_documents_v2() -> None:
    """Write random embeddings to the SQLite store and query them through the v2 retrieval API."""
    num_items = 10
    dim = 8
    with CliRunner().local_repo(
        get_example_repo("example_feature_repo_1.py"), "file"
    ) as store:
        # Switch vector search on for the local repo's online store.
        store.config.online_store.vector_enabled = True
        store.config.online_store.vector_len = dim
        store.config.entity_key_serialization_version = 3

        embeddings_view = store.get_feature_view(name="document_embeddings")
        provider = store._get_provider()

        # One entity key per item id, each paired with a random embedding.
        entity_keys = [
            EntityKeyProto(
                join_keys=["item_id"],
                entity_values=[ValueProto(int64_val=item_id)],
            )
            for item_id in range(num_items)
        ]
        rows = [
            (
                entity_key,
                {
                    "Embeddings": ValueProto(
                        float_list_val=FloatListProto(
                            val=np.random.random(dim).tolist()
                        )
                    )
                },
                _utc_now(),
                _utc_now(),
            )
            for entity_key in entity_keys
        ]

        provider.online_write_batch(
            config=store.config,
            table=embeddings_view,
            data=rows,
            progress=None,
        )

        # Run the similarity search with a fresh random query vector.
        query_vector = np.random.random(dim).tolist()
        response = store.retrieve_online_documents_v2(
            features=["document_embeddings:Embeddings"],
            query=query_vector,
            top_k=3,
        ).to_dict()

        assert "Embeddings" in response
        assert "distance" in response
        assert len(response["distance"]) == 3


@pytest.mark.skip(reason="Skipping this test as CI struggles with it")
def test_local_milvus() -> None:
import random
Expand Down
Loading