
Implement grammar sampling for function and parameters names #74

Merged
jeffreymeetkai merged 32 commits into main from grammar-sampling-parameters on Dec 27, 2023

Conversation

jeffreymeetkai (Collaborator)

No description provided.

@khai-meetkai (Collaborator) left a comment:

Basically, I think we need logic to detect whether the current step is generating tokens for a function name or for parameters. Can we implement this more elegantly using the approach I used for streaming? We keep a state (a dictionary containing: current_tokens, current_text, current_function, current_param) and update the state --> output a new token.
Then we implement a function:
def update_state(current_state, tokens_by_probs, sampled_token) --> new_state, sampled_token
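
A minimal sketch of this state-update idea (field and function names follow the comment above; everything else is illustrative, not the final implementation):

from typing import Any, Dict, List, Tuple

def init_state() -> Dict[str, Any]:
    return {
        "current_tokens": [],      # token ids generated so far in this stage
        "current_text": "",        # decoded text generated so far in this stage
        "current_function": None,  # function name once it has been completed
        "current_param": None,     # parameter name currently being generated
    }

def update_state(
    current_state: Dict[str, Any],
    tokens_by_probs: List[int],   # candidate token ids sorted by probability
    sampled_token: int,           # token id the model originally sampled
) -> Tuple[Dict[str, Any], int]:
    # If we are currently generating a function or parameter name, a real
    # implementation would override sampled_token with the highest-probability
    # candidate that keeps the text a prefix of a valid option; here we only
    # show the state bookkeeping.
    new_state = dict(current_state)
    new_state["current_tokens"] = current_state["current_tokens"] + [sampled_token]
    return new_state, sampled_token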

@jeffreymeetkai changed the title from "Extend grammar sampling to parameters names" to "Implement grammar sampling for function and parameters names" on Dec 11, 2023
@jeffreymeetkai (Collaborator, Author) left a comment:

Main points

  • Stateful mechanism where the stage field in gen_state takes one of ["pre-function", "function", "pre-parameter", "parameter-name", "parameter-value"]
  • The mechanism works like an FSM, moving through the different stages
  • In step_async, self.sample is called first, followed by self.update_gen_state

Additional pointers

  • Refactored update_grammar_sampling_gen_state using the template method design pattern instead of abstract methods to reduce code repetition (see the sketch below). Will change back to abstract methods if future prompt template versions do not fit the template method pattern.
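
A rough illustration of the template-method refactor mentioned above (class and hook names other than update_grammar_sampling_gen_state and get_stop_token_for_function_parameter are assumptions, not the actual code):

from abc import ABC, abstractmethod
from typing import Any, Dict, List

class PromptTemplate(ABC):
    def update_grammar_sampling_gen_state(
        self, gen_state: Dict, new_token_id: int, options: List, tokenizer: Any
    ) -> Dict:
        # The shared FSM logic over the stages "pre-function", "function",
        # "pre-parameter", "parameter-name" and "parameter-value" lives here
        # once, in the base class...
        stop_str = self.get_stop_token_for_function_parameter(stage="function")
        # ...and only version-specific details (such as the stop string above)
        # are delegated to small hooks that subclasses override.
        return gen_state

    @abstractmethod
    def get_stop_token_for_function_parameter(self, stage: str) -> str:
        ...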

tokenizer=self.tokenizer,
)

def sample(
@khai-meetkai (Collaborator), Dec 13, 2023:

Can we implement this in prompt_template? Then in the future we can re-use it for other frameworks such as llama_cpp or HF instead of binding it to vLLM:
def sample_gramma_token(self, tokenizer, state, delta_token_ids, model_sampled_token_id)

@jeffreymeetkai (Collaborator, Author):

Gotcha

output[i].samples[-1].output_token = grammar_sampled_token_id

# Update gen_state
self.update_gen_state(
Collaborator:

Can we also include update_gen_state inside def sample? I think it would be more convenient, with no need to duplicate this chunk of code:

if gen_state["stage"] in ["pre-function", "function"]:
    options = [
        tool_or_func["name"]
        for tool_or_func in self.tools_or_functions[request_id]
    ]
else:
    func_name = gen_state["func_name"]
    for tool_or_func in self.tools_or_functions[request_id]:
        if tool_or_func["name"] == func_name:
            options = list(tool_or_func["parameters"]["properties"].keys())

@jeffreymeetkai (Collaborator, Author), Dec 13, 2023:

I will put this chunk into step_async before calling prompt_template.grammar_sample().

raise NotImplementedError

@abstractmethod
def get_stopping_token(self, stage: Literal["function", "parameter"]) -> int:
Collaborator:

Can we rename get_stopping_token to avoid confusion with def get_stop_tokens_for_generation?

@jeffreymeetkai (Collaborator, Author):

Gotcha

return self.recipient_token

def get_stopping_token(self, stage: Literal["function", "parameter"]) -> int:
if stage == "function":
Collaborator:

Can we return a string, then use the tokenizer to get the token_id, so we won't depend on the model? In the future we might use another model, not Mistral.
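
For example, a model-agnostic sketch along those lines (the caller-side conversion is an assumption about how the engine would use it):

def get_stop_token_for_function_parameter(self, stage: str) -> str:
    # Return the literal text that terminates a function name or a parameter
    # name; the prompt template no longer hard-codes Mistral token ids.
    return ":" if stage == "function" else '":'

# Caller side (e.g. in the inference engine), assuming a HF-style tokenizer:
# stop_str = prompt_template.get_stop_token_for_function_parameter(stage="function")
# stop_token_id = tokenizer.encode(stop_str, add_special_tokens=False)[-1]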

Collaborator:

The same for v1

@jeffreymeetkai (Collaborator, Author):

Gotcha

tool_or_func["name"]
for tool_or_func in self.tools_or_functions[request_id]
]
if self.prompt_templates[request_id].version not in ["v1"]:
Collaborator:

Can we add another method in the prompt template for getting the list of additional predefined function names, with a default implementation that returns []?

def get_predefined_function_names(self) -> List[str]: ...

# in base_template:
def get_predefined_function_names(self) -> List[str]:
    return []

# in template_v2:
def get_predefined_function_names(self) -> List[str]:
    return ["all"]

@jeffreymeetkai (Collaborator, Author):

Gotcha


# Form the parameter name with the current sampled token id
if len(wellformed_params) == 0:
curr_text = gen_state["curr_text"][
Collaborator:

Can we use curr_text with the previous components removed, so we don't have to handle this?
For example, when we enter state=function, curr_text is the function name in progress;
when we enter state=parameter, curr_text is the parameter name in progress;
when we enter state=parameter-value, curr_text is the parameter value in progress.
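
A minimal sketch of that idea, assuming curr_text/curr_tokens are cleared on every stage transition (the helper name is hypothetical):

def enter_stage(gen_state: dict, new_stage: str) -> dict:
    # On each transition, reset the accumulated text/tokens so that curr_text
    # only ever holds the in-progress function name, parameter name, or
    # parameter value for the current stage.
    gen_state["stage"] = new_stage
    gen_state["curr_text"] = ""
    gen_state["curr_tokens"] = []
    return gen_state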

@jeffreymeetkai (Collaborator, Author):

Gotcha

gen_state = self.gen_states[request_id]

# Form the functions/parameters options
if gen_state["stage"] in ["pre-function", "function"]:
Collaborator:

Actually, the chunk of code that gets the options should be inside grammar_sample, right?
We should replace options with tools_or_functions[request_id], so the logic lives entirely inside the prompt template and the inference engine only provides the sorted list of tokens.

Then we pass options to update_grammar_sampling_gen_state.

@jeffreymeetkai (Collaborator, Author):

This should be in async_llm_engine.py because the prompt_template doesn't store the list of tools/functions mapped from request_id.

Collaborator:

We can pass tools_or_functions to grammar_sample, replacing options:

def grammar_sample(
    self,
    gen_state: Dict,
    tools_or_functions: List,
    delta_token_ids: List,
    model_sampled_token_id: int,
    tokenizer: Any,
)
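
Under that signature, the options could then be derived inside the prompt template, roughly like this (the helper name is hypothetical; the body mirrors the engine-side chunk quoted earlier):

from typing import Dict, List

def _get_options(self, gen_state: Dict, tools_or_functions: List) -> List:
    if gen_state["stage"] in ["pre-function", "function"]:
        # Function-name stage: the options are the available function names.
        return [tool_or_func["name"] for tool_or_func in tools_or_functions]
    # Parameter stages: the options are the parameter names of the chosen function.
    for tool_or_func in tools_or_functions:
        if tool_or_func["name"] == gen_state["func_name"]:
            return list(tool_or_func["parameters"]["properties"].keys())
    return []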

@jeffreymeetkai (Collaborator, Author):

Ok


def update_grammar_sampling_gen_state(
self,
gen_state: Dict,
Collaborator:

Maybe we should describe the fields inside gen_state and their values. For example, for stage: what are the possible values and their meaning?

@jeffreymeetkai (Collaborator, Author):

Sure, I will move the explanation from async_llm_engine to here.
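
For reference, the kind of field description being asked for might look like this (the field names are the ones used in this PR; the comments are illustrative):

gen_state = {
    # One of: "pre-function", "function", "pre-parameter",
    # "parameter-name", "parameter-value".
    "stage": "pre-function",
    "curr_tokens": [],   # token ids accumulated for the current stage
    "curr_text": "",     # decoded text accumulated for the current stage
    "func_name": "",     # the sampled function name, once complete
    "param_names": [],   # parameter names generated so far for this call
}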

wellformed_params = gen_state["param_names"]

# Form the parameter name with the current sampled token id
new_curr_tokens = gen_state["curr_text"] + tokenizer.decode(
Collaborator:

Why not use this, like line 203?

new_curr_tokens_id = gen_state["curr_tokens"] + [sampled_token_ind]
new_curr_tokens = tokenizer.decode(new_curr_tokens_id)

@jeffreymeetkai (Collaborator, Author):

Sure

if stage == "function":
    return ":"  # 28747
else:
    return '":'  # 1264
Collaborator:

Oh, I think it should be only '"' instead of '":'?
Because in your check:
sampled_token == self.get_stop_token_for_function_parameter(stage="parameter")
I assume that '":' is 2 tokens?

@jeffreymeetkai (Collaborator, Author), Dec 13, 2023:

This is for parameter names. The parameters are generated in JSON format, so the model always generates '":' right after it completes a parameter name.

# Future versions are assumed to begin directly with "function" stage
self.gen_states[request_id] = {
"stage": "function"
if self.prompt_templates[request_id].version not in ["v1"]
Collaborator:

Oh, I think at first it is always pre-function?

Collaborator:

By the way, I think we should initialize it with an empty dict ({}); that way vLLM doesn't need to know the structure of the state. We will check and initialize inside def grammar_sample:
if len(gen_state) == 0:
    gen_state = {...}

@jeffreymeetkai (Collaborator, Author):

The start-of-function token <|recipient|> is already provided in the v2 prompt, so we basically start from the "function" stage instead. I will move this initialization into prompt_template too.
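
A rough sketch combining both points, assuming initialization moves inside grammar_sample, the engine passes an empty dict, and the prompt template exposes its version as self.version (details illustrative, not the final code):

from typing import Any, Dict, List

def grammar_sample(
    self,
    gen_state: Dict,
    tools_or_functions: List,
    delta_token_ids: List,
    model_sampled_token_id: int,
    tokenizer: Any,
):
    if len(gen_state) == 0:
        gen_state = {
            # v2 prompts already end with the <|recipient|> token, so v2 starts
            # directly in the "function" stage; v1 starts in "pre-function".
            "stage": "function" if self.version not in ["v1"] else "pre-function",
            "curr_tokens": [],
            "curr_text": "",
            "func_name": "",
            "param_names": [],
        }
    ...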

Collaborator:

Ok

_ = json.loads(
'{"'
+ gen_state["param_names"][-1]
+ gen_state["curr_text"].removesuffix(', "')
Collaborator:

I just wonder: if we have
{"x": 123456}
and 123456 is split into the tokens "123" and "456", will it stop at {"x": 123 and move to a new state?

@jeffreymeetkai (Collaborator, Author):

Gotcha. Will add some more checks here.

@jeffreymeetkai (Collaborator, Author):

if gen_state["curr_text"].endswith(', "'):
    # conduct the json.loads operation

):
curr_text = gen_state["curr_text"].rstrip()
while True:
if any([curr_text == option for option in options]):
Collaborator:

Can this while loop run forever if the function name is not in the list, even with grammar sampling? Can this happen?

@jeffreymeetkai (Collaborator, Author):

This while loop will not run forever because the grammar_sample function already forces the function name to be built towards one of the provided function names, including "all". This part just loops from the back to remove unnecessary suffixes until we have the complete function name to put into gen_state["func_name"].
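
Illustrated as a small sketch (the function name is hypothetical), the loop terminates because grammar sampling guarantees the accumulated text is one of the options plus, at most, a trailing suffix:

def extract_func_name(curr_text: str, options: list) -> str:
    curr_text = curr_text.rstrip()
    # Strip characters from the end until the remaining text is exactly one of
    # the allowed function names (including "all").
    while curr_text and curr_text not in options:
        curr_text = curr_text[:-1]
    return curr_text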

pattern = stop_token + r".*$"
match_res = re.search(pattern, latest_param_str)
match_res = re.search(pattern, latest_param_str, re.DOTALL)
Collaborator:

This one is not common, but stop_token='":' is not general enough; it can also be '" :', right?
For example:
{"a" : 10}

@jeffreymeetkai (Collaborator, Author):

For this part, it is OK because a parameter name will always be a string, so it will always be "parameter-name": {parameter-value} only.

@jeffreymeetkai (Collaborator, Author):

Once "parameter-name": is detected, we move to the parameter-value stage.

gen_state["stage"] = "pre-function"
except:
pass

# Check if the current state can be converted to json, it means the
# new state is back to "parameter-name" stage
pattern = r',.*"$'
if bool(re.search(pattern, gen_state["curr_text"])):
pattern = r',[\n\s]*"'
Collaborator:

I think \s already includes "\n".
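
A quick check confirming this (illustrative):

import re

# `\s` matches "\n" as well as spaces and tabs, so `,\s*"` covers both layouts.
assert re.search(r',\s*"', ', "') is not None
assert re.search(r',\s*"', ',\n    "') is not None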

@jeffreymeetkai jeffreymeetkai merged commit 1a4fd68 into main Dec 27, 2023
3 checks passed
@jeffreymeetkai jeffreymeetkai deleted the grammar-sampling-parameters branch November 6, 2024 03:11