diff --git a/go/bin/args.go b/go/bin/args.go
index 88292a70..944e744d 100644
--- a/go/bin/args.go
+++ b/go/bin/args.go
@@ -27,7 +27,6 @@ import (
 	"os"
 	"strings"
 
-	"github.com/ling0322/libllm/go/llm"
 	"github.com/ling0322/libllm/go/skill"
 )
 
@@ -63,25 +62,19 @@ func (a *binArgs) addDeviceFlag() {
 	a.fs.StringVar(&a.device, "device", "auto", "inference device, either cpu, cuda or auto")
 }
 
-func (a *binArgs) getDevice() llm.Device {
-	var device llm.Device
-	if strings.ToLower(a.device) == "cpu" {
-		device = llm.Cpu
-	} else if strings.ToLower(a.device) == "cuda" {
-		device = llm.Cuda
-	} else if strings.ToLower(a.device) == "auto" {
-		device = llm.Auto
+func (a *binArgs) getDevice() string {
+	device := strings.ToLower(a.device)
+	if device == "cpu" || device == "cuda" || device == "auto" {
+		return device
 	} else {
-		log.Fatalf("unexpected device %s", a.device)
+		slog.Error(`invalid device name: must be one of "cpu", "cuda" or "auto"`)
+		a.fs.Usage()
+		os.Exit(1)
 	}
 
 	return device
 }
 
-func (a *binArgs) getRawDevice() string {
-	return a.device
-}
-
 func (a *binArgs) addModelFlag() {
 	a.fs.Var(&a.models, "m", "the libllm model, it could be model name or model file,"+
 		" model files are with suffix \".llmpkg\". "+
diff --git a/go/bin/chat.go b/go/bin/chat.go
index 05b9d485..f12feef7 100644
--- a/go/bin/chat.go
+++ b/go/bin/chat.go
@@ -30,7 +30,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/ling0322/libllm/go/skill"
+	"github.com/ling0322/libllm/go/llm"
 )
 
 func printChatUsage(fs *flag.FlagSet) {
@@ -64,16 +64,11 @@ func chatMain(args []string) {
 		log.Fatal(err)
 	}
 
-	llmChat, err := skill.NewChat(model)
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	fmt.Println(gLocalizer.Get(MsgInputQuestion))
 	fmt.Println(gLocalizer.Get(MsgInputQuestionNew))
 	fmt.Println(gLocalizer.Get(MsgInputQuestionSys))
 
-	history := []skill.Message{}
+	history := []llm.Message{}
 
 	// TODO: get system prompt for different models
 	systemPrompt := ""
@@ -91,22 +86,22 @@ func chatMain(args []string) {
 		question = strings.TrimSpace(question)
 		if len(question) > 5 && strings.ToLower(question)[0:5] == ":sys " {
 			systemPrompt = strings.TrimSpace(question[5:])
-			history = []skill.Message{}
+			history = []llm.Message{}
 			continue
 		} else if strings.ToLower(question) == ":new" {
 			fmt.Println(gLocalizer.Get(MsgNewSession))
-			history = []skill.Message{}
+			history = []llm.Message{}
 			continue
 		} else if question == "" {
 			continue
 		}
 
 		if len(history) == 0 && systemPrompt != "" {
-			history = append(history, skill.Message{Role: "system", Content: systemPrompt})
+			history = append(history, llm.Message{Role: "system", Content: systemPrompt})
 		}
 
-		history = append(history, skill.Message{Role: "user", Content: question})
-		comp, err := llmChat.Chat(history)
+		history = append(history, llm.Message{Role: "user", Content: question})
+		comp, err := model.Complete(history, llm.DefaultCompletionConfig())
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -115,15 +110,15 @@ func chatMain(args []string) {
 		answer := ""
 		numToken := 0
 		for comp.Next() {
-			fmt.Print(comp.Text())
-			answer += comp.Text()
+			fmt.Print(comp.Chunk().Text)
+			answer += comp.Chunk().Text
 			numToken++
 		}
-		if err := comp.Error(); err != nil {
+		if err := comp.Err(); err != nil {
 			log.Fatal(err)
 		}
 
-		history = append(history, skill.Message{Role: "assistant", Content: answer})
+		history = append(history, llm.Message{Role: "assistant", Content: answer})
 		fmt.Println()
 
 		dur := time.Since(t0)
diff --git a/go/bin/download.go b/go/bin/download.go
index ebb61b58..ea6afea8 100644
--- a/go/bin/download.go
+++ b/go/bin/download.go
@@ -243,7 +243,7 @@ func autoDownloadModel(nameOrPath string) (filename string, err error) {
 	return
 }
 
-func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
+func createModelAutoDownload(nameOrPath, device string) (*llm.Model, error) {
 	modelPath, err := autoDownloadModel(nameOrPath)
 	if err != nil {
 		return nil, err
diff --git a/go/bin/subtitle_translator.go b/go/bin/subtitle_translator.go
index 2431bb88..9d4fcb72 100644
--- a/go/bin/subtitle_translator.go
+++ b/go/bin/subtitle_translator.go
@@ -33,7 +33,7 @@ import (
 type transcriptionTranslator struct {
 	historySrc []string
 	historyTgt []string
-	translator skill.Translator
+	translator *skill.Translator
 	srcLang    skill.Lang
 	tgtLang    skill.Lang
 }
@@ -53,7 +53,6 @@ func (t *transcriptionTranslator) translateOneInternal(
 	if err != nil {
 		return translationResult{}, err
 	}
-	slog.Info(fmt.Sprintf("translate \"%s\" to \"%s\"", text, tr.tgtText))
 
 	return tr, nil
 }
@@ -113,16 +112,21 @@ func (t *transcriptionTranslator) translate(
 			return nil, err
 		}
 
-		transcription.Text = tr.tgtText
+		translated := tr.tgtText
+		translated = strings.TrimSpace(translated)
+		translated = strings.Replace(translated, "\n", " ", -1)
+		slog.Info(fmt.Sprintf("translate \"%s\" to \"%s\"", tr.srcText, translated))
+
+		transcription.Text = translated
 		translatedTrxn = append(translatedTrxn, transcription)
 	}
 
 	return translatedTrxn, nil
 }
 
-func newTranscripotionTranslator(
+func newTranscriptionTranslator(
 	translationModel string,
-	device llm.Device,
+	device string,
 	srcLang, tgtLang skill.Lang) (*transcriptionTranslator, error) {
 	model, err := createModelAutoDownload(translationModel, device)
 	if err != nil {
 		return nil, err
 	}
@@ -167,7 +171,7 @@ func ReadSubtitleFile(filename string) ([]llm.RecognitionResult, error) {
 }
 
 func TranslateSubtitle(config translationConfig, inputFile, outputFile string) error {
-	tt, err := newTranscripotionTranslator(
+	tt, err := newTranscriptionTranslator(
 		config.modelName,
 		config.device,
 		config.srcLang,
diff --git a/go/bin/transcribe.go b/go/bin/transcribe.go
index c0840260..49e10d28 100644
--- a/go/bin/transcribe.go
+++ b/go/bin/transcribe.go
@@ -37,7 +37,7 @@ type translationConfig struct {
 	srcLang   skill.Lang
 	tgtLang   skill.Lang
 	modelName string
-	device    llm.Device
+	device    string
 }
 
 func printTranscribeUsage(fs *flag.FlagSet) {
@@ -164,7 +164,7 @@ func transcribeMain(args []string) {
 	var tgtLang skill.Lang
 	modelFile := getTranscriotionModel(ba)
 	tgtLang = ba.getTargetLang()
-	device := ba.getRawDevice()
+	device := ba.getDevice()
 	inputFile := ba.getInput()
 	outputFile := getOutputFile(ba)
diff --git a/go/bin/translate.go b/go/bin/translate.go
index 98fde4d9..4f1605ec 100644
--- a/go/bin/translate.go
+++ b/go/bin/translate.go
@@ -49,7 +49,7 @@ type translationResult struct {
 	processingTime time.Duration
 }
 
-func translate(translator skill.Translator, req skill.TranslationRequest, onToken func(string)) (translationResult, error) {
+func translate(translator *skill.Translator, req skill.TranslationRequest, onToken func(string)) (translationResult, error) {
 	text := strings.TrimSpace(req.Text)
 	if text == "" {
 		return translationResult{
@@ -70,12 +70,12 @@ func translate(translator skill.Translator, req skill.TranslationRequest, onToke
 	numToken := 0
 	for comp.Next() {
 		if onToken != nil {
-			onToken(comp.Text())
+			onToken(comp.Chunk().Text)
 		}
-		answer += comp.Text()
+		answer += comp.Chunk().Text
 		numToken++
 	}
-	if err := comp.Error(); err != nil {
+	if err := comp.Err(); err != nil {
 		return translationResult{}, err
 	}
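The bin/ changes above all track the same API migration: the inference device is now a plain string ("cpu", "cuda" or "auto") instead of llm.Device, and the removed skill.NewChat/Chat pair is replaced by calling llm.Model.Complete directly with a Next/Chunk/Err streaming loop. A minimal sketch of the new calling convention, assuming a local model file (the "model.llmpkg" path is a placeholder):

```go
package main

import (
	"fmt"
	"log"

	"github.com/ling0322/libllm/go/llm"
)

func main() {
	// Device is a plain string now; NewModel loads the shared library on
	// first use and then loads the model package.
	model, err := llm.NewModel("model.llmpkg", "auto")
	if err != nil {
		log.Fatal(err)
	}

	history := []llm.Message{
		{Role: "user", Content: "Hello!"},
	}

	// Complete replaces skill.Chat; chunks stream out of Next/Chunk, and
	// Err reports a decoding or inference error after the loop ends.
	comp, err := model.Complete(history, llm.DefaultCompletionConfig())
	if err != nil {
		log.Fatal(err)
	}
	for comp.Next() {
		fmt.Print(comp.Chunk().Text)
	}
	if err := comp.Err(); err != nil {
		log.Fatal(err)
	}
}
```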
diff --git a/go/llm/asr.go b/go/llm/asr.go
index 71146376..cca809fa 100644
--- a/go/llm/asr.go
+++ b/go/llm/asr.go
@@ -20,19 +20,16 @@ package llm
 
 // #include <stdlib.h>
-// #include "llm_api.h"
-// const int cLLM_ERROR_EOF = LLM_ERROR_EOF;
+// #include "llm.h"
 import "C"
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"runtime"
 	"time"
-	"unsafe"
 )
 
-// A LLM.
+// An ASR model.
 type ASRModel struct {
 	model C.llm_asr_model_t
 }
@@ -58,12 +55,8 @@ type Recognition struct {
 	json *llmJson
 }
 
-type llmJson struct {
-	json C.llm_json_t
-}
-
 func newRecognition() *Recognition {
-	r := &Recognition{}
+	r := new(Recognition)
 	r.json = newJson()
 	C.llm_asr_recognition_init(&r.recognition)
 	runtime.SetFinalizer(r, func(r *Recognition) {
@@ -73,47 +66,6 @@ func newRecognition() *Recognition {
 	return r
 }
 
-func newJson() *llmJson {
-	j := &llmJson{}
-	C.llm_json_init(&j.json)
-	runtime.SetFinalizer(j, func(j *llmJson) {
-		C.llm_json_destroy(&j.json)
-	})
-
-	return j
-}
-
-func (j *llmJson) marshal(v any) error {
-	jsonBytes, err := json.Marshal(v)
-	if err != nil {
-		return err
-	}
-
-	cJsonStr := C.CString(string(jsonBytes))
-	defer C.free(unsafe.Pointer(cJsonStr))
-
-	status := C.llm_json_parse(&j.json, cJsonStr)
-	if status != 0 {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return nil
-}
-
-func (j *llmJson) unmarshal(v any) error {
-	bufSize := 2048
-	buf := C.malloc(C.size_t(bufSize))
-	defer C.free(buf)
-
-	status := C.llm_json_dump(&j.json, (*C.char)(buf), C.int64_t(bufSize))
-	if status != 0 {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	jsonStr := C.GoString((*C.char)(buf))
-	return json.Unmarshal([]byte(jsonStr), v)
-}
-
 func NewASRModel(filename, device string) (*ASRModel, error) {
 	err := initLlm()
 	if err != nil {
@@ -126,7 +78,7 @@ func NewASRModel(filename, device string) (*ASRModel, error) {
 		"device": device,
 	})
 
-	m := &ASRModel{}
+	m := new(ASRModel)
 	C.llm_asr_model_init(&m.model)
 	runtime.SetFinalizer(m, func(m *ASRModel) {
 		C.llm_asr_model_destroy(&m.model)
@@ -134,7 +86,7 @@ func NewASRModel(filename, device string) (*ASRModel, error) {
 
 	status := C.llm_asr_model_load(&m.model, &json.json)
 	if status != 0 {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
 	}
 
 	return m, nil
@@ -150,7 +102,7 @@ func (m *ASRModel) Recognize(filename string) (*Recognition, error) {
 
 	status := C.llm_asr_recognize_media_file(&m.model, &json.json, &r.recognition)
 	if status != 0 {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
 	}
 
 	return r, nil
@@ -176,10 +128,10 @@ func (r *Recognition) Dispose() {
 
 func (r *Recognition) Next() bool {
 	status := C.llm_asr_recognition_get_next_result(&r.recognition, &r.json.json)
-	if status == C.cLLM_ERROR_EOF {
+	if status == LLM_ERROR_EOF {
 		return false
 	} else if status != 0 {
-		r.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		r.err = errors.New(C.GoString(C.llm_get_last_error_message()))
 		return false
 	}
diff --git a/go/llm/common.go b/go/llm/common.go
new file mode 100644
index 00000000..004c482e
--- /dev/null
+++ b/go/llm/common.go
@@ -0,0 +1,87 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2024 Xiaoyang Chen
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+package llm
+
+// #include <stdlib.h>
+// #include "llm.h"
+import "C"
+import (
+	"errors"
+	"os"
+	"path/filepath"
+	"runtime"
+	"sync/atomic"
+	"unsafe"
+)
+
+const LLM_ERROR_EOF = 0x0103
+
+var gInit atomic.Bool
+var gDll unsafe.Pointer
+
+// Initialize the libllm.
+func initLlm() error {
+	if !gInit.Swap(true) {
+		// load the shared library.
+		binPath, err := os.Executable()
+		if err != nil {
+			gInit.Store(false)
+			return err
+		}
+
+		var libname string
+		if runtime.GOOS == "windows" {
+			libname = "llm.dll"
+		} else if runtime.GOOS == "linux" {
+			libname = "libllm.so"
+		} else if runtime.GOOS == "darwin" {
+			libname = "libllm.dylib"
+		}
+
+		binDir := filepath.Dir(binPath)
+		dllPath := C.CString(filepath.Join(binDir, libname))
+		defer C.free(unsafe.Pointer(dllPath))
+
+		gDll = C.llm_load_library(dllPath)
+		if gDll == nil {
+			dllPath := C.CString("libllm.so")
+			defer C.free(unsafe.Pointer(dllPath))
+
+			gDll = C.llm_load_library(dllPath)
+		}
+
+		if gDll == nil {
+			gInit.Store(false)
+			return errors.New("failed to load the libllm dynamic library")
+		}
+
+		// initialize the symbols.
+		status := C.llm_load_symbols(gDll)
+		if status != 0 {
+			gInit.Store(false)
+			return errors.New("failed to load libllm api symbols")
+		}
+
+		// initialize libllm inference engine.
+		C.llm_init()
+	}
+
+	return nil
+}
diff --git a/go/llm/completion.go b/go/llm/completion.go
deleted file mode 100644
index fe3dbd14..00000000
--- a/go/llm/completion.go
+++ /dev/null
@@ -1,152 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-package llm
-
-// #include <stdlib.h>
-// #include "llm_api.h"
-import "C"
-import (
-	"errors"
-	"fmt"
-	"log/slog"
-	"os"
-	"runtime"
-)
-
-// Config for LLM completion.
-type Completion interface {
-	Next() bool
-	Error() error
-	Text() string
-
-	// Get the name of current token.
-	// For a normal token Text() will return its byte piece, for example, "foo "; Token() will
-	// return its name in model, for example, "hello_".
-	// For a control token, Text() will return an empty string ("") and Token() will return its
-	// name, for example, "<|endoftext|>".
-	Token() string
-
-	Dispose()
-}
-
-type completionHandle struct {
-	handle *C.llmCompletion_t
-}
-
-type completionImpl struct {
-	handle     *completionHandle
-	chunkText  string
-	chunkToken string
-	err        error
-}
-
-func (c *completionImpl) Next() bool {
-	ok := C.llmCompletion_Next(c.handle.handle) != 0
-	if !ok {
-		return false
-	}
-
-	cText := C.llmCompletion_GetText(c.handle.handle)
-	if cText == nil {
-		c.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
-		return false
-	}
-	c.chunkText = C.GoString(cText)
-
-	cToken := C.llmCompletion_GetToken(c.handle.handle)
-	if cToken == nil {
-		c.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
-		return false
-	}
-	c.chunkToken = C.GoString(cToken)
-
-	return true
-}
-
-func (c *completionImpl) Error() error {
-	if c.err != nil {
-		return c.err
-	}
-
-	if C.llmCompletion_GetError(c.handle.handle) != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return nil
-}
-
-func (c *completionImpl) Text() string {
-	return c.chunkText
-}
-
-func (c *completionImpl) Token() string {
-	return c.chunkToken
-}
-
-func (c *completionImpl) Dispose() {
-	c.handle.dispose()
-	c.handle = nil
-}
-
-func newCompletionImpl(modelHandle *modelHandle) (*completionImpl, error) {
-	handle, err := newCompletionHandle(modelHandle)
-	if err != nil {
-		return nil, err
-	}
-
-	return &completionImpl{
-		handle: handle,
-	}, nil
-}
-
-func (h *completionHandle) dispose() error {
-	if h.handle == nil {
-		return nil
-	}
-
-	status := C.llmCompletion_Delete(h.handle)
-	if status != C.LLM_OK {
-		slog.Error(
-			"failed to call llmCompletion_Delete()",
-			"message", C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	h.handle = nil
-	return nil
-}
-
-func newCompletionHandle(m *modelHandle) (*completionHandle, error) {
-	cHandle := C.llmCompletion_New(m.handle)
-	if cHandle == nil {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	handle := &completionHandle{
-		cHandle,
-	}
-	runtime.SetFinalizer(handle, func(h *completionHandle) {
-		status := C.llmCompletion_Delete(h.handle)
-		if status != C.LLM_OK {
-			fmt.Fprintln(os.Stderr, "failed to call llmCompletion_Delete()")
-		}
-	})
-
-	return handle, nil
-}
diff --git a/go/llm/completion_config.go b/go/llm/completion_config.go
deleted file mode 100644
index 86e47db8..00000000
--- a/go/llm/completion_config.go
+++ /dev/null
@@ -1,117 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-package llm
-
-// #include <stdlib.h>
-// #include "llm_api.h"
-import "C"
-import (
-	"errors"
-	"unsafe"
-)
-
-// Config for LLM completion.
-type CompletionConfig interface {
-	SetTopP(topP float32)
-	GetTopP() float32
-
-	SetTopK(topK int)
-	GetTopK() int
-
-	SetTemperature(temperature float32)
-	GetTemperature() float32
-
-	SetConfig(key, value string)
-
-	// update the llmCompletion_t according to the config.
-	updateCompHandle(compHandle *completionHandle) error
-}
-
-type completionConfigImpl struct {
-	topP        float32
-	topK        int
-	temperature float32
-
-	kvConfig map[string]string
-}
-
-func NewCompletionConfig() CompletionConfig {
-	return &completionConfigImpl{
-		topP:        0.8,
-		topK:        50,
-		temperature: 1.0,
-		kvConfig:    map[string]string{},
-	}
-}
-
-func (c *completionConfigImpl) SetTopP(topP float32) {
-	c.topP = topP
-}
-
-func (c *completionConfigImpl) GetTopP() float32 {
-	return c.topP
-}
-
-func (c *completionConfigImpl) SetTopK(topK int) {
-	c.topK = topK
-}
-
-func (c *completionConfigImpl) GetTopK() int {
-	return c.topK
-}
-
-func (c *completionConfigImpl) SetTemperature(temperature float32) {
-	c.temperature = temperature
-}
-
-func (c *completionConfigImpl) GetTemperature() float32 {
-	return c.temperature
-}
-
-func (c *completionConfigImpl) SetConfig(key, value string) {
-	c.kvConfig[key] = value
-}
-
-func (c *completionConfigImpl) updateCompHandle(compHandle *completionHandle) error {
-	if C.llmCompletion_SetTopP(compHandle.handle, C.float(c.topP)) != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	if C.llmCompletion_SetTopK(compHandle.handle, C.int(c.topK)) != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	if C.llmCompletion_SetTemperature(compHandle.handle, C.float(c.temperature)) != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	for key, value := range c.kvConfig {
-		cKey := C.CString(key)
-		cValue := C.CString(value)
-		ok := C.llmCompletion_SetConfig(compHandle.handle, cKey, cValue)
-		C.free(unsafe.Pointer(cKey))
-		C.free(unsafe.Pointer(cValue))
-		if ok != C.LLM_OK {
-			return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-		}
-	}
-
-	return nil
-}
diff --git a/go/skill/chat.go b/go/llm/json.go
similarity index 54%
rename from go/skill/chat.go
rename to go/llm/json.go
index b9b4137f..db57ff99 100644
--- a/go/skill/chat.go
+++ b/go/llm/json.go
@@ -17,51 +17,59 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package skill
+package llm
 
+// #include <stdlib.h>
+// #include "llm.h"
+import "C"
 import (
-	"github.com/ling0322/libllm/go/llm"
+	"encoding/json"
+	"errors"
+	"runtime"
+	"unsafe"
 )
 
-type Message struct {
-	Role    string
-	Content string
+type llmJson struct {
+	json C.llm_json_t
 }
 
-type Chat struct {
-	model         llm.Model
-	promptBuilder promptBuilder
-	compConfig    llm.CompletionConfig
+func newJson() *llmJson {
+	j := new(llmJson)
+	C.llm_json_init(&j.json)
+	runtime.SetFinalizer(j, func(j *llmJson) {
+		C.llm_json_destroy(&j.json)
+	})
+
+	return j
 }
 
-func NewChat(model llm.Model) (*Chat, error) {
-	modelName := model.GetName()
-	promptBuilder, err := newPromptBuilder(modelName)
+func (j *llmJson) marshal(v any) error {
+	jsonBytes, err := json.Marshal(v)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	return &Chat{
-		model:         model,
-		promptBuilder: promptBuilder,
-		compConfig:    llm.NewCompletionConfig(),
-	}, nil
-}
+	cJsonStr := C.CString(string(jsonBytes))
+	defer C.free(unsafe.Pointer(cJsonStr))
+
+	status := C.llm_json_parse(&j.json, cJsonStr)
+	if status != 0 {
+		return errors.New(C.GoString(C.llm_get_last_error_message()))
+	}
 
-func (c *Chat) SetTemperature(temperature float32) {
-	c.compConfig.SetTemperature(temperature)
+	return nil
 }
 
-func (c *Chat) Chat(history []Message) (llm.Completion, error) {
-	prompt, err := c.promptBuilder.Build(history)
-	if err != nil {
-		return nil, err
-	}
+func (j *llmJson) unmarshal(v any) error {
+	bufSize := 2048
+	buf := C.malloc(C.size_t(bufSize))
+	defer C.free(buf)
 
-	comp, err := c.model.Complete(c.compConfig, prompt)
-	if err != nil {
-		return nil, err
+	status := C.llm_json_dump(&j.json, (*C.char)(buf), C.int64_t(bufSize))
+	if status != 0 {
+		return errors.New(C.GoString(C.llm_get_last_error_message()))
 	}
 
-	return comp, nil
+	jsonStr := C.GoString((*C.char)(buf))
+	return json.Unmarshal([]byte(jsonStr), v)
 }
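The chat.go-to-json.go rename above is the heart of the new binding: every call into libllm now goes through a JSON round trip. Go values are serialized with encoding/json, handed to the C side as an llm_json_t, and results are dumped back into a fixed 2048-byte buffer before being unmarshaled. A sketch of that round trip as a hypothetical in-package test (llmJson, newJson, marshal and unmarshal are unexported, so this only compiles inside package llm):

```go
package llm

import "testing"

// TestJsonRoundTrip is illustrative only; the kwargs value mirrors what
// NewModel sends through llm_json_parse.
func TestJsonRoundTrip(t *testing.T) {
	j := newJson()

	// Go map -> JSON string -> C-side llm_json_t.
	kwargs := map[string]any{"filename": "model.llmpkg", "device": "cpu"}
	if err := j.marshal(kwargs); err != nil {
		t.Fatal(err)
	}

	// C-side llm_json_t -> JSON string (2048-byte dump buffer) -> Go value.
	var decoded map[string]string
	if err := j.unmarshal(&decoded); err != nil {
		t.Fatal(err)
	}
	if decoded["device"] != "cpu" {
		t.Errorf("unexpected device: %q", decoded["device"])
	}
}
```

Note the hard-coded 2048-byte buffer in unmarshal: a payload larger than that would make llm_json_dump fail (presumably with LLM_ERROR_INSUFFICIENT_BUFFER), so this path assumes small result objects.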
diff --git a/go/llm/llm.c b/go/llm/llm.c
new file mode 100644
index 00000000..16327262
--- /dev/null
+++ b/go/llm/llm.c
@@ -0,0 +1,215 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2024 Xiaoyang Chen
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+#ifdef __APPLE__
+#define LUT_PLATFORM_APPLE
+#elif defined(linux) || defined(__linux) || defined(__linux__)
+#define LUT_PLATFORM_LINUX
+#elif defined(WIN32) || defined(__WIN32__) || defined(_MSC_VER) || defined(_WIN32) || \
+    defined(__MINGW32__)
+#define LUT_PLATFORM_WINDOWS
+#else
+#error unknown platform
+#endif
+
+#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
+#include <dlfcn.h>
+typedef void *LLM_HMODULE;
+#elif defined(LUT_PLATFORM_WINDOWS)
+#include <windows.h>
+typedef HMODULE LLM_HMODULE;
+#endif
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include "llm.h"
+
+// global state
+void (*p_llm_init)();
+const char *(*p_llm_get_last_error_message)();
+
+// llm
+
+int32_t (*p_llm_model_init)(llm_model_t *m);
+int32_t (*p_llm_model_destroy)(llm_model_t *m);
+int32_t (*p_llm_model_load)(llm_model_t *m, llm_json_t *kwargs);
+int32_t (*p_llm_model_get_info)(llm_model_t *m, llm_json_t *info);
+int32_t (*p_llm_model_complete)(llm_model_t *m, llm_json_t *kwargs, llm_completion_t *comp);
+
+int32_t (*p_llm_completion_init)(llm_completion_t *c);
+int32_t (*p_llm_completion_destroy)(llm_completion_t *c);
+int32_t (*p_llm_completion_get_next_chunk)(llm_completion_t *c, llm_json_t *chunk);
+
+// json
+int32_t (*p_llm_json_init)(llm_json_t *j);
+int32_t (*p_llm_json_destroy)(llm_json_t *j);
+int32_t (*p_llm_json_parse)(llm_json_t *j, const char *json_str);
+int32_t (*p_llm_json_dump)(llm_json_t *j, char *buf, int64_t buf_size);
+
+// asr
+
+int32_t (*p_llm_asr_model_init)(llm_asr_model_t *m);
+int32_t (*p_llm_asr_model_load)(llm_asr_model_t *m, llm_json_t *options);
+int32_t (*p_llm_asr_model_destroy)(llm_asr_model_t *m);
+int32_t (*p_llm_asr_recognition_init)(llm_asr_recognition_t *r);
+int32_t (*p_llm_asr_recognition_destroy)(llm_asr_recognition_t *r);
+int32_t (*p_llm_asr_recognition_get_next_result)(llm_asr_recognition_t *r, llm_json_t *result);
+int32_t (*p_llm_asr_recognize_media_file)(
+    llm_asr_model_t *model,
+    llm_json_t *options,
+    llm_asr_recognition_t *recognition);
+
+// load the libllm shared library.
+void *llm_load_library(const char *library_path) {
+  // first try to load the dll from same folder as current module.
+#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
+  return dlopen(library_path, RTLD_NOW);
+#elif defined(LUT_PLATFORM_WINDOWS)
+  return LoadLibraryA(library_path);
+#endif
+}
+
+#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
+#define GET_PROC_ADDR dlsym
+#elif defined(LUT_PLATFORM_WINDOWS)
+#define GET_PROC_ADDR (void *)GetProcAddress
+#endif
+
+#define LOAD_SYMBOL(hDll, symbol)                                      \
+  p_##symbol = GET_PROC_ADDR(hDll, #symbol);                           \
+  if (!p_##symbol) {                                                   \
+    fprintf(stderr, "llm.go: unable to load symbol: %s\n", #symbol);   \
+    return LLM_ERROR_ABORTED;                                          \
+  }
+
+int32_t llm_load_symbols(void *pDll) {
+  LLM_HMODULE hDll = (LLM_HMODULE)pDll;
+
+  LOAD_SYMBOL(hDll, llm_init);
+  LOAD_SYMBOL(hDll, llm_get_last_error_message);
+  LOAD_SYMBOL(hDll, llm_model_init);
+  LOAD_SYMBOL(hDll, llm_model_destroy);
+  LOAD_SYMBOL(hDll, llm_model_load);
+  LOAD_SYMBOL(hDll, llm_model_get_info);
+  LOAD_SYMBOL(hDll, llm_model_complete);
+  LOAD_SYMBOL(hDll, llm_completion_init);
+  LOAD_SYMBOL(hDll, llm_completion_destroy);
+  LOAD_SYMBOL(hDll, llm_completion_get_next_chunk);
+  LOAD_SYMBOL(hDll, llm_json_init);
+  LOAD_SYMBOL(hDll, llm_json_destroy);
+  LOAD_SYMBOL(hDll, llm_json_parse);
+  LOAD_SYMBOL(hDll, llm_json_dump);
+  LOAD_SYMBOL(hDll, llm_asr_model_init);
+  LOAD_SYMBOL(hDll, llm_asr_model_load);
+  LOAD_SYMBOL(hDll, llm_asr_model_destroy);
+  LOAD_SYMBOL(hDll, llm_asr_recognition_init);
+  LOAD_SYMBOL(hDll, llm_asr_recognition_destroy);
+  LOAD_SYMBOL(hDll, llm_asr_recognition_get_next_result);
+  LOAD_SYMBOL(hDll, llm_asr_recognize_media_file);
+
+  return 0;
+}
+
+void llm_init() {
+  return p_llm_init();
+}
+
+const char *llm_get_last_error_message() {
+  return p_llm_get_last_error_message();
+}
+
+int32_t llm_model_init(llm_model_t *m) {
+  return p_llm_model_init(m);
+}
+
+int32_t llm_model_destroy(llm_model_t *m) {
+  return p_llm_model_destroy(m);
+}
+
+int32_t llm_model_load(llm_model_t *m, llm_json_t *kwargs) {
+  return p_llm_model_load(m, kwargs);
+}
+
+int32_t llm_model_get_info(llm_model_t *m, llm_json_t *info) {
+  return p_llm_model_get_info(m, info);
+}
+
+int32_t llm_model_complete(llm_model_t *m, llm_json_t *kwargs, llm_completion_t *comp) {
+  return p_llm_model_complete(m, kwargs, comp);
+}
+
+int32_t llm_completion_init(llm_completion_t *c) {
+  return p_llm_completion_init(c);
+}
+
+int32_t llm_completion_destroy(llm_completion_t *c) {
+  return p_llm_completion_destroy(c);
+}
+
+int32_t llm_completion_get_next_chunk(llm_completion_t *c, llm_json_t *chunk) {
+  return p_llm_completion_get_next_chunk(c, chunk);
+}
+
+int32_t llm_json_init(llm_json_t *j) {
+  return p_llm_json_init(j);
+}
+
+int32_t llm_json_destroy(llm_json_t *j) {
+  return p_llm_json_destroy(j);
+}
+
+int32_t llm_json_parse(llm_json_t *j, const char *json_str) {
+  return p_llm_json_parse(j, json_str);
+}
+
+int32_t llm_json_dump(llm_json_t *j, char *buf, int64_t buf_size) {
+  return p_llm_json_dump(j, buf, buf_size);
+}
+
+int32_t llm_asr_model_init(llm_asr_model_t *m) {
+  return p_llm_asr_model_init(m);
+}
+
+int32_t llm_asr_model_load(llm_asr_model_t *m, llm_json_t *options) {
+  return p_llm_asr_model_load(m, options);
+}
+
+int32_t llm_asr_model_destroy(llm_asr_model_t *m) {
+  return p_llm_asr_model_destroy(m);
+}
+
+int32_t llm_asr_recognition_init(llm_asr_recognition_t *r) {
+  return p_llm_asr_recognition_init(r);
+}
+
+int32_t llm_asr_recognition_destroy(llm_asr_recognition_t *r) {
+  return p_llm_asr_recognition_destroy(r);
+}
+
+int32_t llm_asr_recognition_get_next_result(llm_asr_recognition_t *r, llm_json_t *result) {
+  return p_llm_asr_recognition_get_next_result(r, result);
+}
+
+int32_t llm_asr_recognize_media_file(
+    llm_asr_model_t *model,
+    llm_json_t *options,
+    llm_asr_recognition_t *recognition) {
+  return p_llm_asr_recognize_media_file(model, options, recognition);
+}
diff --git a/go/llm/llm.go b/go/llm/llm.go
index 2b57e3dc..eab141d6 100644
--- a/go/llm/llm.go
+++ b/go/llm/llm.go
@@ -22,100 +22,119 @@ package llm
 // #cgo linux LDFLAGS: -ldl
 // #cgo darwin LDFLAGS: -ldl
 // #include <stdlib.h>
-// #include "llm_api.h"
+// #include "llm.h"
 import "C"
 import (
 	"errors"
-	"fmt"
-	"os"
-	"path/filepath"
 	"runtime"
-	"sync/atomic"
-	"unsafe"
 )
 
-type Device int32
-type AudioFormat int32
+type Model struct {
+	model C.llm_model_t
+}
 
-const (
-	Cpu               = Device(0x0000)
-	Cuda              = Device(0x0100)
-	Auto              = Device(0x1f00)
-	Pcm16kHz16BitMono = AudioFormat(0x0001)
-)
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type CompletionConfig struct {
+	Temperature float32
+	TopK        int
+	TopP        float32
+}
+
+type Completion struct {
+	comp  C.llm_completion_t
+	json  *llmJson
+	err   error
+	chunk Chunk
+}
 
-var gInit atomic.Bool
-var gDll unsafe.Pointer
-
-// Initialize the libllm.
-func initLlm() error {
-	if !gInit.Swap(true) {
-		// load the shared library.
-		binPath, err := os.Executable()
-		if err != nil {
-			gInit.Store(false)
-			return err
-		}
-
-		var libname string
-		if runtime.GOOS == "windows" {
-			libname = "llm.dll"
-		} else if runtime.GOOS == "linux" {
-			libname = "libllm.so"
-		} else if runtime.GOOS == "darwin" {
-			libname = "libllm.dylib"
-		}
-
-		binDir := filepath.Dir(binPath)
-		dllPath := C.CString(filepath.Join(binDir, libname))
-		defer C.free(unsafe.Pointer(dllPath))
-
-		gDll = C.llmLoadLibrary(dllPath)
-		if gDll == nil {
-			dllPath := C.CString("libllm.so")
-			defer C.free(unsafe.Pointer(dllPath))
-
-			gDll = C.llmLoadLibrary(dllPath)
-		}
-
-		if gDll == nil {
-			gInit.Store(false)
-			return errors.New("failed to load the libllm dynamic library")
-		}
-
-		// initialize the symbols.
-		status := C.llmLoadSymbols(gDll)
-		if status != C.LLM_OK {
-			gInit.Store(false)
-			return errors.New("failed to load libllm api symbols")
-		}
-
-		// initialize libllm inference engine.
-		status = C.llmInit(C.LLM_API_VERSION)
-		if status != C.LLM_OK {
-			gInit.Store(false)
-			return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-		}
+type Chunk struct {
+	Text string `json:"text"`
+}
+
+func NewModel(filename, device string) (*Model, error) {
+	err := initLlm()
+	if err != nil {
+		return nil, err
+	}
+
+	json := newJson()
+	json.marshal(map[string]any{
+		"filename": filename,
+		"device":   device,
+	})
+
+	m := new(Model)
+	C.llm_model_init(&m.model)
+	runtime.SetFinalizer(m, func(m *Model) {
+		C.llm_model_destroy(&m.model)
+	})
+
+	status := C.llm_model_load(&m.model, &json.json)
+	if status != 0 {
+		return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
 	}
 
-	return nil
+	return m, nil
 }
 
-// Release all the resources allocated in libllm library.
-func Release() {
-	if gInit.Swap(false) {
-		status := C.llmDestroy()
-		if status != C.LLM_OK {
-			fmt.Fprintf(
-				os.Stderr,
-				"failed to destroy libllm: %s\n",
-				C.GoString(C.llmGetLastErrorMessage()))
-		}
-
-		// release the dynamic library itself.
-		status = C.llmDestroyLibrary(gDll)
-		if status != C.LLM_OK {
-			fmt.Fprintf(os.Stderr, "failed to close dynamic library of libllm\n")
-		}
+func DefaultCompletionConfig() CompletionConfig {
+	config := CompletionConfig{}
+	config.Temperature = 1.0
+	config.TopK = 50
+	config.TopP = 0.8
+
+	return config
+}
+
+func (m *Model) Complete(history []Message, config CompletionConfig) (*Completion, error) {
+	if config.Temperature <= 0 || config.TopK <= 0 || config.TopP <= 0 {
+		return nil, errors.New("invalid completion config")
+	}
+
+	json := newJson()
+	json.marshal(map[string]any{
+		"temperature": config.Temperature,
+		"top_k":       config.TopK,
+		"top_p":       config.TopP,
+		"messages":    history,
+	})
+
+	comp := new(Completion)
+	comp.json = newJson()
+	C.llm_completion_init(&comp.comp)
+	runtime.SetFinalizer(comp, func(c *Completion) {
+		C.llm_completion_destroy(&c.comp)
+	})
+
+	status := C.llm_model_complete(&m.model, &json.json, &comp.comp)
+	if status != 0 {
+		return nil, errors.New(C.GoString(C.llm_get_last_error_message()))
+	}
+
+	return comp, nil
+}
+
+func (c *Completion) Err() error {
+	return c.err
+}
+
+func (c *Completion) Chunk() Chunk {
+	return c.chunk
+}
+
+func (c *Completion) Next() bool {
+	status := C.llm_completion_get_next_chunk(&c.comp, &c.json.json)
+	if status == LLM_ERROR_EOF {
+		return false
+	} else if status != 0 {
+		c.err = errors.New(C.GoString(C.llm_get_last_error_message()))
+		return false
 	}
+
+	c.err = c.json.unmarshal(&c.chunk)
+	return c.err == nil
 }
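With llm.go rewritten, CompletionConfig is a plain struct rather than an interface with setters, and Complete validates it eagerly before calling into the C API. A short sketch of overriding the defaults (the model path and parameter values are illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/ling0322/libllm/go/llm"
)

func main() {
	model, err := llm.NewModel("model.llmpkg", "cpu") // placeholder path
	if err != nil {
		log.Fatal(err)
	}

	// Start from the defaults (Temperature 1.0, TopK 50, TopP 0.8) and
	// override fields directly; there are no setter methods anymore.
	config := llm.DefaultCompletionConfig()
	config.Temperature = 0.7
	config.TopK = 20

	// Temperature, TopK and TopP must all be > 0, otherwise Complete
	// returns "invalid completion config" without touching libllm.
	comp, err := model.Complete([]llm.Message{{Role: "user", Content: "Hi"}}, config)
	if err != nil {
		log.Fatal(err)
	}
	for comp.Next() {
		fmt.Print(comp.Chunk().Text)
	}
	if err := comp.Err(); err != nil {
		log.Fatal(err)
	}
}
```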
diff --git a/go/llm/llm_api.h b/go/llm/llm.h
similarity index 51%
rename from go/llm/llm_api.h
rename to go/llm/llm.h
index 8c4b92fa..645506fc 100644
--- a/go/llm/llm_api.h
+++ b/go/llm/llm.h
@@ -22,80 +22,53 @@
 
 #include <stdint.h>
 
-#define LLM_DEVICE_CPU 0x0000
-#define LLM_DEVICE_CUDA 0x0100
-#define LLM_DEVICE_AUTO 0x1f00
-#define LLM_API_VERSION 20240101
-#define LLM_WAVE_FORMAT_PCM16KHZ16BITMONO 0x0001
-#define LLM_OK 0
-#define LLM_ABORTED 1
-
 #define LLM_ERROR_INVALID_ARG 0x0100
 #define LLM_ERROR_INSUFFICIENT_BUFFER 0x0101
 #define LLM_ERROR_ABORTED 0x0102
 #define LLM_ERROR_EOF 0x0103
 
-typedef int32_t llmStatus_t;
-typedef struct llmModel_t llmModel_t;
-typedef struct llmChunk_t llmChunk_t;
-typedef struct llmPrompt_t llmPrompt_t;
-typedef struct llmCompletion_t llmCompletion_t;
-typedef int32_t llmBool_t;
-typedef int8_t llmByte_t;
+typedef struct llm_model_impl_t *llm_model_t;
+typedef struct llm_completion_impl_t *llm_completion_t;
+typedef struct llm_json_impl_t *llm_json_t;
+typedef struct llm_asr_recognition_impl_t *llm_asr_recognition_t;
+typedef struct llm_asr_model_impl_t *llm_asr_model_t;
+
+// library
 
-void *llmLoadLibrary(const char *libraryPath);
-llmStatus_t llmLoadSymbols(void *hDll);
-llmStatus_t llmDestroyLibrary(void *handle);
+void *llm_load_library(const char *library_path);
+int32_t llm_load_symbols(void *pDll);
 
 // global state
-llmStatus_t llmInit(int32_t apiVersion);
-llmStatus_t llmDestroy();
-const char *llmGetLastErrorMessage();
-
-// llmModel_t
-llmModel_t *llmModel_New();
-llmStatus_t llmModel_Delete(llmModel_t *model);
-llmStatus_t llmModel_SetFile(llmModel_t *model, const char *filename);
-llmStatus_t llmModel_SetDevice(llmModel_t *model, int32_t device);
-llmStatus_t llmModel_Load(llmModel_t *model);
-const char *llmModel_GetName(llmModel_t *model);
-
-// llmPrompt_t
-llmPrompt_t *llmPrompt_New();
-llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt);
-llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text);
-llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token);
-
-llmStatus_t llmPrompt_AppendAudio(
-    llmPrompt_t *prompt,
-    const llmByte_t *audio,
-    int64_t size,
-    int32_t format);
-
-// llmCompletion_t
-llmCompletion_t *llmCompletion_New(llmModel_t *model);
-llmStatus_t llmCompletion_Delete(llmCompletion_t *comp);
-llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value);
-llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt);
-llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP);
-llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK);
-llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature);
-llmBool_t llmCompletion_Next(llmCompletion_t *comp);
-llmStatus_t llmCompletion_GetError(llmCompletion_t *comp);
-const char *llmCompletion_GetText(llmCompletion_t *comp);
-const char *llmCompletion_GetToken(llmCompletion_t *comp);
+void llm_init();
+const char *llm_get_last_error_message();
 
-typedef struct llm_json_impl_t *llm_json_t;
+// JSON
 
 int32_t llm_json_init(llm_json_t *j);
 int32_t llm_json_destroy(llm_json_t *j);
 int32_t llm_json_parse(llm_json_t *j, const char *json_str);
 int32_t llm_json_dump(llm_json_t *j, char *buf, int64_t buf_size);
 
-// ASR
+// LLM
 
-typedef struct llm_asr_recognition_impl_t *llm_asr_recognition_t;
-typedef struct llm_asr_model_impl_t *llm_asr_model_t;
+int32_t llm_model_init(llm_model_t *m);
+int32_t llm_model_destroy(llm_model_t *m);
+int32_t llm_model_load(llm_model_t *m, llm_json_t *kwargs);
+int32_t llm_model_get_info(llm_model_t *m, llm_json_t *info);
+int32_t llm_model_complete(llm_model_t *m, llm_json_t *kwargs, llm_completion_t *comp);
+
+int32_t llm_completion_init(llm_completion_t *c);
+int32_t llm_completion_destroy(llm_completion_t *c);
+int32_t llm_completion_get_next_chunk(llm_completion_t *c, llm_json_t *chunk);
+
+// JSON
+
+int32_t llm_json_init(llm_json_t *j);
+int32_t llm_json_destroy(llm_json_t *j);
+int32_t llm_json_parse(llm_json_t *j, const char *json_str);
+int32_t llm_json_dump(llm_json_t *j, char *buf, int64_t buf_size);
+
+// ASR
 
 int32_t llm_asr_model_init(llm_asr_model_t *m);
 int32_t llm_asr_model_load(llm_asr_model_t *m, llm_json_t *options);
diff --git a/go/llm/llm_api.c b/go/llm/llm_api.c
deleted file mode 100644
index e1aee4cb..00000000
--- a/go/llm/llm_api.c
+++ /dev/null
@@ -1,373 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-#ifdef __APPLE__
-#define LUT_PLATFORM_APPLE
-#elif defined(linux) || defined(__linux) || defined(__linux__)
-#define LUT_PLATFORM_LINUX
-#elif defined(WIN32) || defined(__WIN32__) || defined(_MSC_VER) || defined(_WIN32) || \
-    defined(__MINGW32__)
-#define LUT_PLATFORM_WINDOWS
-#else
-#error unknown platform
-#endif
-
-#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
-#include <dlfcn.h>
-typedef void *LLM_HMODULE;
-#elif defined(LUT_PLATFORM_WINDOWS)
-#include <windows.h>
-typedef HMODULE LLM_HMODULE;
-#endif
-
-#include <stdint.h>
-#include <stdio.h>
-
-#include "llm_api.h"
-
-// global state
-llmStatus_t (*p_llmInit)(int32_t apiVersion);
-llmStatus_t (*p_llmDestroy)();
-const char *(*p_llmGetLastErrorMessage)();
-
-// llmModel_t
-llmModel_t *(*p_llmModel_New)();
-llmStatus_t (*p_llmModel_Delete)(llmModel_t *model);
-llmStatus_t (*p_llmModel_SetFile)(llmModel_t *model, const char *filename);
-llmStatus_t (*p_llmModel_SetDevice)(llmModel_t *model, int32_t device);
-llmStatus_t (*p_llmModel_Load)(llmModel_t *model);
-const char *(*p_llmModel_GetName)(llmModel_t *model);
-
-// llmPrompt_t
-llmPrompt_t *(*p_llmPrompt_New)();
-llmStatus_t (*p_llmPrompt_Delete)(llmPrompt_t *prompt);
-llmStatus_t (*p_llmPrompt_AppendText)(llmPrompt_t *prompt, const char *text);
-llmStatus_t (*p_llmPrompt_AppendControlToken)(llmPrompt_t *prompt, const char *token);
-llmStatus_t (*p_llmPrompt_AppendAudio)(
-    llmPrompt_t *prompt,
-    const llmByte_t *audio,
-    int64_t size,
-    int32_t format);
-
-// llmCompletion_t
-llmCompletion_t *(*p_llmCompletion_New)(llmModel_t *model);
-llmStatus_t (*p_llmCompletion_Delete)(llmCompletion_t *comp);
-llmStatus_t (*p_llmCompletion_SetPrompt)(llmCompletion_t *comp, llmPrompt_t *prompt);
-llmStatus_t (*p_llmCompletion_SetTopP)(llmCompletion_t *comp, float topP);
-llmStatus_t (*p_llmCompletion_SetTopK)(llmCompletion_t *comp, int32_t topK);
-llmStatus_t (*p_llmCompletion_SetTemperature)(llmCompletion_t *comp, float temperature);
-llmStatus_t (*p_llmCompletion_SetConfig)(llmCompletion_t *comp, const char *key, const char *value);
-llmBool_t (*p_llmCompletion_Next)(llmCompletion_t *comp);
-llmStatus_t (*p_llmCompletion_GetError)(llmCompletion_t *comp);
-const char *(*p_llmCompletion_GetText)(llmCompletion_t *comp);
-const char *(*p_llmCompletion_GetToken)(llmCompletion_t *comp);
-
-// json
-int32_t (*p_llm_json_init)(llm_json_t *j);
-int32_t (*p_llm_json_destroy)(llm_json_t *j);
-int32_t (*p_llm_json_parse)(llm_json_t *j, const char *json_str);
-int32_t (*p_llm_json_dump)(llm_json_t *j, char *buf, int64_t buf_size);
-
-// asr
-
-int32_t (*p_llm_asr_model_init)(llm_asr_model_t *m);
-int32_t (*p_llm_asr_model_load)(llm_asr_model_t *m, llm_json_t *options);
-int32_t (*p_llm_asr_model_destroy)(llm_asr_model_t *m);
-int32_t (*p_llm_asr_recognition_init)(llm_asr_recognition_t *r);
-int32_t (*p_llm_asr_recognition_destroy)(llm_asr_recognition_t *r);
-int32_t (*p_llm_asr_recognition_get_next_result)(llm_asr_recognition_t *r, llm_json_t *result);
-int32_t (*p_llm_asr_recognize_media_file)(
-    llm_asr_model_t *model,
-    llm_json_t *options,
-    llm_asr_recognition_t *recognition);
-
-// load the libllm shared library.
-void *llmLoadLibrary(const char *libraryPath) {
-  // first try to load the dll from same folder as current module.
-#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
-  return dlopen(libraryPath, RTLD_NOW);
-#elif defined(LUT_PLATFORM_WINDOWS)
-  return LoadLibraryA(libraryPath);
-#endif
-}
-
-#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
-#define GET_PROC_ADDR dlsym
-#elif defined(LUT_PLATFORM_WINDOWS)
-#define GET_PROC_ADDR (void *)GetProcAddress
-#endif
-
-#define LOAD_SYMBOL(hDll, symbol)                                      \
-  p_##symbol = GET_PROC_ADDR(hDll, #symbol);                           \
-  if (!p_##symbol) {                                                   \
-    fprintf(stderr, "llm.go: unable to load symbol: %s\n", #symbol);   \
-    return LLM_ABORTED;                                                \
-  }
-
-llmStatus_t llmLoadSymbols(void *pDll) {
-  LLM_HMODULE hDll = (LLM_HMODULE)pDll;
-
-  LOAD_SYMBOL(hDll, llmInit);
-  LOAD_SYMBOL(hDll, llmDestroy);
-  LOAD_SYMBOL(hDll, llmGetLastErrorMessage);
-  LOAD_SYMBOL(hDll, llmModel_New);
-  LOAD_SYMBOL(hDll, llmModel_Delete);
-  LOAD_SYMBOL(hDll, llmModel_SetFile);
-  LOAD_SYMBOL(hDll, llmModel_SetDevice);
-  LOAD_SYMBOL(hDll, llmModel_Load);
-  LOAD_SYMBOL(hDll, llmModel_GetName);
-  LOAD_SYMBOL(hDll, llmPrompt_New);
-  LOAD_SYMBOL(hDll, llmPrompt_Delete);
-  LOAD_SYMBOL(hDll, llmPrompt_AppendText);
-  LOAD_SYMBOL(hDll, llmPrompt_AppendControlToken);
-  LOAD_SYMBOL(hDll, llmPrompt_AppendAudio);
-  LOAD_SYMBOL(hDll, llmCompletion_New);
-  LOAD_SYMBOL(hDll, llmCompletion_Delete);
-  LOAD_SYMBOL(hDll, llmCompletion_SetPrompt);
-  LOAD_SYMBOL(hDll, llmCompletion_SetTopP);
-  LOAD_SYMBOL(hDll, llmCompletion_SetTopK);
-  LOAD_SYMBOL(hDll, llmCompletion_SetTemperature);
-  LOAD_SYMBOL(hDll, llmCompletion_SetConfig);
-  LOAD_SYMBOL(hDll, llmCompletion_Next);
-  LOAD_SYMBOL(hDll, llmCompletion_GetError);
-  LOAD_SYMBOL(hDll, llmCompletion_GetText);
-  LOAD_SYMBOL(hDll, llmCompletion_GetToken);
-  LOAD_SYMBOL(hDll, llm_json_init);
-  LOAD_SYMBOL(hDll, llm_json_destroy);
-  LOAD_SYMBOL(hDll, llm_json_parse);
-  LOAD_SYMBOL(hDll, llm_json_dump);
-  LOAD_SYMBOL(hDll, llm_asr_model_init);
-  LOAD_SYMBOL(hDll, llm_asr_model_load);
-  LOAD_SYMBOL(hDll, llm_asr_model_destroy);
-  LOAD_SYMBOL(hDll, llm_asr_recognition_init);
-  LOAD_SYMBOL(hDll, llm_asr_recognition_destroy);
-  LOAD_SYMBOL(hDll, llm_asr_recognition_get_next_result);
-  LOAD_SYMBOL(hDll, llm_asr_recognize_media_file);
-
-  return LLM_OK;
-}
-
-// load the libllm shared library.
-llmStatus_t llmDestroyLibrary(void *handle) {
-  p_llmInit = NULL;
-  p_llmDestroy = NULL;
-  p_llmGetLastErrorMessage = NULL;
-  p_llmModel_New = NULL;
-  p_llmModel_Delete = NULL;
-  p_llmModel_SetFile = NULL;
-  p_llmModel_SetDevice = NULL;
-  p_llmModel_Load = NULL;
-  p_llmModel_GetName = NULL;
-  p_llmPrompt_New = NULL;
-  p_llmPrompt_Delete = NULL;
-  p_llmPrompt_AppendText = NULL;
-  p_llmPrompt_AppendControlToken = NULL;
-  p_llmPrompt_AppendAudio = NULL;
-  p_llmCompletion_New = NULL;
-  p_llmCompletion_Delete = NULL;
-  p_llmCompletion_SetPrompt = NULL;
-  p_llmCompletion_SetTopP = NULL;
-  p_llmCompletion_SetTopK = NULL;
-  p_llmCompletion_SetTemperature = NULL;
-  p_llmCompletion_SetConfig = NULL;
-  p_llmCompletion_Next = NULL;
-  p_llmCompletion_GetError = NULL;
-  p_llmCompletion_GetText = NULL;
-  p_llmCompletion_GetToken = NULL;
-  p_llm_json_init = NULL;
-  p_llm_json_destroy = NULL;
-  p_llm_json_parse = NULL;
-  p_llm_json_dump = NULL;
-  p_llm_asr_model_init = NULL;
-  p_llm_asr_model_load = NULL;
-  p_llm_asr_model_destroy = NULL;
-  p_llm_asr_recognition_init = NULL;
-  p_llm_asr_recognition_destroy = NULL;
-  p_llm_asr_recognition_get_next_result = NULL;
-  p_llm_asr_recognize_media_file = NULL;
-
-  // first try to load the dll from same folder as current module.
-#if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
-  int ret = dlclose(handle);
-  if (ret != 0) {
-    return LLM_ABORTED;
-  }
-#elif defined(LUT_PLATFORM_WINDOWS)
-  BOOL success = FreeLibrary((LLM_HMODULE)handle);
-  if (FALSE == success) {
-    return LLM_ABORTED;
-  }
-#endif
-
-  return LLM_OK;
-}
-
-llmStatus_t llmInit(int32_t apiVersion) {
-  return p_llmInit(apiVersion);
-}
-
-llmStatus_t llmDestroy() {
-  return p_llmDestroy();
-}
-
-const char *llmGetLastErrorMessage() {
-  return p_llmGetLastErrorMessage();
-}
-
-// llmModel_t
-llmModel_t *llmModel_New() {
-  return p_llmModel_New();
-}
-
-llmStatus_t llmModel_Delete(llmModel_t *model) {
-  return p_llmModel_Delete(model);
-}
-
-llmStatus_t llmModel_SetFile(llmModel_t *model, const char *filename) {
-  return p_llmModel_SetFile(model, filename);
-}
-
-llmStatus_t llmModel_SetDevice(llmModel_t *model, int32_t device) {
-  return p_llmModel_SetDevice(model, device);
-}
-
-llmStatus_t llmModel_Load(llmModel_t *model) {
-  return p_llmModel_Load(model);
-}
-
-const char *llmModel_GetName(llmModel_t *model) {
-  return p_llmModel_GetName(model);
-}
-
-// llmPrompt_t
-llmPrompt_t *llmPrompt_New() {
-  return p_llmPrompt_New();
-}
-
-llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt) {
-  return p_llmPrompt_Delete(prompt);
-}
-
-llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text) {
-  return p_llmPrompt_AppendText(prompt, text);
-}
-
-llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token) {
-  return p_llmPrompt_AppendControlToken(prompt, token);
-}
-
-llmStatus_t llmPrompt_AppendAudio(
-    llmPrompt_t *prompt,
-    const llmByte_t *audio,
-    int64_t size,
-    int32_t format) {
-  return p_llmPrompt_AppendAudio(prompt, audio, size, format);
-}
-
-// llmCompletion_t
-llmCompletion_t *llmCompletion_New(llmModel_t *model) {
-  return p_llmCompletion_New(model);
-}
-
-llmStatus_t llmCompletion_Delete(llmCompletion_t *comp) {
-  return p_llmCompletion_Delete(comp);
-}
-
-llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt) {
-  return p_llmCompletion_SetPrompt(comp, prompt);
-}
-
-llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP) {
-  return p_llmCompletion_SetTopP(comp, topP);
-}
-
-llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK) {
-  return p_llmCompletion_SetTopK(comp, topK);
-}
-
-llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature) {
-  return p_llmCompletion_SetTemperature(comp, temperature);
-}
-
-llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value) {
-  return p_llmCompletion_SetConfig(comp, key, value);
-}
-
-llmBool_t llmCompletion_Next(llmCompletion_t *comp) {
-  return p_llmCompletion_Next(comp);
-}
-
-llmStatus_t llmCompletion_GetError(llmCompletion_t *comp) {
-  return p_llmCompletion_GetError(comp);
-}
-
-const char *llmCompletion_GetText(llmCompletion_t *comp) {
-  return p_llmCompletion_GetText(comp);
-}
-
-int32_t llm_json_init(llm_json_t *j) {
-  return p_llm_json_init(j);
-}
-
-int32_t llm_json_destroy(llm_json_t *j) {
-  return p_llm_json_destroy(j);
-}
-
-int32_t llm_json_parse(llm_json_t *j, const char *json_str) {
-  return p_llm_json_parse(j, json_str);
-}
-
-int32_t llm_json_dump(llm_json_t *j, char *buf, int64_t buf_size) {
-  return p_llm_json_dump(j, buf, buf_size);
-}
-
-const char *llmCompletion_GetToken(llmCompletion_t *comp) {
-  return p_llmCompletion_GetToken(comp);
-}
-
-int32_t llm_asr_model_init(llm_asr_model_t *m) {
-  return p_llm_asr_model_init(m);
-}
-
-int32_t llm_asr_model_load(llm_asr_model_t *m, llm_json_t *options) {
-  return p_llm_asr_model_load(m, options);
-}
-
-int32_t llm_asr_model_destroy(llm_asr_model_t *m) {
-  return p_llm_asr_model_destroy(m);
-}
-
-int32_t llm_asr_recognition_init(llm_asr_recognition_t *r) {
-  return p_llm_asr_recognition_init(r);
-}
-
-int32_t llm_asr_recognition_destroy(llm_asr_recognition_t *r) {
-  return p_llm_asr_recognition_destroy(r);
-}
-
-int32_t llm_asr_recognition_get_next_result(llm_asr_recognition_t *r, llm_json_t *result) {
-  return p_llm_asr_recognition_get_next_result(r, result);
-}
-
-int32_t llm_asr_recognize_media_file(
-    llm_asr_model_t *model,
-    llm_json_t *options,
-    llm_asr_recognition_t *recognition) {
-  return p_llm_asr_recognize_media_file(model, options, recognition);
-}
diff --git a/go/llm/model.go b/go/llm/model.go
deleted file mode 100644
index d78d5f00..00000000
--- a/go/llm/model.go
+++ /dev/null
@@ -1,151 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-package llm
-
-// #include <stdlib.h>
-// #include "llm_api.h"
-import "C"
-import (
-	"errors"
-	"log/slog"
-	"runtime"
-	"unsafe"
-)
-
-// A LLM.
-type Model interface {
-	GetName() string
-	Complete(config CompletionConfig, prompt Prompt) (Completion, error)
-	Dispose()
-}
-
-type modelHandle struct {
-	handle *C.llmModel_t
-}
-
-type modelImpl struct {
-	handle *modelHandle
-}
-
-// Load a LLM model from `modelPath`, then save it to the specified device.
-func NewModel(modelPath string, device Device) (Model, error) {
-	err := initLlm()
-	if err != nil {
-		return nil, err
-	}
-
-	handle, err := newModelHandle()
-	if err != nil {
-		return nil, err
-	}
-
-	cPath := C.CString(modelPath)
-	defer C.free(unsafe.Pointer(cPath))
-	if C.llmModel_SetFile(handle.handle, cPath) != C.LLM_OK {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	if C.llmModel_SetDevice(handle.handle, C.int32_t(device)) != C.LLM_OK {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	if C.llmModel_Load(handle.handle) != C.LLM_OK {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	model := &modelImpl{
-		handle: handle,
-	}
-	return model, nil
-}
-
-func (m *modelImpl) Complete(config CompletionConfig, prompt Prompt) (Completion, error) {
-	comp, err := newCompletionImpl(m.handle)
-	if err != nil {
-		return nil, err
-	}
-
-	err = config.updateCompHandle(comp.handle)
-	if err != nil {
-		return nil, err
-	}
-
-	promptHandle, err := newPromptHandle()
-	if err != nil {
-		return nil, err
-	}
-
-	err = prompt.updatePromptHandle(promptHandle)
-	if err != nil {
-		return nil, err
-	}
-
-	ok := C.llmCompletion_SetPrompt(comp.handle.handle, promptHandle.handle)
-	if ok != C.LLM_OK {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return comp, nil
-}
-
-// Get the name of model.
-func (m *modelImpl) GetName() string {
-	name := C.llmModel_GetName(m.handle.handle)
-	if name == nil {
-		return ""
-	} else {
-		return C.GoString(name)
-	}
-}
-
-func (m *modelImpl) Dispose() {
-	m.handle.dispose()
-	m.handle = nil
-}
-
-func (h *modelHandle) dispose() {
-	if h.handle == nil {
-		return
-	}
-	status := C.llmModel_Delete(h.handle)
-	if status != C.LLM_OK {
-		slog.Error(
-			"failed to call llmModel_Delete()",
-			"message", C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	h.handle = nil
-}
-
-func newModelHandle() (*modelHandle, error) {
-	cHandle := C.llmModel_New()
-	if cHandle == nil {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	handle := &modelHandle{
-		cHandle,
-	}
-	runtime.SetFinalizer(handle, func(h *modelHandle) {
-		h.dispose()
-	})
-
-	return handle, nil
-}
diff --git a/go/llm/prompt.go b/go/llm/prompt.go
deleted file mode 100644
index 07708c55..00000000
--- a/go/llm/prompt.go
+++ /dev/null
@@ -1,156 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-package llm
-
-// #include <stdlib.h>
-// #include "llm_api.h"
-import "C"
-import (
-	"errors"
-	"fmt"
-	"os"
-	"runtime"
-	"unsafe"
-)
-
-// The input of LLM.
-type Prompt interface {
-	AppendText(text string)
-	AppendControlToken(text string)
-	AppendAudio(data []byte, format AudioFormat)
-
-	// Update the llmPrompt_t instance according to the current prompt.
-	updatePromptHandle(handle *promptHandle) error
-}
-
-type promptImpl struct {
-	elements []promptElem
-}
-
-type promptElem interface {
-	AppendTo(handle *promptHandle) error
-}
-
-type textPromptElem struct {
-	text string
-}
-
-type controlTokenPromptElem struct {
-	token string
-}
-
-type audioPromptElem struct {
-	payload []byte
-	format  AudioFormat
-}
-
-type promptHandle struct {
-	handle *C.llmPrompt_t
-}
-
-func NewPrompt() Prompt {
-	return &promptImpl{}
-}
-
-func (p *promptImpl) AppendText(text string) {
-	p.elements = append(p.elements, &textPromptElem{text})
-}
-
-func (p *promptImpl) AppendControlToken(text string) {
-	p.elements = append(p.elements, &controlTokenPromptElem{text})
-}
-
-func (p *promptImpl) AppendAudio(audio []byte, format AudioFormat) {
-	p.elements = append(p.elements, &audioPromptElem{audio, format})
-}
-
-func (p *promptImpl) updatePromptHandle(handle *promptHandle) error {
-	if len(p.elements) == 0 {
-		return errors.New("prompt is empty")
-	}
-
-	for _, elem := range p.elements {
-		err := elem.AppendTo(handle)
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-func (e *textPromptElem) AppendTo(handle *promptHandle) error {
-	cText := C.CString(e.text)
-	defer C.free(unsafe.Pointer(cText))
-
-	status := C.llmPrompt_AppendText(handle.handle, cText)
-	if status != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return nil
-}
-
-func (e *controlTokenPromptElem) AppendTo(handle *promptHandle) error {
-	cToken := C.CString(e.token)
-	defer C.free(unsafe.Pointer(cToken))
-
-	status := C.llmPrompt_AppendControlToken(handle.handle, cToken)
-	if status != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return nil
-}
-
-func (e *audioPromptElem) AppendTo(handle *promptHandle) error {
-	cPayload := C.CBytes(e.payload)
-	defer C.free(unsafe.Pointer(cPayload))
-
-	status := C.llmPrompt_AppendAudio(
-		handle.handle,
-		(*C.llmByte_t)(cPayload),
-		C.int64_t(len(e.payload)),
-		C.int32_t(Pcm16kHz16BitMono))
-	if status != C.LLM_OK {
-		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	return nil
-}
-
-func newPromptHandle() (*promptHandle, error) {
-	cHandle := C.llmPrompt_New()
-	if cHandle == nil {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
-	handle := &promptHandle{
-		cHandle,
-	}
-	runtime.SetFinalizer(handle, func(h *promptHandle) {
-		status := C.llmPrompt_Delete(h.handle)
-		if status != C.LLM_OK {
-			fmt.Fprintln(os.Stderr, "failed to call llmPrompt_Delete()")
-		}
-	})
-
-	return handle, nil
-}
diff --git a/go/skill/llama.go b/go/skill/llama.go
deleted file mode 100644
index 4afa08bf..00000000
--- a/go/skill/llama.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2024 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package skill - -import ( - "errors" - - "github.com/ling0322/libllm/go/llm" -) - -type Llama struct { -} - -func (l *Llama) Build(history []Message) (llm.Prompt, error) { - prompt := llm.NewPrompt() - prompt.AppendControlToken("<|begin_of_text|>") - for _, message := range history[:len(history)-1] { - prompt.AppendControlToken("<|start_header_id|>") - prompt.AppendText(message.Role) - prompt.AppendControlToken("<|end_header_id|>") - prompt.AppendText("\n\n" + message.Content) - prompt.AppendControlToken("<|eot_id|>") - } - - lastMessage := history[len(history)-1] - if lastMessage.Role == "user" { - prompt.AppendControlToken("<|start_header_id|>") - prompt.AppendText(lastMessage.Role) - prompt.AppendControlToken("<|end_header_id|>") - prompt.AppendText("\n\n" + lastMessage.Content) - prompt.AppendControlToken("<|eot_id|>") - prompt.AppendControlToken("<|start_header_id|>") - prompt.AppendText("assistant") - prompt.AppendControlToken("<|end_header_id|>") - prompt.AppendText("\n\n") - } else if lastMessage.Role == "assistant" { - prompt.AppendControlToken("<|start_header_id|>") - prompt.AppendText(lastMessage.Role) - prompt.AppendControlToken("<|end_header_id|>") - prompt.AppendText("\n\n" + lastMessage.Content) - } else { - return nil, errors.New("last message should be either user or assistant") - } - - return prompt, nil -} diff --git a/go/skill/qwen.go b/go/skill/qwen.go deleted file mode 100644 index b4729cf8..00000000 --- a/go/skill/qwen.go +++ /dev/null @@ -1,61 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2024 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- -package skill - -import ( - "errors" - "fmt" - - "github.com/ling0322/libllm/go/llm" -) - -type Qwen struct { -} - -func (l *Qwen) Build(history []Message) (llm.Prompt, error) { - if len(history) == 0 { - return nil, errors.New("history is empty") - } - - prompt := llm.NewPrompt() - for _, message := range history[:len(history)-1] { - prompt.AppendControlToken("<|im_start|>") - prompt.AppendText(fmt.Sprintf("%s\n%s", message.Role, message.Content)) - prompt.AppendControlToken("<|im_end|>") - prompt.AppendText("\n") - } - - lastMessage := history[len(history)-1] - if lastMessage.Role == "user" { - prompt.AppendControlToken("<|im_start|>") - prompt.AppendText(fmt.Sprintf("%s\n%s", lastMessage.Role, lastMessage.Content)) - prompt.AppendControlToken("<|im_end|>") - prompt.AppendText("\n") - prompt.AppendControlToken("<|im_start|>") - prompt.AppendText("assistant\n") - } else if lastMessage.Role == "assistant" { - prompt.AppendControlToken("<|im_start|>") - prompt.AppendText(fmt.Sprintf("%s\n%s", lastMessage.Role, lastMessage.Content)) - } else { - return nil, errors.New("last message should be either user or assistant") - } - - return prompt, nil -} diff --git a/go/skill/transcriber.go b/go/skill/transcriber.go deleted file mode 100644 index e2ccdff6..00000000 --- a/go/skill/transcriber.go +++ /dev/null @@ -1,87 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2024 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- -package skill - -import ( - "fmt" - "time" - - "github.com/ling0322/libllm/go/llm" -) - -type TranscriptionResult struct { - Begin time.Duration - End time.Duration - Language string - Text string -} - -type Transcriber interface { - Transcribe() bool - Result() TranscriptionResult - Err() error - Dispose() -} - -type llmTranscriber struct { - recognition *llm.Recognition -} - -func (r *TranscriptionResult) String() string { - return fmt.Sprintf("%8s - %8s: %s", r.Begin.String(), r.End.String(), r.Text) -} - -func (r *TranscriptionResult) Duration() time.Duration { - return r.End - r.Begin -} - -func (t *llmTranscriber) Transcribe() bool { - return t.recognition.Next() -} - -func (t *llmTranscriber) Result() TranscriptionResult { - recoResult := t.recognition.Result() - result := TranscriptionResult{} - result.Text = recoResult.Text - result.Language = recoResult.Language - result.Begin = recoResult.Begin - result.End = recoResult.End - - return result -} - -func (t *llmTranscriber) Err() error { - return t.recognition.Err() -} - -func (t *llmTranscriber) Dispose() { - t.recognition = nil -} - -func NewASRTranscriber(model *llm.ASRModel, inputFile string) (Transcriber, error) { - r, err := model.Recognize(inputFile) - if err != nil { - return nil, err - } - - return &llmTranscriber{ - recognition: r, - }, nil -} diff --git a/go/skill/translation.go b/go/skill/translation.go index 960285d1..ff6d0cd3 100644 --- a/go/skill/translation.go +++ b/go/skill/translation.go @@ -35,35 +35,20 @@ type TranslationRequest struct { Temperature float32 } -type Translator interface { - Translate(request TranslationRequest) (llm.Completion, error) - - // return true if the translator supports the source and target language pairs. - IsSupport(source, target Lang) bool -} - // translator implemented by a chat model -type chatTranslator struct { - model llm.Model +type Translator struct { + model *llm.Model } -func NewTranslator(model llm.Model) (Translator, error) { +func NewTranslator(model *llm.Model) (*Translator, error) { if model == nil { return nil, ErrModelIsNil } - modelName := model.GetName() - switch modelName { - case "index": - return &chatTranslator{model}, nil - case "qwen": - return &chatTranslator{model}, nil - default: - return nil, ErrInvalidModelForTranslation - } + return &Translator{model}, nil } -var sysPromptIndexTranslation = "翻译%s到%s,不能有换行符,回复请以\"翻译结果:\"开头。" +var sysPromptIndexTranslation = "Translate from %s to %s. 
The result should begin with TRANSLATION:" var translationExamples = []map[Lang]string{ { @@ -83,7 +68,7 @@ var translationExamples = []map[Lang]string{ }, } -func (l *chatTranslator) IsSupport(source, target Lang) bool { +func (l *Translator) IsSupport(source, target Lang) bool { var sourceOk, targetOk bool switch source { @@ -105,20 +90,20 @@ func (l *chatTranslator) IsSupport(source, target Lang) bool { return sourceOk && targetOk } -func (l *chatTranslator) getLangString(lang Lang) (name string, err error) { +func (l *Translator) getLangString(lang Lang) (name string, err error) { switch lang { case Chinese: - return "中文", nil + return "Chinese", nil case English: - return "英语", nil + return "English", nil case Japanese: - return "日语", nil + return "Japanese", nil default: return "", ErrUnexpectedLanguage } } -func (l *chatTranslator) getSysPrompt(source, target Lang) (prompt string, err error) { +func (l *Translator) getSysPrompt(source, target Lang) (prompt string, err error) { srcLang, err := l.getLangString(source) if err != nil { return @@ -132,12 +117,7 @@ func (l *chatTranslator) getSysPrompt(source, target Lang) (prompt string, err e return fmt.Sprintf(sysPromptIndexTranslation, srcLang, tgtLang), nil } -func (l *chatTranslator) Translate(request TranslationRequest) (llm.Completion, error) { - chat, err := NewChat(l.model) - if err != nil { - return nil, err - } - +func (l *Translator) Translate(request TranslationRequest) (*llm.Completion, error) { sysPrompt, err := l.getSysPrompt(request.SourceLang, request.TargetLang) if err != nil { return nil, err @@ -151,15 +131,11 @@ func (l *chatTranslator) Translate(request TranslationRequest) (llm.Completion, leftCtxTgt = request.LeftContextTarget } - messages := []Message{ - {"system", sysPrompt}, - {"user", leftCtxSrc + request.Text}, - {"assistant", "翻译结果:" + leftCtxTgt}, - } - - if request.Temperature > 0 { - chat.SetTemperature(request.Temperature) + messages := []llm.Message{ + {Role: "system", Content: sysPrompt}, + {Role: "user", Content: leftCtxSrc + request.Text}, + {Role: "assistant", Content: "TRANSLATION: " + leftCtxTgt}, } - return chat.Chat(messages) + return l.model.Complete(messages, llm.DefaultCompletionConfig()) } diff --git a/src/libllm/CMakeLists.txt b/src/libllm/CMakeLists.txt index 5d5305b4..cdd06afe 100644 --- a/src/libllm/CMakeLists.txt +++ b/src/libllm/CMakeLists.txt @@ -32,6 +32,7 @@ set(libllm_SOURCES "cpu/transform.cc" "cpu/unfold.cc" "cpu/view.cc" + "bilibili_index.cc" "bpe_config.cc" "bpe_encoder.cc" "bpe_model.cc" diff --git a/src/libllm/bilibili_index.cc b/src/libllm/bilibili_index.cc new file mode 100644 index 00000000..2201196e --- /dev/null +++ b/src/libllm/bilibili_index.cc @@ -0,0 +1,78 @@ +// The MIT License (MIT) +// +// Copyright (c) 2023 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. 
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+#include "libllm/bilibili_index.h"
+
+namespace libllm {
+namespace index {
+
+std::shared_ptr<ModelForGeneration> IndexModelForGeneration::fromPackage(
+    const Context &ctx,
+    lut::ZipFile *package) {
+  std::shared_ptr<lut::Reader> reader = package->open(ModelConfig);
+  std::shared_ptr<lut::IniConfig> ini = lut::IniConfig::fromStream(reader.get());
+
+  std::string modelFile = ini->getSection(ModelSection).getString(ModelFileField);
+  std::string modelType = ini->getSection(ModelSection).getString(ModelTypeField);
+  CHECK(modelType == "index");
+
+  const lut::IniSection &indexIni = ini->getSection(modelType);
+
+  std::shared_ptr<IndexModelForGeneration> model{new IndexModelForGeneration()};
+  llama::LlamaConfig llamaConfig = llama::LlamaConfig::loadConfig(indexIni);
+
+  StateMap stateMap;
+  stateMap.read(package->open(modelFile).get());
+
+  model->_model = llama::LlamaModel::create(ctx, llamaConfig);
+  model->_model->initParameters(stateMap);
+  model->_modelName = modelType;
+  model->_eotId = indexIni.getInt("eot_token_id");
+
+  model->initTokenizer(package);
+  return model;
+}
+
+Prompt IndexModelForGeneration::buildPrompt(lut::Span<const Message> history) const {
+  CHECK(!history.empty()) << "history is empty";
+
+  Prompt prompt;
+  if (history.front().role == "system") {
+    prompt.appendControlToken("");
+    prompt.appendText(history.front().content);
+    history = history.subspan(1);
+  }
+
+  for (const Message &message : history) {
+    if (message.role == "user") {
+      prompt.appendControlToken("<|reserved_0|>");
+      prompt.appendText(message.content);
+      prompt.appendControlToken("<|reserved_1|>");
+    } else if (message.role == "assistant") {
+      prompt.appendText(message.content);
+    } else {
+      throw lut::AbortedError("unexpected role");
+    }
+  }
+
+  return prompt;
+}
+
+} // namespace index
+} // namespace libllm
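For orientation, below is a sketch of the chat layout that buildPrompt() above produces. It is an illustration only, not part of the patch: the real implementation appends control tokens as vocabulary items rather than literal text, and the name of the system-prompt control token is elided in the listing above, so a placeholder is used here.

```cpp
// Illustration-only sketch of the Index chat layout built by
// IndexModelForGeneration::buildPrompt(). <system_token> is a placeholder for
// the control token whose name is elided in the listing above.
#include <iostream>
#include <string>
#include <vector>

struct Msg {
  std::string role;
  std::string content;
};

std::string renderIndexPrompt(const std::vector<Msg> &history) {
  std::string out;
  std::size_t i = 0;
  if (!history.empty() && history[0].role == "system") {
    out += "<system_token>" + history[0].content;  // placeholder token name
    i = 1;
  }
  for (; i < history.size(); ++i) {
    if (history[i].role == "user") {
      out += "<|reserved_0|>" + history[i].content + "<|reserved_1|>";
    } else {
      out += history[i].content;  // assistant turns are appended verbatim
    }
  }
  return out;
}

int main() {
  std::cout << renderIndexPrompt(
                   {{"system", "You are a helpful assistant."}, {"user", "Hello!"}})
            << "\n";
  // -> <system_token>You are a helpful assistant.<|reserved_0|>Hello!<|reserved_1|>
}
```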
diff --git a/go/skill/bilibili_index.go b/src/libllm/bilibili_index.h
similarity index 57%
rename from go/skill/bilibili_index.go
rename to src/libllm/bilibili_index.h
index fed92048..83b5119e 100644
--- a/go/skill/bilibili_index.go
+++ b/src/libllm/bilibili_index.h
@@ -17,36 +17,35 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package skill
-
-import (
-	"errors"
-
-	"github.com/ling0322/libllm/go/llm"
-)
-
-type BilibiliIndex struct {
-}
-
-func (l *BilibiliIndex) Build(history []Message) (llm.Prompt, error) {
-	prompt := llm.NewPrompt()
-	if len(history) > 0 && history[0].Role == "system" {
-		prompt.AppendControlToken("")
-		prompt.AppendText(history[0].Content)
-		history = history[1:]
-	}
-
-	for _, message := range history {
-		if message.Role == "user" {
-			prompt.AppendControlToken("<|reserved_0|>")
-			prompt.AppendText(message.Content)
-			prompt.AppendControlToken("<|reserved_1|>")
-		} else if message.Role == "assistant" {
-			prompt.AppendText(message.Content)
-		} else {
-			return nil, errors.New("unexpected role")
-		}
-	}
-
-	return prompt, nil
-}
+#pragma once
+
+#include <memory>
+
+#include "libllm/llama.h"
+#include "libllm/model_for_generation.h"
+#include "lutil/ini_config.h"
+
+namespace libllm {
+namespace index {
+
+/// @brief The Bilibili Index model. The model structure of Index is the same as llama; the only
+/// difference is the prompt format.
+class IndexModelForGeneration : public llama::LlamaModelForGeneration {
+ public:
+  static std::shared_ptr<ModelForGeneration> fromPackage(
+      const Context &ctx,
+      lut::ZipFile *package);
+
+  // noncopyable
+  IndexModelForGeneration(IndexModelForGeneration &) = delete;
+  IndexModelForGeneration &operator=(IndexModelForGeneration &) = delete;
+
+  // override LlamaModelForGeneration
+  Prompt buildPrompt(lut::Span<const Message> history) const override;
+
+ private:
+  IndexModelForGeneration() = default;
+};
+
+} // namespace index
+} // namespace libllm
diff --git a/src/libllm/generator.h b/src/libllm/generator.h
index 49add20d..89bebfde 100644
--- a/src/libllm/generator.h
+++ b/src/libllm/generator.h
@@ -33,7 +33,6 @@ struct GenerationConfig {
   int topK;
   float topP;
   float temperature;
-  std::unordered_map<std::string, std::string> kvConfig;
 
   GenerationConfig();
 };
diff --git a/src/libllm/llama.cc b/src/libllm/llama.cc
index ae8ecf80..22d480c1 100644
--- a/src/libllm/llama.cc
+++ b/src/libllm/llama.cc
@@ -484,5 +484,42 @@ int LlamaModelForGeneration::getOutputDim() const {
   return _model->getOutputDim();
 }
 
+Prompt LlamaModelForGeneration::buildPrompt(lut::Span<const Message> history) const {
+  CHECK(!history.empty()) << "history is empty";
+
+  Prompt prompt;
+  prompt.appendControlToken("<|begin_of_text|>");
+  for (const Message &message : history.subspan(0, history.size() - 1)) {
+    prompt.appendControlToken("<|start_header_id|>");
+    prompt.appendText(message.role);
+    prompt.appendControlToken("<|end_header_id|>");
+    prompt.appendText("\n\n" + message.content);
+    prompt.appendControlToken("<|eot_id|>");
+  }
+
+  const Message &message = history.back();
+  if (message.role == "user") {
+    prompt.appendControlToken("<|start_header_id|>");
+    prompt.appendText(message.role);
+    prompt.appendControlToken("<|end_header_id|>");
+    prompt.appendText("\n\n" + message.content);
+    prompt.appendControlToken("<|eot_id|>");
+    prompt.appendControlToken("<|start_header_id|>");
+    prompt.appendText("assistant");
+    prompt.appendControlToken("<|end_header_id|>");
+    prompt.appendText("\n\n");
+  } else if (message.role == "assistant") {
+    prompt.appendControlToken("<|start_header_id|>");
+    prompt.appendText(message.role);
+    prompt.appendControlToken("<|end_header_id|>");
+    prompt.appendText("\n\n" + message.content);
+  } else {
+    throw lut::AbortedError(
+        "invalid messages: role of last message should be either user or assistant");
+  }
+
+  return prompt;
+}
+
 } // namespace llama
 } // namespace libllm
diff --git a/src/libllm/llama.h b/src/libllm/llama.h
index cb9b8251..ba689531 100644
--- a/src/libllm/llama.h
+++ b/src/libllm/llama.h
@@ -149,6 +149,7 @@ class LlamaModelForGeneration : public ModelForGeneration {
   const char *getName() const override;
   Device getDevice() const override;
   int getOutputDim() const override;
+  Prompt buildPrompt(lut::Span<const Message> history) const override;
 
  protected:
   std::shared_ptr<LlamaModel> _model;
diff --git a/src/libllm/llm.cc b/src/libllm/llm.cc
index 092ceb68..e37991be 100644
--- a/src/libllm/llm.cc
+++ b/src/libllm/llm.cc
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 
 #include "../../third_party/nlohmann/json.hpp"
@@ -38,6 +39,8 @@ using libllm::whisper::WhisperModel;
 using lut::IniConfig;
 using json = nlohmann::json;
 
+std::once_flag gLlmInitOnce;
+
 thread_local std::string gJsonString;
 thread_local std::string gJsonErrorMessage;
 
@@ -46,32 +49,14 @@ constexpr char LlmConfigKey_WhisperLang[] = "whisper.language";
 constexpr char LlmConfigValue_Sampler[] = "sampler";
 constexpr char LlmConfigValue_Whisper[] = "whisper";
 
-struct llmModel_t {
-  Context ctx;
-  std::shared_ptr<ModelForGeneration> model_for_generation;
+struct llm_model_impl_t {
+  std::shared_ptr<ModelForGeneration> model;
   std::shared_ptr<Tokenizer> tokenizer;
-  std::string configFile;
-  int device;
 };
 
-struct llmCompletion_t {
-  int top_k;
-  float top_p;
-  float temperature;
-  std::shared_ptr<Prompt> prompt;
+struct llm_completion_impl_t {
   std::weak_ptr<ModelForGeneration> model_for_generation;
   std::shared_ptr<Generator> generator;
-  lut::Error error;
-  std::string chunkText;
-  std::unordered_map<std::string, std::string> kvConfig;
-};
-
-struct llmChunk_t {
-  std::string text;
-};
-
-struct llmPrompt_t {
-  std::shared_ptr<Prompt> prompt;
 };
 
 struct llm_json_impl_t {
@@ -102,56 +87,65 @@ void llmSetErrorMessage(const std::string &message) {
   snprintf(gErrorMessage, sizeof(gErrorMessage), "%s", what.c_str());
 }
 
-void setErrorCodeAndMessage(const lut::Error &e) {
-  gErrorCode = static_cast<llmStatus_t>(e.getCode());
-  llmSetErrorMessage(e.what());
-}
+void checkJsonKeys(
+    const json &json,
+    std::initializer_list<std::pair<std::string_view, bool>> schema) {
+  std::set<std::string_view> keys;
+  for (auto &[key, value] : json.items()) {
+    keys.emplace(key);
+  }
 
-llmStatus_t runAndCatch(std::function<void()> &&f) {
-  try {
-    f();
-    return LLM_OK;
-  } catch (const lut::Error &e) {
-    setErrorCodeAndMessage(e);
-    return static_cast<llmStatus_t>(e.getCode());
+  for (const auto &entry : schema) {
+    std::string_view key = entry.first;
+    bool required = entry.second;
+
+    auto it = keys.find(key);
+    if (required && it == keys.end()) {
+      throw lut::AbortedError(lut::sprintf("json: required key \"%s\" not found", key));
+    }
+
+    if (it != keys.end()) keys.erase(it);
   }
-}
 
-template <typename T>
-T runAndCatch(std::function<T()> &&c, T default_value) {
-  try {
-    return c();
-  } catch (const lut::Error &e) {
-    setErrorCodeAndMessage(e);
-    return default_value;
+  if (!keys.empty()) {
+    throw lut::AbortedError(lut::sprintf("json: unexpected key \"%s\"", *keys.begin()));
   }
 }
 
-Device getDeviceFromApi(int apiDevice) {
-  switch (apiDevice) {
-    case LLM_DEVICE_CPU:
-      return Device::getCpu();
-    case LLM_DEVICE_CUDA:
-      return Device::getCuda();
-    case LLM_DEVICE_AUTO:
-      if (Device::isCudaAvailable()) {
-        return Device::getCuda();
-      } else {
-        return Device::getCpu();
-      }
-    default:
-      throw lut::InvalidArgError("invalid device type");
+Prompt buildPromptFromJson(const ModelForGeneration *model, const json &kwargsJson) {
+  const json &messageJsons = kwargsJson["messages"];
+  std::vector<Message> messages;
+  for (const json &messageJson : messageJsons) {
+    Message message;
+    message.role = messageJson["role"];
+    message.content = messageJson["content"];
+
+    messages.emplace_back(message);
+    if (messageJson.size() != 2) {
+      throw lut::AbortedError("invalid json for message");
+    }
   }
+
+  return model->buildPrompt(messages);
 }
 
-int parseGeneratorType(const std::string &name) {
-  if (name == LlmConfigValue_Sampler) {
-    return Generator::Sampling;
-  } else if (name == LlmConfigValue_Whisper) {
-    return Generator::Whisper;
-  } else {
-    throw lut::AbortedError("invalid generator type: " + name);
+template <typename T>
+T getValueFromJson(const json &j, std::string_view key, T defaultVal) {
+  T val = defaultVal;
+  if (j.contains(key)) {
+    val = j[key];
   }
+
+  return val;
+}
+
+GenerationConfig parseGenerationConfig(const json &kwargsJson) {
+  GenerationConfig config;
+  config.temperature = getValueFromJson(kwargsJson, "temperature", 1.0);
+  config.topK = getValueFromJson(kwargsJson, "top_k", 50);
+  config.topP = getValueFromJson(kwargsJson, "top_p", 0.8);
+
+  return config;
+}
 
 int32_t llmErrorSetInvalidArg(const std::string &argName) {
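Before the entry points that consume these helpers, it may help to see the kwargs document they validate and unpack. A minimal sketch (not part of the patch; it assumes the standard `<nlohmann/json.hpp>` include path) of a well-formed kwargs object for llm_model_complete(): "messages" is required and each message must carry exactly a "role" and a "content" field; the sampling keys are optional, with the defaults shown in parseGenerationConfig() above.

```cpp
// Sketch of the kwargs JSON accepted by llm_model_complete() below.
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json kwargs = {
      {"temperature", 0.7},  // optional, default 1.0
      {"top_k", 50},         // optional, default 50
      {"top_p", 0.8},        // optional, default 0.8
      {"messages",
       {{{"role", "system"}, {"content", "You are a helpful assistant."}},
        {{"role", "user"}, {"content", "Hello!"}}}},
  };
  std::cout << kwargs.dump(2) << "\n";
}
```

Any key outside this schema is rejected by checkJsonKeys(), so typos in option names fail loudly instead of being silently ignored.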
lut::AbortedError("invalid json for message"); + } } + + return model->buildPrompt(messages); } -int parseGeneratorType(const std::string &name) { - if (name == LlmConfigValue_Sampler) { - return Generator::Sampling; - } else if (name == LlmConfigValue_Whisper) { - return Generator::Whisper; - } else { - throw lut::AbortedError("invalid generator type: " + name); +template +T getValueFromJson(const json &j, std::string_view key, T defaultVal) { + T val = defaultVal; + if (j.contains(key)) { + val = j[key]; } + + return val; +} + +GenerationConfig parseGenerationConfig(const json &kwargsJson) { + GenerationConfig config; + config.temperature = getValueFromJson(kwargsJson, "temperature", 1.0); + config.topK = getValueFromJson(kwargsJson, "top_k", 50); + config.topP = getValueFromJson(kwargsJson, "top_p", 0.8); + + return config; } int32_t llmErrorSetInvalidArg(const std::string &argName) { @@ -198,313 +192,137 @@ libllm::Device parseDevice(const std::string &device) { using namespace libllm; using namespace libllm::api; -llmStatus_t llmInit(int32_t apiVersion) { - if (!gInitialized.exchange(true)) { +void llm_init() { + std::call_once(gLlmInitOnce, []() { try { - if (apiVersion != LLM_API_VERSION) throw lut::InvalidArgError("api version mismatch."); lut::setLogLevel(lut::LogSeverity::kINFO); libllm::initOperators(); - - return LLM_OK; } catch (const lut::Error &e) { - gInitialized = false; - setErrorCodeAndMessage(e); - return static_cast(e.getCode()); - ; + LOG(ERROR) << "initialize libllm failed: " << e.what(); } - } - - return LLM_OK; -} - -llmStatus_t llmDestroy() { - if (gInitialized.exchange(false)) { - libllm::destroyOperators(); - } - - return LLM_OK; + }); } -const char *llmGetLastErrorMessage() { +const char *llm_get_last_error_message() { return gErrorMessage; } -llmModel_t *llmModel_New() { - llmModel_t *model = new llmModel_t(); - model->device = LLM_DEVICE_AUTO; - return model; -} - -llmStatus_t llmModel_Delete(llmModel_t *model) { - delete model; - return LLM_OK; -} - -llmStatus_t llmModel_SetFile(llmModel_t *model, const char *filename) { - return runAndCatch([model, filename]() { - if (!model) throw lut::InvalidArgError("model"); - if (!filename) throw lut::InvalidArgError("filename"); - - model->configFile = filename; - return LLM_OK; - }); -} - -llmStatus_t llmModel_SetDevice(llmModel_t *model, int32_t device) { - return runAndCatch([model, device]() { - if (!model) throw lut::InvalidArgError("model"); - - model->device = device; - return LLM_OK; - }); +int32_t llm_model_init(llm_model_t *m) { + *m = new llm_model_impl_t(); + return 0; } -llmStatus_t llmModel_Load(llmModel_t *model) { - return runAndCatch([model]() { - if (!model) throw lut::InvalidArgError("model"); - if (model->configFile.empty()) throw lut::InvalidArgError("model file not set."); - - LOG(INFO) << "read model package: " << model->configFile; - std::shared_ptr package = lut::ZipFile::fromFile(model->configFile); - - model->ctx.setDevice(getDeviceFromApi(model->device)); - model->ctx.setFloatDType(F::getDefaultFloatType(model->ctx.getDevice())); - model->tokenizer = Tokenizer::fromPackage(package.get()); - model->model_for_generation = ModelForGeneration::fromPackage(model->ctx, package.get()); +int32_t llm_model_destroy(llm_model_t *m) { + if (!m) return llmErrorSetInvalidArg("m"); - return LLM_OK; - }); + delete *m; + *m = nullptr; + return 0; } -const char *llmModel_GetName(llmModel_t *model) { - return runAndCatch( - [model]() { - if (!model) throw lut::InvalidArgError("m"); - if 
(!model->model_for_generation) throw lut::InvalidArgError("model"); - - return model->model_for_generation->getName(); - }, - nullptr); -} +int32_t llm_model_load(llm_model_t *m, llm_json_t *kwargs) { + try { + libllm::Device device; + std::shared_ptr package; + json object = (*kwargs)->jsonObject; + for (auto &[key, value] : object.items()) { + if (key == "filename") { + package = lut::ZipFile::fromFile(value); + } else if (key == "device") { + device = parseDevice(value); + } else { + return llmErrorSetAborted("invalid key in options: " + key); + } + } -llmPrompt_t *llmPrompt_New() { - return runAndCatch( - []() { - llmPrompt_t *prompt = new llmPrompt_t(); - prompt->prompt = std::make_shared(); - return prompt; - }, - nullptr); -} + if (!package) return llmErrorSetAborted("options.filename undefined"); + if (device.getType() == libllm::Device::kUnknown) { + return llmErrorSetAborted("options.device undefined"); + } -llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt) { - delete prompt; - return LLM_OK; -} + Context ctx; + ctx.setDevice(device); + ctx.setFloatDType(F::getDefaultFloatType(device)); + std::shared_ptr model = ModelForGeneration::fromPackage(ctx, package.get()); -llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text) { - return runAndCatch([prompt, text]() { - if (!prompt) throw lut::InvalidArgError("prompt"); - if (!text) throw lut::InvalidArgError("text"); + (*m)->model = model; + } catch (std::exception &e) { + return llmErrorSetAborted(e.what()); + } - prompt->prompt->appendText(text); - return LLM_OK; - }); + return 0; } -llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *name) { - return runAndCatch([prompt, name]() { - if (!prompt) throw lut::InvalidArgError("prompt"); - if (!name) throw lut::InvalidArgError("name"); - - prompt->prompt->appendControlToken(name); - return LLM_OK; - }); -} +int32_t llm_model_get_info(llm_model_t *m, llm_json_t *info) { + if (!m) return llmErrorSetInvalidArg("m"); + if (!info) return llmErrorSetInvalidArg("info"); -llmStatus_t llmPrompt_AppendAudio( - llmPrompt_t *prompt, - const llmByte_t *audio, - int64_t size, - int32_t format) { - return runAndCatch([prompt, audio, size, format]() { - if (!prompt) throw lut::InvalidArgError("prompt"); - if (!audio) throw lut::InvalidArgError("audio"); - if (size <= 0 || size > 1024 * 1024 * 1024) - throw lut::AbortedError("invalid size, [1, 1G) expected"); - if (format != LLM_WAVE_FORMAT_PCM16KHZ16BITMONO) throw lut::AbortedError("invalid format"); - - prompt->prompt->appendWave( - lut::Span(reinterpret_cast(audio), size), - WaveFormat::Wave16kHz16bitMonoPCM); - return LLM_OK; - }); -} + try { + json infoJson; + infoJson["name"] = (*m)->model->getName(); + (*info)->jsonObject = infoJson; + } catch (std::exception &e) { + return llmErrorSetAborted(e.what()); + } -llmCompletion_t *llmCompletion_New(llmModel_t *model) { - return runAndCatch( - [model]() { - if (!model) throw lut::InvalidArgError("model"); - if (!model->model_for_generation) throw lut::InvalidArgError("model not initialized"); - - std::unique_ptr comp = std::make_unique(); - comp->model_for_generation = model->model_for_generation; - comp->temperature = 1.0f; - comp->top_k = 50; - comp->top_p = 0.8f; - - return comp.release(); - }, - nullptr); + return 0; } -llmStatus_t llmCompletion_Delete(llmCompletion_t *comp) { - delete comp; - return LLM_OK; +int32_t llm_completion_init(llm_completion_t *c) { + *c = new llm_completion_impl_t(); + return 0; } -llmStatus_t llmCompletion_SetConfig(llmCompletion_t 
*comp, const char *key, const char *value) { - return runAndCatch([comp, key, value]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!key) throw lut::InvalidArgError("key"); - if (!value) throw lut::InvalidArgError("value"); +int32_t llm_completion_destroy(llm_completion_t *c) { + if (!c) return llmErrorSetInvalidArg("c"); + delete *c; + *c = nullptr; - comp->kvConfig[key] = value; - return LLM_OK; - }); + return 0; } -llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt) { - return runAndCatch([comp, prompt]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!prompt) throw lut::InvalidArgError("prompt"); - if (comp->generator) throw lut::InvalidArgError("completion already started"); - if (prompt->prompt->empty()) throw lut::InvalidArgError("prompt is empty"); +int32_t llm_model_complete(llm_model_t *m, llm_json_t *kwargs, llm_completion_t *comp) { + if (!m) return llmErrorSetInvalidArg("m"); + if (!kwargs) return llmErrorSetInvalidArg("kwargs"); + if (!comp) return llmErrorSetInvalidArg("comp"); - comp->prompt = prompt->prompt; - return LLM_OK; - }); -} + try { + const json &kwargsJson = (*kwargs)->jsonObject; + checkJsonKeys( + kwargsJson, + {{"temperature", false}, {"top_k", false}, {"top_p", false}, {"messages", true}}); -llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP) { - return runAndCatch([comp, topP]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (comp->generator) throw lut::InvalidArgError("completion already started"); + GenerationConfig config = parseGenerationConfig(kwargsJson); + (*comp)->generator = SamplingGenerator::newGenerator(config, (*m)->model); + (*comp)->model_for_generation = (*m)->model; - comp->top_p = topP; - return LLM_OK; - }); -} - -llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK) { - return runAndCatch([comp, topK]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (comp->generator) throw lut::InvalidArgError("completion already started"); + Prompt prompt = buildPromptFromJson((*m)->model.get(), kwargsJson); + (*comp)->generator->setPrompt(prompt); + } catch (std::exception &e) { + return llmErrorSetAborted(e.what()); + } - comp->top_k = topK; - return LLM_OK; - }); + return 0; } -llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature) { - return runAndCatch([comp, temperature]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (comp->generator) throw lut::InvalidArgError("completion already started"); +int32_t llm_completion_get_next_chunk(llm_completion_t *c, llm_json_t *chunk) { + if (!c) return llmErrorSetInvalidArg("c"); + if (!chunk) return llmErrorSetInvalidArg("chunk"); - comp->temperature = temperature; - return LLM_OK; - }); -} - -llmBool_t llmCompletion_Next(llmCompletion_t *comp) { try { - if (!comp) throw lut::InvalidArgError("comp"); - if (comp->prompt->empty()) throw lut::InvalidArgError("prompt is empty"); - - if (comp->error.getCode() != lut::ErrorCode::OK) { - return LLM_FALSE; - } - - if (!comp->generator) { - // prefill - std::shared_ptr model = comp->model_for_generation.lock(); - if (!model) throw lut::InvalidArgError("model had been destroyed"); - - GenerationConfig config; - config.temperature = comp->temperature; - config.topK = comp->top_k; - config.topP = comp->top_p; - - int generatorType = Generator::Sampling; - std::string whisperLang; - for (const auto &kv : comp->kvConfig) { - if (kv.first == LlmConfigKey_GeneratorType) { - generatorType = parseGeneratorType(kv.second); - } else if 
(kv.first == LlmConfigKey_WhisperLang) { - whisperLang = lut::trim(kv.second); - } else { - throw lut::AbortedError("invalid configuration key: " + kv.first); - } - } - - if (generatorType == Generator::Sampling) { - comp->generator = SamplingGenerator::newGenerator(config, model); - } else { - NOT_IMPL(); - } - - comp->generator->setPrompt(*comp->prompt); + bool ok = (*c)->generator->generate(); + if (!ok) { + return llmErrorSetEOF(); } - bool ok = comp->generator->generate(); - if (ok) { - return LLM_TRUE; - } else { - return LLM_FALSE; - } - } catch (const lut::Error &e) { - if (comp) comp->error = e; - return LLM_FALSE; - } -} - -llmStatus_t llmCompletion_GetError(llmCompletion_t *comp) { - if (!comp) { - lut::Error err = lut::InvalidArgError("comp"); - setErrorCodeAndMessage(err); - return static_cast(err.getCode()); - } - - if (comp->error.getCode() == lut::ErrorCode::OK) { - return LLM_OK; - } else { - setErrorCodeAndMessage(comp->error); - return static_cast(comp->error.getCode()); + json chunkJson; + chunkJson["text"] = (*c)->generator->getToken(); + (*chunk)->jsonObject = chunkJson; + } catch (std::exception &e) { + return llmErrorSetAborted(e.what()); } -} -const char *llmCompletion_GetText(llmCompletion_t *comp) { - return runAndCatch( - [comp]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!comp->generator) throw lut::InvalidArgError("completion not started"); - - comp->chunkText = comp->generator->getToken(); - return comp->chunkText.c_str(); - }, - nullptr); -} - -const char *llmCompletion_GetToken(llmCompletion_t *comp) { - return runAndCatch( - [comp]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!comp->generator) throw lut::InvalidArgError("completion not started"); - - comp->chunkText = comp->generator->getTokenName(); - return comp->chunkText.c_str(); - }, - nullptr); + return 0; } int32_t llm_json_init(llm_json_t *j) { @@ -558,23 +376,11 @@ int32_t llm_asr_model_load(llm_asr_model_t *m, llm_json_t *options) { if (!options) return llmErrorSetInvalidArg("options"); try { - libllm::Device device; - std::shared_ptr package; json object = (*options)->jsonObject; - for (auto &[key, value] : object.items()) { - if (key == "filename") { - package = lut::ZipFile::fromFile(value); - } else if (key == "device") { - device = parseDevice(value); - } else { - return llmErrorSetAborted("invalid key in options: " + key); - } - } + checkJsonKeys(object, {{"filename", true}, {"device", true}}); - if (!package) return llmErrorSetAborted("options.filename undefined"); - if (device.getType() == libllm::Device::kUnknown) { - return llmErrorSetAborted("options.device undefined"); - } + std::shared_ptr package = lut::ZipFile::fromFile(object["filename"]); + libllm::Device device = parseDevice(object["device"]); Context ctx = Context().withName("whisper"); ctx.setDevice(device); diff --git a/src/libllm/llm.h b/src/libllm/llm.h index 935b1663..6cb40d2b 100644 --- a/src/libllm/llm.h +++ b/src/libllm/llm.h @@ -39,90 +39,41 @@ extern "C" { #endif // __cplusplus -#define LLM_DEVICE_CPU 0x0000 -#define LLM_DEVICE_CUDA 0x0100 -#define LLM_DEVICE_AUTO 0x1f00 -#define LLM_WAVE_FORMAT_PCM16KHZ16BITMONO 0x0001 -#define LLM_API_VERSION 20240101 -#define LLM_TRUE 1 -#define LLM_FALSE 0 -#define LLM_OK 0 - #define LLM_ERROR_INVALID_ARG 0x0100 #define LLM_ERROR_INSUFFICIENT_BUFFER 0x0101 #define LLM_ERROR_ABORTED 0x0102 #define LLM_ERROR_EOF 0x0103 -typedef int32_t llmStatus_t; -typedef struct llmModel_t llmModel_t; -typedef struct llmChunk_t llmChunk_t; -typedef struct llmPrompt_t 
llmPrompt_t; -typedef struct llmCompletion_t llmCompletion_t; -typedef struct llmLogitsFilter_t llmLogitsFilter_t; -typedef int32_t llmBool_t; -typedef int8_t llmByte_t; +typedef struct llm_model_impl_t *llm_model_t; +typedef struct llm_completion_impl_t *llm_completion_t; +typedef struct llm_json_impl_t *llm_json_t; +typedef struct llm_asr_recognition_impl_t *llm_asr_recognition_t; +typedef struct llm_asr_model_impl_t *llm_asr_model_t; // global state -LLMAPI llmStatus_t llmInit(int32_t apiVersion); -LLMAPI llmStatus_t llmDestroy(); -LLMAPI const char *llmGetLastErrorMessage(); - -// llmModel_t -LLMAPI llmModel_t *llmModel_New(); -LLMAPI llmStatus_t llmModel_Delete(llmModel_t *model); -LLMAPI llmStatus_t llmModel_SetFile(llmModel_t *model, const char *filename); -LLMAPI llmStatus_t llmModel_SetDevice(llmModel_t *model, int32_t device); -LLMAPI llmStatus_t llmModel_Load(llmModel_t *model); -LLMAPI const char *llmModel_GetName(llmModel_t *model); - -// llmPrompt_t -LLMAPI llmPrompt_t *llmPrompt_New(); -LLMAPI llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt); -LLMAPI -llmStatus_t llmPrompt_AppendAudio( - llmPrompt_t *prompt, - const llmByte_t *audio, - int64_t size, - int32_t format); - -LLMAPI llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text); -LLMAPI llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token); - -// llmCompletion_t -LLMAPI llmCompletion_t *llmCompletion_New(llmModel_t *model); -LLMAPI llmStatus_t llmCompletion_Delete(llmCompletion_t *comp); -LLMAPI llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt); -LLMAPI llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP); -LLMAPI llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK); -LLMAPI -llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value); -LLMAPI llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature); -LLMAPI llmBool_t llmCompletion_Next(llmCompletion_t *comp); -LLMAPI llmStatus_t llmCompletion_GetError(llmCompletion_t *comp); -LLMAPI const char *llmCompletion_GetText(llmCompletion_t *comp); - -/// @brief Get the name of last generated token. -/// For a normal token the llmCompletion_GetText() will return its byte piece, for example, "foo "; -/// llmCompletion_GetToken() will return its name in model, for example, "hello_". -/// For a control token, llmCompletion_GetText() will return an empty string ("") and -/// llmCompletion_GetToken() will return its name, for example, "<|endoftext|>". -/// @param comp the llmCompletion_t. -/// @return name of the token. 
-LLMAPI const char *llmCompletion_GetToken(llmCompletion_t *comp);
+LLMAPI void llm_init();
+LLMAPI const char *llm_get_last_error_message();
 
 // JSON
 
-typedef struct llm_json_impl_t *llm_json_t;
-
 LLMAPI int32_t llm_json_init(llm_json_t *j);
 LLMAPI int32_t llm_json_destroy(llm_json_t *j);
 LLMAPI int32_t llm_json_parse(llm_json_t *j, const char *json_str);
 LLMAPI int32_t llm_json_dump(llm_json_t *j, char *buf, int64_t buf_size);
 
-// ASR
+// LLM
 
-typedef struct llm_asr_recognition_impl_t *llm_asr_recognition_t;
-typedef struct llm_asr_model_impl_t *llm_asr_model_t;
+LLMAPI int32_t llm_model_init(llm_model_t *m);
+LLMAPI int32_t llm_model_destroy(llm_model_t *m);
+LLMAPI int32_t llm_model_load(llm_model_t *m, llm_json_t *kwargs);
+LLMAPI int32_t llm_model_get_info(llm_model_t *m, llm_json_t *info);
+LLMAPI int32_t llm_model_complete(llm_model_t *m, llm_json_t *kwargs, llm_completion_t *comp);
+
+LLMAPI int32_t llm_completion_init(llm_completion_t *c);
+LLMAPI int32_t llm_completion_destroy(llm_completion_t *c);
+LLMAPI int32_t llm_completion_get_next_chunk(llm_completion_t *c, llm_json_t *chunk);
+
+// ASR
 
 LLMAPI int32_t llm_asr_model_init(llm_asr_model_t *m);
 LLMAPI int32_t llm_asr_model_load(llm_asr_model_t *m, llm_json_t *options);
diff --git a/src/libllm/model_for_generation.cc b/src/libllm/model_for_generation.cc
index 5bd551af..08d862b3 100644
--- a/src/libllm/model_for_generation.cc
+++ b/src/libllm/model_for_generation.cc
@@ -19,6 +19,7 @@
 
 #include "libllm/model_for_generation.h"
 
+#include "libllm/bilibili_index.h"
 #include "libllm/constants.h"
 #include "libllm/llama.h"
 #include "libllm/qwen.h"
@@ -48,7 +49,7 @@ std::shared_ptr<ModelForGeneration> ModelForGeneration::fromPackage(
   if (modelType == "llama") {
     model = llama::LlamaModelForGeneration::fromPackage(ctx, package);
   } else if (modelType == "index") {
-    model = llama::LlamaModelForGeneration::fromPackage(ctx, package);
+    model = index::IndexModelForGeneration::fromPackage(ctx, package);
   } else if (modelType == "qwen") {
     model = qwen::QwenModelForGeneration::fromPackage(ctx, package);
   } else {
diff --git a/src/libllm/model_for_generation.h b/src/libllm/model_for_generation.h
index a0371892..3b846b30 100644
--- a/src/libllm/model_for_generation.h
+++ b/src/libllm/model_for_generation.h
@@ -83,6 +83,11 @@ class ModelForGeneration {
   /// @return the output dimension of the model.
   virtual int getOutputDim() const = 0;
 
+  /// @brief build prompt from history messages.
+  /// @param history the history.
+  /// @return the prompt.
+  virtual Prompt buildPrompt(lut::Span<const Message> history) const = 0;
+
   /// @brief Get the vocabulary (tokenId to token string) of the model.
   /// @return The vocabulary.
const Vocab *getVocab() const; diff --git a/src/libllm/prompt.cc b/src/libllm/prompt.cc index ebcf408d..e1f9ad1f 100644 --- a/src/libllm/prompt.cc +++ b/src/libllm/prompt.cc @@ -57,15 +57,6 @@ void Prompt::appendControlToken(const std::string &controlToken) { _blocks.emplace_back(std::move(block)); } -void Prompt::appendWave(lut::Span payload, WaveFormat format) { - PromptBlock block; - block.data = std::vector(payload.begin(), payload.end()); - block.waveFormat = format; - block.blockType = PromptBlock::Wave; - - _blocks.emplace_back(std::move(block)); -} - bool Prompt::empty() const { return _blocks.empty(); } diff --git a/src/libllm/prompt.h b/src/libllm/prompt.h index 37a9d465..b1c8bc4f 100644 --- a/src/libllm/prompt.h +++ b/src/libllm/prompt.h @@ -46,11 +46,15 @@ struct PromptBlock { static std::string typeToString(Type blockType); }; +struct Message { + std::string role; + std::string content; +}; + class Prompt { public: void appendText(const std::string &text); void appendControlToken(const std::string &controlToken); - void appendWave(lut::Span payload, WaveFormat format); bool empty() const; diff --git a/src/libllm/qwen.cc b/src/libllm/qwen.cc index d829672a..41817be5 100644 --- a/src/libllm/qwen.cc +++ b/src/libllm/qwen.cc @@ -64,5 +64,35 @@ bool QwenModelForGeneration::isStopToken(int tokenId) const { } } +Prompt QwenModelForGeneration::buildPrompt(lut::Span history) const { + CHECK(!history.empty()) << "history is empty"; + + Prompt prompt; + for (const Message &message : history.subspan(0, history.size() - 1)) { + prompt.appendControlToken("<|im_start|>"); + prompt.appendText(message.role + "\n" + message.content); + prompt.appendControlToken("<|im_end|>"); + prompt.appendText("\n"); + } + + const Message &message = history.back(); + if (message.role == "user") { + prompt.appendControlToken("<|im_start|>"); + prompt.appendText(message.role + "\n" + message.content); + prompt.appendControlToken("<|im_end|>"); + prompt.appendText("\n"); + prompt.appendControlToken("<|im_start|>"); + prompt.appendText("assistant\n"); + } else if (message.role == "assistant") { + prompt.appendControlToken("<|im_start|>"); + prompt.appendText(message.role + "\n" + message.content); + } else { + throw lut::AbortedError( + "invalid messages: role of last message should be either user or assistant"); + } + + return prompt; +} + } // namespace qwen } // namespace libllm diff --git a/src/libllm/qwen.h b/src/libllm/qwen.h index 3bdeee54..d9e9ff84 100644 --- a/src/libllm/qwen.h +++ b/src/libllm/qwen.h @@ -43,6 +43,7 @@ class QwenModelForGeneration : public llama::LlamaModelForGeneration { // override LlamaModelForGeneration bool isStopToken(int tokenId) const override; + Prompt buildPrompt(lut::Span history) const override; private: int _imStartId; diff --git a/go/skill/prompt_builder.go b/src/lutil/optional.h similarity index 70% rename from go/skill/prompt_builder.go rename to src/lutil/optional.h index 02e4986c..eb9cf601 100644 --- a/go/skill/prompt_builder.go +++ b/src/lutil/optional.h @@ -1,13 +1,13 @@ // The MIT License (MIT) // -// Copyright (c) 2024 Xiaoyang Chen +// Copyright (c) 2023 Xiaoyang Chen // // Permission is hereby granted, free of charge, to any person obtaining a copy of this software // and associated documentation files (the "Software"), to deal in the Software without // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished 
to do so, subject to the following conditions:
-//
+//
 // The above copyright notice and this permission notice shall be included in all copies or
 // substantial portions of the Software.
 //
@@ -17,24 +17,26 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package skill
+#pragma once
+
+namespace lut {
+
+struct nullopt_t {
+  nullopt_t() = default;
+};
+constexpr nullopt_t nullopt{};
+
+template <typename T>
+class optional {
+ public:
+  typedef T value_type;
 
-import (
-	"github.com/ling0322/libllm/go/llm"
-)
+  constexpr optional() noexcept = default;
+  constexpr optional(nullopt_t) noexcept {}
 
-type promptBuilder interface {
-	Build(history []Message) (llm.Prompt, error)
-}
+ private:
+  bool _hasValue;
+  typename std::aligned_storage<sizeof(T), alignof(T)>::type _storage;
+};
 
-func newPromptBuilder(modelName string) (promptBuilder, error) {
-	if modelName == "llama" {
-		return &Llama{}, nil
-	} else if modelName == "index" {
-		return &BilibiliIndex{}, nil
-	} else if modelName == "qwen" {
-		return &Qwen{}, nil
-	} else {
-		return nil, ErrInvalidModelForChat
-	}
-}
+} // namespace lut
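Taken together, the patch replaces the two dozen llmModel_*/llmPrompt_*/llmCompletion_* entry points with a handful of JSON-driven calls. The sketch below shows the new end-to-end flow; it is not part of the patch, and it assumes the header is reachable as "libllm/llm.h", that "model.llmpkg" is a placeholder path, and that "auto" is a valid device string as suggested by the CLI's cpu/cuda/auto flag. Error handling is reduced to single checks with llm_get_last_error_message().

```cpp
// End-to-end sketch of the new JSON-based C API.
#include <cstdio>

#include "libllm/llm.h"

int main() {
  llm_init();

  llm_model_t model;
  llm_completion_t comp;
  llm_json_t kwargs, chunk;
  llm_model_init(&model);
  llm_completion_init(&comp);
  llm_json_init(&kwargs);
  llm_json_init(&chunk);

  // Load the model; kwargs takes exactly the "filename" and "device" keys.
  llm_json_parse(&kwargs, R"({"filename": "model.llmpkg", "device": "auto"})");
  if (llm_model_load(&model, &kwargs) != 0) {
    fprintf(stderr, "load failed: %s\n", llm_get_last_error_message());
    return 1;
  }

  // Start a completion; the prompt is built model-specifically from "messages".
  llm_json_parse(&kwargs, R"({
    "temperature": 0.7,
    "messages": [{"role": "user", "content": "Hello!"}]
  })");
  if (llm_model_complete(&model, &kwargs, &comp) != 0) {
    fprintf(stderr, "complete failed: %s\n", llm_get_last_error_message());
    return 1;
  }

  // Stream chunks until the generator reports LLM_ERROR_EOF (nonzero return).
  char buf[4096];
  while (llm_completion_get_next_chunk(&comp, &chunk) == 0) {
    llm_json_dump(&chunk, buf, sizeof(buf));
    printf("%s\n", buf);  // e.g. {"text":"..."}
  }

  llm_completion_destroy(&comp);
  llm_model_destroy(&model);
  llm_json_destroy(&kwargs);
  llm_json_destroy(&chunk);
  return 0;
}
```

Note the design shift this illustrates: every option travels through a schema-checked JSON document instead of a dedicated setter, so the C surface stays stable as new options are added.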