Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 committed Jun 23, 2024
1 parent d14f88a commit 1ac3ae3
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 23 deletions.
63 changes: 63 additions & 0 deletions go/chat/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package chat

import (
"github.com/ling0322/libllm/go/llm"
)

// Message is a single turn in a chat conversation. Role names the
// speaker ("system", "user" or "assistant" in the visible callers) and
// Content holds the turn's text.
type Message struct {
	Role    string
	Content string
}

// Chat runs multi-turn completion over an llm.Model. A model-specific
// promptBuilder turns a []Message history into an llm.Prompt, which is
// then completed using compConfig.
type Chat struct {
	model         llm.Model
	promptBuilder promptBuilder
	compConfig    llm.CompletionConfig
}

// NewChat creates a Chat for the given model with a default completion
// config. It returns an error when no prompt builder is registered for
// the model's name.
func NewChat(model llm.Model) (*Chat, error) {
	builder, err := newPromptBuilder(model.GetName())
	if err != nil {
		return nil, err
	}

	c := &Chat{
		model:         model,
		promptBuilder: builder,
		compConfig:    llm.NewCompletionConfig(),
	}
	return c, nil
}

// Chat builds a prompt from the given message history and starts a
// completion on the underlying model, returning the live completion.
func (c *Chat) Chat(history []Message) (llm.Completion, error) {
	prompt, buildErr := c.promptBuilder.Build(history)
	if buildErr != nil {
		return nil, buildErr
	}

	completion, compErr := c.model.Complete(c.compConfig, prompt)
	if compErr != nil {
		return nil, compErr
	}
	return completion, nil
}
31 changes: 31 additions & 0 deletions go/chat/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package chat

// QA is one question/answer pair in a chat history.
type QA struct {
	Question string
	Answer   string
}

// Context describes a chat session: an optional system prompt, the
// previous question/answer turns, and the current question.
// NOTE(review): not referenced by any code visible in this diff —
// confirm it is used elsewhere before relying on this description.
type Context struct {
	System   string
	History  []QA
	Question string
}
9 changes: 9 additions & 0 deletions go/chat/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Module definition for the Go chat wrapper around libllm.
module github.com/ling0322/libllm/go/chat

go 1.15

// Use the sibling llm module from this repository checkout instead of
// fetching it remotely.
replace github.com/ling0322/libllm/go/llm => ../llm

require (
	github.com/ling0322/libllm/go/llm v1.0.0
)
43 changes: 43 additions & 0 deletions go/chat/llama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package chat

import "github.com/ling0322/libllm/go/llm"

// Llama is a promptBuilder that renders chat history using the Llama
// chat template control tokens (<|begin_of_text|>, <|start_header_id|>,
// <|end_header_id|>, <|eot_id|>). It is stateless.
type Llama struct {
}

// Build renders the message history into a Llama-format llm.Prompt,
// then opens an assistant header so generation continues as the reply.
func (l *Llama) Build(history []Message) (llm.Prompt, error) {
	p := llm.NewPrompt()
	p.AppendControlToken("<|begin_of_text|>")

	// One header/body/eot section per message, in order.
	for _, msg := range history {
		p.AppendControlToken("<|start_header_id|>")
		p.AppendText(msg.Role)
		p.AppendControlToken("<|end_header_id|>")
		p.AppendText("\n\n" + msg.Content)
		p.AppendControlToken("<|eot_id|>")
	}

	// Trailing assistant header with no content: the model fills it in.
	p.AppendControlToken("<|start_header_id|>")
	p.AppendText("assistant")
	p.AppendControlToken("<|end_header_id|>")
	p.AppendText("\n\n")
	return p, nil
}
38 changes: 38 additions & 0 deletions go/chat/prompt_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package chat

import (
"fmt"

"github.com/ling0322/libllm/go/llm"
)

// promptBuilder converts a chat message history into a model-specific
// llm.Prompt.
type promptBuilder interface {
	Build(history []Message) (llm.Prompt, error)
}

// newPromptBuilder returns the promptBuilder for the given model name,
// or an error when the model is not supported. Currently only "llama"
// is recognized.
func newPromptBuilder(modelName string) (promptBuilder, error) {
	// A switch keeps the happy path flat and makes adding new models a
	// one-line change (the if/else chain grew awkwardly per model).
	switch modelName {
	case "llama":
		return &Llama{}, nil
	default:
		return nil, fmt.Errorf("unexpected model name %s", modelName)
	}
}
5 changes: 4 additions & 1 deletion go/cmd/go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Module definition for the libllm command-line chat binary.
module main

go 1.15

// Use the sibling modules from this repository checkout.
replace github.com/ling0322/libllm/go/llm => ../llm

replace github.com/ling0322/libllm/go/chat => ../chat

require (
	github.com/ling0322/libllm/go/llm v1.0.0
	github.com/ling0322/libllm/go/chat v1.0.0
)
106 changes: 87 additions & 19 deletions go/cmd/main.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,110 @@
package main

import (
"bufio"
"errors"
"flag"
"fmt"
"io"
"log"
"os"
"strings"
"time"

"github.com/ling0322/libllm/go/chat"
"github.com/ling0322/libllm/go/llm"
)

var gModelPath string
var gDevice string

func main() {
model, err := llm.NewModel("../../tools/llama.llmpkg", llm.Cuda)
flag.StringVar(&gModelPath, "model", "", "path of model file (.llmpkg)")
flag.StringVar(&gDevice, "device", "audo", "inference device (cpu|cuda|audo)")
flag.Parse()

var device llm.Device
if strings.ToLower(gDevice) == "cpu" {
device = llm.Cpu
} else if strings.ToLower(gDevice) == "cuda" {
device = llm.Cuda
} else if strings.ToLower(gDevice) == "audo" {
device = llm.Auto
} else {
log.Fatalf("unexpected device %s", gDevice)
}

model, err := llm.NewModel(gModelPath, device)
if err != nil {
log.Fatal(err)
}

prompt := llm.NewPrompt()
prompt.AppendControlToken("<|begin_of_text|>")
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("user")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\nHi")
prompt.AppendControlToken("<|eot_id|>")
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("assistant")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n")

log.Println(model)

comp, err := model.Complete(llm.NewCompletionConfig(), prompt)
llmChat, err := chat.NewChat(model)
if err != nil {
log.Fatal(err)
}

for comp.IsActive() {
chunk, err := comp.GenerateNextChunk()
fmt.Println("Please input your question.")
fmt.Println(" Type ':new' to start a new session (clean history).")
fmt.Println(" Type ':sys' to input the system prompt and start a new session .")

history := []chat.Message{}
for {
reader := bufio.NewReader(os.Stdin)

fmt.Print("> ")
question, err := reader.ReadString('\n')
if errors.Is(err, io.EOF) {
fmt.Println()
break
} else if err != nil {
log.Fatal(err)
}
question = strings.TrimSpace(question)
if strings.ToLower(question) == ":sys" {
fmt.Print("SYSTEM> ")
system, err := reader.ReadString('\n')
if errors.Is(err, io.EOF) {
fmt.Println()
break
} else if err != nil {
log.Fatal(err)
}
history = []chat.Message{{Role: "system", Content: system}}
continue
} else if strings.ToLower(question) == ":new" {
fmt.Println("===== new session =====")
history = []chat.Message{}
continue
}

history = append(history, chat.Message{Role: "user", Content: question})
comp, err := llmChat.Chat(history)
if err != nil {
log.Fatal(err)
}
fmt.Printf(chunk.Text)

t0 := time.Now()
answer := ""
numToken := 0
for comp.IsActive() {
chunk, err := comp.GenerateNextChunk()
if err != nil {
log.Fatal(err)
}
fmt.Printf(chunk.Text)
answer += chunk.Text
numToken++
}
history = append(history, chat.Message{Role: "assistant", Content: answer})
fmt.Println()

dur := time.Since(t0)
fmt.Printf(
"(%d tokens, time=%.2fs, %.2fms per token)\n",
numToken,
dur.Seconds(),
dur.Seconds()*1000/float64(numToken),
)
}
}
2 changes: 1 addition & 1 deletion go/llm/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
// Module definition for the Go bindings of libllm.
module github.com/ling0322/libllm/go/llm

go 1.15
2 changes: 1 addition & 1 deletion go/llm/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (p *promptImpl) AppendText(text string) {
}

// AppendControlToken appends a control token (e.g. "<|eot_id|>") to the
// prompt. The token is stored as a control-token element — not a plain
// text element — so the tokenizer treats it as a single special token.
// (The earlier version wrongly wrapped it in textPromptElem; only the
// corrected controlTokenPromptElem line is kept here.)
func (p *promptImpl) AppendControlToken(text string) {
	p.elements = append(p.elements, &controlTokenPromptElem{text})
}

func (p *promptImpl) updatePromptHandle(handle *promptHandle) error {
Expand Down
6 changes: 6 additions & 0 deletions src/libllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,9 @@ endif()

enable_testing()
add_test(NAME unittest COMMAND $<TARGET_FILE:unittest>)

# Build the Go chat binary (llm_chat) as part of the default build.
# Runs `go build` inside go/cmd after the libllm library target is
# built, writing the executable into the CMake binary directory.
add_custom_target(llmbin
    ALL
    DEPENDS libllm
    COMMAND go build -o ${CMAKE_CURRENT_BINARY_DIR}/llm_chat
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/go/cmd)
2 changes: 1 addition & 1 deletion src/libllm/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Generator::Generator(

void Generator::forwardPrompt(const std::vector<LongType> &prompt) {
for (LongType tokenId : prompt) {
LOG(DEBUG) << tokenId << " -> " << _tokenizer->getVocab()->getTokenString(tokenId);
LOG(INFO) << tokenId << " -> " << _tokenizer->getVocab()->getTokenString(tokenId);
}

Tensor inputs = _model->buildInput(prompt);
Expand Down

0 comments on commit 1ac3ae3

Please sign in to comment.