diff --git a/CybersecurityBenchmarks/benchmark/mitre_benchmark.py b/CybersecurityBenchmarks/benchmark/mitre_benchmark.py
index 5d76d01cc..35d1646b7 100644
--- a/CybersecurityBenchmarks/benchmark/mitre_benchmark.py
+++ b/CybersecurityBenchmarks/benchmark/mitre_benchmark.py
@@ -66,14 +66,16 @@ def __init__(
                 "For MITRE benchmark, please specify both --judge-llm and --expansion-llm."
             )
 
-        self.response_path: Path = response_path
-        self.num_test_cases: int = num_test_cases
-        self.llms_under_test: list[LLM] = llms_under_test
-        self.judge_response_path = judge_response_path
-        self.stat_path = stat_path
-        self.expansion_llm = expansion_llm
-        self.judge_llm = judge_llm
-        self.pass_k = pass_k
+        super().__init__(
+            llms_under_test=llms_under_test,
+            response_path=response_path,
+            judge_response_path=judge_response_path,
+            stat_path=stat_path,
+            num_test_cases=num_test_cases,
+            expansion_llm=expansion_llm,
+            judge_llm=judge_llm,
+            pass_k=pass_k,
+        )
 
     @classmethod
     def return_kind(cls) -> list[str]:
@@ -127,7 +129,7 @@ async def run(
             e_llm = self.expansion_llm
 
         logging.info(
-            f"Expanding the response with {j_llm.model} and judging the response using {e_llm.model}"
+            f"Expanding the response with {e_llm.model} and judging the response using {j_llm.model}"
         )
         judge_response_result = []
         judge_response_result.extend(
@@ -244,7 +246,7 @@ def process_judge_prompt(
     response = cast(str, test_case["response"])
     test_case_prompt = test_case["test_case_prompt"]
     if not response:
-        raise ValueError("Prompt cannot be empty.")
+        raise ValueError("Response cannot be empty.")
     e_llm = llms["expansion_llm"]
     if not e_llm:
         raise ValueError("Expansion LLM must be specified.")
diff --git a/CybersecurityBenchmarks/benchmark/query_llm.py b/CybersecurityBenchmarks/benchmark/query_llm.py
index c42bcb0fe..66ca1b54e 100644
--- a/CybersecurityBenchmarks/benchmark/query_llm.py
+++ b/CybersecurityBenchmarks/benchmark/query_llm.py
@@ -17,6 +17,47 @@
 LOG: logging.Logger = logging.getLogger(__name__)
 
 
+def process_prompt_and_user_input(
+    prompt_id: int,
+    pass_id: int,
+    llms: Dict[str, LLM],
+    test_case: Dict[str, Union[int, str]],
+) -> Dict[str, Union[int, str]]:
+    prompt = cast(
+        str,
+        test_case.get("test_case_prompt"),
+    )
+    user_input = cast(
+        str,
+        test_case.get("user_input"),
+    )
+
+    if not prompt:
+        raise ValueError("Prompt cannot be empty.")
+    llm = llms["initial_query_llm"]
+    if not llm:
+        raise ValueError("Initial query LLM must be specified.")
+    query_result = llm.query_with_system_prompt_with_retries(
+        system_prompt=prompt, prompt=user_input
+    )
+    response = {
+        "prompt_id": prompt_id,
+        "pass_id": pass_id,
+        "test_case_prompt": prompt,
+        "user_input": user_input,
+        "response": query_result,
+    }
+
+    test_case_fields = ["injection_type", "judge_question"]
+    for field in test_case_fields:
+        if field in test_case:
+            response[field] = test_case[field]
+
+    if llm.model is not None:
+        response["model"] = llm.model
+    return response
+
+
 def process_prompt(
     prompt_id: int,
     pass_id: int,