diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py
index 003812a914..59bd0429e6 100644
--- a/haystack/nodes/answer_generator/openai.py
+++ b/haystack/nodes/answer_generator/openai.py
@@ -37,8 +37,6 @@
 OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
-OPENAI_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
-OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
 
 
 class OpenAIAnswerGenerator(BaseGenerator):
@@ -55,12 +53,14 @@ def __init__(
         api_key: str,
         model: str = "text-davinci-003",
         max_tokens: int = 50,
-        top_k: int = 5,
+        top_k: int = 1,
         temperature: float = 0.2,
         presence_penalty: float = 0.1,
         frequency_penalty: float = 0.1,
         examples_context: Optional[str] = None,
         examples: Optional[List[List[str]]] = None,
+        instructions: Optional[str] = None,
+        add_runtime_instructions: bool = False,
         stop_words: Optional[List[str]] = None,
         progress_bar: bool = True,
         prompt_template: Optional[PromptTemplate] = None,
@@ -87,13 +87,16 @@ def __init__(
             [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
         :param examples_context: A text snippet containing the contextual information used to generate the Answers for
                                  the examples you provide.
-                                 If not supplied, the default from OpenAI API docs is used:
-                                 `"In 2017, U.S. life expectancy was 78.6 years."`
         :param examples: List of (question, answer) pairs that helps steer the model towards the tone and answer
-                         format you'd like. We recommend adding 2 to 3 examples.
-                         If not supplied, the default from OpenAI API docs is used:
-                         `[["Q: What is human life expectancy in the United States?", "A: 78 years."]]`
-        :param stop_words: Up to four sequences where the API stops generating further tokens. The returned text does
+                         format you'd like.
+        :param instructions: Custom instructions to use as the prompt. Defaults to 'Create a concise and informative answer...'
+        :param add_runtime_instructions: Whether to add the prompt instructions (the instructions wrapped around
+                                         the question) at query time. Defaults to using the predefined prompt instructions.
+                                         If you add instructions at runtime, separate the instructions and the question like this:
+                                         "<instructions> [SEPARATOR] <question>"
+                                         Also make sure to mention "$documents" and "$query" in the <instructions> so that
+                                         both are substituted correctly.
+        :param stop_words: Up to 4 sequences where the API stops generating further tokens. The returned text does
                            not contain the stop sequence.
                            If you don't provide any stop words, the default value from OpenAI API docs is used: `["\n", "<|endoftext|>"]`.
         :param prompt_template: A PromptTemplate that tells the model how to generate answers given a
@@ -128,13 +131,28 @@ def __init__(
         if stop_words is None:
             stop_words = ["\n", "<|endoftext|>"]
         if prompt_template is None:
-            prompt_template = PromptTemplate(
-                name="question-answering-with-examples",
-                prompt_text="Please answer the question according to the above context."
- "\n===\nContext: $examples_context\n===\n$examples\n\n" - "===\nContext: $context\n===\n$query", - prompt_params=["examples_context", "examples", "context", "query"], - ) + if instructions: + prompt_template = PromptTemplate( + name="custom", + prompt_text=f"{instructions}\n" + "\n===\nContext: $examples_context\n===\n$examples\n\n" + "===\nContext: $context\n===\n$query", + prompt_params=["examples_context", "examples", "context", "query"], + ) + else: + prompt_template = PromptTemplate( + name="question-answering-with-examples", + prompt_text=f"Create a concise and informative answer (no more than {max_tokens} words) for a given " + f"question based solely on the given documents. You must only use information from the given " + f"documents. Use an unbiased and journalistic tone. Do not repeat text. Cite the documents " + f"using Document[$number] notation. If multiple documents contain the answer, cite " + f"each document like Document[$number], Document[$number], Document[$number] ... If " + f"the documents do not contain the answer to the question, say that " + f"'answering is not possible given the available information.'\n " + "\n===\nContext: $examples_context\n===\n$examples\n\n" + "===\nContext: $context\n===\n$query", + prompt_params=["examples_context", "examples", "context", "query"], + ) else: # Check for required prompts required_params = ["context", "query"] @@ -167,6 +185,7 @@ def __init__( self.frequency_penalty = frequency_penalty self.examples_context = examples_context self.examples = examples + self.add_runtime_instructions = add_runtime_instructions self.stop_words = stop_words self.prompt_template = prompt_template self.context_join_str = context_join_str @@ -186,7 +205,11 @@ def __init__( logger.debug("Using GPT2TokenizerFast") self._hf_tokenizer: PreTrainedTokenizerFast = GPT2TokenizerFast.from_pretrained(tokenizer) - @retry_with_exponential_backoff(backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES) + @retry_with_exponential_backoff( + backoff_in_seconds=int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 1)), + max_retries=int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5)), + errors=(OpenAIRateLimitError, OpenAIError), + ) def predict( self, query: str, @@ -275,7 +298,7 @@ def predict( @staticmethod def _create_context(documents: List[Document], join_str: str = " ") -> str: """Join the documents to create a single context to be used in the PromptTemplate.""" - doc_contents = [doc.content for doc in documents] + doc_contents = [OpenAIAnswerGenerator._clean_documents(doc.content) for doc in documents] # We reverse the docs to put the most relevant documents at the bottom of the context context = join_str.join(reversed(doc_contents)) return context @@ -292,7 +315,27 @@ def _fill_prompt(self, query: str, documents: List[Document]) -> str: ): kwargs["examples_context"] = self.examples_context kwargs["examples"] = example_prompts - full_prompt = next(self.prompt_template.fill(**kwargs)) + prompt_template = self.prompt_template + # Switch for adding the prompt instructions at runtime. + if self.add_runtime_instructions: + temp = query.split("[SEPARATOR]") + if len(temp) != 2: + logger.error( + "Instructions given to the OpenAIAnswerGenerator were not correct, please follow the structure " + "from the docstrings. You supplied: %s", + query, + ) + prompt_template = PromptTemplate(name="custom", prompt_text="$query", prompt_params=["query"]) + kwargs["query"] = "Say: incorrect prompt." 
+            else:
+                current_prompt = temp[0].strip()
+                prompt_template = PromptTemplate(
+                    name="custom",
+                    prompt_text=f"{current_prompt}\n" "\n===\nContext: $context\n===\n$query",
+                    prompt_params=["context", "query"],
+                )
+                kwargs["query"] = temp[1].strip()
+        full_prompt = next(prompt_template.fill(**kwargs))
         return full_prompt
 
     def _build_prompt_within_max_length(self, query: str, documents: List[Document]) -> Tuple[str, List[Document]]:
@@ -352,3 +395,10 @@ def _count_tokens(self, text: str) -> int:
             return len(self._tk_tokenizer.encode(text))
         else:
             return len(self._hf_tokenizer.tokenize(text))
+
+    @staticmethod
+    def _clean_documents(text: str) -> str:
+        to_remove = {"$documents": "#documents", "$query": "#query", "\n": " "}
+        for key, val in to_remove.items():
+            text = text.replace(key, val)
+        return text
diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 3cff64f1fd..101119ab96 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -437,6 +437,7 @@
     @retry_with_exponential_backoff(
         backoff_in_seconds=int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 5)),
         max_retries=int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5)),
+        errors=(OpenAIRateLimitError, OpenAIError),
     )
     def invoke(self, *args, **kwargs):
         """
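
Taken together, the two new parameters work like this: `instructions` replaces the default prompt instructions at construction time, while `add_runtime_instructions=True` lets you ship the instructions inside the query itself, split from the question by the `[SEPARATOR]` token that `_fill_prompt` looks for. A minimal usage sketch; the API key, document text, and question below are placeholders, not part of this diff:

```python
from haystack.nodes import OpenAIAnswerGenerator
from haystack.schema import Document

# Hypothetical values; replace with your own key and documents.
generator = OpenAIAnswerGenerator(
    api_key="YOUR_OPENAI_API_KEY",
    add_runtime_instructions=True,
)

docs = [Document(content="Haystack is an open source framework for building search pipelines.")]

# The part before [SEPARATOR] becomes the prompt instructions; the part after
# it becomes the question that fills $query. The node inserts the documents
# as $context on its own.
result = generator.predict(
    query="Answer using only the given documents. [SEPARATOR] What is Haystack?",
    documents=docs,
)
print(result["answers"][0].answer)
```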
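
Note also that the backoff and retry settings are no longer module-level constants: `backoff_in_seconds` and `max_retries` are now read from the environment when the decorator is applied, which in Python happens at class-definition time, i.e. on import. A sketch, assuming the `HAYSTACK_REMOTE_API_*` constants resolve to environment variables of the same spelling:

```python
import os

# Assumption: the HAYSTACK_REMOTE_API_* constants hold env var names with the
# same spelling. Set the variables before importing Haystack, because the
# decorator arguments are evaluated when the module is imported.
os.environ["HAYSTACK_REMOTE_API_BACKOFF_SEC"] = "2"
os.environ["HAYSTACK_REMOTE_API_MAX_RETRIES"] = "3"

from haystack.nodes import OpenAIAnswerGenerator  # noqa: E402
```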