From 1f39f9258255d8285b5fe42c086faaa6fdc00f17 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 5 Feb 2025 20:34:25 -0800 Subject: [PATCH] Update PythonListCustomToolGenerator to support overriding system prompt (#271) Summary: Support user supplied prompt template in PythonListCustomToolGenerator. This is to allow user to provided their own system prompt without having to format function descirptions. Test Plan: python -m unittest llama_models.llama3.tests.prompt_templates.test_system_prompts --- .../llama3/prompt_templates/system_prompts.py | 35 ++++++++++---- .../prompt_templates/test_system_prompts.py | 46 +++++++++++++++++++ 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/models/llama3/prompt_templates/system_prompts.py b/models/llama3/prompt_templates/system_prompts.py index b1714412..eb5de8f9 100644 --- a/models/llama3/prompt_templates/system_prompts.py +++ b/models/llama3/prompt_templates/system_prompts.py @@ -7,7 +7,7 @@ import textwrap from datetime import datetime -from typing import Any, List +from typing import Any, List, Optional from llama_models.llama3.api.datatypes import ( BuiltinTool, @@ -215,14 +215,33 @@ def data_examples(self) -> List[List[ToolDefinition]]: class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + DEFAULT_PROMPT = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + {{ function_description }} + """.strip( + "\n" + ) + ) + + def gen( + self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None + ) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description( + self, custom_tools: List[ToolDefinition] + ) -> PromptTemplate: template_str = textwrap.dedent( """ - You are an expert in composing functions. You are given a question and a set of possible functions. - Based on the question, you will need to make one or more function/tool calls to achieve the purpose. - If none of the function can be used, point it out. If the given question lacks the parameters required by the function, - also point it out. You should only return the function call in tools call sections. - If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] You SHOULD NOT include any other text in the response. @@ -263,7 +282,7 @@ def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: return PromptTemplate( template_str.strip("\n"), {"tools": [t.model_dump() for t in custom_tools]}, - ) + ).render() def data_examples(self) -> List[List[ToolDefinition]]: return [ diff --git a/models/llama3/tests/prompt_templates/test_system_prompts.py b/models/llama3/tests/prompt_templates/test_system_prompts.py index 615e3ea8..5f0ebcf9 100644 --- a/models/llama3/tests/prompt_templates/test_system_prompts.py +++ b/models/llama3/tests/prompt_templates/test_system_prompts.py @@ -145,3 +145,49 @@ def test_llama_3_2_system_zero_shot(self): """ ) self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_provided_system_prompt(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Overriding message. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ]""" + ) + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"