Implement grammar sampling for function and parameter names #74

Merged
merged 32 commits on Dec 27, 2023
Changes from 31 commits

Commits (32)
9043291
set up grammar sampling for functions, stress test batched inference
jeffreymeetkai Dec 7, 2023
05faca3
add detailed documentation
jeffreymeetkai Dec 7, 2023
0909ae7
make vllm monkey-patch backward-compatible with prompt template v1
jeffreymeetkai Dec 7, 2023
5aed290
extend grammar sampling to params wip
jeffreymeetkai Dec 8, 2023
8ffebd7
fix check_to_sample for params
jeffreymeetkai Dec 9, 2023
dd635f2
extend grammar sampling to params
jeffreymeetkai Dec 11, 2023
bd0d8b0
pull from main
jeffreymeetkai Dec 11, 2023
1c71f61
generalize prompt_template versions to integrate new versions easier
jeffreymeetkai Dec 11, 2023
eced962
revamp implementation for template v2
jeffreymeetkai Dec 12, 2023
0e410c0
fixes to v2
jeffreymeetkai Dec 12, 2023
8439346
revamp implementation for template v1
jeffreymeetkai Dec 12, 2023
bbd0bb5
Add documentation
jeffreymeetkai Dec 12, 2023
5a13a80
refactor update_grammar_sampling_gen_state
jeffreymeetkai Dec 12, 2023
36cfaa2
refactor based on comments
jeffreymeetkai Dec 13, 2023
2e0faab
refactor based on comments
jeffreymeetkai Dec 13, 2023
d039411
minor edit based on comment
jeffreymeetkai Dec 14, 2023
5e39ae2
fix delta_token_ids_by_logprobs
jeffreymeetkai Dec 14, 2023
c824b17
resolve merge conflict
jeffreymeetkai Dec 14, 2023
6efcd19
minor edit
jeffreymeetkai Dec 14, 2023
3d5c5c9
grammar sampling for pre-parameter stage
jeffreymeetkai Dec 15, 2023
c7a98d6
add no-function-call stage
jeffreymeetkai Dec 18, 2023
76b80ed
handle no args
jeffreymeetkai Dec 18, 2023
7b51c49
minor edit
jeffreymeetkai Dec 18, 2023
856e415
fixes
jeffreymeetkai Dec 19, 2023
6274640
Merge branch 'main' of https://github.com/MeetKai/functionary into gr…
jeffreymeetkai Dec 19, 2023
543508c
make parameter-name and value parsing compatible to all formats
jeffreymeetkai Dec 21, 2023
70ba0ef
resolve merge conflict
jeffreymeetkai Dec 21, 2023
762b7f1
edits
jeffreymeetkai Dec 21, 2023
1911833
upgrade vllm dependency
jeffreymeetkai Dec 26, 2023
174de56
resolve merge conflict
jeffreymeetkai Dec 26, 2023
58f49d0
make grammar sampling more flexible
jeffreymeetkai Dec 26, 2023
26e71ea
minor edit based on comments
jeffrey-fong Dec 27, 2023
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -3,7 +3,7 @@
         "editor.defaultFormatter": "ms-python.black-formatter",
         "editor.formatOnSave": true,
         "editor.codeActionsOnSave": {
-            "source.organizeImports": true
+            "source.organizeImports": "explicit"
         },
     },
     "isort.args":["--profile", "black"],
17 changes: 11 additions & 6 deletions functionary/inference.py
@@ -1,12 +1,18 @@
 from typing import List, Optional
 
 import torch
-from transformers import (LlamaForCausalLM, LlamaTokenizer, StoppingCriteria,
-                          StoppingCriteriaList)
+from transformers import (
+    LlamaForCausalLM,
+    LlamaTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+)
 
 from functionary.openai_types import ChatMessage, Function, FunctionCall, Tool
-from functionary.prompt_template import (PromptTemplate,
-                                         get_prompt_template_from_tokenizer)
+from functionary.prompt_template import (
+    PromptTemplate,
+    get_prompt_template_from_tokenizer,
+)
 
 
 class StopWordsCriteria(StoppingCriteria):
@@ -37,9 +43,8 @@ def prepare_messages_for_inference(
     tools: Optional[List[Tool]] = None,
     device="cuda:0",
 ) -> torch.Tensor:
-
     prompt_template = get_prompt_template_from_tokenizer(tokenizer)
 
     dic_messages = [mess.dict() for mess in messages]
     dic_messages.append({"role": "assistant"})
 
392 changes: 391 additions & 1 deletion functionary/prompt_template/base_template.py

Large diffs are not rendered by default.
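Editor's note: since the base_template.py diff is collapsed here, the following is a minimal sketch (not the PR's actual code) of how the generation state created by initialize_grammar_sampling_gen_state, shown in the template files below, might be advanced during sampling. The update_gen_state helper, its signature, and the transition rule are illustrative assumptions.

from typing import Dict


def update_gen_state(gen_state: Dict, sampled_text: str, fn_param_sep_token: str) -> Dict:
    # Hypothetical sketch: the real transitions live in
    # functionary/prompt_template/base_template.py (diff not rendered above).
    gen_state["curr_text"] += sampled_text
    if gen_state["stage"] == "function" and fn_param_sep_token in gen_state["curr_text"]:
        # Everything generated before the separator is taken as the function name;
        # after it, sampling moves on to constraining JSON parameter names.
        gen_state["func_name"] = gen_state["curr_text"].split(fn_param_sep_token)[0].strip()
        gen_state["stage"] = "parameter"
        gen_state["curr_text"] = ""
    return gen_state


# Example with the v1 separator ":\n{":
state = {"stage": "function", "curr_tokens": [], "curr_text": "", "func_name": "", "param_names": []}
for piece in ["get_", "weather", ':\n{"']:
    state = update_gen_state(state, piece, ":\n{")
print(state["func_name"], state["stage"])  # get_weather parameter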

27 changes: 25 additions & 2 deletions functionary/prompt_template/prompt_template_v1.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+import json
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from functionary.prompt_template.base_template import PromptTemplate

@@ -10,6 +11,9 @@ class PromptTemplateV1(PromptTemplate):
     end_assistant = "<|END_OF_ASSISTANT|>"
     end_function = "<|END_OF_FUNCTION_RESULT|>"
     end_function_call = "<|END_OF_FUNCTION_CALL|>"
+    version = "v1"
+    # This token splits between function name and parameters
+    fn_param_sep_token = ":\n{"
 
     def get_end_token_from_message(self, message: Dict) -> str:
         """this function is used for getting the end token for each message.
@@ -36,6 +40,26 @@ def get_end_token_from_message(self, message: Dict) -> str:
         else:
             return self.end_assistant
 
+    def get_start_of_function_call_token(self) -> str:
+        return self.start_function
+
+    def get_stop_token_for_function_parameter(
+        self, stage: Literal["function", "parameter"]
+    ) -> str:
+        if stage == "function":
+            return ":"  # 28747
+        else:
+            return '":'  # 1264
Collaborator:

Oh, I think it is only '"' instead of '":'? Because in your check:
sampled_token == self.get_stop_token_for_function_parameter(stage="parameter")
I assume that '":' is 2 tokens?

jeffreymeetkai (Collaborator, Author), Dec 13, 2023:

This is for parameter names. The parameters are generated in JSON format, so the model always generates '":' right after it completes a parameter name.


+    def initialize_grammar_sampling_gen_state(self) -> Dict:
+        return {
+            "stage": "pre-function",
+            "curr_tokens": [],
+            "curr_text": "",
+            "func_name": "",
+            "param_names": [],
+        }
+
     def get_additional_tokens(self) -> List[str]:
         return [
             self.start_function,
@@ -234,4 +258,3 @@ def get_chat_template_jinja(self) -> str:
         chat_template = chat_template.replace("<br>\n", "")
         chat_template = chat_template.strip()
         return chat_template
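
Editor's note: to make the stop-token exchange above concrete, here is a rough sketch of the kind of check being discussed: detecting the end of a JSON parameter name when the decoded token equals the '":' stop string from get_stop_token_for_function_parameter(stage="parameter"). The maybe_close_parameter_name helper and the surrounding state handling are assumptions for illustration, not the PR's implementation.

from typing import Dict


def maybe_close_parameter_name(gen_state: Dict, sampled_token: str, stop_token: str = '":') -> None:
    # Hypothetical helper mirroring the check quoted in the thread:
    #   sampled_token == self.get_stop_token_for_function_parameter(stage="parameter")
    # In JSON output the model emits '":' immediately after a parameter name,
    # so seeing that token means the accumulated text is a complete name.
    if sampled_token == stop_token:
        gen_state["param_names"].append(gen_state["curr_text"].lstrip('"'))
        gen_state["curr_text"] = ""
    else:
        gen_state["curr_text"] += sampled_token


state = {"curr_text": "", "param_names": []}
for tok in ['"loc', 'ation', '":']:
    maybe_close_parameter_name(state, tok)
print(state["param_names"])  # ['location']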

31 changes: 29 additions & 2 deletions functionary/prompt_template/prompt_template_v2.py
@@ -1,6 +1,7 @@
 import json
 import random
 import string
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from functionary.prompt_template.base_template import PromptTemplate

@@ -10,6 +11,32 @@ class PromptTemplateV2(PromptTemplate):
     recipient_token = "<|recipient|>"
     content_token = "<|content|>"
     stop_token = "<|stop|>"
+    version = "v2"
+    # This token splits between function name and parameters
+    fn_param_sep_token = "\n<|content|> {"
+
+    def get_start_of_function_call_token(self) -> str:
+        return self.recipient_token
+
+    def get_stop_token_for_function_parameter(
+        self, stage: Literal["function", "parameter"]
+    ) -> str:
+        if stage == "function":
Collaborator:

Can we return a string, then use the tokenizer to get the token_id, so we won't depend on the model? In the future we might use another model, not Mistral.

Collaborator:

The same for v1.

jeffreymeetkai (Collaborator, Author):

Gotcha.

return "\n" # 13
else:
return '":' # 1264

+    def get_predefined_function_names(self) -> List[str]:
+        return ["all"]
+
+    def initialize_grammar_sampling_gen_state(self) -> Dict:
+        return {
+            "stage": "function",
+            "curr_tokens": [],
+            "curr_text": "",
+            "func_name": "",
+            "param_names": [],
+        }
 
     def get_additional_tokens(self) -> List[str]:
         return [
@@ -264,7 +291,7 @@ def update_response_state_from_delta_text(
         "func_name": None,  # function_name of the current tool, if the response requires to use tool
         "response_type": None,  # response_type=text(text response)/function (using tool)
         "func_index": -1,  # index of the tool in tool_calls
-        "call_id": None,  # call_id of the current tool
+        "call_id": None,  # call_id of the current tool
         # skip_until_reach: we skip new tokens until we reach a certain token. This is used when we hit special tokens
         "skip_until_reach": self.content_token,  # at first we will skip until reach <|content|>
         "first_time": True,  # if first_time we return an empty delta with role=assistant
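
Editor's note: following up on the reviewer's suggestion above, a model-agnostic variant could return the stop string and resolve its token id(s) through whatever tokenizer is loaded, instead of hard-coding Mistral's 13 and 1264. The stop_token_ids function below is a hypothetical sketch of that idea, not code from this PR.

from typing import List

from transformers import PreTrainedTokenizerBase


def stop_token_ids(tokenizer: PreTrainedTokenizerBase, stage: str) -> List[int]:
    # Encode the stop string with the active tokenizer so the ids are
    # correct for any model, not just Mistral's vocabulary.
    # Note: SentencePiece-based tokenizers may add a leading-space piece when
    # encoding bare strings; a real implementation would need to handle that.
    stop_str = "\n" if stage == "function" else '":'
    return tokenizer.encode(stop_str, add_special_tokens=False)


# Example (requires downloading the tokenizer):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# stop_token_ids(tok, "parameter")  # expected to include 1264 on Mistral's vocab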