Merge pull request #60 from MeetKai/packing_for_flash_attention
Packing for flash attention
musab-mk authored Nov 27, 2023
2 parents d8189b0 + b9478f9 commit b76a6c1
Showing 31 changed files with 3,898 additions and 1,151 deletions.
109 changes: 57 additions & 52 deletions functionary/inference.py
@@ -4,10 +4,9 @@
from transformers import (LlamaForCausalLM, LlamaTokenizer, StoppingCriteria,
StoppingCriteriaList)

from functionary.openai_types import ChatMessage, Function, FunctionCall
from functionary.prompt import (SYSTEM_MESSAGE, EndToken, StartToken,
get_prompt_from_messages)
from functionary.schema import generate_schema_from_functions
from functionary.openai_types import ChatMessage, Function, FunctionCall, Tool
from functionary.prompt_template import (PromptTemplate,
get_prompt_template_from_tokenizer)


class StopWordsCriteria(StoppingCriteria):
@@ -25,28 +24,43 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):

def tokenize(message: ChatMessage, tokenizer: LlamaTokenizer, device="cuda:0"):
text = str(message)
return tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
return tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids.to(
device
)


def prepare_messages_for_inference(
*,
tokenizer: LlamaTokenizer,
messages: List[ChatMessage],
functions=None,
functions: Optional[List[Function]] = None,
tools: Optional[List[Tool]] = None,
device="cuda:0",
) -> torch.Tensor:

prompt_template = get_prompt_template_from_tokenizer(tokenizer)

dic_messages = [mess.dict() for mess in messages]
dic_messages.append({"role": "assistant"})
func_list = []
if functions is not None:
for item in functions:
func_list.append(item.dict())
final_prompt = get_prompt_from_messages(dic_messages, func_list)

tools_or_functions = []
if functions:
tools_or_functions = [item.dict() for item in functions]
elif tools:
tools_or_functions = [item.dict() for item in tools]

dic_messages = prompt_template.pre_process_messages_before_inference(dic_messages)
final_prompt = prompt_template.get_prompt_from_messages(
dic_messages, tools_or_functions=tools_or_functions
)
input_ids = tokenizer(final_prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device)
return input_ids
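
The rewritten helper routes prompt construction through the tokenizer's prompt template and accepts either functions or tools. A minimal sketch of the tools path follows; the Tool/Function field shapes used here (type/function and name/description/parameters) are OpenAI-style assumptions, since the model definitions are not part of this diff:

# Sketch only: the Tool/Function field names below (type, function, name,
# description, parameters) are assumed OpenAI-style shapes, not shown in this diff.
weather_tool = Tool(
    type="function",
    function=Function(
        name="get_current_weather",
        description="Get the current weather for a location",
        parameters={"type": "object", "properties": {}},
    ),
)
input_ids = prepare_messages_for_inference(
    tokenizer=tokenizer,
    messages=[ChatMessage(role="user", content="What's the weather today?")],
    tools=[weather_tool],
    device="cpu",
)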


def remove_stop_tokens_from_end(token_ids: List[int], stop_sequences: List[List[int]]) -> List[int]:
def remove_stop_tokens_from_end(
token_ids: List[int], stop_sequences: List[List[int]]
) -> List[int]:
"""This function is used to remove the hitting stop-sequence of id at the end of generated token_ids
Args:
@@ -65,50 +79,31 @@ def remove_stop_tokens_from_end(token_ids: List[int], stop_sequences: List[List[
return token_ids
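
The body of remove_stop_tokens_from_end is collapsed in this view. A minimal sketch of the suffix-trimming idea, under the assumption that a stop sequence is dropped only when it appears as a suffix (the helper name here is hypothetical):

def trim_stop_suffix(token_ids: List[int], stop_sequences: List[List[int]]) -> List[int]:
    # Assumption: a stop sequence is removed only when it is a suffix of token_ids.
    for stop in stop_sequences:
        if len(stop) <= len(token_ids) and token_ids[-len(stop):] == stop:
            return token_ids[: -len(stop)]
    return token_ids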


def parse_generated_content(generated_content: str) -> ChatMessage:
"""Parse LLM output into ChatMessage
Args:
generated_content (str): llm output
Returns:
ChatMessage: _description_
"""
# strip end_of_function_call and end_of_assistant
generated_content = generated_content.strip()
for endtoken in [EndToken.function_call, EndToken.assistant]:
if generated_content.endswith(endtoken):
generated_content = generated_content[: - len(endtoken)].strip()
# First we need to check if llm_output contains start_token or not
start_function_index = generated_content.find(StartToken.function.value)
text_content = generated_content
result = ChatMessage(role="assistant")
if start_function_index >= 0:
func_info = generated_content[start_function_index + len(StartToken.function.value): ].strip()
index = func_info.find(":")
func_name = func_info[: index].strip()
arguments = func_info[index + 1: ].strip()
text_content = generated_content[: start_function_index].strip()
result.function_call = FunctionCall(name=func_name, arguments=arguments)
if len(text_content) > 0:
result.content = text_content
return result


def generate_message(
*,
model: LlamaForCausalLM,
tokenizer: LlamaTokenizer,
messages: List[ChatMessage],
functions: Optional[List[Function]] = None,
tools: Optional[List[Tool]] = None,
temperature: float = 0.7,
max_new_tokens=256,
device="cuda:0",
**kwargs,
) -> ChatMessage:
inputs = prepare_messages_for_inference(tokenizer=tokenizer, messages=messages, functions=functions, device=device)
prompt_template = get_prompt_template_from_tokenizer(tokenizer)
inputs = prepare_messages_for_inference(
tokenizer=tokenizer,
messages=messages,
functions=functions,
tools=tools,
device=device,
)
stop_words_ids = []
# [EndToken.assistant, EndToken.function_call]
for stop in kwargs.get("stops", []) + [EndToken.assistant, EndToken.function_call]:
for stop in (
kwargs.get("stops", []) + prompt_template.get_stop_tokens_for_generation()
):
tok_ids = tokenizer.encode(stop, add_special_tokens=False)
if (
len(tok_ids) > 1 and tok_ids[0] == 29871
@@ -125,12 +120,14 @@ def generate_message(
)
token_ids = generate_ids[:, inputs.shape[1] :][0].tolist()

#token_ids = remove_stop_tokens_from_end(token_ids, stop_words_ids)

generated_content = tokenizer.decode(
token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, max_new_tokens=max_new_tokens
token_ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=False,
max_new_tokens=max_new_tokens,
).strip()
return parse_generated_content(generated_content)
result = prompt_template.parse_assistant_response(generated_content)
return ChatMessage(**result)
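
For reference, a usage sketch of generate_message as declared above; the checkpoint name is the one loaded in the __main__ block below, and running on CPU is purely illustrative:

# Usage sketch; every keyword below exists in the generate_message signature above.
model = LlamaForCausalLM.from_pretrained("musabgultekin/functionary-7b-v1")
tokenizer = LlamaTokenizer.from_pretrained("musabgultekin/functionary-7b-v1")
reply = generate_message(
    model=model,
    tokenizer=tokenizer,
    messages=[ChatMessage(role="user", content="Hi, what can you do?")],
    temperature=0.7,
    max_new_tokens=128,
    device="cpu",
)
print(reply.content, reply.function_call)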


if __name__ == "__main__":
@@ -161,7 +158,9 @@ def generate_message(
ChatMessage(role="assistant", content="Hi there!"),
ChatMessage(role="user", content="How are you?"),
ChatMessage(role="assistant", content="I'm good thanks!"),
ChatMessage(role="user", content="What's the weather like today in san francisco?"),
ChatMessage(
role="user", content="What's the weather like today in san francisco?"
),
ChatMessage(
role="assistant",
content="I can help you find out! Lets call the get_current_weather function.",
@@ -170,13 +169,19 @@
arguments='{"location": "San Francisco, CA", "format": "celsius"}',
),
),
ChatMessage(role="function", name="get_current_weather", content='{"value": 32}'),
ChatMessage(role="assistant", content="It's 32 degrees celsius in San Francisco today."),
ChatMessage(
role="function", name="get_current_weather", content='{"value": 32}'
),
ChatMessage(
role="assistant", content="It's 32 degrees celsius in San Francisco today."
),
ChatMessage(role="user", content="Thanks!"),
ChatMessage(role="assistant", content="No problem!"),
]

# Now let's prepare the messages for inference
tokenizer = LlamaTokenizer.from_pretrained("musabgultekin/functionary-7b-v1")
inputs = prepare_messages_for_inference(tokenizer=tokenizer, messages=messages, functions=functions, device="cpu")
inputs = prepare_messages_for_inference(
tokenizer=tokenizer, messages=messages, functions=functions, device="cpu"
)
print(inputs.shape)
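
One detail in generate_message worth noting: the tok_ids[0] == 29871 guard when encoding stop words. With the Llama SentencePiece tokenizer, encoding a word in isolation typically prepends the standalone space piece, whose id is 29871 in the Llama vocabulary, so the guard drops it to keep the stop sequence aligned with how it appears mid-generation. A minimal sketch (the stop string is illustrative):

# Sketch: drop SentencePiece's standalone leading-space piece (id 29871 in the
# Llama vocab) so the encoded stop word matches its in-context token ids.
tok_ids = tokenizer.encode("<|stop|>", add_special_tokens=False)
if len(tok_ids) > 1 and tok_ids[0] == 29871:
    tok_ids = tok_ids[1:]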
