Skip to content

Commit

Permalink
Merge pull request #272 from MeetKai/jinja-tojson-fix
Browse files Browse the repository at this point in the history
Fix jinja tojson and add edge cases to prompt creation unittests
  • Loading branch information
jeffreymeetkai authored Sep 24, 2024
2 parents b368a11 + 7df905d commit 9afb17b
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 54 deletions.
28 changes: 23 additions & 5 deletions functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,35 @@

from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils
from functionary.schema import generate_schema_from_functions


def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)


def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(
x,
ensure_ascii=ensure_ascii,
indent=indent,
separators=separators,
sort_keys=sort_keys,
)


class PromptTemplate:
_jinja_env = jinja2.Environment()
_jinja_env.policies["json.dumps_kwargs"] = {"sort_keys": False}
_jinja_env.filters["tojson"] = tojson
_jinja_env.globals["raise_exception"] = raise_exception
# Mapping from class --> instance to create singleton instance
_instances = {}

def __init__(self):
self._jinja_template = self._jinja_env.from_string(self.get_chat_template_jinja())
self._jinja_template = self._jinja_env.from_string(
self.get_chat_template_jinja()
)

@abstractmethod
def get_start_of_function_call_token(self) -> str:
Expand Down Expand Up @@ -341,7 +359,7 @@ def get_chat_template_jinja(self) -> str:
json_to_ts_schema = f.read()
with open(f"{path_prefix}{self.version}.txt", "r") as f:
template = f.read()

return (
template[: template.index("{%")]
+ json_to_ts_schema
Expand Down
4 changes: 2 additions & 2 deletions functionary/prompt_template/jinja_templates/v3-llama3.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
{{- "\nYou have access to the following functions:\n\n" }}
{%- for t in tools %}
{%- if "type" in t -%}
{{ "Use the function '"|safe + t["function"]["name"] + "' to '"|safe + t["function"]["description"] + "'\n"|safe + t["function"] | tojson() | safe }}
{{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }}
{%- else -%}
{{ "Use the function '"|safe + t["name"] + "' to '"|safe + t["description"] + "'\n"|safe + t | tojson() | safe }}
{{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }}
{%- endif -%}
{{- "\n\n" }}
{%- endfor %}
Expand Down
13 changes: 6 additions & 7 deletions tests/prompt_test_v2.llama3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type get_car_price = (_: {
car_name: string,
}) => any;

// get the weather of a location
// This function's purpose is to get the weather of a location
type get_weather = (_: {
// where to get weather.
location: string,
Expand Down Expand Up @@ -47,26 +47,25 @@ name=get_weather

The temperature in Hanoi is: 10 degree Celcious<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will use code interpreter to handle this<|reserved_special_token_249|>python
l=[0,1,2,3,4,5]
I'll use code interpreter to handle this<|reserved_special_token_249|>python
l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]
l.remove(3.6)<|eot_id|><|start_header_id|>tool<|end_header_id|>

name=python
ValueError: list.remove(x): x not in list<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will fix the code<|reserved_special_token_249|>python
l=[0,1,2,3,4,5]
l.remove(4)
l.remove(round(3.6))
l<|eot_id|><|start_header_id|>tool<|end_header_id|>

name=python
[0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The final list is: 0,1,2,3,5<|eot_id|><|start_header_id|>user<|end_header_id|>

Thanks! What is the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Thanks! What's the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<|reserved_special_token_249|>get_weather
{"location": "San Francisco, CA"}<|eot_id|><|start_header_id|>tool<|end_header_id|>
Expand Down
13 changes: 6 additions & 7 deletions tests/prompt_test_v2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type get_car_price = (_: {
car_name: string,
}) => any;

// get the weather of a location
// This function's purpose is to get the weather of a location
type get_weather = (_: {
// where to get weather.
location: string,
Expand Down Expand Up @@ -60,13 +60,13 @@ location: string,
<|content|>The temperature in Hanoi is: 10 degree Celcious<|stop|>
<|from|>user
<|recipient|>all
<|content|>Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most
<|content|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most
<|from|>assistant
<|recipient|>all
<|content|>I will use code interpreter to handle this
<|content|>I'll use code interpreter to handle this
<|from|>assistant
<|recipient|>python
<|content|>l=[0,1,2,3,4,5]
<|content|>l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]
l.remove(3.6)<|stop|>
<|from|>python
<|recipient|>all
Expand All @@ -76,8 +76,7 @@ l.remove(3.6)<|stop|>
<|content|>I will fix the code
<|from|>assistant
<|recipient|>python
<|content|>l=[0,1,2,3,4,5]
l.remove(4)
<|content|>l.remove(round(3.6))
l<|stop|>
<|from|>python
<|recipient|>all
Expand All @@ -87,7 +86,7 @@ l<|stop|>
<|content|>The final list is: 0,1,2,3,5<|stop|>
<|from|>user
<|recipient|>all
<|content|>Thanks! What is the weather in San Francisco?
<|content|>Thanks! What's the weather in San Francisco?
<|from|>assistant
<|recipient|>get_weather
<|content|>{"location": "San Francisco, CA"}<|stop|>
Expand Down
13 changes: 6 additions & 7 deletions tests/prompt_test_v3-llama3.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ You have access to the following functions:
Use the function 'get_car_price' to 'Get the price of a particular car model'
{"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}

Use the function 'get_weather' to 'get the weather of a location'
{"name": "get_weather", "description": "get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}
Use the function 'get_weather' to 'This function's purpose is to get the weather of a location'
{"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}


Think very carefully before calling functions.
Expand Down Expand Up @@ -57,22 +57,21 @@ what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_hea

The temperature in Hanoi is: 10 degree Celcious<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will use code interpreter to handle this<|python_tag|>l=[0,1,2,3,4,5]
I'll use code interpreter to handle this<|python_tag|>l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]
l.remove(3.6)<|eom_id|><|start_header_id|>ipython<|end_header_id|>

ValueError: list.remove(x): x not in list<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will fix the code<|python_tag|>l=[0,1,2,3,4,5]
l.remove(4)
I will fix the code<|python_tag|>l.remove(round(3.6))
l<|eom_id|><|start_header_id|>ipython<|end_header_id|>

[0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The final list is: 0,1,2,3,5<|eot_id|><|start_header_id|>user<|end_header_id|>

Thanks! What is the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Thanks! What's the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<function=get_weather>{"location": "San Francisco, CA"}</function><|eom_id|><|start_header_id|>ipython<|end_header_id|>

Expand Down
13 changes: 6 additions & 7 deletions tests/prompt_test_v3.llama3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type get_car_price = (_: {
car_name: string,
}) => any;

// get the weather of a location
// This function's purpose is to get the weather of a location
type get_weather = (_: {
// where to get weather.
location: string,
Expand Down Expand Up @@ -56,27 +56,26 @@ what's the weather like in Hanoi?<|eot_id|><|start_header_id|>assistant<|end_hea
>>>all
The temperature in Hanoi is: 10 degree Celcious<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
I will use code interpreter to handle this>>>python
l=[0,1,2,3,4,5]
I'll use code interpreter to handle this>>>python
l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]
l.remove(3.6)<|eot_id|><|start_header_id|>tool<|end_header_id|>

ValueError: list.remove(x): x not in list<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
I will fix the code>>>python
l=[0,1,2,3,4,5]
l.remove(4)
l.remove(round(3.6))
l<|eot_id|><|start_header_id|>tool<|end_header_id|>

[0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
The final list is: 0,1,2,3,5<|eot_id|><|start_header_id|>user<|end_header_id|>

Thanks! What is the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Thanks! What's the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>get_weather
{"location": "San Francisco, CA"}<|eot_id|><|start_header_id|>tool<|end_header_id|>
Expand Down
13 changes: 6 additions & 7 deletions tests/prompt_test_v3.llava_llama.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type get_car_price = (_: {
car_name: string,
}) => any;

// get the weather of a location
// This function's purpose is to get the weather of a location
type get_weather = (_: {
// where to get weather.
location: string,
Expand Down Expand Up @@ -56,27 +56,26 @@ No, the car Tang is less expensive than the car Song. The car Song is priced at
>>>all
The temperature in Hanoi is: 10 degree Celcious<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
I will use code interpreter to handle this>>>python
l=[0,1,2,3,4,5]
I'll use code interpreter to handle this>>>python
l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]
l.remove(3.6)<|eot_id|><|start_header_id|>tool<|end_header_id|>

ValueError: list.remove(x): x not in list<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
I will fix the code>>>python
l=[0,1,2,3,4,5]
l.remove(4)
l.remove(round(3.6))
l<|eot_id|><|start_header_id|>tool<|end_header_id|>

[0,1,2,3,5,]<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>all
The final list is: 0,1,2,3,5<|eot_id|><|start_header_id|>user<|end_header_id|>

Thanks! What is the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Thanks! What's the weather in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

>>>get_weather
{"location": "San Francisco, CA"}<|eot_id|><|start_header_id|>tool<|end_header_id|>
Expand Down
12 changes: 6 additions & 6 deletions tests/test_case.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"type": "function",
"function": {
"name": "get_weather",
"description": "get the weather of a location",
"description": "This function's purpose is to get the weather of a location",
"parameters": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -118,16 +118,16 @@
},
{
"role": "user",
"content": "Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most"
"content": "Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most"
},
{
"role": "assistant",
"content": "I will use code interpreter to handle this",
"content": "I'll use code interpreter to handle this",
"tool_calls": [
{
"function": {
"name": "python",
"arguments": "l=[0,1,2,3,4,5]\nl.remove(3.6)"
"arguments": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"
}
}
],
Expand All @@ -147,7 +147,7 @@
{
"function": {
"name": "python",
"arguments": "l=[0,1,2,3,4,5]\nl.remove(4)\nl"
"arguments": "l.remove(round(3.6))\nl"
}
}
]
Expand All @@ -163,7 +163,7 @@
},
{
"role": "user",
"content": "Thanks! What is the weather in San Francisco?"
"content": "Thanks! What's the weather in San Francisco?"
},
{
"role": "assistant",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_case_vision.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"type": "function",
"function": {
"name": "get_weather",
"description": "get the weather of a location",
"description": "This function's purpose is to get the weather of a location",
"parameters": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -125,16 +125,16 @@
},
{
"role": "user",
"content": "Given the list: 0,1,2,3,4,5 remove the number in the list that is close to 3.6 the most"
"content": "Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most"
},
{
"role": "assistant",
"content": "I will use code interpreter to handle this",
"content": "I'll use code interpreter to handle this",
"tool_calls": [
{
"function": {
"name": "python",
"arguments": "l=[0,1,2,3,4,5]\nl.remove(3.6)"
"arguments": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"
}
}
],
Expand All @@ -154,7 +154,7 @@
{
"function": {
"name": "python",
"arguments": "l=[0,1,2,3,4,5]\nl.remove(4)\nl"
"arguments": "l.remove(round(3.6))\nl"
}
}
]
Expand All @@ -170,7 +170,7 @@
},
{
"role": "user",
"content": "Thanks! What is the weather in San Francisco?"
"content": "Thanks! What's the weather in San Francisco?"
},
{
"role": "assistant",
Expand Down

0 comments on commit 9afb17b

Please sign in to comment.