Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and improve retry mechanism for Compass client #84

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 114 additions & 44 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
password: Optional[str] = None,
bearer_token: Optional[str] = None,
http_session: Optional[requests.Session] = None,
default_max_retries: int = DEFAULT_MAX_RETRIES,
default_sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
):
"""
Initialize the Compass client.
Expand All @@ -110,6 +112,15 @@ def __init__(
self.session = http_session or requests.Session()
self.bearer_token = bearer_token

if default_max_retries < 0:
raise ValueError("default_max_retries must be a non-negative integer.")
if default_sleep_retry_seconds < 0:
raise ValueError(
"default_sleep_retry_seconds must be a non-negative integer."
)
self.default_max_retries = default_max_retries
self.default_sleep_retry_seconds = default_sleep_retry_seconds

self.api_method = {
"create_index": self.session.put,
"list_indexes": self.session.get,
Expand Down Expand Up @@ -154,7 +165,12 @@ def __init__(
}

def create_index(
self, *, index_name: str, index_config: Optional[IndexConfig] = None
self,
*,
index_name: str,
index_config: Optional[IndexConfig] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Create an index in Compass.
Expand All @@ -165,13 +181,19 @@ def create_index(
"""
return self._send_request(
api_name="create_index",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
data=index_config,
)

def refresh_index(self, *, index_name: str):
def refresh_index(
self,
*,
index_name: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Refresh index.

Expand All @@ -180,12 +202,18 @@ def refresh_index(self, *, index_name: str):
"""
return self._send_request(
api_name="refresh",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def delete_index(self, *, index_name: str):
def delete_index(
self,
*,
index_name: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete an index from Compass.

Expand All @@ -194,12 +222,19 @@ def delete_index(self, *, index_name: str):
"""
return self._send_request(
api_name="delete_index",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def delete_document(self, *, index_name: str, document_id: str):
def delete_document(
self,
*,
index_name: str,
document_id: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete a document from Compass.

Expand All @@ -211,12 +246,19 @@ def delete_document(self, *, index_name: str, document_id: str):
return self._send_request(
api_name="delete_document",
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def get_document(self, *, index_name: str, document_id: str):
def get_document(
self,
*,
index_name: str,
document_id: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Get a document from Compass.

Expand All @@ -228,21 +270,25 @@ def get_document(self, *, index_name: str, document_id: str):
return self._send_request(
api_name="get_document",
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def list_indexes(self):
def list_indexes(
self,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
List all indexes in Compass.

:returns: the response from the Compass API
"""
return self._send_request(
api_name="list_indexes",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name="",
)

Expand All @@ -252,8 +298,8 @@ def add_attributes(
index_name: str,
document_id: str,
attributes: DocumentAttributes,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[str]:
"""
Update the content field of an existing document with additional context.
Expand Down Expand Up @@ -284,10 +330,10 @@ def insert_doc(
*,
index_name: str,
doc: CompassDocument,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
authorized_groups: Optional[list[str]] = None,
merge_groups_on_conflict: bool = False,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[list[dict[str, str]]]:
"""
Insert a parsed document into an index in Compass.
Expand Down Expand Up @@ -315,8 +361,8 @@ def upload_document(
content_type: str,
document_id: uuid.UUID,
attributes: DocumentAttributes = DocumentAttributes(),
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[Union[str, dict[str, Any]]]:
"""
Parse and insert a document into an index in Compass.
Expand Down Expand Up @@ -367,13 +413,13 @@ def insert_docs(
docs: Iterator[CompassDocument],
max_chunks_per_request: int = DEFAULT_MAX_CHUNKS_PER_REQUEST,
max_error_rate: float = DEFAULT_MAX_ERROR_RATE,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
errors_sliding_window_size: Optional[int] = 10,
skip_first_n_docs: int = 0,
num_jobs: Optional[int] = None,
authorized_groups: Optional[list[str]] = None,
merge_groups_on_conflict: bool = False,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[list[dict[str, str]]]:
"""
Insert multiple parsed documents into an index in Compass.
Expand Down Expand Up @@ -482,8 +528,8 @@ def create_datasource(
self,
*,
datasource: CreateDataSource,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[DataSource, str]:
"""
Create a new datasource in Compass.
Expand All @@ -506,8 +552,8 @@ def create_datasource(
def list_datasources(
self,
*,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[PaginatedList[DataSource], str]:
"""
List all datasources in Compass.
Expand All @@ -529,8 +575,8 @@ def get_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Get a datasource in Compass.
Expand All @@ -554,8 +600,8 @@ def delete_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete a datasource in Compass.
Expand All @@ -579,8 +625,8 @@ def sync_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Sync a datasource in Compass.
Expand All @@ -606,8 +652,8 @@ def list_datasources_objects_states(
datasource_id: str,
skip: int = 0,
limit: int = 100,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[PaginatedList[DocumentStatus], str]:
"""
List all objects states in a datasource in Compass.
Expand Down Expand Up @@ -692,6 +738,8 @@ def _search(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
return self._send_request(
api_name=api_name,
Expand All @@ -708,6 +756,8 @@ def search_documents(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> SearchDocumentsResponse:
"""
Search documents in an index.
Expand All @@ -725,6 +775,8 @@ def search_documents(
query=query,
top_k=top_k,
filters=filters,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
Expand All @@ -739,6 +791,8 @@ def search_chunks(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> SearchChunksResponse:
"""
Search chunks in an index.
Expand All @@ -756,6 +810,8 @@ def search_chunks(
query=query,
top_k=top_k,
filters=filters,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
Expand All @@ -764,7 +820,12 @@ def search_chunks(
return SearchChunksResponse.model_validate(result.result)

def update_group_authorization(
self, *, index_name: str, group_auth_input: GroupAuthorizationInput
self,
*,
index_name: str,
group_auth_input: GroupAuthorizationInput,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> PutDocumentsResponse:
"""
Edit group authorization for an index.
Expand All @@ -776,18 +837,19 @@ def update_group_authorization(
api_name="update_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)
if result.error:
raise CompassError(result.error)
return PutDocumentsResponse.model_validate(result.result)

def _send_request(
# TODO: Simplify this method so we don't have to ignore the C901 complexity warning.
def _send_request( # noqa: C901
self,
api_name: str,
max_retries: int,
sleep_retry_seconds: int,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
data: Optional[BaseModel] = None,
**url_params: str,
) -> _RetryResult:
Expand All @@ -801,6 +863,14 @@ def _send_request(
:param data: the data to send
:returns: A _RetryResult; its ``error`` field is set if the request failed, and its ``result`` field carries the response otherwise.
"""
if not max_retries:
max_retries = self.default_max_retries
if not sleep_retry_seconds:
sleep_retry_seconds = self.default_sleep_retry_seconds
if max_retries < 0:
raise ValueError("max_retries must be a non-negative integer.")
if sleep_retry_seconds < 0:
raise ValueError("sleep_retry_seconds must be a non-negative integer.")

@retry(
stop=stop_after_attempt(max_retries),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "compass-sdk"
version = "0.14.2"
version = "0.15.0"
authors = []
description = "Compass SDK"
readme = "README.md"
Expand Down