Skip to content

Commit

Permalink
Fix and improve retry mechanism for Compass client (#84)
Browse files Browse the repository at this point in the history
Some APIs, e.g. search APIs, were hard-coding the retry count and sleep
time to 1 and 1, respectively. Others, used the DEFAULT_MAX_RETRIES and
DEFAULT_SLEEP_RETRY_SECONDS constants, respectively. Yet others allowed
passing values from the clients. This PR fixes this as follows:

1. Introduce `default_max_retries` and `default_sleep_retry_seconds` at
the client level, which gets applied to all APIs by default.
2. Each API optionally has `max_retries` and `sleep_retry_seconds`
parameters that, if specified, will override the default values from the
client.
  • Loading branch information
corafid authored Feb 5, 2025
1 parent e2002ff commit 56d8e03
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 45 deletions.
158 changes: 114 additions & 44 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
password: Optional[str] = None,
bearer_token: Optional[str] = None,
http_session: Optional[requests.Session] = None,
default_max_retries: int = DEFAULT_MAX_RETRIES,
default_sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
):
"""
Initialize the Compass client.
Expand All @@ -110,6 +112,15 @@ def __init__(
self.session = http_session or requests.Session()
self.bearer_token = bearer_token

if default_max_retries < 0:
raise ValueError("default_max_retries must be a non-negative integer.")
if default_sleep_retry_seconds < 0:
raise ValueError(
"default_sleep_retry_seconds must be a non-negative integer."
)
self.default_max_retries = default_max_retries
self.default_sleep_retry_seconds = default_sleep_retry_seconds

self.api_method = {
"create_index": self.session.put,
"list_indexes": self.session.get,
Expand Down Expand Up @@ -154,7 +165,12 @@ def __init__(
}

def create_index(
self, *, index_name: str, index_config: Optional[IndexConfig] = None
self,
*,
index_name: str,
index_config: Optional[IndexConfig] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Create an index in Compass.
Expand All @@ -165,13 +181,19 @@ def create_index(
"""
return self._send_request(
api_name="create_index",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
data=index_config,
)

def refresh_index(self, *, index_name: str):
def refresh_index(
self,
*,
index_name: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Refresh index.
Expand All @@ -180,12 +202,18 @@ def refresh_index(self, *, index_name: str):
"""
return self._send_request(
api_name="refresh",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def delete_index(self, *, index_name: str):
def delete_index(
self,
*,
index_name: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete an index from Compass.
Expand All @@ -194,12 +222,19 @@ def delete_index(self, *, index_name: str):
"""
return self._send_request(
api_name="delete_index",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def delete_document(self, *, index_name: str, document_id: str):
def delete_document(
self,
*,
index_name: str,
document_id: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete a document from Compass.
Expand All @@ -211,12 +246,19 @@ def delete_document(self, *, index_name: str, document_id: str):
return self._send_request(
api_name="delete_document",
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def get_document(self, *, index_name: str, document_id: str):
def get_document(
self,
*,
index_name: str,
document_id: str,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Get a document from Compass.
Expand All @@ -228,21 +270,25 @@ def get_document(self, *, index_name: str, document_id: str):
return self._send_request(
api_name="get_document",
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
)

def list_indexes(self):
def list_indexes(
self,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
List all indexes in Compass.
:returns: the response from the Compass API
"""
return self._send_request(
api_name="list_indexes",
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name="",
)

Expand All @@ -252,8 +298,8 @@ def add_attributes(
index_name: str,
document_id: str,
attributes: DocumentAttributes,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[str]:
"""
Update the content field of an existing document with additional context.
Expand Down Expand Up @@ -284,10 +330,10 @@ def insert_doc(
*,
index_name: str,
doc: CompassDocument,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
authorized_groups: Optional[list[str]] = None,
merge_groups_on_conflict: bool = False,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[list[dict[str, str]]]:
"""
Insert a parsed document into an index in Compass.
Expand Down Expand Up @@ -315,8 +361,8 @@ def upload_document(
content_type: str,
document_id: uuid.UUID,
attributes: DocumentAttributes = DocumentAttributes(),
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[Union[str, dict[str, Any]]]:
"""
Parse and insert a document into an index in Compass.
Expand Down Expand Up @@ -367,13 +413,13 @@ def insert_docs(
docs: Iterator[CompassDocument],
max_chunks_per_request: int = DEFAULT_MAX_CHUNKS_PER_REQUEST,
max_error_rate: float = DEFAULT_MAX_ERROR_RATE,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
errors_sliding_window_size: Optional[int] = 10,
skip_first_n_docs: int = 0,
num_jobs: Optional[int] = None,
authorized_groups: Optional[list[str]] = None,
merge_groups_on_conflict: bool = False,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Optional[list[dict[str, str]]]:
"""
Insert multiple parsed documents into an index in Compass.
Expand Down Expand Up @@ -482,8 +528,8 @@ def create_datasource(
self,
*,
datasource: CreateDataSource,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[DataSource, str]:
"""
Create a new datasource in Compass.
Expand All @@ -506,8 +552,8 @@ def create_datasource(
def list_datasources(
self,
*,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[PaginatedList[DataSource], str]:
"""
List all datasources in Compass.
Expand All @@ -529,8 +575,8 @@ def get_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Get a datasource in Compass.
Expand All @@ -554,8 +600,8 @@ def delete_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Delete a datasource in Compass.
Expand All @@ -579,8 +625,8 @@ def sync_datasource(
self,
*,
datasource_id: str,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
"""
Sync a datasource in Compass.
Expand All @@ -606,8 +652,8 @@ def list_datasources_objects_states(
datasource_id: str,
skip: int = 0,
limit: int = 100,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> Union[PaginatedList[DocumentStatus], str]:
"""
List all objects states in a datasource in Compass.
Expand Down Expand Up @@ -692,6 +738,8 @@ def _search(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
):
return self._send_request(
api_name=api_name,
Expand All @@ -708,6 +756,8 @@ def search_documents(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> SearchDocumentsResponse:
"""
Search documents in an index.
Expand All @@ -725,6 +775,8 @@ def search_documents(
query=query,
top_k=top_k,
filters=filters,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
Expand All @@ -739,6 +791,8 @@ def search_chunks(
query: str,
top_k: int = 10,
filters: Optional[list[SearchFilter]] = None,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> SearchChunksResponse:
"""
Search chunks in an index.
Expand All @@ -756,6 +810,8 @@ def search_chunks(
query=query,
top_k=top_k,
filters=filters,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
Expand All @@ -764,7 +820,12 @@ def search_chunks(
return SearchChunksResponse.model_validate(result.result)

def update_group_authorization(
self, *, index_name: str, group_auth_input: GroupAuthorizationInput
self,
*,
index_name: str,
group_auth_input: GroupAuthorizationInput,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
) -> PutDocumentsResponse:
"""
Edit group authorization for an index.
Expand All @@ -776,18 +837,19 @@ def update_group_authorization(
api_name="update_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)
if result.error:
raise CompassError(result.error)
return PutDocumentsResponse.model_validate(result.result)

def _send_request(
# todo Simplify this method so we don't have to ignore the C901 complexity warning.
def _send_request( # noqa: C901
self,
api_name: str,
max_retries: int,
sleep_retry_seconds: int,
max_retries: Optional[int] = None,
sleep_retry_seconds: Optional[int] = None,
data: Optional[BaseModel] = None,
**url_params: str,
) -> _RetryResult:
Expand All @@ -801,6 +863,14 @@ def _send_request(
:param data: the data to send
:returns: An error message if the request failed, otherwise None.
"""
if not max_retries:
max_retries = self.default_max_retries
if not sleep_retry_seconds:
sleep_retry_seconds = self.default_sleep_retry_seconds
if max_retries < 0:
raise ValueError("max_retries must be a non-negative integer.")
if sleep_retry_seconds < 0:
raise ValueError("sleep_retry_seconds must be a non-negative integer.")

@retry(
stop=stop_after_attempt(max_retries),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "compass-sdk"
version = "0.14.2"
version = "0.15.0"
authors = []
description = "Compass SDK"
readme = "README.md"
Expand Down

0 comments on commit 56d8e03

Please sign in to comment.