Skip to content

Commit

Permalink
support whisper models (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Jul 30, 2024
1 parent bf6e1d1 commit a57a344
Show file tree
Hide file tree
Showing 85 changed files with 8,620 additions and 2,296 deletions.
2 changes: 2 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ AllowShortFunctionsOnASingleLine: false
PenaltyBreakAssignment: 2000
BreakConstructorInitializers: BeforeColon
PackConstructorInitializers: Never
ReturnTypeBreakingStyle: Automatic
PenaltyReturnTypeOnItsOwnLine: 8000
9 changes: 4 additions & 5 deletions go/bin/go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
module github.com/ling0322/libllm/go/bin

go 1.15
go 1.21

replace github.com/ling0322/libllm/go/llm => ../llm
replace github.com/ling0322/libllm/go/i18n => ../i18n
replace github.com/ling0322/libllm/go/chat => ../chat

replace github.com/ling0322/libllm/go/skill => ../skill

require (
github.com/ling0322/libllm/go/llm v1.0.0
github.com/ling0322/libllm/go/chat v1.0.0
github.com/ling0322/libllm/go/i18n v1.0.0
github.com/ling0322/libllm/go/skill v1.0.0
)
2 changes: 2 additions & 0 deletions go/bin/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/ling0322/libllm/go/i18n v0.0.0-20240626100001-bf6e1d13e2be h1:W/7pkPYLRaHvWNCezcFNzv/WzKSl7wZpkgKndToLh94=
github.com/ling0322/libllm/go/i18n v0.0.0-20240626100001-bf6e1d13e2be/go.mod h1:+AS95R9P+RH73A5IROS+q0l6NoozpyfJGhrBurYiokM=
2 changes: 1 addition & 1 deletion go/i18n/i18n.go → go/bin/i18n.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package i18n
package bin

import (
"errors"
Expand Down
105 changes: 105 additions & 0 deletions go/bin/llm/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package main

import (
"bufio"
"errors"
"fmt"
"io"
"log"
"os"
"strings"
"time"

"github.com/ling0322/libllm/go/llm"
"github.com/ling0322/libllm/go/skill"
)

// chatMain runs an interactive chat session on stdin/stdout using the given model.
//
// Every line read from stdin is handled as one of:
//   - ":sys <prompt>" (case-insensitive): store <prompt> as the system prompt and clear the history;
//   - ":new" (case-insensitive): clear the history and start a new session;
//   - an empty line: ignored;
//   - anything else: appended to the history as a user message and sent to the model. The reply is
//     printed chunk by chunk as it is generated, appended to the history as the assistant message
//     and followed by a token-count / timing summary.
//
// The loop ends when stdin reaches EOF. Any other error terminates the process through log.Fatal.
func chatMain(model llm.Model) {
	llmChat, err := skill.NewChat(model)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(gLocalizer.Get(MsgInputQuestion))
	fmt.Println(gLocalizer.Get(MsgInputQuestionNew))
	fmt.Println(gLocalizer.Get(MsgInputQuestionSys))

	// The reader must outlive a single iteration: bufio reads ahead, so building a new reader for
	// every question would drop whatever the previous one had already buffered (for example when
	// the questions are piped in or pasted as several lines at once).
	reader := bufio.NewReader(os.Stdin)

	var history []skill.Message
	systemPrompt := ""
	for {
		fmt.Print("> ")
		question, err := reader.ReadString('\n')
		if errors.Is(err, io.EOF) {
			fmt.Println()
			break
		} else if err != nil {
			log.Fatal(err)
		}

		question = strings.TrimSpace(question)
		switch {
		case len(question) > 5 && strings.EqualFold(question[:5], ":sys "):
			systemPrompt = strings.TrimSpace(question[5:])
			history = nil
			continue
		case strings.EqualFold(question, ":new"):
			fmt.Println(gLocalizer.Get(MsgNewSession))
			history = nil
			continue
		case question == "":
			continue
		}

		// The system prompt is only injected as the first message of a session.
		if len(history) == 0 && systemPrompt != "" {
			history = append(history, skill.Message{Role: "system", Content: systemPrompt})
		}

		history = append(history, skill.Message{Role: "user", Content: question})
		comp, err := llmChat.Chat(history)
		if err != nil {
			log.Fatal(err)
		}

		t0 := time.Now()
		var answer strings.Builder
		numToken := 0
		for comp.Next() {
			chunk := comp.Text()
			fmt.Print(chunk)
			answer.WriteString(chunk)
			numToken++
		}
		if err := comp.Error(); err != nil {
			log.Fatal(err)
		}

		history = append(history, skill.Message{Role: "assistant", Content: answer.String()})
		fmt.Println()

		dur := time.Since(t0)

		// Avoid printing +Inf/NaN when the model produced no chunk at all.
		msPerToken := 0.0
		if numToken > 0 {
			msPerToken = dur.Seconds() * 1000 / float64(numToken)
		}
		fmt.Printf(
			gLocalizer.Get(MsgStat),
			numToken,
			dur.Seconds(),
			msPerToken,
		)
	}
}
102 changes: 13 additions & 89 deletions go/bin/llm/main.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
package main

import (
"bufio"
"errors"
"flag"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
"time"

"github.com/ling0322/libllm/go/chat"
"github.com/ling0322/libllm/go/i18n"
"github.com/ling0322/libllm/go/bin"
"github.com/ling0322/libllm/go/llm"
)

var gModelPath string
var gDevice string
var gTask string
var gAudioFile string
var gLocalizer *bin.Localizer

const (
MsgInputQuestion = iota
Expand Down Expand Up @@ -47,9 +42,12 @@ var gMsgEnUs = map[int]string{
func main() {
flag.StringVar(&gModelPath, "model", "", "path of model file (.llmpkg)")
flag.StringVar(&gDevice, "device", "audo", "inference device (cpu|cuda|audo)")
flag.StringVar(&gAudioFile, "audio", "", "the input audio file, only used in transcribe task")
flag.StringVar(&gTask, "task", "chat", "the task to run (chat|transcribe)")
flag.Parse()

localizer, err := i18n.NewLocalizer(map[string]map[int]string{
var err error
gLocalizer, err = bin.NewLocalizer(map[string]map[int]string{
"en_us": gMsgEnUs,
"zh_cn": gMsgZhCn,
})
Expand All @@ -68,92 +66,18 @@ func main() {
log.Fatalf("unexpected device %s", gDevice)
}

// if model is empty, automatically choose a *.llmpkg file in working directory.
if gModelPath == "" {
llmpkgFiles, err := filepath.Glob("*.llmpkg")
if err != nil {
log.Fatal(err)
}

if len(llmpkgFiles) != 1 {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
} else {
gModelPath = llmpkgFiles[0]
}
log.Fatal("argument -model is required")
}

model, err := llm.NewModel(gModelPath, device)
if err != nil {
log.Fatal(err)
}

llmChat, err := chat.NewChat(model)
if err != nil {
log.Fatal(err)
}

fmt.Println(localizer.Get(MsgInputQuestion))
fmt.Println(localizer.Get(MsgInputQuestionNew))
fmt.Println(localizer.Get(MsgInputQuestionSys))

history := []chat.Message{}
systemPrompt := ""
for {
reader := bufio.NewReader(os.Stdin)

fmt.Print("> ")
question, err := reader.ReadString('\n')
if errors.Is(err, io.EOF) {
fmt.Println()
break
} else if err != nil {
log.Fatal(err)
}
question = strings.TrimSpace(question)
if len(question) > 5 && strings.ToLower(question)[0:5] == ":sys " {
systemPrompt = strings.TrimSpace(question[5:])
history = []chat.Message{}
continue
} else if strings.ToLower(question) == ":new" {
fmt.Println(localizer.Get(MsgNewSession))
history = []chat.Message{}
continue
} else if question == "" {
continue
}

if len(history) == 0 && systemPrompt != "" {
history = append(history, chat.Message{Role: "system", Content: systemPrompt})
}

history = append(history, chat.Message{Role: "user", Content: question})
comp, err := llmChat.Chat(history)
if err != nil {
log.Fatal(err)
}

t0 := time.Now()
answer := ""
numToken := 0
for comp.IsActive() {
chunk, err := comp.GenerateNextChunk()
if err != nil {
log.Fatal(err)
}
fmt.Printf(chunk.Text)
answer += chunk.Text
numToken++
}
history = append(history, chat.Message{Role: "assistant", Content: answer})
fmt.Println()

dur := time.Since(t0)
fmt.Printf(
localizer.Get(MsgStat),
numToken,
dur.Seconds(),
dur.Seconds()*1000/float64(numToken),
)
if gTask == "chat" {
chatMain(model)
} else {
log.Fatalf("unexpected task: %s", gTask)
}
}
78 changes: 78 additions & 0 deletions go/bin/llm_transcribe/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package main

import (
"flag"
"log"
"strings"

"github.com/ling0322/libllm/go/bin"
"github.com/ling0322/libllm/go/llm"
)

// gModelPath receives the value of the -model flag: the path of the model file (.llmpkg).
var gModelPath string
// gDevice receives the value of the -device flag; main translates it into an llm.Device.
var gDevice string
// gInputFile receives the value of the -input flag: the input audio file.
var gInputFile string
// gLocalizer is the message localizer; main assigns it from bin.NewLocalizer.
var gLocalizer *bin.Localizer
// gFfmpegBin: no assignment is visible in main — presumably the path of the ffmpeg executable
// used elsewhere in this package; TODO confirm against its users.
var gFfmpegBin string

// Message IDs. They are the keys of the per-locale message tables (gMsgZhCn and gMsgEnUs) that
// main hands to bin.NewLocalizer.
// NOTE(review): every ID here is chat-related; whether the transcribe binary actually uses them
// cannot be told from this file — confirm before pruning.
const (
MsgInputQuestion = iota
MsgInputQuestionNew
MsgInputQuestionSys
MsgNewSession
// MsgStat is a format string taking three arguments: %d, %.2f and %.2f.
MsgStat
)

// gMsgZhCn maps each message ID to its Simplified Chinese text. main registers it with
// bin.NewLocalizer under the "zh_cn" key.
var gMsgZhCn = map[int]string{
MsgInputQuestion: "请输入问题:",
MsgInputQuestionNew: " 输入 ':new' 重新开始一个新的对话 (清除历史).",
MsgInputQuestionSys: " 输入 ':sys <系统指令>' 设置对话的系统指令,并重新开始一个新的对话.",
MsgNewSession: "===== 新的对话 =====",
MsgStat: "(%d个Token, 总共耗时%.2f秒, 平均每个Token耗时%.2f毫秒)\n",
}

// gMsgEnUs maps each message ID to its English text. main registers it with bin.NewLocalizer
// under the "en_us" key.
var gMsgEnUs = map[int]string{
MsgInputQuestion: "Please input your question.",
MsgInputQuestionNew: " Type ':new' to start a new session (clean history).",
MsgInputQuestionSys: " Type ':sys <system_prompt>' to set the system prompt and start a new session .",
MsgNewSession: "===== new session =====",
MsgStat: "(%d tokens, time=%.2fs, %.2fms per token)\n",
}

// main is the entry point of the transcribe command. It parses the command-line flags, builds
// the localizer, resolves the inference device, loads the model and hands it to transcribeMain.
//
// Flags:
//
//	-model   path of the model file (.llmpkg); required
//	-device  inference device: cpu, cuda or auto (case-insensitive, default auto)
//	-input   the input audio file
//
// Any failure (localizer creation, unknown device, missing -model, model loading) terminates the
// process through log.Fatal.
func main() {
	flag.StringVar(&gModelPath, "model", "", "path of model file (.llmpkg)")
	flag.StringVar(&gDevice, "device", "auto", "inference device (cpu|cuda|auto)")
	flag.StringVar(&gInputFile, "input", "", "the input audio file, only used in transcribe task")
	flag.Parse()

	var err error
	gLocalizer, err = bin.NewLocalizer(map[string]map[int]string{
		"en_us": gMsgEnUs,
		"zh_cn": gMsgZhCn,
	})
	if err != nil {
		log.Fatal(err)
	}

	var device llm.Device
	switch strings.ToLower(gDevice) {
	case "cpu":
		device = llm.Cpu
	case "cuda":
		device = llm.Cuda
	case "auto", "audo":
		// "audo" is the misspelling this flag originally required; it stays accepted so that
		// existing command lines and scripts keep working.
		device = llm.Auto
	default:
		log.Fatalf("unexpected device %s", gDevice)
	}

	if gModelPath == "" {
		log.Fatal("argument -model is required")
	}

	model, err := llm.NewModel(gModelPath, device)
	if err != nil {
		log.Fatal(err)
	}

	transcribeMain(model)
}
Loading

0 comments on commit a57a344

Please sign in to comment.