Skip to content

Commit

Permalink
fix: prevent access to a closed instance methods in local mode (#594)
Browse files Browse the repository at this point in the history
* fix: prevent access to a closed method instance in local mode

* fix: regen async
  • Loading branch information
joein committed Apr 16, 2024
1 parent c6322f0 commit efb0309
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
15 changes: 15 additions & 0 deletions qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, location: str, force_disable_check_same_thread: bool = False)
self._load()
self._closed: bool = False

@property
def closed(self) -> bool:
return self._closed

Expand Down Expand Up @@ -129,6 +130,8 @@ def _load(self) -> None:
def _save(self) -> None:
if not self.persistent:
return
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
meta_path = os.path.join(self.location, META_INFO_FILENAME)
with open(meta_path, "w") as f:
f.write(
Expand All @@ -144,6 +147,8 @@ def _save(self) -> None:
)

def _get_collection(self, collection_name: str) -> LocalCollection:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
if collection_name in self.collections:
return self.collections[collection_name]
if collection_name in self.aliases:
Expand Down Expand Up @@ -552,6 +557,8 @@ async def update_collection_aliases(
async def get_collection_aliases(
self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
Expand All @@ -561,6 +568,8 @@ async def get_collection_aliases(
)

async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
Expand All @@ -569,6 +578,8 @@ async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
)

async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsResponse(
collections=[
rest_models.CollectionDescription(name=name)
Expand Down Expand Up @@ -598,6 +609,8 @@ def _collection_path(self, collection_name: str) -> Optional[str]:
return None

async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
_collection = self.collections.pop(collection_name, None)
del _collection
self.aliases = {
Expand All @@ -619,6 +632,8 @@ async def create_collection(
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
**kwargs: Any,
) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
src_collection = None
from_collection_name = None
if init_from is not None:
Expand Down
23 changes: 23 additions & 0 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, location: str, force_disable_check_same_thread: bool = False)
self._load()
self._closed: bool = False

@property
def closed(self) -> bool:
return self._closed

Expand Down Expand Up @@ -123,6 +124,10 @@ def _load(self) -> None:
def _save(self) -> None:
if not self.persistent:
return

if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

meta_path = os.path.join(self.location, META_INFO_FILENAME)
with open(meta_path, "w") as f:
f.write(
Expand All @@ -138,6 +143,9 @@ def _save(self) -> None:
)

def _get_collection(self, collection_name: str) -> LocalCollection:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

if collection_name in self.collections:
return self.collections[collection_name]
if collection_name in self.aliases:
Expand Down Expand Up @@ -567,6 +575,9 @@ def update_collection_aliases(
def get_collection_aliases(
self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(
Expand All @@ -579,6 +590,9 @@ def get_collection_aliases(
)

def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(
Expand All @@ -590,6 +604,9 @@ def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
)

def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

return types.CollectionsResponse(
collections=[
rest_models.CollectionDescription(name=name)
Expand Down Expand Up @@ -619,6 +636,9 @@ def _collection_path(self, collection_name: str) -> Optional[str]:
return None

def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

_collection = self.collections.pop(collection_name, None)
del _collection
self.aliases = {
Expand All @@ -640,6 +660,9 @@ def create_collection(
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
**kwargs: Any,
) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")

src_collection = None
from_collection_name = None
if init_from is not None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1742,6 +1742,20 @@ def test_client_close():
"test", vectors_config=VectorParams(size=100, distance=Distance.COSINE)
)
local_client_in_mem.close()
assert local_client_in_mem._client.closed is True

with pytest.raises(RuntimeError):
local_client_in_mem.upsert(
"test", [PointStruct(id=1, vector=np.random.rand(100).tolist())]
)

with pytest.raises(RuntimeError):
local_client_in_mem.create_collection(
"test", vectors_config=VectorParams(size=100, distance=Distance.COSINE)
)

with pytest.raises(RuntimeError):
local_client_in_mem.delete_collection("test")

with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir + "/test.db"
Expand Down

0 comments on commit efb0309

Please sign in to comment.