feat: Add OpenAIError to retry mechanism (#4178)
* Add OpenAIError to the retry mechanism. Use an environment variable for the OpenAI request timeout in PromptNode.

* Updated retry in OpenAI embedding encoder as well.

* Empty commit
sjrl authored Feb 17, 2023
1 parent 7eeb3e0 commit 44509cd
Showing 3 changed files with 20 additions and 5 deletions.
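For context, the retry_with_exponential_backoff decorator used throughout this diff accepts a tuple of exception types to retry on; the commit widens that tuple from OpenAIRateLimitError alone to (OpenAIRateLimitError, OpenAIError). The snippet below is not Haystack's implementation, only a minimal sketch of the pattern; the signature matches the call sites in this diff, but the retry loop and jitter strategy are assumptions.

import random
import time
from functools import wraps


def retry_with_exponential_backoff(backoff_in_seconds: float = 5, max_retries: int = 5, errors: tuple = (Exception,)):
    """Retry the wrapped callable on the listed exception types, doubling the wait each attempt."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except errors:
                    if attempt == max_retries:
                        raise  # retries exhausted: surface the original error
                    # exponential backoff plus a little jitter
                    time.sleep(backoff_in_seconds * (2**attempt) + random.uniform(0, 1))

        return wrapper

    return decorator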
haystack/nodes/answer_generator/openai.py (3 additions & 1 deletion)
@@ -186,7 +186,9 @@ def __init__(
     logger.debug("Using GPT2TokenizerFast")
     self._hf_tokenizer: PreTrainedTokenizerFast = GPT2TokenizerFast.from_pretrained(tokenizer)

-@retry_with_exponential_backoff(backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES)
+@retry_with_exponential_backoff(
+    backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
+)
 def predict(
     self,
     query: str,
haystack/nodes/prompt/prompt_node.py (14 additions & 3 deletions)
@@ -20,7 +20,11 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES

 from haystack import MultiLabel
-from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
+from haystack.environment import (
+    HAYSTACK_REMOTE_API_BACKOFF_SEC,
+    HAYSTACK_REMOTE_API_MAX_RETRIES,
+    HAYSTACK_REMOTE_API_TIMEOUT_SEC,
+)
 from haystack.errors import OpenAIError, OpenAIRateLimitError
 from haystack.modeling.utils import initialize_device_settings
 from haystack.nodes.base import BaseComponent
@@ -435,8 +439,9 @@ def __init__(
     }

 @retry_with_exponential_backoff(
-    backoff_in_seconds=int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 5)),
+    backoff_in_seconds=float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 5)),
     max_retries=int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5)),
+    errors=(OpenAIRateLimitError, OpenAIError),
 )
 def invoke(self, *args, **kwargs):
     """
@@ -478,7 +483,13 @@ def invoke(self, *args, **kwargs):
     "logit_bias": kwargs_with_defaults.get("logit_bias", {}),
 }
 headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
+response = requests.request(
+    "POST",
+    self.url,
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30)),
+)
 res = json.loads(response.text)

 if response.status_code != 200:
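With this hunk, the OpenAI request timeout in PromptNode is no longer hard-coded to 30 seconds but read from an environment variable, alongside the existing backoff and retry settings. As a usage sketch, and assuming the haystack.environment constants resolve to environment variables of the same name and the usual PromptNode constructor arguments (both assumptions, not shown in this diff), a deployment could configure them like this:

import os

# Hypothetical values; the defaults visible in this diff are a 5 s backoff,
# 5 retries, and a 30 s request timeout.
os.environ["HAYSTACK_REMOTE_API_BACKOFF_SEC"] = "10"
os.environ["HAYSTACK_REMOTE_API_MAX_RETRIES"] = "3"
os.environ["HAYSTACK_REMOTE_API_TIMEOUT_SEC"] = "60"

# Set the variables before importing: in this diff the backoff and retry values
# are read when the decorator is applied at import time, while the timeout is
# read on every request.
from haystack.nodes import PromptNode

prompt_node = PromptNode(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY", ""))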
haystack/nodes/retriever/_openai_encoder.py (3 additions & 1 deletion)
@@ -104,7 +104,9 @@ def _ensure_text_limit(self, text: str) -> str:

     return decoded_string

-@retry_with_exponential_backoff(backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES)
+@retry_with_exponential_backoff(
+    backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
+)
 def embed(self, model: str, text: List[str]) -> np.ndarray:
     payload = {"model": model, "input": text}
     headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
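The practical effect of adding OpenAIError to the errors tuple is that transient failures other than rate limiting (for example, errors raised for 5xx responses, assuming OpenAIError is the broader exception class) are now retried rather than raised on the first attempt. A self-contained toy, reusing the sketch decorator above with stand-in exception classes:

class OpenAIError(Exception):
    """Stand-in for haystack.errors.OpenAIError (hypothetical hierarchy)."""


class OpenAIRateLimitError(OpenAIError):
    """Stand-in for haystack.errors.OpenAIRateLimitError."""


attempts = {"count": 0}


@retry_with_exponential_backoff(backoff_in_seconds=0.1, max_retries=5, errors=(OpenAIRateLimitError, OpenAIError))
def flaky_call():
    attempts["count"] += 1
    if attempts["count"] < 3:
        # Not a rate limit, but with OpenAIError in the tuple it is now retried.
        raise OpenAIError("transient server error")
    return "ok"


print(flaky_call())  # prints "ok" after two retried failures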
