-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Implement grammar sampling for function and parameters names #74
Changes from 31 commits
9043291
05faca3
0909ae7
5aed290
8ffebd7
dd635f2
bd0d8b0
1c71f61
eced962
0e410c0
8439346
bbd0bb5
5a13a80
36cfaa2
2e0faab
d039411
5e39ae2
c824b17
6efcd19
3d5c5c9
c7a98d6
76b80ed
7b51c49
856e415
6274640
543508c
70ba0ef
762b7f1
1911833
174de56
58f49d0
26e71ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import json | ||
import random | ||
import string | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union | ||
|
||
from functionary.prompt_template.base_template import PromptTemplate | ||
|
||
|
@@ -10,6 +11,32 @@ class PromptTemplateV2(PromptTemplate): | |
recipient_token = "<|recipient|>" | ||
content_token = "<|content|>" | ||
stop_token = "<|stop|>" | ||
version = "v2" | ||
# This token splits between function name and parameters | ||
fn_param_sep_token = "\n<|content|> {" | ||
|
||
def get_start_of_function_call_token(self) -> str: | ||
return self.recipient_token | ||
|
||
def get_stop_token_for_function_parameter( | ||
self, stage: Literal["function", "parameter"] | ||
) -> int: | ||
if stage == "function": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we return a string, then use the tokenizer to get the token_id, so we won't depend on the model? In the future we might use another model, not Mistral. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The same for v1. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Gotcha |
||
return "\n" # 13 | ||
else: | ||
return '":' # 1264 | ||
|
||
def get_predefined_function_names(self) -> List[str]: | ||
return ["all"] | ||
|
||
def initialize_grammar_sampling_gen_state(self) -> Dict: | ||
return { | ||
"stage": "function", | ||
"curr_tokens": [], | ||
"curr_text": "", | ||
"func_name": "", | ||
"param_names": [], | ||
} | ||
|
||
def get_additional_tokens(self) -> List[str]: | ||
return [ | ||
|
@@ -264,7 +291,7 @@ def update_response_state_from_delta_text( | |
"func_name": None, # function_name of the current tool, if the response requires to use tool | ||
"response_type": None, # response_type=text(text response)/function (using tool) | ||
"func_index": -1, # index of the tool in tool_calls | ||
"call_id": None, # call_id of the current tool | ||
"call_id": None, # call_id of the current tool | ||
# skip_until_reach we skip new tokens until we reach certain token. This is used when we hit special tokens | ||
"skip_until_reach": self.content_token, # at first we will skip until reach <|content|> | ||
"first_time": True,  # if first_time we return an empty delta with role=assistant | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I think it is only: '"' instead of '":' ?
Because in your check:
sampled_token== self.get_stop_token_for_function_parameter(stage="parameter")
I assume that: '":' is 2 tokens?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for parameter names. The parameters are generated in json format so the models always generate '":' right after it completes a parameter name.