update chat_template and add chat_template to training
khai-meetkai committed Nov 24, 2023
1 parent b1d26b4 commit 4a8845f
Showing 5 changed files with 189 additions and 29 deletions.
153 changes: 135 additions & 18 deletions functionary/prompt.py
@@ -18,26 +18,69 @@ class PromptTemplate:

@abstractmethod
def get_additional_tokens(self) -> List[str]:
"""return list of added tokens if using this template
Returns:
List[str]: list of tokens, each token is a string
"""
raise NotImplementedError

@abstractmethod
def get_text_from_message(self, message: Dict) -> str:
"""Return the prompt of this message
Args:
message (Dict): Dictionary in OpenAI message format
Returns:
str: prompt of this message
"""
raise NotImplementedError

@abstractmethod
def get_stop_tokens_for_generation(self) -> List[str]:
"""Function to get list of stop tokens in generation
Returns:
List[str]: list of stop tokens
"""
raise NotImplementedError

@abstractmethod
def get_assistant_prefixes(self) -> List[str]:
"""Return the assistant prefixs in the final prompt, this is used for masking the labels
in unmasking labels, the system will unmask chunks that start with assistant prefixs and end with stop tokens.
For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>"
In this case unmasked chunks in labels would be tokens in ... of: <|from|>assistant\n<|recipient|> ... <|stop|>
Returns:
List[str]: list of possible assistant prefixs
"""
raise NotImplementedError
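As an illustration of how such prefixes can drive label masking, here is a minimal, hypothetical sketch (the helper below is not part of this repo): every position outside an assistant chunk is set to the ignore index, so only assistant tokens contribute to the loss.

# Hypothetical sketch: unmask only spans that start with an assistant prefix
# and end with a stop token; all other label positions are ignored by the loss.
IGNORE_INDEX = -100

def mask_labels(input_ids, prefix_ids_list, stop_token_ids):
    labels = [IGNORE_INDEX] * len(input_ids)
    i = 0
    while i < len(input_ids):
        prefix = next(
            (p for p in prefix_ids_list if input_ids[i : i + len(p)] == p), None
        )
        if prefix is None:
            i += 1
            continue
        j = i + len(prefix)
        while j < len(input_ids) and input_ids[j] not in stop_token_ids:
            labels[j] = input_ids[j]  # supervise the assistant's content tokens
            j += 1
        if j < len(input_ids):
            labels[j] = input_ids[j]  # also supervise the stop token itself
        i = j + 1
    return labels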

def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
"""This function is used if we need to process messages before doing inference.
This is used when the messages in training and inference are different.
For example, in training we have no: tool_call_id, but in inference, we have tool_call_id to know the order of function calls.
This function woule be called to convert inference messages to the format of training messages.
Args:
messages (List[Dict]): list of input messages
Returns:
List[Dict]: list of output messages
"""
return messages

def get_prompt_from_messages(
self, messages: List[Dict], tools_or_functions: Optional[List[Dict]] = None
) -> str:
"""This function is used to get the complete prompt for list of messages
Args:
messages (List[Dict]): List of messages
tools_or_functions (Optional[List[Dict]], optional): List of tools or functions. Defaults to None.
Returns:
str: the prompt for inference/training
"""
messages_clone = messages.copy() # To avoid modifying the original list

functions = []
@@ -62,7 +105,6 @@ def get_prompt_from_messages(

def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:
"""return a dictionary mapping from end_token --> token_id
Args:
tokenizer (Any): tokenizer in transformers
@@ -82,6 +124,14 @@ def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:

@abstractmethod
def parse_assistant_response(self, llm_ouput: str) -> Dict:
"""This function is used to parse llm_output to the Message of OpenAI ({"role": xxx, "content": xxx, ...})
this is used in inference.
Args:
llm_ouput (str): The generated content from Model
Returns:
Dict: Dictionary of OpenAI message format
"""
raise NotImplementedError

@abstractmethod
@@ -95,18 +145,20 @@ def update_response_state_from_delta_text(
"""This function is used for streaming
Args:
current_state (Dict[str, Any]): a dictionary:
+ func_name: Optional[str],
+ response_type: Optional[str] text/function
current_text: the llm_output until now
current_state (Dict[str, Any]): a dictionary containing the state of the stream, such as the current function_name
delta_text: new token generated
finish_reason: if finished or not
Returns:
Tuple[Dict[str, Any], Optional[Dict]]: {func_name, response_type}, response
Tuple[Dict[str, Any], Optional[Dict]]: the updated state and the response; the response can be None, a dictionary: {}, or a list of dictionaries: [{}, ..., {}]
"""
raise NotImplementedError
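For orientation, a hedged sketch of how a serving loop might consume this method during streaming; the keyword names follow the docstring above and may differ from the real signature, and token_stream/emit are placeholders.

# Hypothetical streaming loop: feed each generated delta into the template's
# state machine and forward any chunks it produces in OpenAI delta format.
current_state = {}
for delta_text, finish_reason in token_stream:  # token_stream is a placeholder
    current_state, response = prompt_template.update_response_state_from_delta_text(
        current_state=current_state,
        delta_text=delta_text,
        finish_reason=finish_reason,
    )
    if response is None:
        continue
    chunks = response if isinstance(response, list) else [response]
    for chunk in chunks:
        emit(chunk)  # emit() is a placeholder for writing out the streamed chunk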

@abstractmethod
def get_chat_template(self):
"""Return chat_template in jinja format"""
raise NotImplementedError

@classmethod
def get_prompt_template(cls):
if cls._instance is None:
@@ -315,6 +367,33 @@ def update_response_state_from_delta_text(
} # format of openAI at the end, delta must be empty
return current_state, response

def get_chat_template(self) -> str:
chat_template = """{% for message in messages %}
{% if message['role'] == 'user' %}
{{ message['role'] + ':\n' + message['content'] + '<|END_OF_USER|>' + '\n' }}<br>
{% elif message['role'] == 'system' %}
{{ message['role'] + ':\n' + message['content'] + '<|END_OF_SYSTEM|>' + '\n' }}<br>
{% elif message['role'] == 'function' %}
{{ 'function name=' + message['name'] + ':\n' + message['content']+ '<|END_OF_FUNCTION_RESULT|>\n' }}<br>
{% elif message['role'] == 'assistant' %}
{% if 'function_call' in message and message['function_call'] is not none %}
{% if message['content'] is not none %}
{{ 'assistant:\n' + message['content'] + '\n<|START_OF_FUNCTION_CALL|>' + message['function_call']['name'] + ':\n' + message['function_call']['arguments'] + '<|END_OF_FUNCTION_CALL|>\n' }}<br>
{% else %}
{{ 'assistant:\n<|START_OF_FUNCTION_CALL|>' + message['function_call']['name'] + ':\n' + message['function_call']['arguments'] + '<|END_OF_FUNCTION_CALL|>\n' }}<br>
{% endif %}
{% else %}
{{ 'assistant:\n' + message['content'] + '<|END_OF_ASSISTANT|>' + '\n' }}<br>
{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ 'assistant:' }}{% endif %}
"""
chat_template = chat_template.replace(" ", "")
chat_template = chat_template.replace("<br>\n", "")
chat_template = chat_template.strip()
return chat_template
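As a sanity check (mirroring the test added in this commit), the template can be attached to a Hugging Face tokenizer and rendered; the model name here is only an example.

# Sketch: render the v1 chat template through transformers' apply_chat_template.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("meetkai/functionary-7b-v1.1", legacy=True)
tokenizer.chat_template = PromptTemplateV1.get_prompt_template().get_chat_template()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# Passing add_generation_prompt=True additionally appends the trailing "assistant:".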


class PromptTemplateV2(PromptTemplate):
from_token = "<|from|>"
@@ -379,11 +458,11 @@ def parse_assistant_response(self, llm_ouput: str) -> Dict:
for stop in self.get_stop_tokens_for_generation():
if llm_ouput.endswith(stop):
llm_ouput = llm_ouput[: -len(stop)]
print("---------------------------")

llm_ouput = f"{self.from_token}assistant\n{self.recipient_token}" + llm_ouput
print(llm_ouput)
responses = llm_ouput.split(self.from_token)
responses = [response.strip() for response in responses]

functions = []
text_response = None
for response in responses:
@@ -422,19 +501,25 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
while index < len(messages):
message = messages[index]
tool_calls = message.get("tool_calls", None)

result.append(message)
if message["role"] == "assistant" and tool_calls:
num_calls = len(tool_calls)
tool_call_ids = [item["id"] for item in tool_calls]
if (
tool_calls[0].get("id", None) is not None
): # if tool_call contains "id" for mapping
tool_call_ids = [item["id"] for item in tool_calls]

tool_messages = [messages[index + 1 + j] for j in range(num_calls)]
id_2_tool_messages = {
item["tool_call_id"]: item for item in tool_messages
}
new_messages = [id_2_tool_messages[cid] for cid in tool_call_ids]
tool_messages = [messages[index + 1 + j] for j in range(num_calls)]
id_2_tool_messages = {
item["tool_call_id"]: item for item in tool_messages
}
new_messages = [id_2_tool_messages[cid] for cid in tool_call_ids]

result.extend(new_messages)
index += num_calls + 1
result.extend(new_messages)
index += num_calls + 1
else:
index += 1
else:
index += 1
return result
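To make the reordering concrete, a hypothetical example (contents are made up): the tool results arrive out of order, and this function re-sorts them to follow the order of the assistant's tool_calls.

# Hypothetical input: tool results are listed in the "wrong" order relative to tool_calls.
messages = [
    {"role": "user", "content": "Weather in Hanoi and Paris?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": "call_A", "function": {"name": "get_weather", "arguments": '{"city": "Hanoi"}'}},
            {"id": "call_B", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        ],
    },
    {"role": "tool", "tool_call_id": "call_B", "name": "get_weather", "content": "22C"},
    {"role": "tool", "tool_call_id": "call_A", "name": "get_weather", "content": "30C"},
]
# pre_process_messages_before_inference is expected to emit the tool messages in
# call_A, call_B order, matching the order of the assistant's tool_calls.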
@@ -492,6 +577,38 @@ def get_recipient(self, current_text: str) -> str:
end_index = current_text.find(f"\n{self.content_token}")
return current_text[start_index:end_index].strip()

def get_chat_template(self) -> str:
chat_template = """{% for message in messages %}
{% if message['role'] == 'user' or message['role'] == 'system' %}
{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}<br>
{% elif message['role'] == 'tool' %}
{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}<br>
{% else %}
{% set contain_content='no'%}
{% if message['content'] is not none %}
{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}<br>
{% set contain_content='yes'%}
{% endif %}
{% if 'tool_calls' in message and message['tool_calls'] is not none %}
{% for tool_call in message['tool_calls'] %}
{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}
{% if loop.index == 1 and contain_content == "no" %}
{{ prompt }}<br>
{% else %}
{{ '\n' + prompt}}<br>
{% endif %}
{% endfor %}
{% endif %}
{{ '<|stop|>\n' }}<br>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}
"""
chat_template = chat_template.replace(" ", "")
chat_template = chat_template.replace("<br>\n", "")
chat_template = chat_template.strip()
return chat_template
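For reference, an illustrative rendering of this template for a single user message, derived by reading the template rather than captured from the repo's tests.

# Illustrative only: one user turn rendered by the v2 template.
messages = [{"role": "user", "content": "How are you?"}]
# tokenizer.apply_chat_template(messages, tokenize=False) is expected to return:
#   "<|from|>user\n<|recipient|>all\n<|content|>How are you?\n"
# and with add_generation_prompt=True it ends with "<|from|>assistant\n<|recipient|>".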

def update_response_state_from_delta_text(
self,
*,
@@ -592,12 +709,12 @@ def get_prompt_template(version: int) -> PromptTemplate:
return PromptTemplateV2.get_prompt_template()


def get_prompt_template_from_tokenizer(tokenizer: Any):
def get_prompt_template_from_tokenizer(tokenizer: Any) -> PromptTemplate:
"""This function will determine the prompt template based on tokenizer.
Under the hood, this function checks whether the tokenizer contains the special tokens of a particular template.
Args:
tokenizer (Any): _description_
tokenizer (Any): Tokenizer
Returns:
_type_: _description_
3 changes: 3 additions & 0 deletions functionary/train/train.py
@@ -105,6 +105,9 @@ def initialize_tokenizer(
added_tokens = prompt_template.get_additional_tokens()
special_tokens = {"additional_special_tokens": added_tokens}
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

# add chat_template for tokenizer
tokenizer.chat_template = prompt_template.get_chat_template()
print("tokenizer: ", tokenizer)

# Resize embedding
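With the chat template stored on the tokenizer (and persisted by save_pretrained in recent transformers versions), downstream code can rebuild prompts without importing the prompt-template classes. A hedged sketch; the directory is a placeholder.

# Sketch: reload a saved tokenizer and render a prompt from its stored chat_template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/saved_model")  # placeholder path
messages = [{"role": "user", "content": "Hi"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)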
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,4 +12,4 @@ typer==0.9.0
protobuf==3.20.0
tokenizers==0.14.1
vllm==0.2.1
openai==0.28.0
openai==1.3.5
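Note that moving from openai==0.28.0 to openai==1.3.5 crosses the 1.0 breaking change, so callers must use the new client interface. A hedged sketch against a locally running functionary server; the base URL, API key, and model name are example values.

# Sketch: query an OpenAI-compatible functionary endpoint with the 1.x client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")  # example values
response = client.chat.completions.create(
    model="meetkai/functionary-7b-v1.1",  # example model name
    messages=[{"role": "user", "content": "What is the weather in Hanoi?"}],
)
print(response.choices[0].message)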
1 change: 0 additions & 1 deletion server.py
@@ -56,7 +56,6 @@ async def chat_endpoint(chat_input: ChatInput):
def get_response_stream():
for response in response_generator:
chunk = StreamChoice(**response)
print("chunk: ", chunk)
result = ChatCompletionChunk(id=request_id, choices=[chunk])
chunk_dic = result.dict(exclude_unset=True)
chunk_data = json.dumps(chunk_dic, ensure_ascii=False)
59 changes: 50 additions & 9 deletions tests/test_prompt_creation.py
@@ -9,6 +9,7 @@
get_default_prompt_template,
PromptTemplateV1,
PromptTemplateV2,
SYSTEM_MESSAGE,
)
from functionary.schema import generate_schema_from_functions
from functionary.train.custom_datasets import prepare_training_inputs
@@ -73,23 +74,25 @@ def test_final_prompt_generation(self):

def test_prepare_training_inputs_fast_tokenizer(self):
print("start testing fast tokenizer")
for keep_assistant_prefix in [False, True]:
for keep_assistant_prefix in [False]:
self.run_prepare_training_inputs(
use_fast=True,
use_fast=True,
pretrained="mistralai/Mistral-7B-v0.1",
keep_assistant_prefix=keep_assistant_prefix
keep_assistant_prefix=keep_assistant_prefix,
)

def test_prepare_training_inputs_normal_tokenizer(self):
print("start testing normal tokenizer")
for keep_assistant_prefix in [False, True]:
for keep_assistant_prefix in [False]:
self.run_prepare_training_inputs(
use_fast=False,
use_fast=False,
pretrained="mistralai/Mistral-7B-v0.1",
keep_assistant_prefix=keep_assistant_prefix
keep_assistant_prefix=keep_assistant_prefix,
)

def run_prepare_training_inputs(self, use_fast: bool, pretrained: str, keep_assistant_prefix: bool=False):
def run_prepare_training_inputs(
self, use_fast: bool, pretrained: str, keep_assistant_prefix: bool = False
):
"""this function is used to test function: prepare_training_inputs"""
# note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176
tokenizer_class = LlamaTokenizer
@@ -112,8 +115,8 @@ def run_prepare_training_inputs(self, use_fast: bool, pretrained: str, keep_assistant_prefix: bool=False):
padding="longest",
max_length=1024,
return_tensor=False,
verbose=True,
keep_assistant_prefix=keep_assistant_prefix
verbose=False,
keep_assistant_prefix=keep_assistant_prefix,
)
input_ids = inputs["inputs"]["input_ids"]
labels = inputs["inputs"]["labels"]
@@ -158,6 +161,44 @@ def run_prepare_training_inputs(self, use_fast: bool, pretrained: str, keep_assistant_prefix: bool=False):
"decoded content is different from original content",
)

def test_chat_template(self):
messages = self.test_case["messages"]
if "functions" in self.test_case:
functions = self.test_case["functions"]
else:
functions = []
for tool in self.test_case["tools"]:
functions.append(tool["function"])

messages.insert(
0, {"role": "system", "content": generate_schema_from_functions(functions)}
)
messages.insert(1, {"role": "system", "content": SYSTEM_MESSAGE})

chat_template = self.prompt_template.get_chat_template()
tokenizer = LlamaTokenizer.from_pretrained(
"meetkai/functionary-7b-v1.1", legacy=True
)
tokenizer.chat_template = chat_template

final_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
self.assertEqual(
final_prompt.strip(),
self.final_prompt,
"wrong final prompt for chat template",
)

prompt_gen = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
self.assertEqual(
prompt_gen,
self.final_prompt
+ "\n"
+ self.prompt_template.get_text_from_message({"role": "assistant"}),
"wrong prompt for generation",
)


if __name__ == "__main__":
unittest.main()
