From efb03097398c0b9bdcab7cf0e3afceef35d90780 Mon Sep 17 00:00:00 2001 From: George Date: Mon, 15 Apr 2024 20:00:08 +0200 Subject: [PATCH] fix: prevent access to a closed instance methods in local mode (#594) * fix: prevent access to a closed method instance in local mode * fix: regen async --- qdrant_client/local/async_qdrant_local.py | 15 +++++++++++++++ qdrant_client/local/qdrant_local.py | 23 +++++++++++++++++++++++ tests/test_qdrant_client.py | 14 ++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/qdrant_client/local/async_qdrant_local.py b/qdrant_client/local/async_qdrant_local.py index 137b91c5..ad49c967 100644 --- a/qdrant_client/local/async_qdrant_local.py +++ b/qdrant_client/local/async_qdrant_local.py @@ -71,6 +71,7 @@ def __init__(self, location: str, force_disable_check_same_thread: bool = False) self._load() self._closed: bool = False + @property def closed(self) -> bool: return self._closed @@ -129,6 +130,8 @@ def _load(self) -> None: def _save(self) -> None: if not self.persistent: return + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") meta_path = os.path.join(self.location, META_INFO_FILENAME) with open(meta_path, "w") as f: f.write( @@ -144,6 +147,8 @@ def _save(self) -> None: ) def _get_collection(self, collection_name: str) -> LocalCollection: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") if collection_name in self.collections: return self.collections[collection_name] if collection_name in self.aliases: @@ -552,6 +557,8 @@ async def update_collection_aliases( async def get_collection_aliases( self, collection_name: str, **kwargs: Any ) -> types.CollectionsAliasesResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") return types.CollectionsAliasesResponse( aliases=[ rest_models.AliasDescription(alias_name=alias_name, collection_name=name) @@ -561,6 +568,8 @@ async def get_collection_aliases( ) async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") return types.CollectionsAliasesResponse( aliases=[ rest_models.AliasDescription(alias_name=alias_name, collection_name=name) @@ -569,6 +578,8 @@ async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse: ) async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") return types.CollectionsResponse( collections=[ rest_models.CollectionDescription(name=name) @@ -598,6 +609,8 @@ def _collection_path(self, collection_name: str) -> Optional[str]: return None async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") _collection = self.collections.pop(collection_name, None) del _collection self.aliases = { @@ -619,6 +632,8 @@ async def create_collection( sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None, **kwargs: Any, ) -> bool: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") src_collection = None from_collection_name = None if init_from is not None: diff --git a/qdrant_client/local/qdrant_local.py b/qdrant_client/local/qdrant_local.py index 1ee32f86..113ae925 100644 --- a/qdrant_client/local/qdrant_local.py +++ b/qdrant_client/local/qdrant_local.py @@ -60,6 +60,7 @@ def __init__(self, location: str, force_disable_check_same_thread: bool = False) self._load() self._closed: bool = False + @property def closed(self) -> bool: return self._closed @@ -123,6 +124,10 @@ def _load(self) -> None: def _save(self) -> None: if not self.persistent: return + + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + meta_path = os.path.join(self.location, META_INFO_FILENAME) with open(meta_path, "w") as f: f.write( @@ -138,6 +143,9 @@ def _save(self) -> None: ) def _get_collection(self, collection_name: str) -> LocalCollection: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + if collection_name in self.collections: return self.collections[collection_name] if collection_name in self.aliases: @@ -567,6 +575,9 @@ def update_collection_aliases( def get_collection_aliases( self, collection_name: str, **kwargs: Any ) -> types.CollectionsAliasesResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + return types.CollectionsAliasesResponse( aliases=[ rest_models.AliasDescription( @@ -579,6 +590,9 @@ def get_collection_aliases( ) def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + return types.CollectionsAliasesResponse( aliases=[ rest_models.AliasDescription( @@ -590,6 +604,9 @@ def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse: ) def get_collections(self, **kwargs: Any) -> types.CollectionsResponse: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + return types.CollectionsResponse( collections=[ rest_models.CollectionDescription(name=name) @@ -619,6 +636,9 @@ def _collection_path(self, collection_name: str) -> Optional[str]: return None def delete_collection(self, collection_name: str, **kwargs: Any) -> bool: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + _collection = self.collections.pop(collection_name, None) del _collection self.aliases = { @@ -640,6 +660,9 @@ def create_collection( sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None, **kwargs: Any, ) -> bool: + if self.closed: + raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.") + src_collection = None from_collection_name = None if init_from is not None: diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index 0ab66380..123cb835 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -1742,6 +1742,20 @@ def test_client_close(): "test", vectors_config=VectorParams(size=100, distance=Distance.COSINE) ) local_client_in_mem.close() + assert local_client_in_mem._client.closed is True + + with pytest.raises(RuntimeError): + local_client_in_mem.upsert( + "test", [PointStruct(id=1, vector=np.random.rand(100).tolist())] + ) + + with pytest.raises(RuntimeError): + local_client_in_mem.create_collection( + "test", vectors_config=VectorParams(size=100, distance=Distance.COSINE) + ) + + with pytest.raises(RuntimeError): + local_client_in_mem.delete_collection("test") with tempfile.TemporaryDirectory() as tmpdir: path = tmpdir + "/test.db"