chore: Moving Milvus client to PyMilvus #4907

Merged · 4 commits · Jan 8, 2025
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_store.py
@@ -1757,7 +1757,7 @@ def retrieve_online_documents(
query: Union[str, List[float]],
top_k: int,
features: Optional[List[str]] = None,
-        distance_metric: Optional[str] = None,
+        distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
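With this default in place, callers can omit `distance_metric` entirely. A minimal usage sketch, mirroring the integration test at the bottom of this diff (the repo path and feature view names are illustrative):

```python
from feast import FeatureStore

fs = FeatureStore(repo_path=".")  # assumes a repo defining an "item_embeddings" view

# distance_metric is omitted, so the new "L2" default applies.
documents = fs.retrieve_online_documents(
    feature=None,
    features=["item_embeddings:embedding_float"],
    query=[1.0, 2.0],
    top_k=2,
).to_dict()
```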
147 changes: 84 additions & 63 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -7,9 +7,8 @@
CollectionSchema,
DataType,
FieldSchema,
-    connections,
+    MilvusClient,
 )
-from pymilvus.orm.connections import Connections

from feast import Entity
from feast.feature_view import FeatureView
@@ -85,14 +84,15 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
"""

type: Literal["milvus"] = "milvus"

host: Optional[StrictStr] = "localhost"
port: Optional[int] = 19530
index_type: Optional[str] = "IVF_FLAT"
metric_type: Optional[str] = "L2"
embedding_dim: Optional[int] = 128
vector_enabled: Optional[bool] = True
nlist: Optional[int] = 128
+    username: Optional[StrictStr] = ""
+    password: Optional[StrictStr] = ""


class MilvusOnlineStore(OnlineStore):
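For reference, a sketch of constructing the config directly with the new credential fields (values are hypothetical; every field falls back to the defaults declared above):

```python
from feast.infra.online_stores.milvus_online_store.milvus import (
    MilvusOnlineStoreConfig,
)

# username/password default to ""; _connect below only builds an auth token
# when both are non-empty, so anonymous local setups keep working unchanged.
config = MilvusOnlineStoreConfig(
    host="localhost",
    port=19530,
    username="feast_user",  # hypothetical credentials
    password="feast_pass",
)
```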
@@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore):
_collections: Dictionary to cache Milvus collections.
"""

-    _conn: Optional[Connections] = None
-    _collections: Dict[str, Collection] = {}
+    client: Optional[MilvusClient] = None
+    _collections: Dict[str, Any] = {}

-    def _connect(self, config: RepoConfig) -> connections:
-        if not self._conn:
-            if not connections.has_connection("feast"):
-                self._conn = connections.connect(
-                    alias="feast",
-                    host=config.online_store.host,
-                    port=str(config.online_store.port),
-                )
-        return self._conn
+    def _connect(self, config: RepoConfig) -> MilvusClient:
+        if not self.client:
+            self.client = MilvusClient(
+                url=f"{config.online_store.host}:{config.online_store.port}",
+                token=f"{config.online_store.username}:{config.online_store.password}"
+                if config.online_store.username and config.online_store.password
+                else "",
+            )
+        return self.client
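As a standalone sketch of the same connection pattern written against pymilvus's documented `MilvusClient` constructor, whose endpoint parameter is named `uri` (host, port, and credentials below are placeholders):

```python
from pymilvus import MilvusClient

host, port = "localhost", 19530
username, password = "feast_user", "feast_pass"  # placeholder credentials

# MilvusClient's documented endpoint parameter is `uri`; token is
# "user:password" when auth is enabled and "" otherwise.
client = MilvusClient(
    uri=f"http://{host}:{port}",
    token=f"{username}:{password}" if username and password else "",
)
```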

-    def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
+    def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, Any]:
+        self.client = self._connect(config)
         collection_name = _table_id(config.project, table)
         if collection_name not in self._collections:
-            self._connect(config)

# Create a composite key by combining entity fields
composite_key_name = (
"_".join([field.name for field in table.entity_columns]) + "_pk"
@@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
schema = CollectionSchema(
fields=fields, description="Feast feature view data"
)
-            collection = Collection(name=collection_name, schema=schema, using="feast")
-            if not collection.has_index():
-                index_params = {
-                    "index_type": config.online_store.index_type,
-                    "metric_type": config.online_store.metric_type,
-                    "params": {"nlist": config.online_store.nlist},
-                }
-                for vector_field in schema.fields:
-                    if vector_field.dtype in [
-                        DataType.FLOAT_VECTOR,
-                        DataType.BINARY_VECTOR,
-                    ]:
-                        collection.create_index(
-                            field_name=vector_field.name, index_params=index_params
-                        )
-                collection.load()
-            self._collections[collection_name] = collection
+            collection_exists = self.client.has_collection(
+                collection_name=collection_name
+            )
+            if not collection_exists:
+                self.client.create_collection(
+                    collection_name=collection_name,
+                    dimension=config.online_store.embedding_dim,
+                    schema=schema,
+                )
+                index_params = self.client.prepare_index_params()
+                for vector_field in schema.fields:
+                    if vector_field.dtype in [
+                        DataType.FLOAT_VECTOR,
+                        DataType.BINARY_VECTOR,
+                    ]:
+                        index_params.add_index(
+                            collection_name=collection_name,
+                            field_name=vector_field.name,
+                            metric_type=config.online_store.metric_type,
+                            index_type=config.online_store.index_type,
+                            index_name=f"vector_index_{vector_field.name}",
+                            params={"nlist": config.online_store.nlist},
+                        )
+                self.client.create_index(
+                    collection_name=collection_name,
+                    index_params=index_params,
+                )
+            else:
+                self.client.load_collection(collection_name)
+            self._collections[collection_name] = self.client.describe_collection(
+                collection_name
+            )
return self._collections[collection_name]
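The create-or-load flow above, condensed into a self-contained sketch against `MilvusClient` (the collection name and field layout are hypothetical stand-ins for what Feast derives from a feature view):

```python
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

client = MilvusClient(uri="http://localhost:19530")
collection_name = "demo_feature_view"  # hypothetical

fields = [
    FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=512, is_primary=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
    FieldSchema(name="event_ts", dtype=DataType.INT64),
]
schema = CollectionSchema(fields=fields, description="demo")

if not client.has_collection(collection_name=collection_name):
    client.create_collection(collection_name=collection_name, schema=schema)
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        index_type="IVF_FLAT",
        metric_type="L2",
        index_name="vector_index_vector",
        params={"nlist": 128},
    )
    client.create_index(collection_name=collection_name, index_params=index_params)
else:
    client.load_collection(collection_name)

# describe_collection returns a plain dict (not an ORM object), which is
# what the new _collections cache holds.
collection_meta = client.describe_collection(collection_name)
```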

def online_write_batch(
@@ -199,6 +213,7 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
+        self.client = self._connect(config)
collection = self._get_collection(config, table)
entity_batch_to_insert = []
for entity_key, values_dict, timestamp, created_ts in data:
@@ -231,8 +246,9 @@ def online_write_batch(
if progress:
progress(1)

-        collection.insert(entity_batch_to_insert)
-        collection.flush()
+        self.client.insert(
+            collection_name=collection["collection_name"], data=entity_batch_to_insert
+        )
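A short sketch of the corresponding write path: `MilvusClient.insert` takes rows as plain dicts keyed by field name, with no separate `flush()` step as in the old `Collection` ORM (names reuse the hypothetical collection from the sketch above):

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Rows are plain dicts; event_ts is stored as integer microseconds, matching
# the `/ 1e6` conversion in retrieve_online_documents below.
rows = [
    {"id": "user_1_pk", "vector": [0.1] * 128, "event_ts": 1736294400000000},
]
client.insert(collection_name="demo_feature_view", data=rows)
```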

def online_read(
self,
@@ -252,14 +268,14 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
-        self._connect(config)
+        self.client = self._connect(config)
         for table in tables_to_keep:
-            self._get_collection(config, table)
+            self._collections = self._get_collection(config, table)

for table in tables_to_delete:
collection_name = _table_id(config.project, table)
-            collection = Collection(name=collection_name)
-            if collection.exists():
-                collection.drop()
+            if self._collections.get(collection_name, None):
+                self.client.drop_collection(collection_name)
+                self._collections.pop(collection_name, None)

def plan(
@@ -273,12 +289,12 @@ def teardown(
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
-        self._connect(config)
+        self.client = self._connect(config)
for table in tables:
-            collection = self._get_collection(config, table)
-            if collection:
-                collection.drop()
-                self._collections.pop(collection.name, None)
+            collection_name = _table_id(config.project, table)
+            if self._collections.get(collection_name, None):
+                self.client.drop_collection(collection_name)
+                self._collections.pop(collection_name, None)
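The same drop-and-evict pattern as a standalone sketch (collection name and local cache are hypothetical stand-ins):

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Drop server-side first, then evict the local metadata cache entry.
collections = {"demo_feature_view": {}}  # stand-in for self._collections
for name in list(collections):
    if client.has_collection(collection_name=name):
        client.drop_collection(name)
    collections.pop(name, None)
```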

def retrieve_online_documents(
self,
@@ -298,6 +314,8 @@ def retrieve_online_documents(
Optional[ValueProto],
]
]:
+        self.client = self._connect(config)
+        collection_name = _table_id(config.project, table)
collection = self._get_collection(config, table)
if not config.online_store.vector_enabled:
raise ValueError("Vector search is not enabled in the online store config")
@@ -321,42 +339,45 @@
+ ["created_ts", "event_ts"]
)
         assert all(
-            field
+            field in [f["name"] for f in collection["fields"]]
             for field in output_fields
-            if field in [f.name for f in collection.schema.fields]
-        ), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
+        ), f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"

# Note we choose the first vector field as the field to search on. Not ideal but it's something.
ann_search_field = None
-        for field in collection.schema.fields:
+        for field in collection["fields"]:
             if (
-                field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
-                and field.name in output_fields
+                field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
+                and field["name"] in output_fields
             ):
-                ann_search_field = field.name
+                ann_search_field = field["name"]
                 break

-        results = collection.search(
+        self.client.load_collection(collection_name)
+        results = self.client.search(
+            collection_name=collection_name,
             data=[embedding],
             anns_field=ann_search_field,
-            param=search_params,
+            search_params=search_params,
             limit=top_k,
             output_fields=output_fields,
             consistency_level="Strong",
         )

result_list = []
for hits in results:
for hit in hits:
single_record = {}
for field in output_fields:
-                    single_record[field] = hit.entity.get(field)
+                    single_record[field] = hit.get("entity", {}).get(field, None)

-                entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name))
-                embedding = hit.entity.get(ann_search_field)
+                entity_key_bytes = bytes.fromhex(
+                    hit.get("entity", {}).get(composite_key_name, None)
+                )
+                embedding = hit.get("entity", {}).get(ann_search_field)
                 serialized_embedding = _serialize_vector_to_float_list(embedding)
-                distance = hit.distance
-                event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6)
+                distance = hit.get("distance", None)
+                event_ts = datetime.fromtimestamp(
+                    hit.get("entity", {}).get("event_ts") / 1e6
+                )
prepared_result = _build_retrieve_online_document_record(
entity_key_bytes,
# This may have a bug
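The hit-dict accesses above follow `MilvusClient`'s result shape: `search` returns one list of hits per query vector, and each hit is a dict with `"id"`, `"distance"`, and an `"entity"` dict of the requested output fields. A sketch (collection and field names are the hypothetical ones from earlier):

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

results = client.search(
    collection_name="demo_feature_view",
    data=[[0.1] * 128],  # one query embedding
    anns_field="vector",
    search_params={"metric_type": "L2", "params": {"nprobe": 10}},
    limit=2,
    output_fields=["id", "vector", "event_ts"],
)

# One hit list per query vector; each hit is a dict, hence the
# hit.get("entity", {}).get(...) pattern in the store code above.
for hits in results:
    for hit in hits:
        print(hit["distance"], hit.get("entity", {}).get("event_ts"))
```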
@@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str):
self._connect()

def _connect(self):
-        return connections.connect(alias="default", host=self.host, port=str(self.port))
+        raise NotImplementedError

def to_infra_object_proto(self) -> InfraObjectProto:
# Implement serialization if needed
@@ -1,6 +1,8 @@
from typing import Any, Dict

-from testcontainers.milvus import MilvusContainer
+import docker
+from testcontainers.core.container import DockerContainer
+from testcontainers.core.waiting_utils import wait_for_logs

from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
@@ -11,13 +13,19 @@ class MilvusOnlineStoreCreator(OnlineStoreCreator):
def __init__(self, project_name: str, **kwargs):
super().__init__(project_name)
self.fixed_port = 19530
-        self.container = MilvusContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
+        self.container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
self.fixed_port
)
+        self.client = docker.from_env()

def create_online_store(self) -> Dict[str, Any]:
self.container.start()
+        # Wait for Milvus server to be ready
+        # log_string_to_wait_for = "Ready to accept connections"
+        log_string_to_wait_for = ""
+        wait_for_logs(
+            container=self.container, predicate=log_string_to_wait_for, timeout=30
+        )
host = "localhost"
port = self.container.get_exposed_port(self.fixed_port)
return {
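A standalone sketch of this container wiring with testcontainers' generic `DockerContainer` and `wait_for_logs`. Note that the empty predicate above matches immediately; a production setup would likely wait on a real readiness message, and a Milvus standalone container may need extra command or environment configuration not shown here:

```python
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(19530)
container.start()

# wait_for_logs polls container logs until the predicate regex matches or
# the timeout expires; an empty pattern matches immediately.
wait_for_logs(container=container, predicate="", timeout=30)

host = "localhost"
port = container.get_exposed_port(19530)  # host port mapped to 19530
print(f"Milvus reachable at {host}:{port}")
container.stop()
```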
@@ -897,26 +897,26 @@ def test_retrieve_online_documents(environment, fake_document_data):
).to_dict()


-# @pytest.mark.integration
-# @pytest.mark.universal_online_stores(only=["milvus"])
-# def test_retrieve_online_milvus_documents(environment, fake_document_data):
-#     fs = environment.feature_store
-#     df, data_source = fake_document_data
-#     item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
-#     fs.apply([item_embeddings_feature_view, item()])
-#     fs.write_to_online_store("item_embeddings", df)
-#     documents = fs.retrieve_online_documents(
-#         feature=None,
-#         features=[
-#             "item_embeddings:embedding_float",
-#             "item_embeddings:item_id",
-#             "item_embeddings:string_feature",
-#         ],
-#         query=[1.0, 2.0],
-#         top_k=2,
-#         distance_metric="L2",
-#     ).to_dict()
-#     assert len(documents["embedding_float"]) == 2
-#
-#     assert len(documents["item_id"]) == 2
-#     assert documents["item_id"] == [2, 3]
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["milvus"])
+def test_retrieve_online_milvus_documents(environment, fake_document_data):
+    fs = environment.feature_store
+    df, data_source = fake_document_data
+    item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
+    fs.apply([item_embeddings_feature_view, item()])
+    fs.write_to_online_store("item_embeddings", df)
+    documents = fs.retrieve_online_documents(
+        feature=None,
+        features=[
+            "item_embeddings:embedding_float",
+            "item_embeddings:item_id",
+            "item_embeddings:string_feature",
+        ],
+        query=[1.0, 2.0],
+        top_k=2,
+        distance_metric="L2",
+    ).to_dict()
+    assert len(documents["embedding_float"]) == 2
+
+    assert len(documents["item_id"]) == 2
+    assert documents["item_id"] == [2, 3]