Implement grammar sampling for function and parameter names #74
Conversation
I think, basically, we need to handle the logic to detect whether the current step is generating tokens for the function name or for parameters. Can we implement this more elegantly using the same approach I did for streaming? We keep a state (a dictionary containing: current_tokens, current_text, current_function, current_param) and update the state --> output a new token.
Then we implement a function:
`def update_state(current_state, tokens_by_probs, sampled_token) --> new_state, sampled_token`
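A minimal sketch of that shape (the signature follows the pseudocode above; the body and the name-matching step are illustrative assumptions, not the final implementation):

```python
from typing import Any, Dict, List, Tuple

def update_state(
    current_state: Dict[str, Any],  # current_tokens, current_text, current_function, current_param
    tokens_by_probs: List[int],     # candidate token ids, sorted by probability
    sampled_token: int,             # token id the model originally sampled
) -> Tuple[Dict[str, Any], int]:
    new_state = dict(current_state)
    # If we are mid-way through a function or parameter name, this is where
    # sampled_token would be overridden with the most probable candidate
    # from tokens_by_probs that still extends a valid name.
    new_state["current_tokens"] = new_state["current_tokens"] + [sampled_token]
    return new_state, sampled_token
```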
Main points
- Stateful mechanism where the stage in `gen_state` consists of `["pre-function", "function", "pre-parameter", "parameter-name", "parameter-value"]`
- The mechanism works like an FSM, going through the different stages (sketched below)
- In `step_async`, `self.sample` will be called first, followed by `self.update_gen_state`

Additional pointers
- Refactored `update_grammar_sampling_gen_state` by using the template method design pattern instead of an abstract method, to reduce code repetition. Will change back to an abstract method if future prompt template versions do not allow the template method pattern.
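A toy rendering of those transitions (the event names are invented labels for illustration; the real logic lives in `update_grammar_sampling_gen_state`):

```python
STAGES = ["pre-function", "function", "pre-parameter", "parameter-name", "parameter-value"]

# (current stage, event) -> next stage; event names are hypothetical labels
TRANSITIONS = {
    ("pre-function", "recipient_token_emitted"): "function",
    ("function", "function_name_completed"): "pre-parameter",
    ("pre-parameter", "json_object_opened"): "parameter-name",
    ("parameter-name", "name_completed"): "parameter-value",
    ("parameter-value", "value_completed"): "parameter-name",  # next parameter
    ("parameter-value", "call_completed"): "pre-function",     # next function call
}

def next_stage(stage: str, event: str) -> str:
    return TRANSITIONS.get((stage, event), stage)  # stay put if no transition fires
```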
```python
    tokenizer=self.tokenizer,
)

def sample(
```
Can we implement this in prompt_template? Then in the future we can re-use this for other frameworks such as llama_cpp and HF, instead of binding it to vLLM.

`def sample_grammar_token(self, tokenizer, state, delta_token_ids, model_sampled_token_id)`
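Something like this, perhaps (a rough sketch under the suggested signature; the prefix-matching body and the `state` keys are assumptions):

```python
class PromptTemplate:
    def sample_grammar_token(self, tokenizer, state, delta_token_ids, model_sampled_token_id):
        # delta_token_ids: candidate token ids for this step, sorted by probability.
        # Keep the most probable token whose decoded continuation still
        # prefixes an allowed name; otherwise fall back to the model's pick.
        options = state.get("options", [])
        for token_id in delta_token_ids:
            candidate = state["current_text"] + tokenizer.decode([token_id])
            if any(option.startswith(candidate) for option in options):
                return token_id
        return model_sampled_token_id
```

That way any engine (vLLM, llama_cpp, HF) only has to supply the probability-sorted candidates and its tokenizer.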
Gotcha
```python
output[i].samples[-1].output_token = grammar_sampled_token_id

# Update gen_state
self.update_gen_state(
```
Can we also include `update_gen_state` inside `def sample`? I think it would be more convenient; no need to duplicate this chunk of code:

```python
if gen_state["stage"] in ["pre-function", "function"]:
    options = [
        tool_or_func["name"]
        for tool_or_func in self.tools_or_functions[request_id]
    ]
else:
    func_name = gen_state["func_name"]
    for tool_or_func in self.tools_or_functions[request_id]:
        if tool_or_func["name"] == func_name:
            options = list(tool_or_func["parameters"]["properties"].keys())
```
I will put this chunk into `step_async` before calling `prompt_template.grammar_sample()`.
```python
    raise NotImplementedError

@abstractmethod
def get_stopping_token(self, stage: Literal["function", "parameter"]) -> int:
```
Can we rename `get_stopping_token` to avoid confusion with `def get_stop_tokens_for_generation`?
Gotcha
```python
    return self.recipient_token

def get_stopping_token(self, stage: Literal["function", "parameter"]) -> int:
    if stage == "function":
```
Can we return a string, then use the tokenizer to get the token_id, so we won't depend on the model? In the future we might use another model, not Mistral.
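For example (a sketch; `get_stop_token_string`/`resolve_stop_token_id` are hypothetical names, and `encode(..., add_special_tokens=False)` assumes a Hugging Face-style tokenizer):

```python
def get_stop_token_string(stage: str) -> str:
    # Return the literal text rather than a Mistral-specific token id
    return ":" if stage == "function" else '":'

def resolve_stop_token_id(tokenizer, stage: str) -> int:
    ids = tokenizer.encode(get_stop_token_string(stage), add_special_tokens=False)
    assert len(ids) == 1, "stop string is assumed to map to a single token"
    return ids[0]
```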
The same for v1
Gotcha
tool_or_func["name"] | ||
for tool_or_func in self.tools_or_functions[request_id] | ||
] | ||
if self.prompt_templates[request_id].version not in ["v1"]: |
Can we add another method in the prompt template for getting the list of additional predefined function names, with a default implementation that returns `[]`?

`def get_predefined_function_names() --> List[str]`

in base_template:

```python
def get_predefined_function_names(self) -> List[str]:
    return []
```

in template_v2:

```python
def get_predefined_function_names(self) -> List[str]:
    return ["all"]
```
Gotcha
```python
# Form the parameter name with the current sampled token id
if len(wellformed_params) == 0:
    curr_text = gen_state["curr_text"][
```
Can we use curr_text with the previous components removed, so we don't have to handle this?
For example, when we enter stage=function --> curr_text = the function name in progress;
when we enter stage=parameter-name --> curr_text = the parameter name in progress;
when we enter stage=parameter-value --> curr_text = the parameter value in progress.
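i.e. something like this (a hypothetical helper, for illustration):

```python
def enter_stage(gen_state: dict, new_stage: str) -> dict:
    # After a transition, curr_text/curr_tokens only ever hold the
    # in-progress function name, parameter name, or parameter value.
    gen_state["stage"] = new_stage
    gen_state["curr_text"] = ""
    gen_state["curr_tokens"] = []
    return gen_state
```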
Gotcha
```python
gen_state = self.gen_states[request_id]

# Form the functions/parameters options
if gen_state["stage"] in ["pre-function", "function"]:
```
Actually, the chunk of code that gets `options` should be inside `grammar_sample`, right? We should replace `options` with `tools_or_functions[request_id]`, so the logic is entirely inside the prompt template and the inference engine only provides the sorted list of tokens. Then we pass `options` to `update_grammar_sampling_gen_state`.
This should be in async_llm_engine.py because the prompt_template doesn't store the list of tools/functions mapped from request_id.
We can pass `tools_or_functions` to `grammar_sample`, replacing `options`:

```python
def grammar_sample(
    self,
    gen_state: Dict,
    tools_or_functions: List,
    delta_token_ids: List,
    model_sampled_token_id: int,
    tokenizer: Any,
)
```
Ok
```python
def update_grammar_sampling_gen_state(
    self,
    gen_state: Dict,
```
Maybe we should describe the fields inside `gen_state` and their values. For example, for `stage`: what are the possible values and their meanings?
Sure, I will move the explanation from async_llm_engine to here.
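For example, the docstring could spell out the fields; the names below are the ones that appear in this PR, and the `*args, **kwargs` stub just keeps the sketch short:

```python
from typing import Dict

def update_grammar_sampling_gen_state(self, gen_state: Dict, *args, **kwargs) -> Dict:
    """Update the grammar-sampling state after each newly sampled token.

    gen_state fields:
      stage:       one of "pre-function", "function", "pre-parameter",
                   "parameter-name", "parameter-value"
      curr_tokens: token ids accumulated so far for the current stage
      curr_text:   decoded text of curr_tokens
      func_name:   name of the function currently being generated
      param_names: parameter names generated so far for func_name
    """
```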
```python
wellformed_params = gen_state["param_names"]

# Form the parameter name with the current sampled token id
new_curr_tokens = gen_state["curr_text"] + tokenizer.decode(
```
Why not write this like line 203?

```python
new_curr_tokens_id = gen_state["curr_tokens"] + [sampled_token_ind]
new_curr_tokens = tokenizer.decode(new_curr_tokens_id)
```
Sure
if stage == "function": | ||
return ":" # 28747 | ||
else: | ||
return '":' # 1264 |
Oh, I think it should be only `'"'` instead of `'":'`? Because in your check:

`sampled_token == self.get_stop_token_for_function_parameter(stage="parameter")`

I assume that `'":'` is 2 tokens?
This is for parameter names. The parameters are generated in JSON format, so the model always generates `'":'` right after it completes a parameter name.
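For instance (the parameter names here are made up):

```python
import json

# In the JSON arguments, every parameter name is a quoted key, so the
# model emits '":' immediately after finishing a name:
args = json.loads('{"location": "Hanoi", "unit": "celsius"}')
# completing the name "location" forces the next characters to be '":'
```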
```python
# Future versions are assumed to begin directly with "function" stage
self.gen_states[request_id] = {
    "stage": "function"
    if self.prompt_templates[request_id].version not in ["v1"]
```
Oh, I think at first it is always "pre-function"?
By the way, I think we should initialize with an empty dict (`{}`); this way vLLM doesn't need to know the structure of the state. We will check and initialize inside `def grammar_sample`:

```python
if len(gen_state) == 0:
    gen_state = {...}
```
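A hedged sketch of what that initialization could look like (field names as used elsewhere in this PR; v1 would start from "pre-function" instead):

```python
if len(gen_state) == 0:
    gen_state = {
        "stage": "function",  # assumed: "pre-function" for v1 templates
        "curr_tokens": [],
        "curr_text": "",
        "func_name": "",
        "param_names": [],
    }
```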
The start-of-function token `<|recipient|>` is already provided in the v2 prompt, so we basically start from the "function" stage instead. I will move this initialization into prompt_template too.
Ok
```python
_ = json.loads(
    '{"'
    + gen_state["param_names"][-1]
    + gen_state["curr_text"].removesuffix(', "')
```
I just wonder: if the arguments are `{"x": 123456}` and 123456 is split into the tokens "123" and "456", will it stop at `{"x": 123` and move to the new state?
Gotcha. Will add some more checks here.
if gen_state["curr_text"].endswith(', "'):
"""Conduct the json.loads operation"""
```python
):
    curr_text = gen_state["curr_text"].rstrip()
    while True:
        if any([curr_text == option for option in options]):
```
Can this while loop run forever if the function name is not in the list, even with grammar sampling? Can that happen?
This while loop will not run forever because the `grammar_sample` function already forces the function name to be built towards one of the provided function names, including "all". This part just loops from the back to remove unnecessary suffixes until we have the complete function name to put into `gen_state["func_name"]`.
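In other words (a simplified sketch; the helper name is illustrative):

```python
def extract_func_name(curr_text: str, options: list) -> str:
    curr_text = curr_text.rstrip()
    # Grammar sampling guarantees curr_text is a valid option plus at most
    # a short suffix (e.g. the stop token), so this trimming terminates.
    while curr_text and curr_text not in options:
        curr_text = curr_text[:-1]
    return curr_text
```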
```diff
  pattern = stop_token + r".*$"
- match_res = re.search(pattern, latest_param_str)
+ match_res = re.search(pattern, latest_param_str, re.DOTALL)
```
This one is not common, but `stop_token='":'` is not general enough; it can also be `'" :'`, right? For example: `{"a" : 10}`
For this part, it is OK because a parameter name will always be a string, so it will always be `"parameter-name": {parameter-value}` ONLY.
Once "parameter-name":
is detected, we will go to parameter-value stage already.
gen_state["stage"] = "pre-function" | ||
except: | ||
pass | ||
|
||
# Check if the current state can be converted to json, it means the | ||
# new state is back to "parameter-name" stage | ||
pattern = r',.*"$' | ||
if bool(re.search(pattern, gen_state["curr_text"])): | ||
pattern = r',[\n\s]*"' |
I think `\s` already includes `\n`.
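Quick check:

```python
import re

# \s is [ \t\n\r\f\v], so r',\s*"' already matches across newlines:
assert re.search(r',\s*"', ',\n  "')
```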