Skip to content

Commit

Permalink
Update llm C api (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Nov 20, 2024
1 parent 004ea5c commit 62cef9c
Show file tree
Hide file tree
Showing 36 changed files with 913 additions and 1,945 deletions.
21 changes: 7 additions & 14 deletions go/bin/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"os"
"strings"

"github.com/ling0322/libllm/go/llm"
"github.com/ling0322/libllm/go/skill"
)

Expand Down Expand Up @@ -63,25 +62,19 @@ func (a *binArgs) addDeviceFlag() {
a.fs.StringVar(&a.device, "device", "auto", "inference device, either cpu, cuda or auto")
}

// getDevice resolves the value of the "device" flag stored in a.device.
//
// NOTE(review): this span comes from a rendered commit diff, so two revisions
// of the function appear back to back without +/- markers: one with the
// signature returning llm.Device and one returning string. Presumably the
// llm.Device lines are the ones being removed — confirm against the repository.
func (a *binArgs) getDevice() llm.Device {
var device llm.Device
// llm.Device revision: compare the lower-cased flag value against each
// accepted name and map it to the matching llm constant.
if strings.ToLower(a.device) == "cpu" {
device = llm.Cpu
} else if strings.ToLower(a.device) == "cuda" {
device = llm.Cuda
} else if strings.ToLower(a.device) == "auto" {
device = llm.Auto
// string revision: lower-case the flag value once and return it unchanged
// when it is "cpu", "cuda" or "auto".
func (a *binArgs) getDevice() string {
device := strings.ToLower(a.device)
if device == "cpu" || device == "cuda" || device == "auto" {
return device
} else {
// llm.Device revision: any other value aborts through log.Fatalf.
log.Fatalf("unexpected device %s", a.device)
// string revision: log the problem, print the flag set's usage text and
// exit with status 1.
slog.Error(`invalid device name: must be one of "cpu", "cuda" or "auto"`)
a.fs.Usage()
os.Exit(1)
}

// In the string revision this is not reached for invalid names because
// os.Exit(1) does not return; it gives the function a final return statement.
return device
}

// getRawDevice returns a.device exactly as stored, with no lower-casing or
// validation.
//
// NOTE(review): in this diff rendering the visible call site in transcribeMain
// appears with both getRawDevice and getDevice; presumably this accessor is
// being removed by the commit — confirm.
func (a *binArgs) getRawDevice() string {
return a.device
}

func (a *binArgs) addModelFlag() {
a.fs.Var(&a.models, "m", "the libllm model, it could be model name or model file,"+
" model files are with suffix \".llmpkg\". "+
Expand Down
27 changes: 11 additions & 16 deletions go/bin/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"strings"
"time"

"github.com/ling0322/libllm/go/skill"
"github.com/ling0322/libllm/go/llm"
)

func printChatUsage(fs *flag.FlagSet) {
Expand Down Expand Up @@ -64,16 +64,11 @@ func chatMain(args []string) {
log.Fatal(err)
}

llmChat, err := skill.NewChat(model)
if err != nil {
log.Fatal(err)
}

fmt.Println(gLocalizer.Get(MsgInputQuestion))
fmt.Println(gLocalizer.Get(MsgInputQuestionNew))
fmt.Println(gLocalizer.Get(MsgInputQuestionSys))

history := []skill.Message{}
history := []llm.Message{}

// TODO: get system prompt for different models
systemPrompt := ""
Expand All @@ -91,22 +86,22 @@ func chatMain(args []string) {
question = strings.TrimSpace(question)
if len(question) > 5 && strings.ToLower(question)[0:5] == ":sys " {
systemPrompt = strings.TrimSpace(question[5:])
history = []skill.Message{}
history = []llm.Message{}
continue
} else if strings.ToLower(question) == ":new" {
fmt.Println(gLocalizer.Get(MsgNewSession))
history = []skill.Message{}
history = []llm.Message{}
continue
} else if question == "" {
continue
}

if len(history) == 0 && systemPrompt != "" {
history = append(history, skill.Message{Role: "system", Content: systemPrompt})
history = append(history, llm.Message{Role: "system", Content: systemPrompt})
}

history = append(history, skill.Message{Role: "user", Content: question})
comp, err := llmChat.Chat(history)
history = append(history, llm.Message{Role: "user", Content: question})
comp, err := model.Complete(history, llm.DefaultCompletionConfig())
if err != nil {
log.Fatal(err)
}
Expand All @@ -115,15 +110,15 @@ func chatMain(args []string) {
answer := ""
numToken := 0
for comp.Next() {
fmt.Print(comp.Text())
answer += comp.Text()
fmt.Print(comp.Chunk().Text)
answer += comp.Chunk().Text
numToken++
}
if err := comp.Error(); err != nil {
if err := comp.Err(); err != nil {
log.Fatal(err)
}

history = append(history, skill.Message{Role: "assistant", Content: answer})
history = append(history, llm.Message{Role: "assistant", Content: answer})
fmt.Println()

dur := time.Since(t0)
Expand Down
2 changes: 1 addition & 1 deletion go/bin/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func autoDownloadModel(nameOrPath string) (filename string, err error) {
return
}

func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
func createModelAutoDownload(nameOrPath, device string) (*llm.Model, error) {
modelPath, err := autoDownloadModel(nameOrPath)
if err != nil {
return nil, err
Expand Down
16 changes: 10 additions & 6 deletions go/bin/subtitle_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
// transcriptionTranslator bundles a skill translator with a source/target
// language pair and two string histories; its methods translate
// transcriptions.
type transcriptionTranslator struct {
// historySrc and historyTgt: presumably earlier source texts and their
// translations — their use is not visible in this span; confirm.
historySrc []string
historyTgt []string
// NOTE(review): translator is listed twice because this span comes from a
// rendered diff (value type and pointer type). A Go struct cannot declare
// the same field name twice, so only one of these lines exists in the file.
translator skill.Translator
translator *skill.Translator
// srcLang / tgtLang: presumably the language translated from / to — confirm.
srcLang skill.Lang
tgtLang skill.Lang
}
Expand All @@ -53,7 +53,6 @@ func (t *transcriptionTranslator) translateOneInternal(
if err != nil {
return translationResult{}, err
}
slog.Info(fmt.Sprintf("translate \"%s\" to \"%s\"", text, tr.tgtText))
return tr, nil
}

Expand Down Expand Up @@ -113,16 +112,21 @@ func (t *transcriptionTranslator) translate(
return nil, err
}

transcription.Text = tr.tgtText
translated := tr.tgtText
translated = strings.TrimSpace(translated)
translated = strings.Replace(translated, "\n", " ", -1)
slog.Info(fmt.Sprintf("translate \"%s\" to \"%s\"", tr.srcText, translated))

transcription.Text = translated
translatedTrxn = append(translatedTrxn, transcription)
}

return translatedTrxn, nil
}

func newTranscripotionTranslator(
func newTranscriptionTranslator(
translationModel string,
device llm.Device,
device string,
srcLang, tgtLang skill.Lang) (*transcriptionTranslator, error) {
model, err := createModelAutoDownload(translationModel, device)
if err != nil {
Expand Down Expand Up @@ -167,7 +171,7 @@ func ReadSubtitleFile(filename string) ([]llm.RecognitionResult, error) {
}

func TranslateSubtitle(config translationConfig, inputFile, outputFile string) error {
tt, err := newTranscripotionTranslator(
tt, err := newTranscriptionTranslator(
config.modelName,
config.device,
config.srcLang,
Expand Down
4 changes: 2 additions & 2 deletions go/bin/transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type translationConfig struct {
srcLang skill.Lang
tgtLang skill.Lang
modelName string
device llm.Device
device string
}

func printTranscribeUsage(fs *flag.FlagSet) {
Expand Down Expand Up @@ -164,7 +164,7 @@ func transcribeMain(args []string) {
var tgtLang skill.Lang
modelFile := getTranscriotionModel(ba)
tgtLang = ba.getTargetLang()
device := ba.getRawDevice()
device := ba.getDevice()
inputFile := ba.getInput()
outputFile := getOutputFile(ba)

Expand Down
8 changes: 4 additions & 4 deletions go/bin/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type translationResult struct {
processingTime time.Duration
}

func translate(translator skill.Translator, req skill.TranslationRequest, onToken func(string)) (translationResult, error) {
func translate(translator *skill.Translator, req skill.TranslationRequest, onToken func(string)) (translationResult, error) {
text := strings.TrimSpace(req.Text)
if text == "" {
return translationResult{
Expand All @@ -70,12 +70,12 @@ func translate(translator skill.Translator, req skill.TranslationRequest, onToke
numToken := 0
for comp.Next() {
if onToken != nil {
onToken(comp.Text())
onToken(comp.Chunk().Text)
}
answer += comp.Text()
answer += comp.Chunk().Text
numToken++
}
if err := comp.Error(); err != nil {
if err := comp.Err(); err != nil {
return translationResult{}, err
}

Expand Down
64 changes: 8 additions & 56 deletions go/llm/asr.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,16 @@
package llm

// #include <stdlib.h>
// #include "llm_api.h"
// const int cLLM_ERROR_EOF = LLM_ERROR_EOF;
// #include "llm.h"
import "C"
import (
"encoding/json"
"errors"
"fmt"
"runtime"
"time"
"unsafe"
)

// A LLM.
// ASRModel is an ASR model handle backed by a C llm_asr_model_t value.
type ASRModel struct {
// model is the C-side handle. NewASRModel initialises it with
// llm_asr_model_init and registers a finalizer that calls
// llm_asr_model_destroy.
model C.llm_asr_model_t
}
Expand All @@ -58,12 +55,8 @@ type Recognition struct {
json *llmJson
}

// llmJson wraps a C llm_json_t value so that Go code can pass it to the
// libllm C functions (for example as &json.json in NewASRModel).
type llmJson struct {
// json is the underlying C value; newJson initialises it with llm_json_init.
json C.llm_json_t
}

func newRecognition() *Recognition {
r := &Recognition{}
r := new(Recognition)
r.json = newJson()
C.llm_asr_recognition_init(&r.recognition)
runtime.SetFinalizer(r, func(r *Recognition) {
Expand All @@ -73,47 +66,6 @@ func newRecognition() *Recognition {
return r
}

// newJson allocates an llmJson, initialises its embedded C value with
// llm_json_init and returns the wrapper.
func newJson() *llmJson {
j := &llmJson{}
C.llm_json_init(&j.json)
// Release the C-side value with llm_json_destroy when the Go wrapper is
// garbage collected.
runtime.SetFinalizer(j, func(j *llmJson) {
C.llm_json_destroy(&j.json)
})

return j
}

// marshal encodes v with encoding/json and loads the resulting text into the
// wrapped C value through llm_json_parse.
//
// It returns the json.Marshal error if encoding fails, or an error built from
// llmGetLastErrorMessage when llm_json_parse reports a non-zero status.
func (j *llmJson) marshal(v any) error {
jsonBytes, err := json.Marshal(v)
if err != nil {
return err
}

// C.CString copies the text into C-allocated memory, which must be freed
// explicitly.
cJsonStr := C.CString(string(jsonBytes))
defer C.free(unsafe.Pointer(cJsonStr))

status := C.llm_json_parse(&j.json, cJsonStr)
if status != 0 {
return errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

return nil
}

// unmarshal dumps the wrapped C JSON value to text with llm_json_dump and
// decodes that text into v with json.Unmarshal.
//
// It returns an error built from llmGetLastErrorMessage when llm_json_dump
// reports a non-zero status, otherwise the result of json.Unmarshal.
func (j *llmJson) unmarshal(v any) error {
// Fixed-size scratch buffer in C memory, freed on return.
// NOTE(review): what happens when the dumped JSON exceeds 2048 bytes depends
// on llm_json_dump — presumably it returns a non-zero status; confirm.
bufSize := 2048
buf := C.malloc(C.size_t(bufSize))
defer C.free(buf)

status := C.llm_json_dump(&j.json, (*C.char)(buf), C.int64_t(bufSize))
if status != 0 {
return errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

// C.GoString copies up to the first NUL byte; this assumes llm_json_dump
// NUL-terminates its output — TODO confirm.
jsonStr := C.GoString((*C.char)(buf))
return json.Unmarshal([]byte(jsonStr), v)
}

func NewASRModel(filename, device string) (*ASRModel, error) {
err := initLlm()
if err != nil {
Expand All @@ -126,15 +78,15 @@ func NewASRModel(filename, device string) (*ASRModel, error) {
"device": device,
})

m := &ASRModel{}
m := new(ASRModel)
C.llm_asr_model_init(&m.model)
runtime.SetFinalizer(m, func(m *ASRModel) {
C.llm_asr_model_destroy(&m.model)
})

status := C.llm_asr_model_load(&m.model, &json.json)
if status != 0 {
return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
}

return m, nil
Expand All @@ -150,7 +102,7 @@ func (m *ASRModel) Recognize(filename string) (*Recognition, error) {

status := C.llm_asr_recognize_media_file(&m.model, &json.json, &r.recognition)
if status != 0 {
return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
}

return r, nil
Expand All @@ -176,10 +128,10 @@ func (r *Recognition) Dispose() {

func (r *Recognition) Next() bool {
status := C.llm_asr_recognition_get_next_result(&r.recognition, &r.json.json)
if status == C.cLLM_ERROR_EOF {
if status == LLM_ERROR_EOF {
return false
} else if status != 0 {
r.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
r.err = errors.New(C.GoString(C.llm_get_last_error_message()))
return false
}

Expand Down
Loading

0 comments on commit 62cef9c

Please sign in to comment.