Skip to content

Commit

Permalink
Add the ability to submit system prompts to LLM in benchmarks
Browse files Browse the repository at this point in the history
Summary:
This allows LLMs to be queried with a system prompt if the option to do so is available, which is necessary to test prompt injection most effectively (which assumes that the LLM has been given a hidden system instruction in most cases).

This should be completely backwards compatible and not affect any of the existing benchmarks or LLM integrations - a `query_with_system_prompt` can be optionally implemented for any LLM if desired and used to test the prompt injection benchmarks in the future.

Reviewed By: csahana95

Differential Revision: D53552178

fbshipit-source-id: 94e9cd43263a925bd9cab0ebe68b4c45fee528ae
  • Loading branch information
cynikolai authored and facebook-github-bot committed Feb 8, 2024
1 parent fadbeed commit c8146d4
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions CybersecurityBenchmarks/benchmark/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import time
from abc import ABC, abstractmethod
from typing import Callable

import openai
from langchain.llms import Together
Expand Down Expand Up @@ -45,17 +46,30 @@ def query(self, prompt: str) -> str:
"""
pass

def query_with_retries(
self, prompt: str, retries: int = NUM_LLM_RETRIES, backoff_factor: float = 0.5
) -> str:
def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str:
    """
    Query the LLM with both a system prompt and a user prompt.

    This is a concrete default (not abstract): it prepends the system
    prompt to the user prompt, separated by a newline, and delegates to
    ``query``. Subclasses whose backing API natively supports system
    prompts should override this method.

    Args:
        system_prompt (str): The system prompt to send to the LLM.
        prompt (str): The prompt to send to the LLM.

    Returns:
        str: The response from the LLM.
    """
    return self.query(system_prompt + "\n" + prompt)

def _query_with_retries(
self,
func: Callable[..., str],
*args: str,
retries: int = NUM_LLM_RETRIES,
backoff_factor: float = 0.5,
) -> str:
last_exception = None
for retry in range(retries):
try:
return self.query(prompt)
return func(*args)
except Exception as exception:
last_exception = exception
sleep_time = backoff_factor * (2**retry)
Expand All @@ -67,6 +81,16 @@ def query_with_retries(
f"Unable to query LLM after {retries} retries: {last_exception}"
)

def query_with_retries(self, prompt: str) -> str:
return self._query_with_retries(self.query, prompt)

def query_with_system_prompt_with_retries(
    self, system_prompt: str, prompt: str
) -> str:
    """
    Query the LLM with a system prompt, retrying on failure with
    exponential backoff (see ``_query_with_retries``).

    Args:
        system_prompt (str): The system prompt to send to the LLM.
        prompt (str): The prompt to send to the LLM.

    Returns:
        str: The response from the LLM.
    """
    return self._query_with_retries(
        self.query_with_system_prompt, system_prompt, prompt
    )

def valid_models(self) -> list[str]:
    """Return the valid model parameters for this LLM, e.g. 'gpt-3.5-turbo' for GPT."""
    # Base implementation declares no models; subclasses may override.
    return []
Expand Down

0 comments on commit c8146d4

Please sign in to comment.