diff --git a/CybersecurityBenchmarks/benchmark/llm.py b/CybersecurityBenchmarks/benchmark/llm.py index 68ef4c6bc..492d9de57 100644 --- a/CybersecurityBenchmarks/benchmark/llm.py +++ b/CybersecurityBenchmarks/benchmark/llm.py @@ -9,6 +9,7 @@ import time from abc import ABC, abstractmethod +from typing import Callable import openai from langchain.llms import Together @@ -45,17 +46,30 @@ def query(self, prompt: str) -> str: """ pass - def query_with_retries( - self, prompt: str, retries: int = NUM_LLM_RETRIES, backoff_factor: float = 0.5 - ) -> str: + def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str: """ - Wrapper around query that retries the query if it fails. - Implements exponential backoff. + Queries an LLM with a given prompt and system prompt and returns the response. + + Args: + system_prompt (str): The system prompt to send to the LLM. + prompt (str): The prompt to send to the LLM. + + Returns: + str: The response from the LLM. """ + return self.query(system_prompt + "\n" + prompt) + + def _query_with_retries( + self, + func: Callable[..., str], + *args: str, + retries: int = NUM_LLM_RETRIES, + backoff_factor: float = 0.5, + ) -> str: last_exception = None for retry in range(retries): try: - return self.query(prompt) + return func(*args) except Exception as exception: last_exception = exception sleep_time = backoff_factor * (2**retry) @@ -67,6 +81,16 @@ def query_with_retries( f"Unable to query LLM after {retries} retries: {last_exception}" ) + def query_with_retries(self, prompt: str) -> str: + return self._query_with_retries(self.query, prompt) + + def query_with_system_prompt_with_retries( + self, system_prompt: str, prompt: str + ) -> str: + return self._query_with_retries( + self.query_with_system_prompt, system_prompt, prompt + ) + def valid_models(self) -> list[str]: """List of valid model parameters, e.g. 'gpt-3.5-turbo' for GPT""" return []