Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 committed Sep 27, 2024
1 parent d62435e commit 2326238
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
32 changes: 26 additions & 6 deletions go/skill/llama.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,45 @@

package skill

import "github.com/ling0322/libllm/go/llm"
import (
"errors"

"github.com/ling0322/libllm/go/llm"
)

// Llama is a prompt builder for the Llama 3 family chat format
// (<|begin_of_text|> / <|start_header_id|> ... <|eot_id|> control tokens).
type Llama struct {
}

func (l *Llama) Build(history []Message) (llm.Prompt, error) {
prompt := llm.NewPrompt()
prompt.AppendControlToken("<|begin_of_text|>")
for _, message := range history {
for _, message := range history[:len(history)-1] {
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText(message.Role)
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n" + message.Content)
prompt.AppendControlToken("<|eot_id|>")
}

prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("assistant")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n")
lastMessage := history[len(history)-1]
if lastMessage.Role == "user" {
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText(lastMessage.Role)
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n" + lastMessage.Content)
prompt.AppendControlToken("<|eot_id|>")
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("assistant")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n")
} else if lastMessage.Role == "assistant" {
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText(lastMessage.Role)
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n" + lastMessage.Content)
} else {
return nil, errors.New("last message should be either user or assistant")
}

return prompt, nil
}
14 changes: 9 additions & 5 deletions tools/llama_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ def _export_block(self, ctx: Context, model_block):
self._export_mlp(ctx.with_subname("mlp"), model_block.mlp)

def _export_rope(self, ctx: Context, rope_embd):
    """Export the rotary position embedding tables.

    Instead of reading the (removed in newer transformers) ``cos_cached`` /
    ``sin_cached`` attributes, invoke the rotary embedding module over all
    positions up to ``original_max_seq_len`` and write the stacked
    (cos, sin) tensor unquantized.
    """
    original_max_seq_len = rope_embd.original_max_seq_len
    # Position ids for every position, shaped (1, original_max_seq_len).
    position_ids = torch.arange(original_max_seq_len).unsqueeze(0)
    # Dummy activation tensor; presumably only its dtype/device are used by
    # the rope module -- TODO(review): confirm against transformers source.
    x = torch.ones(1)

    cos_cached, sin_cached = rope_embd(x, position_ids)

    rope = torch.stack((cos_cached, sin_cached))
    self._write(ctx.with_quant(Quant.NONE), rope)
Expand All @@ -79,7 +82,8 @@ def generate_config(cls, llama_config) -> configparser.ConfigParser:
config = configparser.ConfigParser()
config["llama"] = {}

assert llama_config.rope_scaling is None
print("llama_config.rope_scaling =", llama_config.rope_scaling)

assert llama_config.pretraining_tp == 1
assert llama_config.hidden_act == "silu"

Expand Down Expand Up @@ -137,7 +141,7 @@ def run_llama_chat(huggingface_name):
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_BIN = "model.bin"
MODEL_INI = "model.ini"
TOKENIZER_BIN = "tokenizer.bin"
Expand Down Expand Up @@ -171,7 +175,7 @@ def run_llama_chat(huggingface_name):
libllm_tokenizer = read_spm_model(args.huggingface_name)

with package.open(MODEL_BIN, "w", force_zip64=True) as fp:
config = LlamaExporter.export(model, fp, args.quantization)
config = LlamaExporter.export(model, fp, args.quant)

if args.llama_version == 3:
config["llama"]["bot_token_id"] = str(tokenizer.bos_token_id)
Expand Down

0 comments on commit 2326238

Please sign in to comment.