diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index 7121ae5..322145a 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -94,6 +94,8 @@ def __init__( password: Optional[str] = None, bearer_token: Optional[str] = None, http_session: Optional[requests.Session] = None, + default_max_retries: int = DEFAULT_MAX_RETRIES, + default_sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ): """ Initialize the Compass client. @@ -110,6 +112,15 @@ def __init__( self.session = http_session or requests.Session() self.bearer_token = bearer_token + if default_max_retries < 0: + raise ValueError("default_max_retries must be a non-negative integer.") + if default_sleep_retry_seconds < 0: + raise ValueError( + "default_sleep_retry_seconds must be a non-negative integer." + ) + self.default_max_retries = default_max_retries + self.default_sleep_retry_seconds = default_sleep_retry_seconds + self.api_method = { "create_index": self.session.put, "list_indexes": self.session.get, @@ -154,7 +165,12 @@ def __init__( } def create_index( - self, *, index_name: str, index_config: Optional[IndexConfig] = None + self, + *, + index_name: str, + index_config: Optional[IndexConfig] = None, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ): """ Create an index in Compass. @@ -165,13 +181,19 @@ def create_index( """ return self._send_request( api_name="create_index", - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, index_name=index_name, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, data=index_config, ) - def refresh_index(self, *, index_name: str): + def refresh_index( + self, + *, + index_name: str, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, + ): """ Refresh index. @@ -180,12 +202,18 @@ def refresh_index(self, *, index_name: str): """ return self._send_request( api_name="refresh", - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, ) - def delete_index(self, *, index_name: str): + def delete_index( + self, + *, + index_name: str, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, + ): """ Delete an index from Compass. @@ -194,12 +222,19 @@ def delete_index(self, *, index_name: str): """ return self._send_request( api_name="delete_index", - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, ) - def delete_document(self, *, index_name: str, document_id: str): + def delete_document( + self, + *, + index_name: str, + document_id: str, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, + ): """ Delete a document from Compass. @@ -211,12 +246,19 @@ def delete_document(self, *, index_name: str, document_id: str): return self._send_request( api_name="delete_document", document_id=document_id, - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, ) - def get_document(self, *, index_name: str, document_id: str): + def get_document( + self, + *, + index_name: str, + document_id: str, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, + ): """ Get a document from Compass. @@ -228,12 +270,16 @@ def get_document(self, *, index_name: str, document_id: str): return self._send_request( api_name="get_document", document_id=document_id, - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, ) - def list_indexes(self): + def list_indexes( + self, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, + ): """ List all indexes in Compass. @@ -241,8 +287,8 @@ def list_indexes(self): """ return self._send_request( api_name="list_indexes", - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, index_name="", ) @@ -252,8 +298,8 @@ def add_attributes( index_name: str, document_id: str, attributes: DocumentAttributes, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Optional[str]: """ Update the content field of an existing document with additional context. @@ -284,10 +330,10 @@ def insert_doc( *, index_name: str, doc: CompassDocument, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, authorized_groups: Optional[list[str]] = None, merge_groups_on_conflict: bool = False, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Optional[list[dict[str, str]]]: """ Insert a parsed document into an index in Compass. @@ -315,8 +361,8 @@ def upload_document( content_type: str, document_id: uuid.UUID, attributes: DocumentAttributes = DocumentAttributes(), - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Optional[Union[str, dict[str, Any]]]: """ Parse and insert a document into an index in Compass. @@ -367,13 +413,13 @@ def insert_docs( docs: Iterator[CompassDocument], max_chunks_per_request: int = DEFAULT_MAX_CHUNKS_PER_REQUEST, max_error_rate: float = DEFAULT_MAX_ERROR_RATE, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, errors_sliding_window_size: Optional[int] = 10, skip_first_n_docs: int = 0, num_jobs: Optional[int] = None, authorized_groups: Optional[list[str]] = None, merge_groups_on_conflict: bool = False, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Optional[list[dict[str, str]]]: """ Insert multiple parsed documents into an index in Compass. @@ -482,8 +528,8 @@ def create_datasource( self, *, datasource: CreateDataSource, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Union[DataSource, str]: """ Create a new datasource in Compass. @@ -506,8 +552,8 @@ def create_datasource( def list_datasources( self, *, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Union[PaginatedList[DataSource], str]: """ List all datasources in Compass. @@ -529,8 +575,8 @@ def get_datasource( self, *, datasource_id: str, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ): """ Get a datasource in Compass. @@ -554,8 +600,8 @@ def delete_datasource( self, *, datasource_id: str, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ): """ Delete a datasource in Compass. @@ -579,8 +625,8 @@ def sync_datasource( self, *, datasource_id: str, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ): """ Sync a datasource in Compass. @@ -606,8 +652,8 @@ def list_datasources_objects_states( datasource_id: str, skip: int = 0, limit: int = 100, - max_retries: int = DEFAULT_MAX_RETRIES, - sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> Union[PaginatedList[DocumentStatus], str]: """ List all objects states in a datasource in Compass. @@ -692,6 +738,8 @@ def _search( query: str, top_k: int = 10, filters: Optional[list[SearchFilter]] = None, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ): return self._send_request( api_name=api_name, @@ -708,6 +756,8 @@ def search_documents( query: str, top_k: int = 10, filters: Optional[list[SearchFilter]] = None, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> SearchDocumentsResponse: """ Search documents in an index. @@ -725,6 +775,8 @@ def search_documents( query=query, top_k=top_k, filters=filters, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, ) if result.error: @@ -739,6 +791,8 @@ def search_chunks( query: str, top_k: int = 10, filters: Optional[list[SearchFilter]] = None, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> SearchChunksResponse: """ Search chunks in an index. @@ -756,6 +810,8 @@ def search_chunks( query=query, top_k=top_k, filters=filters, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, ) if result.error: @@ -764,7 +820,12 @@ def search_chunks( return SearchChunksResponse.model_validate(result.result) def update_group_authorization( - self, *, index_name: str, group_auth_input: GroupAuthorizationInput + self, + *, + index_name: str, + group_auth_input: GroupAuthorizationInput, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, ) -> PutDocumentsResponse: """ Edit group authorization for an index. @@ -776,18 +837,19 @@ def update_group_authorization( api_name="update_group_authorization", index_name=index_name, data=group_auth_input, - max_retries=DEFAULT_MAX_RETRIES, - sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, ) if result.error: raise CompassError(result.error) return PutDocumentsResponse.model_validate(result.result) - def _send_request( + # todo Simplify this method so we don't have to ignore the C901 complexity warning. + def _send_request( # noqa: C901 self, api_name: str, - max_retries: int, - sleep_retry_seconds: int, + max_retries: Optional[int] = None, + sleep_retry_seconds: Optional[int] = None, data: Optional[BaseModel] = None, **url_params: str, ) -> _RetryResult: @@ -801,6 +863,14 @@ def _send_request( :param data: the data to send :returns: An error message if the request failed, otherwise None. """ + if not max_retries: + max_retries = self.default_max_retries + if not sleep_retry_seconds: + sleep_retry_seconds = self.default_sleep_retry_seconds + if max_retries < 0: + raise ValueError("max_retries must be a non-negative integer.") + if sleep_retry_seconds < 0: + raise ValueError("sleep_retry_seconds must be a non-negative integer.") @retry( stop=stop_after_attempt(max_retries), diff --git a/pyproject.toml b/pyproject.toml index 4d6362a..f5128bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "compass-sdk" -version = "0.14.2" +version = "0.15.0" authors = [] description = "Compass SDK" readme = "README.md"