diff --git a/.clang-format b/.clang-format
index 3d301cb3..6937be72 100644
--- a/.clang-format
+++ b/.clang-format
@@ -11,3 +11,5 @@ AllowShortFunctionsOnASingleLine: false
 PenaltyBreakAssignment: 2000
 BreakConstructorInitializers: BeforeColon
 PackConstructorInitializers: Never
+ReturnTypeBreakingStyle: Automatic
+PenaltyReturnTypeOnItsOwnLine: 8000
\ No newline at end of file
diff --git a/go/bin/go.mod b/go/bin/go.mod
index 84e19cd6..2ec3aed3 100644
--- a/go/bin/go.mod
+++ b/go/bin/go.mod
@@ -1,13 +1,12 @@
 module github.com/ling0322/libllm/go/bin
 
-go 1.15
+go 1.21
 
 replace github.com/ling0322/libllm/go/llm => ../llm
 
-replace github.com/ling0322/libllm/go/i18n => ../i18n
-replace github.com/ling0322/libllm/go/chat => ../chat
+
+replace github.com/ling0322/libllm/go/skill => ../skill
 
 require (
 	github.com/ling0322/libllm/go/llm v1.0.0
-	github.com/ling0322/libllm/go/chat v1.0.0
-	github.com/ling0322/libllm/go/i18n v1.0.0
+	github.com/ling0322/libllm/go/skill v1.0.0
 )
diff --git a/go/bin/go.sum b/go/bin/go.sum
new file mode 100644
index 00000000..dc8e02bd
--- /dev/null
+++ b/go/bin/go.sum
@@ -0,0 +1,2 @@
+github.com/ling0322/libllm/go/i18n v0.0.0-20240626100001-bf6e1d13e2be h1:W/7pkPYLRaHvWNCezcFNzv/WzKSl7wZpkgKndToLh94=
+github.com/ling0322/libllm/go/i18n v0.0.0-20240626100001-bf6e1d13e2be/go.mod h1:+AS95R9P+RH73A5IROS+q0l6NoozpyfJGhrBurYiokM=
diff --git a/go/i18n/i18n.go b/go/bin/i18n.go
similarity index 98%
rename from go/i18n/i18n.go
rename to go/bin/i18n.go
index c1161d7c..c3d19dee 100644
--- a/go/i18n/i18n.go
+++ b/go/bin/i18n.go
@@ -1,4 +1,4 @@
-package i18n
+package bin
 
 import (
 	"errors"
diff --git a/go/bin/llm/chat.go b/go/bin/llm/chat.go
new file mode 100644
index 00000000..2eaf078c
--- /dev/null
+++ b/go/bin/llm/chat.go
@@ -0,0 +1,105 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2024 Xiaoyang Chen
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+package main
+
+import (
+	"bufio"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/ling0322/libllm/go/llm"
+	"github.com/ling0322/libllm/go/skill"
+)
+
+func chatMain(model llm.Model) {
+	llmChat, err := skill.NewChat(model)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(gLocalizer.Get(MsgInputQuestion))
+	fmt.Println(gLocalizer.Get(MsgInputQuestionNew))
+	fmt.Println(gLocalizer.Get(MsgInputQuestionSys))
+
+	history := []skill.Message{}
+	systemPrompt := ""
+	for {
+		reader := bufio.NewReader(os.Stdin)
+
+		fmt.Print("> ")
+		question, err := reader.ReadString('\n')
+		if errors.Is(err, io.EOF) {
+			fmt.Println()
+			break
+		} else if err != nil {
+			log.Fatal(err)
+		}
+		question = strings.TrimSpace(question)
+		if len(question) > 5 && strings.ToLower(question)[0:5] == ":sys " {
+			systemPrompt = strings.TrimSpace(question[5:])
+			history = []skill.Message{}
+			continue
+		} else if strings.ToLower(question) == ":new" {
+			fmt.Println(gLocalizer.Get(MsgNewSession))
+			history = []skill.Message{}
+			continue
+		} else if question == "" {
+			continue
+		}
+
+		if len(history) == 0 && systemPrompt != "" {
+			history = append(history, skill.Message{Role: "system", Content: systemPrompt})
+		}
+
+		history = append(history, skill.Message{Role: "user", Content: question})
+		comp, err := llmChat.Chat(history)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		t0 := time.Now()
+		answer := ""
+		numToken := 0
+		for comp.Next() {
+			fmt.Print(comp.Text())
+			answer += comp.Text()
+			numToken++
+		}
+		if err := comp.Error(); err != nil {
+			log.Fatal(err)
+		}
+
+		history = append(history, skill.Message{Role: "assistant", Content: answer})
+		fmt.Println()
+
+		dur := time.Since(t0)
+		fmt.Printf(
+			gLocalizer.Get(MsgStat),
+			numToken,
+			dur.Seconds(),
+			dur.Seconds()*1000/float64(numToken),
+		)
+	}
+}
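For reference, a minimal non-interactive sketch of the same skill.NewChat API that the REPL above is built on. This is not part of the patch; the model path and question are placeholders, and error handling is reduced to log.Fatal:

	package main

	import (
		"fmt"
		"log"

		"github.com/ling0322/libllm/go/llm"
		"github.com/ling0322/libllm/go/skill"
	)

	func main() {
		// "model.llmpkg" is a hypothetical path to a model package.
		model, err := llm.NewModel("model.llmpkg", llm.Auto)
		if err != nil {
			log.Fatal(err)
		}

		chat, err := skill.NewChat(model)
		if err != nil {
			log.Fatal(err)
		}

		comp, err := chat.Chat([]skill.Message{{Role: "user", Content: "Hello!"}})
		if err != nil {
			log.Fatal(err)
		}
		for comp.Next() { // pull tokens until generation ends
			fmt.Print(comp.Text())
		}
		if err := comp.Error(); err != nil { // terminal error is checked once, after the loop
			log.Fatal(err)
		}
	}

Note that the terminal error surfaces through comp.Error() after the Next() loop, matching the pull-style contract chat.go relies on.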
diff --git a/go/bin/llm/main.go b/go/bin/llm/main.go
index 794cc34f..42ecf2bf 100644
--- a/go/bin/llm/main.go
+++ b/go/bin/llm/main.go
@@ -1,24 +1,19 @@
 package main
 
 import (
-	"bufio"
-	"errors"
 	"flag"
-	"fmt"
-	"io"
 	"log"
-	"os"
-	"path/filepath"
 	"strings"
-	"time"
 
-	"github.com/ling0322/libllm/go/chat"
-	"github.com/ling0322/libllm/go/i18n"
+	"github.com/ling0322/libllm/go/bin"
 	"github.com/ling0322/libllm/go/llm"
 )
 
 var gModelPath string
 var gDevice string
+var gTask string
+var gAudioFile string
+var gLocalizer *bin.Localizer
 
 const (
 	MsgInputQuestion = iota
@@ -47,9 +42,12 @@ var gMsgEnUs = map[int]string{
 func main() {
 	flag.StringVar(&gModelPath, "model", "", "path of model file (.llmpkg)")
 	flag.StringVar(&gDevice, "device", "audo", "inference device (cpu|cuda|audo)")
+	flag.StringVar(&gAudioFile, "audio", "", "the input audio file, only used in transcribe task")
+	flag.StringVar(&gTask, "task", "chat", "the task to run (chat|transcribe)")
 	flag.Parse()
 
-	localizer, err := i18n.NewLocalizer(map[string]map[int]string{
+	var err error
+	gLocalizer, err = bin.NewLocalizer(map[string]map[int]string{
 		"en_us": gMsgEnUs,
 		"zh_cn": gMsgZhCn,
 	})
@@ -68,19 +66,8 @@ func main() {
 		log.Fatalf("unexpected device %s", gDevice)
 	}
 
-	// if model is empty, automatically choose a *.llmpkg file in working directory.
 	if gModelPath == "" {
-		llmpkgFiles, err := filepath.Glob("*.llmpkg")
-		if err != nil {
-			log.Fatal(err)
-		}
-
-		if len(llmpkgFiles) != 1 {
-			fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
-			flag.PrintDefaults()
-		} else {
-			gModelPath = llmpkgFiles[0]
-		}
+		log.Fatal("argument -model is required")
 	}
 
 	model, err := llm.NewModel(gModelPath, device)
@@ -88,72 +75,9 @@ func main() {
 		log.Fatal(err)
 	}
 
-	llmChat, err := chat.NewChat(model)
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	fmt.Println(localizer.Get(MsgInputQuestion))
-	fmt.Println(localizer.Get(MsgInputQuestionNew))
-	fmt.Println(localizer.Get(MsgInputQuestionSys))
-
-	history := []chat.Message{}
-	systemPrompt := ""
-	for {
-		reader := bufio.NewReader(os.Stdin)
-
-		fmt.Print("> ")
-		question, err := reader.ReadString('\n')
-		if errors.Is(err, io.EOF) {
-			fmt.Println()
-			break
-		} else if err != nil {
-			log.Fatal(err)
-		}
-		question = strings.TrimSpace(question)
-		if len(question) > 5 && strings.ToLower(question)[0:5] == ":sys " {
-			systemPrompt = strings.TrimSpace(question[5:])
-			history = []chat.Message{}
-			continue
-		} else if strings.ToLower(question) == ":new" {
-			fmt.Println(localizer.Get(MsgNewSession))
-			history = []chat.Message{}
-			continue
-		} else if question == "" {
-			continue
-		}
-
-		if len(history) == 0 && systemPrompt != "" {
-			history = append(history, chat.Message{Role: "system", Content: systemPrompt})
-		}
-
-		history = append(history, chat.Message{Role: "user", Content: question})
-		comp, err := llmChat.Chat(history)
-		if err != nil {
-			log.Fatal(err)
-		}
-
-		t0 := time.Now()
-		answer := ""
-		numToken := 0
-		for comp.IsActive() {
-			chunk, err := comp.GenerateNextChunk()
-			if err != nil {
-				log.Fatal(err)
-			}
-			fmt.Printf(chunk.Text)
-			answer += chunk.Text
-			numToken++
-		}
-		history = append(history, chat.Message{Role: "assistant", Content: answer})
-		fmt.Println()
-
-		dur := time.Since(t0)
-		fmt.Printf(
-			localizer.Get(MsgStat),
-			numToken,
-			dur.Seconds(),
-			dur.Seconds()*1000/float64(numToken),
-		)
+	if gTask == "chat" {
+		chatMain(model)
+	} else {
+		log.Fatalf("unexpected task: %s", gTask)
 	}
 }
diff --git a/go/bin/llm_transcribe/main.go b/go/bin/llm_transcribe/main.go
new file mode 100644
index 00000000..f3a51a5a
--- /dev/null
+++ b/go/bin/llm_transcribe/main.go
@@ -0,0 +1,78 @@
+package main
+
+import (
+	"flag"
+	"log"
+	"strings"
+
+	"github.com/ling0322/libllm/go/bin"
+	"github.com/ling0322/libllm/go/llm"
+)
+
+var gModelPath string
+var gDevice string
+var gInputFile string
+var gLocalizer *bin.Localizer
+var gFfmpegBin string
+
+const (
+	MsgInputQuestion = iota
+	MsgInputQuestionNew
+	MsgInputQuestionSys
+	MsgNewSession
+	MsgStat
+)
+
+var gMsgZhCn = map[int]string{
+	MsgInputQuestion:    "请输入问题:",
+	MsgInputQuestionNew: " 输入 ':new' 重新开始一个新的对话 (清除历史).",
+	MsgInputQuestionSys: " 输入 ':sys <系统指令>' 设置对话的系统指令,并重新开始一个新的对话.",
+	MsgNewSession:       "===== 新的对话 =====",
+	MsgStat:             "(%d个Token, 总共耗时%.2f秒, 平均每个Token耗时%.2f毫秒)\n",
+}
+
+var gMsgEnUs = map[int]string{
+	MsgInputQuestion:    "Please input your question.",
+	MsgInputQuestionNew: " Type ':new' to start a new session (clean history).",
+	MsgInputQuestionSys: " Type ':sys <system prompt>' to set the system prompt and start a new session.",
+	MsgNewSession:       "===== new session =====",
+	MsgStat:             "(%d tokens, time=%.2fs, %.2fms per token)\n",
+}
+
+func main() {
+	flag.StringVar(&gModelPath, "model", "", "path of model file (.llmpkg)")
+	flag.StringVar(&gDevice, "device", "auto", "inference device (cpu|cuda|auto)")
+	flag.StringVar(&gInputFile, "input", "", "the input audio file")
+	flag.Parse()
+
+	var err error
+	gLocalizer, err = bin.NewLocalizer(map[string]map[int]string{
+		"en_us": gMsgEnUs,
+		"zh_cn": gMsgZhCn,
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	var device llm.Device
+	if strings.ToLower(gDevice) == "cpu" {
+		device = llm.Cpu
+	} else if strings.ToLower(gDevice) == "cuda" {
+		device = llm.Cuda
+	} else if strings.ToLower(gDevice) == "auto" {
+		device = llm.Auto
+	} else {
+		log.Fatalf("unexpected device %s", gDevice)
+	}
+
+	if gModelPath == "" {
+		log.Fatal("argument -model is required")
+	}
+
+	model, err := llm.NewModel(gModelPath, device)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	transcribeMain(model)
+}
diff --git a/go/llm/chunk.go b/go/bin/llm_transcribe/transcribe.go
similarity index 65%
rename from go/llm/chunk.go
rename to go/bin/llm_transcribe/transcribe.go
index 52fac61e..f808bb75 100644
--- a/go/llm/chunk.go
+++ b/go/bin/llm_transcribe/transcribe.go
@@ -17,42 +17,35 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package llm
+package main
 
-// #include <stdlib.h>
-// #include "llm_api.h"
-import "C"
 import (
-	"errors"
 	"fmt"
+	"log"
 	"os"
-	"runtime"
-)
 
-// Generate by Compeltion.
-type Chunk struct {
-	Text string
-}
+	"github.com/ling0322/libllm/go/llm"
+	"github.com/ling0322/libllm/go/skill"
+)
 
-type chunkHandle struct {
-	handle *C.llmChunk_t
-}
+func transcribeMain(model llm.Model) {
+	if gInputFile == "" {
+		log.Fatal("argument -input is required for transcribe task")
+	}
 
-func newChunkHandle() (*chunkHandle, error) {
-	cHandle := C.llmChunk_New()
-	if cHandle == nil {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
+	fd, err := os.Open(gInputFile)
+	if err != nil {
+		log.Fatal(err)
 	}
+	defer fd.Close()
 
-	handle := &chunkHandle{
-		cHandle,
+	transcriber := skill.NewWhisperTranscriber(model, fd)
+	for transcriber.Transcribe() {
+		r := transcriber.Result()
+		fmt.Println(r.String())
 	}
-	runtime.SetFinalizer(handle, func(h *chunkHandle) {
-		status := C.llmChunk_Delete(h.handle)
-		if status != C.LLM_OK {
-			fmt.Fprintln(os.Stderr, "failed to call llmPrompt_Delete()")
-		}
-	})
-	return handle, nil
+	if err = transcriber.Err(); err != nil {
+		log.Fatal(err)
+	}
 }
diff --git a/go/i18n/posix.go b/go/bin/posix.go
similarity index 95%
rename from go/i18n/posix.go
rename to go/bin/posix.go
index ef1bd745..44d76239 100644
--- a/go/i18n/posix.go
+++ b/go/bin/posix.go
@@ -1,7 +1,7 @@
 //go:build !windows
 // +build !windows
 
-package i18n
+package bin
 
 import (
 	"fmt"
diff --git a/go/i18n/windows.go b/go/bin/windows.go
similarity index 99%
rename from go/i18n/windows.go
rename to go/bin/windows.go
index c96037d2..6b806192 100644
--- a/go/i18n/windows.go
+++ b/go/bin/windows.go
@@ -4,7 +4,7 @@
 // Original source file: https://github.com/jeandeaual/go-locale/blob/master/locale_windows.go
 // LICENSE: MIT: https://github.com/jeandeaual/go-locale/blob/master/LICENSE
 
-package i18n
+package bin
 
 import (
 	"strings"
diff --git a/go/chat/go.mod b/go/chat/go.mod
deleted file mode 100644
index f5c2301a..00000000
--- a/go/chat/go.mod
+++ /dev/null
@@ -1,9 +0,0 @@
-module github.com/ling0322/libllm/go/chat
-
-go 1.15
-
-replace github.com/ling0322/libllm/go/llm => ../llm
-
-require (
-	github.com/ling0322/libllm/go/llm v1.0.0
-)
diff --git a/go/i18n/go.mod b/go/i18n/go.mod
deleted file mode 100644
index baa0a938..00000000
--- a/go/i18n/go.mod
+++ /dev/null
@@ -1,3 +0,0 @@
-module github.com/ling0322/libllm/go/i18n
-
-go 1.15
diff --git a/go/llm/completion.go b/go/llm/completion.go
index f6e220de..5e4a2bbf 100644
--- a/go/llm/completion.go
+++ b/go/llm/completion.go
@@ -31,8 +31,16 @@ import (
 
 // Config for LLM completion.
 type Completion interface {
-	IsActive() bool
-	GenerateNextChunk() (Chunk, error)
+	Next() bool
+	Error() error
+	Text() string
+
+	// Get the name of the current token.
+	// For a normal token, Text() will return its byte piece, for example, "foo "; Token() will
+	// return its name in the model, for example, "hello_".
+	// For a control token, Text() will return an empty string ("") and Token() will return its
+	// name, for example, "<|endoftext|>".
+	Token() string
 }
 
 type completionHandle struct {
@@ -40,28 +48,53 @@ type completionHandle struct {
 }
 
 type completionImpl struct {
-	handle      *completionHandle
-	chunkHandle *chunkHandle
+	handle     *completionHandle
+	chunkText  string
+	chunkToken string
+	err        error
 }
 
-func (c *completionImpl) IsActive() bool {
-	return C.llmCompletion_IsActive(c.handle.handle) != 0
+func (c *completionImpl) Next() bool {
+	ok := C.llmCompletion_Next(c.handle.handle) != 0
+	if !ok {
+		return false
+	}
+
+	cText := C.llmCompletion_GetText(c.handle.handle)
+	if cText == nil {
+		c.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		return false
+	}
+	c.chunkText = C.GoString(cText)
+
+	cToken := C.llmCompletion_GetToken(c.handle.handle)
+	if cToken == nil {
+		c.err = errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		return false
+	}
+	c.chunkToken = C.GoString(cToken)
+
+	return true
 }
 
-func (c *completionImpl) GenerateNextChunk() (Chunk, error) {
-	status := C.llmCompletion_GenerateNextChunk(c.handle.handle, c.chunkHandle.handle)
-	if status != C.LLM_OK {
-		return Chunk{}, errors.New(C.GoString(C.llmGetLastErrorMessage()))
+func (c *completionImpl) Error() error {
+	if c.err != nil {
+		return c.err
 	}
 
-	chunk := Chunk{}
-	cText := C.llmChunk_GetText(c.chunkHandle.handle)
-	if cText == nil {
-		return Chunk{}, errors.New(C.GoString(C.llmGetLastErrorMessage()))
+	if C.llmCompletion_GetError(c.handle.handle) != C.LLM_OK {
+		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
 	}
 
-	chunk.Text = C.GoString(cText)
-	return chunk, nil
+	return nil
+}
+
+func (c *completionImpl) Text() string {
+	return c.chunkText
+}
+
+func (c *completionImpl) Token() string {
+	return c.chunkToken
 }
 
 func newCompletionImpl(modelHandle *modelHandle) (*completionImpl, error) {
@@ -70,14 +103,8 @@ func newCompletionImpl(modelHandle *modelHandle) (*completionImpl, error) {
 		return nil, err
 	}
 
-	chunkHandle, err := newChunkHandle()
-	if err != nil {
-		return nil, err
-	}
-
 	return &completionImpl{
-		handle:      handle,
-		chunkHandle: chunkHandle,
+		handle: handle,
 	}, nil
 }
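A short sketch of how the new pull-style Completion interface is meant to be consumed, including the Text()/Token() distinction documented above. This fragment is not part of the patch; comp is assumed to be a Completion obtained from model.Complete:

	// Drain the stream, then check the terminal error exactly once.
	for comp.Next() {
		if comp.Text() == "" {
			// Control tokens carry no text; Token() still reports the name.
			fmt.Printf("[control token: %s]\n", comp.Token())
			continue
		}
		fmt.Print(comp.Text())
	}
	if err := comp.Error(); err != nil {
		log.Fatal(err)
	}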
diff --git a/go/llm/completion_config.go b/go/llm/completion_config.go
index dbaecfff..86e47db8 100644
--- a/go/llm/completion_config.go
+++ b/go/llm/completion_config.go
@@ -22,7 +22,10 @@ package llm
 // #include <stdlib.h>
 // #include "llm_api.h"
 import "C"
-import "errors"
+import (
+	"errors"
+	"unsafe"
+)
 
 // Config for LLM completion.
 type CompletionConfig interface {
@@ -35,6 +38,8 @@ type CompletionConfig interface {
 	SetTemperature(temperature float32)
 	GetTemperature() float32
 
+	SetConfig(key, value string)
+
 	// update the llmCompletion_t according to the config.
 	updateCompHandle(compHandle *completionHandle) error
 }
@@ -43,6 +48,8 @@ type completionConfigImpl struct {
 	topP        float32
 	topK        int
 	temperature float32
+
+	kvConfig map[string]string
 }
 
 func NewCompletionConfig() CompletionConfig {
@@ -50,6 +57,7 @@ func NewCompletionConfig() CompletionConfig {
 		topP:        0.8,
 		topK:        50,
 		temperature: 1.0,
+		kvConfig:    map[string]string{},
 	}
 }
@@ -77,6 +85,10 @@ func (c *completionConfigImpl) GetTemperature() float32 {
 	return c.temperature
 }
 
+func (c *completionConfigImpl) SetConfig(key, value string) {
+	c.kvConfig[key] = value
+}
+
 func (c *completionConfigImpl) updateCompHandle(compHandle *completionHandle) error {
 	if C.llmCompletion_SetTopP(compHandle.handle, C.float(c.topP)) != C.LLM_OK {
 		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
@@ -90,5 +102,16 @@ func (c *completionConfigImpl) updateCompHandle(compHandle *completionHandle) er
 		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
 	}
 
+	for key, value := range c.kvConfig {
+		cKey := C.CString(key)
+		cValue := C.CString(value)
+		ok := C.llmCompletion_SetConfig(compHandle.handle, cKey, cValue)
+		C.free(unsafe.Pointer(cKey))
+		C.free(unsafe.Pointer(cValue))
+		if ok != C.LLM_OK {
+			return errors.New(C.GoString(C.llmGetLastErrorMessage()))
+		}
+	}
+
 	return nil
}
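The new SetConfig key/value channel lets callers select alternative generators without widening the Go API surface; keys are passed through to the C API as-is. A sketch of the call pattern, mirroring the "generator.type" key used by the whisper skill later in this diff (model and prompt assumed from context):

	config := llm.NewCompletionConfig()
	config.SetTopK(1)
	config.SetConfig("generator.type", "whisper") // key consumed by skill/whisper.go

	comp, err := model.Complete(config, prompt)
	if err != nil {
		log.Fatal(err)
	}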
diff --git a/go/llm/llm.go b/go/llm/llm.go
index beafb9b9..2b57e3dc 100644
--- a/go/llm/llm.go
+++ b/go/llm/llm.go
@@ -35,11 +35,13 @@ import (
 )
 
 type Device int32
+type AudioFormat int32
 
 const (
-	Cpu  = Device(0x0000)
-	Cuda = Device(0x0100)
-	Auto = Device(0x1f00)
+	Cpu               = Device(0x0000)
+	Cuda              = Device(0x0100)
+	Auto              = Device(0x1f00)
+	Pcm16kHz16BitMono = AudioFormat(0x0001)
 )
 
 var gInit atomic.Bool
diff --git a/go/llm/llm_api.c b/go/llm/llm_api.c
index fe6bfe79..d0497db0 100644
--- a/go/llm/llm_api.c
+++ b/go/llm/llm_api.c
@@ -17,15 +17,12 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-#ifndef LIBLLM_LLM_API_
-#define LIBLLM_LLM_API_
-
 #ifdef __APPLE__
 #define LUT_PLATFORM_APPLE
 #elif defined(linux) || defined(__linux) || defined(__linux__)
 #define LUT_PLATFORM_LINUX
-#elif defined(WIN32) || defined(__WIN32__) || defined(_MSC_VER) || \
-    defined(_WIN32) || defined(__MINGW32__)
+#elif defined(WIN32) || defined(__WIN32__) || defined(_MSC_VER) || defined(_WIN32) || \
+    defined(__MINGW32__)
 #define LUT_PLATFORM_WINDOWS
 #else
 #error unknown platform
@@ -33,7 +30,7 @@
 
 #if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
 #include <dlfcn.h>
-typedef void* LLM_HMODULE;
+typedef void *LLM_HMODULE;
 #elif defined(LUT_PLATFORM_WINDOWS)
 #include <windows.h>
 typedef HMODULE LLM_HMODULE;
@@ -42,19 +39,7 @@ typedef HMODULE LLM_HMODULE;
 #include <stdint.h>
 #include <stdio.h>
 
-#define LLM_DEVICE_CPU 0x0000
-#define LLM_DEVICE_CUDA 0x0100
-#define LLM_DEVICE_AUTO 0x1f00
-#define LLM_API_VERSION 20240101
-#define LLM_OK 0
-#define LLM_ABORTED 1
-
-typedef int32_t llmStatus_t;
-typedef struct llmModel_t llmModel_t;
-typedef struct llmChunk_t llmChunk_t;
-typedef struct llmPrompt_t llmPrompt_t;
-typedef struct llmCompletion_t llmCompletion_t;
-typedef int32_t llmBool_t;
+#include "llm_api.h"
 
 // global state
 llmStatus_t (*p_llmInit)(int32_t apiVersion);
@@ -70,10 +55,15 @@ llmStatus_t (*p_llmModel_Load)(llmModel_t *model);
 const char *(*p_llmModel_GetName)(llmModel_t *model);
 
 // llmPrompt_t
-llmPrompt_t *(*p_llmPrompt_New)(llmModel_t *model);
+llmPrompt_t *(*p_llmPrompt_New)();
 llmStatus_t (*p_llmPrompt_Delete)(llmPrompt_t *prompt);
 llmStatus_t (*p_llmPrompt_AppendText)(llmPrompt_t *prompt, const char *text);
 llmStatus_t (*p_llmPrompt_AppendControlToken)(llmPrompt_t *prompt, const char *token);
+llmStatus_t (*p_llmPrompt_AppendAudio)(
+    llmPrompt_t *prompt,
+    const llmByte_t *audio,
+    int64_t size,
+    int32_t format);
 
 // llmCompletion_t
 llmCompletion_t *(*p_llmCompletion_New)(llmModel_t *model);
@@ -82,24 +72,20 @@ llmStatus_t (*p_llmCompletion_SetPrompt)(llmCompletion_t *comp, llmPrompt_t *pro
 llmStatus_t (*p_llmCompletion_SetTopP)(llmCompletion_t *comp, float topP);
 llmStatus_t (*p_llmCompletion_SetTopK)(llmCompletion_t *comp, int32_t topK);
 llmStatus_t (*p_llmCompletion_SetTemperature)(llmCompletion_t *comp, float temperature);
-llmStatus_t (*p_llmCompletion_Start)(llmCompletion_t *comp);
-llmBool_t (*p_llmCompletion_IsActive)(llmCompletion_t *comp);
-llmStatus_t (*p_llmCompletion_GenerateNextChunk)(llmCompletion_t *comp, llmChunk_t *chunk);
-
-// llmChunk_t
-llmChunk_t *(*p_llmChunk_New)();
-llmStatus_t (*p_llmChunk_Delete)(llmChunk_t *chunk);
-const char *(*p_llmChunk_GetText)(llmChunk_t *chunk);
+llmStatus_t (*p_llmCompletion_SetConfig)(llmCompletion_t *comp, const char *key, const char *value);
+llmBool_t (*p_llmCompletion_Next)(llmCompletion_t *comp);
+llmStatus_t (*p_llmCompletion_GetError)(llmCompletion_t *comp);
+const char *(*p_llmCompletion_GetText)(llmCompletion_t *comp);
+const char *(*p_llmCompletion_GetToken)(llmCompletion_t *comp);
 
 // load the libllm shared library.
-LLM_HMODULE llmLoadLibrary(const char *libraryPath) {
+void *llmLoadLibrary(const char *libraryPath) {
   // first try to load the dll from same folder as current module.
 #if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
   return dlopen(libraryPath, RTLD_NOW);
 #elif defined(LUT_PLATFORM_WINDOWS)
   return LoadLibraryA(libraryPath);
 #endif
-
 }
 
 #if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
@@ -109,13 +95,15 @@ LLM_HMODULE llmLoadLibrary(const char *libraryPath) {
 #endif
 
 #define LOAD_SYMBOL(hDll, symbol)                                     \
-  p_##symbol = GET_PROC_ADDR(hDll, #symbol);                           \
+  p_##symbol = GET_PROC_ADDR(hDll, #symbol);                          \
   if (!p_##symbol) {                                                  \
     fprintf(stderr, "llm.go: unable to load symbol: %s\n", #symbol);  \
     return LLM_ABORTED;                                               \
   }
 
-llmStatus_t llmLoadSymbols(LLM_HMODULE hDll) {
+llmStatus_t llmLoadSymbols(void *pDll) {
+  LLM_HMODULE hDll = (LLM_HMODULE)pDll;
+
   LOAD_SYMBOL(hDll, llmInit);
   LOAD_SYMBOL(hDll, llmDestroy);
   LOAD_SYMBOL(hDll, llmGetLastErrorMessage);
@@ -129,24 +117,24 @@ llmStatus_t llmLoadSymbols(LLM_HMODULE hDll) {
   LOAD_SYMBOL(hDll, llmPrompt_Delete);
   LOAD_SYMBOL(hDll, llmPrompt_AppendText);
   LOAD_SYMBOL(hDll, llmPrompt_AppendControlToken);
+  LOAD_SYMBOL(hDll, llmPrompt_AppendAudio);
   LOAD_SYMBOL(hDll, llmCompletion_New);
   LOAD_SYMBOL(hDll, llmCompletion_Delete);
   LOAD_SYMBOL(hDll, llmCompletion_SetPrompt);
   LOAD_SYMBOL(hDll, llmCompletion_SetTopP);
   LOAD_SYMBOL(hDll, llmCompletion_SetTopK);
   LOAD_SYMBOL(hDll, llmCompletion_SetTemperature);
-  LOAD_SYMBOL(hDll, llmCompletion_Start);
-  LOAD_SYMBOL(hDll, llmCompletion_IsActive);
-  LOAD_SYMBOL(hDll, llmCompletion_GenerateNextChunk);
-  LOAD_SYMBOL(hDll, llmChunk_New);
-  LOAD_SYMBOL(hDll, llmChunk_Delete);
-  LOAD_SYMBOL(hDll, llmChunk_GetText);
+  LOAD_SYMBOL(hDll, llmCompletion_SetConfig);
+  LOAD_SYMBOL(hDll, llmCompletion_Next);
+  LOAD_SYMBOL(hDll, llmCompletion_GetError);
+  LOAD_SYMBOL(hDll, llmCompletion_GetText);
+  LOAD_SYMBOL(hDll, llmCompletion_GetToken);
 
   return LLM_OK;
 }
 
 // load the libllm shared library.
-llmStatus_t llmDestroyLibrary(LLM_HMODULE handle) {
+llmStatus_t llmDestroyLibrary(void *handle) {
   p_llmInit = NULL;
   p_llmDestroy = NULL;
   p_llmGetLastErrorMessage = NULL;
@@ -160,18 +148,18 @@ llmStatus_t llmDestroyLibrary(LLM_HMODULE handle) {
   p_llmPrompt_Delete = NULL;
   p_llmPrompt_AppendText = NULL;
   p_llmPrompt_AppendControlToken = NULL;
+  p_llmPrompt_AppendAudio = NULL;
   p_llmCompletion_New = NULL;
   p_llmCompletion_Delete = NULL;
   p_llmCompletion_SetPrompt = NULL;
   p_llmCompletion_SetTopP = NULL;
   p_llmCompletion_SetTopK = NULL;
   p_llmCompletion_SetTemperature = NULL;
-  p_llmCompletion_Start = NULL;
-  p_llmCompletion_IsActive = NULL;
-  p_llmCompletion_GenerateNextChunk = NULL;
-  p_llmChunk_New = NULL;
-  p_llmChunk_Delete = NULL;
-  p_llmChunk_GetText = NULL;
+  p_llmCompletion_SetConfig = NULL;
+  p_llmCompletion_Next = NULL;
+  p_llmCompletion_GetError = NULL;
+  p_llmCompletion_GetText = NULL;
+  p_llmCompletion_GetToken = NULL;
 
   // first try to load the dll from same folder as current module.
 #if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX)
@@ -180,7 +168,7 @@ llmStatus_t llmDestroyLibrary(LLM_HMODULE handle) {
     return LLM_ABORTED;
   }
 #elif defined(LUT_PLATFORM_WINDOWS)
-  BOOL success = FreeLibrary(handle);
+  BOOL success = FreeLibrary((LLM_HMODULE)handle);
   if (FALSE == success) {
     return LLM_ABORTED;
   }
@@ -227,8 +215,8 @@ const char *llmModel_GetName(llmModel_t *model) {
 }
 
 // llmPrompt_t
-llmPrompt_t *llmPrompt_New(llmModel_t *model) {
-  return p_llmPrompt_New(model);
+llmPrompt_t *llmPrompt_New() {
+  return p_llmPrompt_New();
 }
 
 llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt) {
@@ -243,6 +231,14 @@ llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token)
   return p_llmPrompt_AppendControlToken(prompt, token);
 }
 
+llmStatus_t llmPrompt_AppendAudio(
+    llmPrompt_t *prompt,
+    const llmByte_t *audio,
+    int64_t size,
+    int32_t format) {
+  return p_llmPrompt_AppendAudio(prompt, audio, size, format);
+}
+
 // llmCompletion_t
 llmCompletion_t *llmCompletion_New(llmModel_t *model) {
   return p_llmCompletion_New(model);
 }
@@ -268,29 +264,22 @@ llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperatur
   return p_llmCompletion_SetTemperature(comp, temperature);
 }
 
-llmStatus_t llmCompletion_Start(llmCompletion_t *comp) {
-  return p_llmCompletion_Start(comp);
-}
-
-llmBool_t llmCompletion_IsActive(llmCompletion_t *comp) {
-  return p_llmCompletion_IsActive(comp);
-}
-
-llmStatus_t llmCompletion_GenerateNextChunk(llmCompletion_t *comp, llmChunk_t *chunk) {
-  return p_llmCompletion_GenerateNextChunk(comp, chunk);
+llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value) {
+  return p_llmCompletion_SetConfig(comp, key, value);
 }
 
-// llmChunk_t
-llmChunk_t *llmChunk_New() {
-  return p_llmChunk_New();
+llmBool_t llmCompletion_Next(llmCompletion_t *comp) {
+  return p_llmCompletion_Next(comp);
 }
 
-llmStatus_t llmChunk_Delete(llmChunk_t *chunk) {
-  return p_llmChunk_Delete(chunk);
+llmStatus_t llmCompletion_GetError(llmCompletion_t *comp) {
+  return p_llmCompletion_GetError(comp);
 }
 
-const char *llmChunk_GetText(llmChunk_t *chunk) {
-  return p_llmChunk_GetText(chunk);
+const char *llmCompletion_GetText(llmCompletion_t *comp) {
+  return p_llmCompletion_GetText(comp);
 }
 
-#endif  // LIBLLM_LLM_API_
+const char *llmCompletion_GetToken(llmCompletion_t *comp) {
+  return p_llmCompletion_GetToken(comp);
+}
\ No newline at end of file
diff --git a/go/llm/llm_api.h b/go/llm/llm_api.h
index dd1971e5..d1317f2b 100644
--- a/go/llm/llm_api.h
+++ b/go/llm/llm_api.h
@@ -26,7 +26,9 @@
 #define LLM_DEVICE_CUDA 0x0100
 #define LLM_DEVICE_AUTO 0x1f00
 #define LLM_API_VERSION 20240101
+#define LLM_WAVE_FORMAT_PCM16KHZ16BITMONO 0x0001
 #define LLM_OK 0
+#define LLM_ABORTED 1
 
 typedef int32_t llmStatus_t;
 typedef struct llmModel_t llmModel_t;
@@ -34,6 +36,7 @@ typedef struct llmChunk_t llmChunk_t;
 typedef struct llmPrompt_t llmPrompt_t;
 typedef struct llmCompletion_t llmCompletion_t;
 typedef int32_t llmBool_t;
+typedef int8_t llmByte_t;
 
 void *llmLoadLibrary(const char *libraryPath);
 llmStatus_t llmLoadSymbols(void *hDll);
@@ -53,25 +56,28 @@ llmStatus_t llmModel_Load(llmModel_t *model);
 const char *llmModel_GetName(llmModel_t *model);
 
 // llmPrompt_t
-llmPrompt_t *llmPrompt_New(llmModel_t *model);
+llmPrompt_t *llmPrompt_New();
 llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt);
 llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text);
 llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token);
+llmStatus_t llmPrompt_AppendAudio(
+    llmPrompt_t *prompt,
+    const llmByte_t *audio,
+    int64_t size,
+    int32_t format);
+
 // llmCompletion_t
 llmCompletion_t *llmCompletion_New(llmModel_t *model);
 llmStatus_t llmCompletion_Delete(llmCompletion_t *comp);
+llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value);
 llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt);
 llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP);
 llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK);
 llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature);
-llmStatus_t llmCompletion_Start(llmCompletion_t *comp);
-llmBool_t llmCompletion_IsActive(llmCompletion_t *comp);
-llmStatus_t llmCompletion_GenerateNextChunk(llmCompletion_t *comp, llmChunk_t *chunk);
-
-// llmChunk_t
-llmChunk_t *llmChunk_New();
-llmStatus_t llmChunk_Delete(llmChunk_t *chunk);
-const char *llmChunk_GetText(llmChunk_t *chunk);
+llmBool_t llmCompletion_Next(llmCompletion_t *comp);
+llmStatus_t llmCompletion_GetError(llmCompletion_t *comp);
+const char *llmCompletion_GetText(llmCompletion_t *comp);
+const char *llmCompletion_GetToken(llmCompletion_t *comp);
 
 #endif  // LIBLLM_LLM_API_
diff --git a/go/llm/model.go b/go/llm/model.go
index c3c4dea2..d47d1da1 100644
--- a/go/llm/model.go
+++ b/go/llm/model.go
@@ -87,7 +87,7 @@ func (m *modelImpl) Complete(config CompletionConfig, prompt Prompt) (Completion
 		return nil, err
 	}
 
-	promptHandle, err := newPromptHandle(m.handle)
+	promptHandle, err := newPromptHandle()
 	if err != nil {
 		return nil, err
 	}
@@ -102,11 +102,6 @@ func (m *modelImpl) Complete(config CompletionConfig, prompt Prompt) (Completion
 		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
 	}
 
-	ok = C.llmCompletion_Start(comp.handle.handle)
-	if ok != C.LLM_OK {
-		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
-	}
-
 	return comp, nil
 }
diff --git a/go/llm/prompt.go b/go/llm/prompt.go
index f7f1132a..07708c55 100644
--- a/go/llm/prompt.go
+++ b/go/llm/prompt.go
@@ -34,6 +34,7 @@ type Prompt interface {
 	AppendText(text string)
 	AppendControlToken(text string)
+	AppendAudio(data []byte, format AudioFormat)
 
 	// Update the llmPrompt_t instance according to the current prompt.
 	updatePromptHandle(handle *promptHandle) error
@@ -55,6 +56,11 @@ type controlTokenPromptElem struct {
 	token string
 }
 
+type audioPromptElem struct {
+	payload []byte
+	format  AudioFormat
+}
+
 type promptHandle struct {
 	handle *C.llmPrompt_t
 }
@@ -71,6 +77,10 @@ func (p *promptImpl) AppendControlToken(text string) {
 	p.elements = append(p.elements, &controlTokenPromptElem{text})
 }
 
+func (p *promptImpl) AppendAudio(audio []byte, format AudioFormat) {
+	p.elements = append(p.elements, &audioPromptElem{audio, format})
+}
+
 func (p *promptImpl) updatePromptHandle(handle *promptHandle) error {
 	if len(p.elements) == 0 {
 		return errors.New("prompt is empty")
@@ -110,8 +120,24 @@ func (e *controlTokenPromptElem) AppendTo(handle *promptHandle) error {
 	return nil
 }
 
-func newPromptHandle(m *modelHandle) (*promptHandle, error) {
-	cHandle := C.llmPrompt_New(m.handle)
+func (e *audioPromptElem) AppendTo(handle *promptHandle) error {
+	cPayload := C.CBytes(e.payload)
+	defer C.free(unsafe.Pointer(cPayload))
+
+	status := C.llmPrompt_AppendAudio(
+		handle.handle,
+		(*C.llmByte_t)(cPayload),
+		C.int64_t(len(e.payload)),
+		C.int32_t(e.format)) // use the element's own format instead of hard-coding Pcm16kHz16BitMono
+	if status != C.LLM_OK {
+		return errors.New(C.GoString(C.llmGetLastErrorMessage()))
+	}
+
+	return nil
+}
+
+func newPromptHandle() (*promptHandle, error) {
+	cHandle := C.llmPrompt_New()
 	if cHandle == nil {
 		return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
 	}
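For context, a sketch of how an audio prompt is assembled with the new AppendAudio element. This fragment is not part of the patch; pcm is assumed to already hold raw 16 kHz, 16-bit, mono samples, and the control-token name mirrors the whisper skill later in this diff:

	// Build a prompt that carries raw PCM audio plus a Whisper-style control token.
	prompt := llm.NewPrompt()
	prompt.AppendAudio(pcm, llm.Pcm16kHz16BitMono)
	prompt.AppendControlToken("<|startoftranscript|>")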
diff --git a/go/chat/bilibili_index.go b/go/skill/bilibili_index.go
similarity index 99%
rename from go/chat/bilibili_index.go
rename to go/skill/bilibili_index.go
index e523a31b..23c49081 100644
--- a/go/chat/bilibili_index.go
+++ b/go/skill/bilibili_index.go
@@ -17,7 +17,7 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package chat
+package skill
 
 import "github.com/ling0322/libllm/go/llm"
diff --git a/go/chat/chat.go b/go/skill/chat.go
similarity index 99%
rename from go/chat/chat.go
rename to go/skill/chat.go
index 56ffb64c..68dbcf64 100644
--- a/go/chat/chat.go
+++ b/go/skill/chat.go
@@ -17,7 +17,7 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package chat
+package skill
 
 import (
 	"github.com/ling0322/libllm/go/llm"
diff --git a/go/skill/go.mod b/go/skill/go.mod
new file mode 100644
index 00000000..83f4c9fd
--- /dev/null
+++ b/go/skill/go.mod
@@ -0,0 +1,7 @@
+module github.com/ling0322/libllm/go/skill
+
+go 1.21
+
+replace github.com/ling0322/libllm/go/llm => ../llm
+
+require github.com/ling0322/libllm/go/llm v1.0.0
diff --git a/go/chat/llama.go b/go/skill/llama.go
similarity index 99%
rename from go/chat/llama.go
rename to go/skill/llama.go
index aec05018..70744ab5 100644
--- a/go/chat/llama.go
+++ b/go/skill/llama.go
@@ -17,7 +17,7 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package chat
+package skill
 
 import "github.com/ling0322/libllm/go/llm"
diff --git a/go/chat/prompt_builder.go b/go/skill/prompt_builder.go
similarity index 99%
rename from go/chat/prompt_builder.go
rename to go/skill/prompt_builder.go
index 32643acd..ead49531 100644
--- a/go/chat/prompt_builder.go
+++ b/go/skill/prompt_builder.go
@@ -17,7 +17,7 @@
 // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-package chat
+package skill
 
 import (
 	"fmt"
diff --git a/go/skill/transcriber.go b/go/skill/transcriber.go
new file mode 100644
index 00000000..c1fd84d7
--- /dev/null
+++ b/go/skill/transcriber.go
@@ -0,0 +1,42 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2024 Xiaoyang Chen
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+package skill
+
+import (
+	"fmt"
+	"time"
+)
+
+type TranscriptionResult struct {
+	Begin    time.Duration
+	End      time.Duration
+	Language string
+	Text     string
+}
+
+type Transcriber interface {
+	Transcribe() bool
+	Result() TranscriptionResult
+	Err() error
+}
+
+func (r *TranscriptionResult) String() string {
+	return fmt.Sprintf("%8s - %8s: %s", r.Begin.String(), r.End.String(), r.Text)
+}
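A sketch of how a Transcriber implementation is meant to be driven — the same loop the llm-transcribe binary uses in go/bin/llm_transcribe/transcribe.go earlier in this diff. This fragment is not part of the patch; model and reader are assumed to be a loaded llm.Model and an open audio stream:

	// Drain all results, then check the terminal error exactly once.
	var t skill.Transcriber = skill.NewWhisperTranscriber(model, reader)
	for t.Transcribe() {
		r := t.Result()
		fmt.Printf("%s [%s]\n", r.String(), r.Language)
	}
	if err := t.Err(); err != nil {
		log.Fatal(err)
	}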
diff --git a/go/skill/whisper.go b/go/skill/whisper.go
new file mode 100644
index 00000000..bfae92dd
--- /dev/null
+++ b/go/skill/whisper.go
@@ -0,0 +1,375 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2024 Xiaoyang Chen
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+// and associated documentation files (the "Software"), to deal in the Software without
+// restriction, including without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+package skill
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"log/slog"
+	"math"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"regexp"
+	"runtime"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/ling0322/libllm/go/llm"
+)
+
+var regexpLangToken = regexp.MustCompile(`^<\|([a-z][a-z][a-z]?)\|>$`)
+var getFfmpegBin = sync.OnceValue[string](getFfmpegBinInternal)
+
+var ErrInvalidWhisperSequence = errors.New("invalid sequence for Whisper model")
+var ErrStreamIsNil = errors.New("input stream is nil")
+var ErrWhisperModelIsNil = errors.New("whisper model is nil")
+var ErrNoMoreResults = errors.New("no more results")
+var ErrAudioEndOfStream = errors.New("audio end of stream")
+
+const SampleRate = 16000
+const BytesPerSample = 2
+
+const (
+	whisperStateAudio = iota
+	whisperStateStartOfTranscription
+	whisperStateLanguage
+	whisperStateTranscribe
+	whisperStateBeginTime
+	whisperStateText
+	whisperStateEndTime
+)
+
+// the transcriber with whisper model. Implements the interface Transcriber.
+type WhisperTranscriber struct {
+	// the whisper model.
+	WhisperModel llm.Model
+
+	// the reader for the input file.
+	InputStream io.Reader
+
+	// current state in whisper sequence decoding.
+	state int
+
+	// any error that occurred during transcription.
+	err error
+
+	// internal completion object for the llm API.
+	comp llm.Completion
+
+	// the current transcription result.
+	result TranscriptionResult
+
+	// the predicted language.
+	predictedLanguage string
+
+	// the wave bytes for decoding. The format is 16kHz 16-bit mono-channel PCM without header.
+	wavePayload []byte
+
+	// offset of the current segment in wavePayload.
+	waveOffset time.Duration
+}
+
+// create a new instance of WhisperTranscriber from a whisper model and an input stream.
+func NewWhisperTranscriber(whisperModel llm.Model, inputStream io.Reader) *WhisperTranscriber {
+	return &WhisperTranscriber{
+		WhisperModel: whisperModel,
+		InputStream:  inputStream,
+		state:        whisperStateAudio,
+	}
+}
+
+// find the path of ffmpeg.
+func getFfmpegBinInternal() string {
+	ffmpegBin := "ffmpeg"
+	if runtime.GOOS == "windows" {
+		ffmpegBin += ".exe"
+	}
+
+	cmd := exec.Command(ffmpegBin, "-version")
+	err := cmd.Run()
+	if err == nil {
+		// ffmpeg in $PATH
+		return ffmpegBin
+	}
+
+	binPath, err := os.Executable()
+	if err != nil {
+		return ""
+	}
+
+	binDir := filepath.Dir(binPath)
+	ffmpegPath := filepath.Join(binDir, ffmpegBin)
+	_, err = os.Stat(ffmpegPath)
+	if err != nil {
+		return ""
+	}
+
+	// ffmpeg found in the same dir as llm, check it.
+	cmd = exec.Command(ffmpegPath, "-version")
+	err = cmd.Run()
+	if err != nil {
+		return ""
+	}
+
+	return ffmpegPath
+}
+
+// convert the input stream to raw 16kHz 16-bit mono-channel PCM using ffmpeg.
+func convertToPcm(inputStream io.Reader) ([]byte, error) {
+	ffmpegBin := getFfmpegBin()
+	if ffmpegBin == "" {
+		return nil, errors.New("unable to find ffmpeg")
+	}
+
+	// pipe the input stream through ffmpeg and capture the raw PCM from stdout.
+	cmd := exec.Command(
+		ffmpegBin, "-hide_banner", "-nostdin", "-vn", "-threads", "0", "-i", "-", "-f",
+		"s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", "16000", "-")
+	cmd.Stdin = inputStream
+
+	var dataBuffer, errBuffer bytes.Buffer
+	cmd.Stdout = &dataBuffer
+	cmd.Stderr = &errBuffer
+	if err := cmd.Run(); err != nil {
+		slog.Error("ffmpeg failed", "stderr", errBuffer.String())
+		return nil, err
+	}
+
+	return dataBuffer.Bytes(), nil
+}
+
+// parse a whisper timestamp token like <|3.22|>. On success, return (parsedTime, true). On
+// failure, return (time.Duration(0), false).
+func (w *WhisperTranscriber) parseTimestampToken(token string) (time.Duration, bool) {
+	if token == "" {
+		return time.Duration(0), false
+	}
+
+	if len(token) < 8 || token[:2] != "<|" || token[len(token)-2:] != "|>" {
+		return time.Duration(0), false
+	}
+
+	offset, err := strconv.ParseFloat(token[2:len(token)-2], 64)
+	if err != nil {
+		return time.Duration(0), false
+	}
+
+	return time.Duration(math.Round(offset*1000) * float64(time.Millisecond)), true
+}
+
+// complete the next token from the whisper model. If the completion ends, return
+// ErrNoMoreResults. For other errors, return them directly.
+func (w *WhisperTranscriber) completeNext() error {
+	ok := w.comp.Next()
+	slog.Debug("completeNext()", "token", w.comp.Token(), "piece", w.comp.Text())
+	if w.comp.Error() != nil {
+		return w.comp.Error()
+	} else if !ok {
+		return ErrNoMoreResults
+	}
+
+	return nil
+}
+
+// decode one transcription from Whisper. Here a transcription means a piece of text wrapped by
+// begin and end timestamp tokens. On success, return the result. If the whisper model generation
+// ends at the beginning or in the middle of a transcription, return ErrNoMoreResults.
+func (w *WhisperTranscriber) decodeTranscription() (TranscriptionResult, error) {
+	result := TranscriptionResult{Language: w.predictedLanguage}
+
+	err := w.completeNext()
+	if err != nil {
+		return TranscriptionResult{}, err
+	}
+
+	beginOffset, ok := w.parseTimestampToken(w.comp.Token())
+	if !ok {
+		return TranscriptionResult{}, ErrInvalidWhisperSequence
+	}
+	result.Begin = w.waveOffset + beginOffset
+
+	transcriptionDone := false
+	for w.comp.Next() {
+		token := w.comp.Token()
+		piece := w.comp.Text()
+		slog.Debug("comp.next()", "token", token, "piece", piece)
+		offset, isTimestampToken := w.parseTimestampToken(token)
+		if isTimestampToken {
+			result.End = w.waveOffset + offset
+			transcriptionDone = true
+			break
+		}
+
+		result.Text += piece
+	}
+	if w.comp.Error() != nil {
+		// return the completion error, not the stale (nil) err from above.
+		return TranscriptionResult{}, w.comp.Error()
+	} else if !transcriptionDone {
+		// generation stops at half of the transcription.
+		return TranscriptionResult{}, ErrNoMoreResults
+	}
+
+	return result, nil
+}
+
+// parse a language token and return the language name.
+func (w *WhisperTranscriber) parseLanguageToken(token string) (lang string, ok bool) {
+	match := regexpLangToken.FindStringSubmatch(token)
+	if len(match) == 0 {
+		return "", false
+	}
+
+	return match[1], true
+}
+
+// prefill audio and prompt at the beginning of decoding or once the previous audio segment is
+// finished. If there is no transcription result or <|nospeech|> is emitted, return
+// ErrNoMoreResults.
+func (w *WhisperTranscriber) prefillNextAudioSegment() error {
+	slog.Debug("prefill segment", "offset", w.waveOffset)
+	nsPerSample := 1000000000 / SampleRate
+	sampleOffset := int(w.waveOffset.Nanoseconds() / int64(nsPerSample))
+	byteOffset := sampleOffset * 2
+	if len(w.wavePayload)-byteOffset < SampleRate/10 {
+		// ignore a trailing segment shorter than 0.1s.
+		return ErrAudioEndOfStream
+	}
+
+	nBytes := min(len(w.wavePayload)-byteOffset, 30*SampleRate*2)
+	audio := w.wavePayload[byteOffset : byteOffset+nBytes]
+
+	prompt := llm.NewPrompt()
+	prompt.AppendAudio(audio, llm.Pcm16kHz16BitMono)
+	prompt.AppendControlToken("<|startoftranscript|>")
+
+	compConfig := llm.NewCompletionConfig()
+	compConfig.SetTopK(1)
+	compConfig.SetTemperature(1.5)
+	compConfig.SetConfig("generator.type", "whisper")
+	comp, err := w.WhisperModel.Complete(compConfig, prompt)
+	if err != nil {
+		return err
+	}
+
+	// the first token would be a language token or <|nospeech|>
+	w.comp = comp
+	err = w.completeNext()
+	if err != nil {
+		return err
+	}
+
+	// exit once no speech for the whole audio segment.
+	if w.comp.Token() == "<|nospeech|>" {
+		return ErrNoMoreResults
+	}
+
+	// language token.
+	lang, ok := w.parseLanguageToken(w.comp.Token())
+	if !ok {
+		return ErrInvalidWhisperSequence
+	}
+	w.predictedLanguage = lang
+
+	// transcribe token
+	err = w.completeNext()
+	if err != nil {
+		return err
+	}
+
+	if w.comp.Token() != "<|transcribe|>" {
+		return ErrInvalidWhisperSequence
+	}
+
+	// the completion is now set up.
+	return nil
+}
+
+// implements interface Transcriber.
+func (w *WhisperTranscriber) Transcribe() bool {
+	if w.wavePayload == nil {
+		w.wavePayload, w.err = convertToPcm(w.InputStream)
+	}
+	if w.err != nil {
+		return false
+	}
+
+	if w.WhisperModel == nil {
+		w.err = ErrWhisperModelIsNil
+		return false
+	}
+
+	if w.InputStream == nil {
+		w.err = ErrStreamIsNil
+		return false
+	}
+
+	// loop until we get a valid transcription.
+	for {
+		beginOfSegment := false
+		if w.comp == nil {
+			w.err = w.prefillNextAudioSegment()
+			beginOfSegment = true
+		}
+
+		if errors.Is(w.err, ErrNoMoreResults) {
+			w.comp = nil
+			w.err = nil
+			w.waveOffset += 30 * time.Second
+			continue
+		} else if errors.Is(w.err, ErrAudioEndOfStream) {
+			w.comp = nil
+			w.err = nil
+			return false
+		} else if w.err != nil {
+			return false
+		}
+
+		result, err := w.decodeTranscription()
+		if errors.Is(err, ErrNoMoreResults) && beginOfSegment {
+			// if there is no result for the whole audio segment, move forward to the next 30s
+			// segment.
+			w.comp = nil
+			w.waveOffset += 30 * time.Second
+			continue
+		} else if errors.Is(err, ErrNoMoreResults) && !beginOfSegment {
+			// move the wave offset to the end of the last completed transcription.
+			w.comp = nil
+			w.waveOffset = w.result.End
+			continue
+		} else if err != nil {
+			w.err = err
+			return false
+		}
+
+		w.result = result
+		return true
+	}
+}
+
+// implements interface Transcriber.
+func (w *WhisperTranscriber) Result() TranscriptionResult {
+	return w.result
+}
+
+// implements interface Transcriber.
+func (w *WhisperTranscriber) Err() error {
+	return w.err
+}
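The segment windowing in prefillNextAudioSegment reduces to byte arithmetic on the raw PCM buffer. A sketch of the same math, not part of the patch (16 kHz × 2 bytes per sample = 32,000 bytes per second; Whisper consumes at most 30-second windows):

	// waveOffset → byte offset into the 16kHz 16-bit mono PCM payload.
	const bytesPerSecond = 16000 * 2 // SampleRate * BytesPerSample
	waveOffset := 45 * time.Second
	byteOffset := int(waveOffset.Seconds() * bytesPerSecond) // 1,440,000 bytes
	windowBytes := 30 * bytesPerSecond                       // 960,000 bytes per 30s window
	fmt.Println(byteOffset, windowBytes)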
diff --git a/src/libllm/CMakeLists.txt b/src/libllm/CMakeLists.txt
index c85b5118..704af33d 100644
--- a/src/libllm/CMakeLists.txt
+++ b/src/libllm/CMakeLists.txt
@@ -7,6 +7,7 @@ endif()
 set(lut_SOURCES
     "lut/internal/log.cc"
     "lut/internal/sprintf.cc"
+    "lut/base64.cc"
     "lut/error.cc"
     "lut/flags.cc"
     "lut/half.cc"
@@ -32,47 +33,52 @@ set(libllm_SOURCES
     "cpu/copy.cc"
     "cpu/cpu_operators.cc"
     "cpu/cpu_tensor_data.cc"
+    "cpu/fill.cc"
     "cpu/fingerprint.cc"
+    "cpu/gelu.cc"
+    "cpu/log_mel_spectrogram.cc"
     "cpu/lookup.cc"
     "cpu/matmul.cc"
+    "cpu/normalizations.cc"
    "cpu/print.cc"
     "cpu/rand.cc"
-    "cpu/rms_norm.cc"
+    "cpu/reduce.cc"
     "cpu/softmax.cc"
     "cpu/swiglu.cc"
     "cpu/tensor.cc"
     "cpu/transform.cc"
+    "cpu/unfold.cc"
     "cpu/view.cc"
     "bpe_config.cc"
     "bpe_encoder.cc"
     "bpe_model.cc"
-    "chatglm.cc"
-    "c_api.cc"
     "context.cc"
     "device.cc"
     "dtype.cc"
     "functional.cc"
     "generator.cc"
     "llama.cc"
+    "llm.cc"
     "model_for_generation.cc"
     "module.cc"
     "mp.cc"
     "operators.cc"
+    "prompt.cc"
     "qwen.cc"
-    "sampler.cc"
     "state_map.cc"
     "tensor.cc"
     "tokenizer.cc"
+    "wave.cc"
+    "whisper.cc"
     "../../third_party/ruapu/ruapu.cc")
 
 set(unittest_SOURCES
     "cpu/kernel/benchmark.cc"
     "cpu/kernel/interface_test.cc"
-    # "cpu/kernel/unittest_kernel.cc"
+    "cpu/log_mel_spectrogram_test.cc"
     "cpu/test.cc"
     "lut/path_test.cc"
     "lut/strings_test.cc"
-    "chatglm_test.cc"
     "llama_test.cc"
     "module_test.cc"
     "operator_tester.cc"
@@ -255,4 +261,12 @@ add_custom_target(llmbin
     ALL
     DEPENDS libllm
     COMMAND go build
        -o $<TARGET_FILE_DIR:libllm>/llm${CMAKE_EXECUTABLE_SUFFIX}
        ${CMAKE_SOURCE_DIR}/go/bin/llm
+    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/go/bin)
+
+add_custom_target(llmtranscribe
+    ALL
+    DEPENDS libllm
+    COMMAND go build
+        -o $<TARGET_FILE_DIR:libllm>/llm-transcribe${CMAKE_EXECUTABLE_SUFFIX}
+        ${CMAKE_SOURCE_DIR}/go/bin/llm_transcribe
     WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/go/bin)
\ No newline at end of file
diff --git a/src/libllm/benchmark_main.cc b/src/libllm/benchmark_main.cc
index 17c0a680..f7a0a3e0 100644
--- a/src/libllm/benchmark_main.cc
+++ b/src/libllm/benchmark_main.cc
@@ -30,8 +30,8 @@
 #include "libllm/lut/flags.h"
 #include "libllm/lut/random.h"
 #include "libllm/lut/time.h"
-#include "libllm/operators.h"
 #include "libllm/model_for_generation.h"
+#include "libllm/operators.h"
 
 constexpr int MagicNumber = 0x55aa;
 constexpr double MaxWait = 10;
@@ -96,14 +96,14 @@ float benchmarkTokenGeneration(
   inputToken = F::to(model->getCtx().getDevice(), inputToken);
 
   x = model->forward(pastClone, inputToken);
-  x = model->forwardHidden(x);
+  x = model->forwardLmHead(x);
 
   double t0 = lut::now();
   int nLoop = 0;
   while (lut::now() - t0 < MaxWait) {
     StateMap pastClone = past.clone();
     x = model->forward(pastClone, inputToken);
-    x = model->forwardHidden(x);
+    x = model->forwardLmHead(x);
     ++nLoop;
   }
   double t1 = lut::now();
@@ -132,8 +132,11 @@ llama::LlamaConfig getLlamaConfig(LlamaType type) {
   NOT_IMPL();
 }
 
-std::shared_ptr<llama::LlamaModel>
-getLlamaModel(lut::Random *r, LlamaType type, Device device, DType weightType) {
+std::shared_ptr<llama::LlamaModel> getLlamaModel(
+    lut::Random *r,
+    LlamaType type,
+    Device device,
+    DType weightType) {
   Context ctx;
   ctx.setDevice(device);
   ctx.setFloatDType(F::getDefaultFloatType(device));
@@ -183,7 +186,7 @@ int benchmarkMain(Device device) {
   libllm::benchmarkLlama(model, 512, libllm::DType::kQInt4x32);
   printf("----------------------------------------------------------\n");
-  
+
   libllm::destroyOperators();
   return 0;
 }
diff --git a/src/libllm/chatglm.cc b/src/libllm/chatglm.cc
deleted file mode 100644
index 4edf22ad..00000000
--- a/src/libllm/chatglm.cc
+++ /dev/null
@@ -1,410 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2023 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-#include "libllm/chatglm.h"
-
-#include "libllm/constants.h"
-#include "libllm/lut/strings.h"
-
-namespace libllm {
-namespace chatglm {
-
-// -----------------------------------------------------------------------------------------------+
-// class ChatGlmConfig                                                                             |
-// -----------------------------------------------------------------------------------------------+
-
-constexpr char ChatGlmConfig::kSection[];
-
-ChatGlmConfig::ChatGlmConfig()
-    : hiddenSize(0),
-      vocabSize(0),
-      kvChannels(0),
-      seqLength(0),
-      hiddenSizePerAttentionHead(0),
-      multiQueryGroupNum(0),
-      normEps(0.0f),
-      numLayers(0),
-      symbolGMask(0),
-      symbolSOP(0),
-      symbolEOS(0) {
-}
-
-ChatGlmConfig ChatGlmConfig::loadConfig(const lut::IniConfig &ini) {
-  const lut::IniSection &section = ini.getSection(kSection);
-
-  ChatGlmConfig config;
-  config.hiddenSize = section.getInt("hidden_size");
-  config.vocabSize = section.getInt("vocab_size");
-  config.kvChannels = section.getInt("kv_channels");
-  config.seqLength = section.getInt("seq_length");
-  config.hiddenSizePerAttentionHead = section.getInt("hidden_size_per_attention_head");
-  config.multiQueryGroupNum = section.getInt("multi_query_group_num");
-  config.normEps = section.getFloat("norm_eps");
-  config.ffnHiddenSize = section.getInt("ffn_hidden_size");
-  config.numLayers = section.getInt("num_layers");
-  config.symbolGMask = section.getInt("symbol_gmask");
-  config.symbolSOP = section.getInt("symbol_sop");
-  config.symbolEOS = section.getInt("symbol_eos");
-
-  return config;
-}
-
-// -----------------------------------------------------------------------------------------------+
-// class MLP                                                                                       |
-// -----------------------------------------------------------------------------------------------+
-
-std::unique_ptr<MLP> MLP::create(const Context &ctx, ChatGlmConfig config) {
-  std::unique_ptr<MLP> layer{new MLP()};
-  layer->setCtx(ctx);
-
-  layer->_ffnHiddenSize = config.ffnHiddenSize;
-  layer->_hiddenSize = config.hiddenSize;
-
-  int dh = config.hiddenSize;
-  int df = config.ffnHiddenSize;
-  layer->_dense1 = Linear::create(ctx.withName("dense1"), dh, df * 2, false);
-  layer->_dense2 = Linear::create(ctx.withName("dense2"), df, dh, false);
-
-  return layer;
-}
-
-void MLP::initParameters(const StateMap &stateDict) {
-  const Context &ctx = getCtx();
-
-  _dense1->initParameters(stateDict);
-  _dense2->initParameters(stateDict);
-}
-
-void MLP::initParameters(lut::Random *generator, DType weightType) {
-  _dense1->initParameters(generator, weightType);
-  _dense2->initParameters(generator, weightType);
-}
-
-Tensor MLP::forward(const Tensor &input) const {
-  Tensor x = _dense1->forward(input);
-  x = F::swiglu(x);
-  x = _dense2->forward(x);
-
-  return x;
-}
-
-// -----------------------------------------------------------------------------------------------+
-// class SelfAttention                                                                             |
-// -----------------------------------------------------------------------------------------------+
-
-std::unique_ptr<SelfAttention> SelfAttention::create(const Context &ctx, ChatGlmConfig config) {
-  std::unique_ptr<SelfAttention> layer{new SelfAttention()};
-  layer->setCtx(ctx);
-
-  layer->_kvProjDim = config.hiddenSizePerAttentionHead * config.multiQueryGroupNum;
-  layer->_qProjDim = config.hiddenSize;
-  layer->_hiddenSizePerHead = config.hiddenSizePerAttentionHead;
-  layer->_namePastK = ctx.name("k");
-  layer->_namePastV = ctx.name("v");
-  layer->_namePastLength = ctx.name("len");
-
-  if (config.hiddenSize % config.hiddenSizePerAttentionHead != 0) {
-    throw lut::AbortedError("invalid hidden_size and hidden_size_per_head");
-  }
-
-  int qkvProjOutDim = layer->_qProjDim + 2 * layer->_kvProjDim;
-  int dh = config.hiddenSize;
-  layer->_qkvProj = Linear::create(ctx.withName("qkv_proj"), dh, qkvProjOutDim, true);
-  layer->_outProj = Linear::create(ctx.withName("out_proj"), dh, dh, false);
-  return layer;
-}
-
-void SelfAttention::initParameters(const StateMap &stateDict) {
-  _qkvProj->initParameters(stateDict);
-  _outProj->initParameters(stateDict);
-}
-
-void SelfAttention::initParameters(lut::Random *g, DType weightType) {
-  _qkvProj->initParameters(g, weightType);
-  _outProj->initParameters(g, weightType);
-}
-
-int SelfAttention::getCtxLength(StateMap *past) const {
-  if (past && past->hasValue<int>(_namePastLength)) {
-    return past->getValue<int>(_namePastLength);
-  } else {
-    return 0;
-  }
-}
-
-Tensor SelfAttention::forward(StateMap &past, Tensor input, Tensor roPE) const {
-  Tensor qkvProj = _qkvProj->forward(input);
-
-  CHECK(qkvProj.getDim() == 3 && qkvProj.getShape(-1) == _kvProjDim * 2 + _qProjDim);
-  Tensor qProj = qkvProj.slice(-1, {0, _qProjDim});
-  Tensor kProj = qkvProj.slice(-1, {_qProjDim, _qProjDim + _kvProjDim});
-  Tensor vProj = qkvProj.slice(-1, {_qProjDim + _kvProjDim, _qProjDim + 2 * _kvProjDim});
-
-  int N = input.getShape(0);   // batch size
-  int qL = input.getShape(1);  // sequence length
-  int qNH = _qProjDim / _hiddenSizePerHead;    // query num-heads
-  int kvNH = _kvProjDim / _hiddenSizePerHead;  // key and value num-heads
-  int D = _hiddenSizePerHead;
-
-  Tensor q = qProj.view({N, qL, qNH, D});
-  Tensor k = kProj.view({N, qL, kvNH, D});
-  Tensor v = vProj.view({N, qL, kvNH, D});
-
-  // apply roPE to [..., :_hiddenSizePerHead / 2] of QKV
-  Tensor qe = q.slice(-1, {0, D / 2});
-  Tensor ke = k.slice(-1, {0, D / 2});
-  Tensor ve = v.slice(-1, {0, D / 2});
-
-  // fetch and update past length.
-  // TODO: check kvL length oveflow.
-  int kvL = qL;
-  if (past.hasValue<int>(_namePastLength)) {
-    kvL += past.getValue<int>(_namePastLength);
-  }
-  past.putValue(_namePastLength, kvL);
-
-  // apply rope.
-  Tensor qkRoPE = roPE.slice({kvL - qL, kvL});
-  F::copy(F::applyRotaryPosEmb(qe, qkRoPE), qe);
-  F::copy(F::applyRotaryPosEmb(ke, qkRoPE), ke);
-
-  // fetch and update past k.
-  if (past.hasTensor(_namePastK) && past.hasTensor(_namePastV)) {
-    const Tensor &pastK = past.getTensor(_namePastK);
-    const Tensor &pastV = past.getTensor(_namePastV);
-
-    k = F::cat(pastK, k, 1);
-    v = F::cat(pastV, v, 1);
-
-    CHECK(k.getShape(1) == v.getShape(1) && k.getShape(1) == kvL);
-  }
-
-  // update kv_cache in past.
-  past.putTensor(_namePastK, k);
-  past.putTensor(_namePastV, v);
-
-  // expand KV
-  CHECK(qNH % kvNH == 0);
-  std::initializer_list<int> expandShape = {N, kvL, kvNH, qNH / kvNH, D};
-  std::initializer_list<int> qShape = {N, kvL, qNH, D};
-  k = F::contiguous(k.unsqueeze(3).expand(expandShape)).view(qShape);
-  v = F::contiguous(v.unsqueeze(3).expand(expandShape)).view(qShape);
-
-  // apply attention.
-  // TODO: streaming mode support.
-  q = q.transpose(1, 2);
-  k = k.transpose(1, 2);
-  v = v.transpose(1, 2);
-  Tensor x = qL == 1 ? F::attention(q, k, v)
-                     : F::attention(q, k, v, F::causalMask(q.getShape(2), getCtx().getDevice()));
-
-  x = F::contiguous(x.transpose(1, 2)).view({N, qL, qNH * D});
-  x = _outProj->forward(x);
-
-  return x;
-}
-
-// -----------------------------------------------------------------------------------------------+
-// class GLMBlock                                                                                  |
-// -----------------------------------------------------------------------------------------------+
-
-std::unique_ptr<GLMBlock> GLMBlock::create(const Context &ctx, ChatGlmConfig config) {
-  std::unique_ptr<GLMBlock> layer{new GLMBlock()};
-  layer->setCtx(ctx);
-
-  int hiddenSize = config.hiddenSize;
-  float normEps = config.normEps;
-
-  layer->_inputNorm = RMSNorm::create(ctx.withName("norm"), hiddenSize, normEps);
-  layer->_attnNorm = RMSNorm::create(ctx.withName("attn_norm"), hiddenSize, normEps);
-  layer->_attn = SelfAttention::create(ctx.withName("attn"), config);
-  layer->_mlp = MLP::create(ctx.withName("mlp"), config);
-
-  return layer;
-}
-
-void GLMBlock::initParameters(const StateMap &stateMap) {
-  _attn->initParameters(stateMap);
-  _inputNorm->initParameters(stateMap);
-  _attnNorm->initParameters(stateMap);
-  _mlp->initParameters(stateMap);
-}
-
-void GLMBlock::initParameters(lut::Random *generator, DType weightType) {
-  _attn->initParameters(generator, weightType);
-  _inputNorm->initParameters(generator, weightType);
-  _attnNorm->initParameters(generator, weightType);
-  _mlp->initParameters(generator, weightType);
-}
-
-Tensor GLMBlock::forward(StateMap &past, Tensor input, Tensor roPE) const {
-  Tensor residual = input;
-
-  // norm+attention
-  Tensor x = _inputNorm->forward(input);
-  x = _attn->forward(past, x, roPE);
-
-  // residual
-  x = F::add(x, residual);
-  residual = x;
-
-  // norm+mlp
-  x = _attnNorm->forward(x);
-  x = _mlp->forward(x);
-
-  // residual
-  x = F::add(x, residual);
-
-  return x;
-}
-
-// -----------------------------------------------------------------------------------------------+
-// class ChatGlmModel                                                                              |
-// -----------------------------------------------------------------------------------------------+
-
-ChatGlmModel::ChatGlmModel() {
-}
-
-std::unique_ptr<ChatGlmModel> ChatGlmModel::create(const Context &ctx, ChatGlmConfig c) {
-  std::unique_ptr<ChatGlmModel> model{new ChatGlmModel()};
-  model->setCtx(ctx);
-
-  int dh = c.hiddenSize;
-  model->_config = c;
-  model->_embedding = Embedding::create(ctx.withName("embd"), dh, c.vocabSize);
-  model->_finalNorm = RMSNorm::create(ctx.withName("final_norm"), dh, c.normEps);
-  model->_outProj = Linear::create(ctx.withName("out_proj"), dh, c.vocabSize, false);
-  for (int i = 0; i < c.numLayers; ++i) {
-    model->_blocks.emplace_back(
-        GLMBlock::create(ctx.withName(lut::sprintf("%s%d", "block", i)), c));
-  }
-
-  if (c.kvChannels % 4 != 0) {
-    throw lut::AbortedError("invalid kv_channels");
-  }
-
-  return model;
-}
-
-void ChatGlmModel::initParameters(const StateMap &stateDict) {
-  const Context &ctx = getCtx();
-
-  _embedding->initParameters(stateDict);
-  _finalNorm->initParameters(stateDict);
-  _outProj->initParameters(stateDict);
-
-  for (int i = 0; i < _config.numLayers; ++i) {
-    _blocks[i]->initParameters(stateDict);
-  }
-
-  _rope = stateDict.getTensor(ctx.name("rope"));
-  _rope.throwIfInvalidShape({_config.seqLength, _config.kvChannels / 4, 2}, ctx.name("rope"));
-  _rope = _rope.view({_config.seqLength, 1, _config.kvChannels / 2});
-  _rope = moveAndCastFloat(_rope, ctx);
-}
-
-void ChatGlmModel::initParameters(lut::Random *generator, DType weightType) {
-  Context ctx = getCtx();
-
-  _embedding->initParameters(generator, weightType);
-  _finalNorm->initParameters(generator, weightType);
-  _outProj->initParameters(generator, weightType);
-
-  _rope = F::rand(
-      {_config.seqLength, 1, _config.kvChannels / 2},
-      DType::kFloat,  // roPE must be float
-      Device::getCpu(),
-      generator,
-      -0.2f,
-      0.2f);
-  _rope = moveAndCastFloat(_rope, ctx);
-
-  for (int i = 0; i < _config.numLayers; ++i) {
-    _blocks[i]->initParameters(generator, weightType);
-  }
-}
-
-Tensor ChatGlmModel::forwardHidden(Tensor hiddenState) const {
-  return _outProj->forward(hiddenState);
-}
-
-Tensor ChatGlmModel::forward(StateMap &past, Tensor input) const {
-  Tensor x = _embedding->forward(input);
-  for (int i = 0; i < _config.numLayers; ++i) {
-    x = _blocks[i]->forward(past, x, _rope);
-  }
-  x = _finalNorm->forward(x);
-
-  return x;
-}
-
-// -----------------------------------------------------------------------------------------------+
-// class ChatGlmModelForGeneration                                                                 |
-// -----------------------------------------------------------------------------------------------+
-
-std::shared_ptr<ModelForGeneration> ChatGlmModelForGeneration::fromConfig(
-    const Context &ctx,
-    const lut::IniConfig &config) {
-  std::shared_ptr<ChatGlmModelForGeneration> model{new ChatGlmModelForGeneration()};
-
-  ChatGlmConfig ChatGlmConfig = ChatGlmConfig::loadConfig(config);
-  model->_model = ChatGlmModel::create(ctx, ChatGlmConfig);
-  model->_config = ChatGlmConfig;
-  model->_modelName = config.getSection(ModelSection).getString(ModelTypeField);
-
-  return model;
-}
-
-void ChatGlmModelForGeneration::initParameters(const StateMap &stateMap) {
-  _model->initParameters(stateMap);
-}
-
-Tensor ChatGlmModelForGeneration::buildInput(const std::vector<int> &prompt) const {
-  std::vector<int> inputData{_config.symbolGMask, _config.symbolSOP};
-  inputData.insert(inputData.end(), prompt.begin(), prompt.end());
-
-  int len = inputData.size();
-  Tensor inputs = Tensor::create({1, len}, inputData);
-  inputs = F::to(_model->getCtx().getDevice(), inputs);
-  return inputs;
-}
-
-Tensor ChatGlmModelForGeneration::forward(StateMap &past, Tensor input) const {
-  Tensor x = _model->forward(past, input);
-  return x;
-}
-
-Tensor ChatGlmModelForGeneration::forwardHidden(Tensor hidden) const {
-  return _model->forwardHidden(hidden);
-}
-
-bool ChatGlmModelForGeneration::isStopToken(int tokenId) const {
-  return tokenId == _config.symbolEOS;
-}
-
-const char *ChatGlmModelForGeneration::getName() const {
-  return _modelName.c_str();
-}
-
-Device ChatGlmModelForGeneration::getDevice() const {
-  return _model->getCtx().getDevice();
-}
-
-}  // namespace chatglm
-}  // namespace libllm
diff --git a/src/libllm/chatglm.h b/src/libllm/chatglm.h
deleted file mode 100644
index 53ab49dd..00000000
--- a/src/libllm/chatglm.h
+++ /dev/null
@@ -1,175 +0,0 @@
-// The MIT License (MIT)
-//
-// Copyright (c) 2023 Xiaoyang Chen
-//
-// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-// and associated documentation files (the "Software"), to deal in the Software without
-// restriction, including without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
-// Software is furnished to do so, subject to the following conditions:
-//
-// The above copyright notice and this permission notice shall be included in all copies or
-// substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
-// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-#pragma once
-
-#include <memory>
-
-#include <vector>
-
-#include "libllm/functional.h"
-#include "libllm/lut/error.h"
-#include "libllm/lut/ini_config.h"
-#include "libllm/model_for_generation.h"
-#include "libllm/module.h"
-
-namespace libllm {
-namespace chatglm {
-
-struct ChatGlmConfig {
-  // config section in ini
-  static constexpr char kSection[] = "chatglm";
-
-  int hiddenSize;
-  int vocabSize;
-  int kvChannels;
-  int seqLength;
-  int hiddenSizePerAttentionHead;
-  int multiQueryGroupNum;
-  float normEps;
-  int ffnHiddenSize;
-  int numLayers;
-
-  int symbolGMask;
-  int symbolSOP;
-  int symbolEOS;
-
-  ChatGlmConfig();
-  static ChatGlmConfig loadConfig(const lut::IniConfig &ini);
-};
-
-class SelfAttention : public Module {
- public:
-  static std::unique_ptr<SelfAttention> create(const Context &ctx, ChatGlmConfig config);
-
-  // implement interface Module
-  void initParameters(const StateMap &state_dict) override;
-  void initParameters(lut::Random *generator, DType weightType) override;
-
-  Tensor forward(StateMap &past, Tensor input, Tensor roPE) const;
-
- private:
-  std::shared_ptr<Linear> _qkvProj;
-  std::shared_ptr<Linear> _outProj;
-
-  int _kvProjDim;
-  int _qProjDim;
-  int _hiddenSizePerHead;
-
-  std::string _namePastK;
-  std::string _namePastV;
-  std::string _namePastLength;
-
-  SelfAttention() = default;
-
-  int getCtxLength(StateMap *past) const;
-};
-
-class MLP : public Module {
- public:
-  static std::unique_ptr<MLP> create(const Context &ctx, ChatGlmConfig config);
-
-  // implement interface Module
-  void initParameters(const StateMap &state_dict) override;
-  void initParameters(lut::Random *generator, DType weightType) override;
-
-  Tensor forward(const Tensor &input) const;
-
- private:
-  std::shared_ptr<Linear> _dense1;
-  std::shared_ptr<Linear> _dense2;
-
-  int _hiddenSize;
-  int _ffnHiddenSize;
-
-  MLP() = default;
-};
-
-class GLMBlock : public Module {
- public:
-  static std::unique_ptr<GLMBlock> create(const Context &ctx, ChatGlmConfig config);
-
-  // implement interface Module
-  void initParameters(const StateMap &state_dict) override;
-  void initParameters(lut::Random *generator, DType weightType) override;
-
-  Tensor forward(StateMap &past, Tensor input, Tensor roPE) const;
-
- private:
-  std::unique_ptr<RMSNorm> _inputNorm;
-  std::unique_ptr<RMSNorm> _attnNorm;
-  std::unique_ptr<SelfAttention> _attn;
-  std::unique_ptr<MLP> _mlp;
-
-  GLMBlock() = default;
-};
-
-// The ChatGLM2 model.
-class ChatGlmModel : public Module { - public: - // create ChatGLM2 Model. - static std::unique_ptr create(const Context &ctx, ChatGlmConfig config); - - // implement interface Module - void initParameters(const StateMap &state_dict) override; - void initParameters(lut::Random *generator, DType weightType) override; - - Tensor forward(StateMap &past, Tensor input) const; - Tensor forwardHidden(Tensor hiddenState) const; - - private: - ChatGlmConfig _config; - - std::unique_ptr _embedding; - std::vector> _blocks; - std::unique_ptr _finalNorm; - Tensor _rope; - - std::shared_ptr _outProj; - - ChatGlmModel(); -}; - -class ChatGlmModelForGeneration : public ModelForGeneration { - public: - static std::shared_ptr fromConfig( - const Context &ctx, - const lut::IniConfig &config); - - // implements interface ModelForGeneration - void initParameters(const StateMap &state_dict) override; - - Tensor forward(StateMap &past, Tensor input) const override; - Tensor forwardHidden(Tensor hidden) const override; - Tensor buildInput(const std::vector &prompt) const override; - bool isStopToken(int tokenId) const override; - const char *getName() const override; - Device getDevice() const override; - - private: - std::string _modelName; - - std::shared_ptr _model; - ChatGlmConfig _config; - - ChatGlmModelForGeneration() = default; -}; - -} // namespace chatglm -} // namespace libllm diff --git a/src/libllm/chatglm_test.cc b/src/libllm/chatglm_test.cc deleted file mode 100644 index c26b1854..00000000 --- a/src/libllm/chatglm_test.cc +++ /dev/null @@ -1,235 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
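The SelfAttention removed above implements grouped-query attention: the fused QKV projection emits only `multiQueryGroupNum` key/value heads, and `unsqueeze(3).expand(...)` repeats each of them `qNH / kvNH` times so K and V line up with the query heads before `F::attention`. A minimal sketch of that expansion, with per-head `std::vector`s standing in for libllm tensors (an illustration, not the library API):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the GQA head expansion in the removed
// SelfAttention::forward: repeat each of the kvNH key/value heads
// qNH / kvNH times so K and V match the qNH query heads.
std::vector<std::vector<float>> expandKvHeads(
    const std::vector<std::vector<float>> &kvHeads,  // [kvNH][kvL * D]
    int qNH) {
  int kvNH = static_cast<int>(kvHeads.size());
  assert(qNH % kvNH == 0);  // mirrors CHECK(qNH % kvNH == 0)
  int groupSize = qNH / kvNH;

  std::vector<std::vector<float>> expanded;
  expanded.reserve(qNH);
  for (const auto &head : kvHeads) {
    for (int g = 0; g < groupSize; ++g) expanded.push_back(head);
  }
  return expanded;  // [qNH][kvL * D]
}
```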
- -#include "libllm/chatglm.h" - -#include - -#include "catch2/catch_amalgamated.hpp" -#include "libllm/cpu/fingerprint.h" -#include "libllm/lut/random.h" -#include "libllm/lut/span.h" -#include "libllm/tensor.h" -#include "libllm/test_helper.h" - -namespace libllm { -namespace chatglm { - -class TestCommon { - public: - static ChatGlmConfig getConfig() { - ChatGlmConfig config; - config.ffnHiddenSize = 512; - config.hiddenSize = 256; - config.hiddenSizePerAttentionHead = 64; - config.kvChannels = 64; - config.multiQueryGroupNum = 2; - config.normEps = 1e-5; - config.numLayers = 2; - config.seqLength = 8192; - config.symbolEOS = 2; - config.symbolGMask = 98; - config.symbolSOP = 99; - config.vocabSize = 100; - - return config; - } -}; - -class ChatGlmTester : public ModuleTester { - public: - ChatGlmTester(Device device, DType weightType) - : ModuleTester(device, weightType) { - } - - void run() { - ChatGlmConfig config = TestCommon::getConfig(); - std::shared_ptr layer = ChatGlmModel::create(getCtx(), config); - randomInit(layer); - - Tensor x = Tensor::create({1, 7}, {3, 6, 7, 99, 23, 1, 2}); - x = toTargetDevice(x); - - StateMap past; - x = layer->forward(past, x); - - std::vector xr0, xr1; - if (getWeightType() == DType::kQInt4x32) { - xr0 = {0.1738, 1.0190, -0.4752, 0.2405, 0.0359, 0.0821, -0.3922, 0.7598}; - xr1 = {-0.0869, 0.3474, 0.3587, 0.3405, -0.1254, 0.4693, 1.6388, 0.4857}; - } else { - xr0 = {0.1209, 1.0635, -0.3218, 0.2810, 0.0438, 0.0445, -0.4287, 0.7803}; - xr1 = {-0.1908, 0.2644, 0.4468, 0.3596, -0.1265, 0.4871, 1.4609, 0.5425}; - } - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr0)); - - // forward next token. - x = Tensor::create({1, 1}, {5}); - x = toTargetDevice(x); - - x = layer->forward(past, x); - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr1)); - } -}; - -class GlmBlockTester : public ModuleTester { - public: - GlmBlockTester(Device device, DType weightType) - : ModuleTester(device, weightType) { - } - - void run() { - ChatGlmConfig config = TestCommon::getConfig(); - std::shared_ptr layer = GLMBlock::create(getCtx(), config); - randomInit(layer); - - Tensor x = generateTensor({1, 20, config.hiddenSize}); - Tensor roPE = generateTensor({256, 1, config.hiddenSizePerAttentionHead / 2}); - - StateMap past; - x = layer->forward(past, x, roPE); - - std::vector xr0, xr1; - if (getWeightType() == DType::kQInt4x32) { - xr0 = {-0.6995, -1.7179, 0.6030, 0.4212, 0.1758, -0.0991, 0.1561, 0.3834}; - xr1 = {0.0515, -0.5976, 0.9120, -0.8255, 0.5175, 0.3463, 0.1092, 0.0433}; - } else { - xr0 = {-0.6816, -1.6572, 0.6406, 0.4470, 0.1726, -0.1067, 0.1279, 0.3716}; - xr1 = {0.0298, -0.5806, 0.9023, -0.8027, 0.5244, 0.2893, 0.0976, 0.0643}; - } - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr0)); - - // forward next token. 
- x = generateTensor({1, 1, config.hiddenSize}); - x = layer->forward(past, x, roPE); - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr1)); - } -}; - -class MlpTester : public ModuleTester { - public: - MlpTester(Device device, DType weightType) - : ModuleTester(device, weightType) { - } - - void run() { - ChatGlmConfig config = TestCommon::getConfig(); - std::shared_ptr layer = MLP::create(getCtx(), config); - randomInit(layer); - - Tensor x = generateTensor({1, 20, config.hiddenSize}); - x = layer->forward(x); - - std::vector xr; - if (getWeightType() == DType::kQInt4x32) { - xr = {6.4475e-03, -0.1200, -0.1880, 0.1787, -0.0284, -0.3479, 0.0240, 0.1778}; - } else { - xr = {-4.1466e-03, -0.1296, -0.1753, 0.1750, -0.0217, -0.3584, -5.1689e-03, 0.1554}; - } - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr)); - } -}; - -class SelfAttnetionTester : public ModuleTester { - public: - SelfAttnetionTester(Device device, DType weightType) - : ModuleTester(device, weightType) { - } - - float getRtol() const override { - DType defaultFloatType = F::getDefaultFloatType(getDevice()); - if (getDevice().getType() == Device::kCpu && defaultFloatType == DType::kFloat16) { - return 3.5e-2; - } else { - return 5e-3; - } - } - - void run() { - ChatGlmConfig config = TestCommon::getConfig(); - std::shared_ptr layer = SelfAttention::create(getCtx(), config); - randomInit(layer); - - Tensor x = generateTensor({1, 20, config.hiddenSize}); - Tensor roPE = generateTensor({256, 1, config.hiddenSizePerAttentionHead / 2}); - - StateMap past; - x = layer->forward(past, x, roPE); - - std::vector xr0, xr1; - if (getWeightType() == DType::kQInt4x32) { - xr0 = {-0.4076, 0.4038, -0.1918, -9.9697e-03, -6.4159e-03, 0.0496, -0.0568, 0.0421}; - xr1 = {0.3660, -0.1488, -0.0495, -0.0875, 0.0862, -0.0287, -9.2791e-03, -0.0864}; - } else { - xr0 = {-0.3489, 0.4617, -0.1710, -0.0709, 0.0242, 0.0579, -0.0675, 0.0418}; - xr1 = {0.3424, -0.1663, -0.0594, -0.0924, 0.0969, -0.0248, -5.9996e-03, -0.0793}; - } - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr0)); - - // forward next token. 
- x = generateTensor({1, 1, config.hiddenSize}); - x = layer->forward(past, x, roPE); - CATCH_REQUIRE(allClose(op::cpu::fingerprint(toCpu(x)), xr1)); - } -}; - -CATCH_TEST_CASE("test chatglm::ChatGlmModel", "[llm][chatglm]") { - ChatGlmTester(Device::getCpu(), DType::kFloat).run(); - ChatGlmTester(Device::getCpu(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::GLMBlock", "[llm][chatglm]") { - GlmBlockTester(Device::getCpu(), DType::kFloat).run(); - GlmBlockTester(Device::getCpu(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::MLP", "[llm][chatglm]") { - MlpTester(Device::getCpu(), DType::kFloat).run(); - MlpTester(Device::getCpu(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::SelfAttnetion", "[llm][chatglm]") { - SelfAttnetionTester(Device::getCpu(), DType::kFloat).run(); - SelfAttnetionTester(Device::getCpu(), DType::kQInt4x32).run(); -} - -#ifdef LIBLLM_CUDA_ENABLED -CATCH_TEST_CASE("test chatglm::ChatGlmModel (cuda)", "[llm][chatglm][cuda]") { - ChatGlmTester(Device::getCuda(), DType::kFloat).run(); - ChatGlmTester(Device::getCuda(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::GLMBlock (cuda)", "[llm][chatglm][cuda]") { - GlmBlockTester(Device::getCuda(), DType::kFloat).run(); - GlmBlockTester(Device::getCuda(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::MLP (cuda)", "[llm][chatglm][cuda]") { - MlpTester(Device::getCuda(), DType::kFloat).run(); - MlpTester(Device::getCuda(), DType::kQInt4x32).run(); -} - -CATCH_TEST_CASE("test chatglm::SelfAttnetion (cuda)", "[llm][chatglm][cuda]") { - SelfAttnetionTester(Device::getCuda(), DType::kFloat).run(); - SelfAttnetionTester(Device::getCuda(), DType::kQInt4x32).run(); -} -#endif - -} // namespace chatglm -} // namespace libllm diff --git a/src/libllm/cpu/common.h b/src/libllm/cpu/common.h index b401185f..e39f4369 100644 --- a/src/libllm/cpu/common.h +++ b/src/libllm/cpu/common.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -19,10 +19,10 @@ #pragma once -#include "libllm/tensor.h" #include "libllm/cpu/accessor.h" #include "libllm/cpu/kernel/interface.h" #include "libllm/lut/span.h" +#include "libllm/tensor.h" namespace libllm { namespace op { @@ -65,6 +65,6 @@ inline void applyDequant(int64_t offset, int n, const TensorData *data, Float16 kernel::CpuMathBackend::DEFAULT); } -} // cpu -} // op -} // ly +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/cpu_operators.cc b/src/libllm/cpu/cpu_operators.cc index f72e1d3c..58ee1748 100644 --- a/src/libllm/cpu/cpu_operators.cc +++ b/src/libllm/cpu/cpu_operators.cc @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. 
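The deleted testers all follow one fingerprint pattern: run the module, reduce the output tensor to eight sampled values with `op::cpu::fingerprint`, and compare against hard-coded references per weight type via `allClose` (the self-attention tester relaxes rtol to 3.5e-2 for fp16 on CPU). A self-contained sketch of that comparison logic, with a strided sampler standing in for the real fingerprint op (an assumption about its behavior, not its actual definition):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Strided sampler standing in for op::cpu::fingerprint: pick up to n evenly
// spaced elements so a large activation tensor reduces to a tiny signature.
std::vector<float> fingerprint(const std::vector<float> &data, std::size_t n = 8) {
  std::vector<float> fp;
  std::size_t stride = std::max<std::size_t>(1, data.size() / n);
  for (std::size_t i = 0; i < data.size() && fp.size() < n; i += stride) {
    fp.push_back(data[i]);
  }
  return fp;
}

// Element-wise closeness check in the spirit of F::allClose.
bool allClose(const std::vector<float> &a, const std::vector<float> &b,
              float rtol = 5e-3f, float atol = 1e-5f) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i])) return false;
  }
  return true;
}
```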
// @@ -20,26 +20,33 @@ #include "libllm/cpu/cpu_operators.h" #include + #include #include #include -#include "libllm/cpu/kernel/interface.h" + #include "libllm/cpu/all_close.h" #include "libllm/cpu/apply_rotary_pos_emb.h" #include "libllm/cpu/binary_op.h" #include "libllm/cpu/cast.h" #include "libllm/cpu/common.h" #include "libllm/cpu/copy.h" +#include "libllm/cpu/cpu_tensor_data.h" +#include "libllm/cpu/fill.h" +#include "libllm/cpu/gelu.h" +#include "libllm/cpu/kernel/interface.h" +#include "libllm/cpu/log_mel_spectrogram.h" #include "libllm/cpu/lookup.h" #include "libllm/cpu/matmul.h" +#include "libllm/cpu/normalizations.h" #include "libllm/cpu/print.h" #include "libllm/cpu/rand.h" -#include "libllm/cpu/rms_norm.h" +#include "libllm/cpu/reduce.h" #include "libllm/cpu/softmax.h" #include "libllm/cpu/swiglu.h" #include "libllm/cpu/tensor.h" #include "libllm/cpu/transform.h" -#include "libllm/cpu/cpu_tensor_data.h" +#include "libllm/cpu/unfold.h" #include "libllm/operators.h" #include "libllm/tensor.h" @@ -47,8 +54,8 @@ namespace libllm { namespace op { namespace cpu { -CPUOperators::CPUOperators() {} - +CPUOperators::CPUOperators() { +} Tensor CPUOperators::tensor(lut::Span shape, DType dtype) { return op::cpu::tensor(shape, dtype); @@ -60,8 +67,12 @@ Tensor CPUOperators::tensorLike(Tensor input) { // -- class CPUOperators ---------- -Tensor CPUOperators::rand(lut::Span shape, DType dtype, lut::Random *generator, - float min, float max) { +Tensor CPUOperators::rand( + lut::Span shape, + DType dtype, + lut::Random *generator, + float min, + float max) { return op::cpu::rand(shape, dtype, generator, min, max); } @@ -101,6 +112,22 @@ Tensor CPUOperators::lookup(Tensor table, Tensor indices) { return cpu::lookup(table, indices); } +Tensor CPUOperators::gelu(Tensor input) { + return cpu::gelu(input); +} + +void CPUOperators::fill(Tensor input, float value) { + return cpu::fill(input, value); +} + +Tensor CPUOperators::sum(Tensor inputs) { + return cpu::reduce(inputs, MapReduceType::SUM); +} + +Tensor CPUOperators::max(Tensor inputs) { + return cpu::reduce(inputs, MapReduceType::MAX); +} + Tensor CPUOperators::rmsNorm(Tensor input, Tensor weight, float eps) { CHECK(input.getDType() == weight.getDType()); @@ -115,6 +142,10 @@ Tensor CPUOperators::applyRotaryPosEmb(Tensor A, Tensor roPE) { return cpu::applyRotaryPosEmb(A, roPE); } +Tensor CPUOperators::layerNorm(Tensor input, Tensor weight, Tensor bias, float eps) { + return cpu::layerNorm(input, weight, bias, eps); +} + void CPUOperators::copy(Tensor src, Tensor dest) { return cpu::copy(src, dest); } @@ -129,6 +160,14 @@ Tensor CPUOperators::to(Device device, Tensor tensor) { NOT_IMPL(); } +Tensor CPUOperators::logMelSpectrogram(Tensor wave) { + return cpu::logMelSpectrogram(wave); +} + +Tensor CPUOperators::unfold(Tensor input, int kernelSize, int stride) { + return cpu::unfold(input, kernelSize, stride); +} + Tensor CPUOperators::cast(Tensor tensor, DType dtype) { return cpu::cast(tensor, dtype); } @@ -137,6 +176,6 @@ DType CPUOperators::getDefaultFloatType() { return DType::getType(); } -} // cpu -} // op -} // ly +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/cpu_operators.h b/src/libllm/cpu/cpu_operators.h index f2e61254..6ccd9d2b 100644 --- a/src/libllm/cpu/cpu_operators.h +++ b/src/libllm/cpu/cpu_operators.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit 
persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -20,7 +20,9 @@ #pragma once #include + #include + #include "libllm/operators.h" #include "libllm/tensor.h" @@ -42,10 +44,15 @@ class CPUOperators : public Operators { // implement interface Operators Tensor lookup(Tensor table, Tensor indices) override; Tensor matmul(Tensor a, Tensor b) override; + Tensor layerNorm(Tensor input, Tensor weight, Tensor bias, float eps) override; Tensor mul(Tensor input, float other) override; Tensor mul(Tensor input, Tensor other) override; Tensor softmax(Tensor input) override; + Tensor gelu(Tensor input) override; + void fill(Tensor input, float value) override; Tensor add(Tensor a, Tensor b) override; + Tensor sum(Tensor inputs) override; + Tensor max(Tensor inputs) override; Tensor tensor(lut::Span shape, DType dtype) override; Tensor tensorLike(Tensor input) override; Tensor zeros(lut::Span shape, DType dtype) override; @@ -58,8 +65,10 @@ class CPUOperators : public Operators { Tensor swiglu(Tensor A) override; Tensor to(Device device, Tensor tensor) override; Tensor cast(Tensor tensor, DType dtype) override; - Tensor rand(lut::Span shape, DType dtype, lut::Random *generator, float min, - float max) override; + Tensor logMelSpectrogram(Tensor wave) override; + Tensor unfold(Tensor input, int kernelSize, int stride) override; + Tensor rand(lut::Span shape, DType dtype, lut::Random *generator, float min, float max) + override; DType getDefaultFloatType() override; @@ -67,7 +76,6 @@ class CPUOperators : public Operators { typedef TensorShape::Elem Shape; }; -} // cpu -} // op -} // ly - +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/fill.cc b/src/libllm/cpu/fill.cc new file mode 100644 index 00000000..1af87a7b --- /dev/null +++ b/src/libllm/cpu/fill.cc @@ -0,0 +1,71 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +#include "libllm/cpu/fill.h" + +#include "libllm/cpu/accessor.h" +#include "libllm/cpu/common.h" +#include "libllm/cpu/tensor.h" +#include "libllm/mp.h" +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +template +void fillKernel(Tensor A, float value) { + TensorList vC = TensorList::fromTensor(A); + MP::parallelFor({vC.getLength()}, [&vC, value](MP::Partition partition) { + for (int j : partition.getRange()) { + TensorAccessor c = vC.getTensor(j); + + for (int i = 0; i < c.getShape(0); ++i) { + c[i] = value; + } + } + }); +} + +void fill(Tensor src, float value) { + if (src.getDType() == DType::kFloat) { + if (src.getNumEl() == 1) { + *src.getData() = value; + } else { + fillKernel(src, value); + } + return; + } +#if LUT_CPU_ARCH == LUT_AARCH64 + if (src.getDType() == DType::kFloat16) { + if (src.getNumEl() == 1) { + *src.getData() = value; + } else { + fillKernel(src, value); + } + return; + } +#endif + + NOT_IMPL(); +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/go/chat/context.go b/src/libllm/cpu/fill.h similarity index 83% rename from go/chat/context.go rename to src/libllm/cpu/fill.h index 42fa29e4..a95efc83 100644 --- a/go/chat/context.go +++ b/src/libllm/cpu/fill.h @@ -17,15 +17,17 @@ // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package chat +#pragma once -type QA struct { - Question string - Answer string -} +#include "libllm/tensor.h" -type Context struct { - System string - History []QA - Question string -} +namespace libllm { +namespace op { +namespace cpu { + +// fill tensor with value +void fill(Tensor tensor, float value); + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/rms_norm.cc b/src/libllm/cpu/gelu.cc similarity index 61% rename from src/libllm/cpu/rms_norm.cc rename to src/libllm/cpu/gelu.cc index a0d101de..7f552717 100644 --- a/src/libllm/cpu/rms_norm.cc +++ b/src/libllm/cpu/gelu.cc @@ -17,51 +17,39 @@ // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-#include "libllm/cpu/rms_norm.h" +#include "libllm/cpu/gelu.h" -#include +#include #include "libllm/cpu/accessor.h" -#include "libllm/cpu/common.h" #include "libllm/cpu/tensor.h" +#include "libllm/lut/thread_pool.h" #include "libllm/mp.h" -#include "libllm/tensor.h" namespace libllm { namespace op { namespace cpu { -template -Tensor rmsNormKernel(const Tensor &tensor, const Tensor &weight, float eps) { - CHECK(weight.getDim() == 1); - CHECK(tensor.getShape(-1) == weight.getShape(0)); +constexpr float Sqrt2 = 1.4142136f; - Tensor C = tensorLike(tensor); +template +Tensor geluKernel(const Tensor &A) { + Tensor C = tensor(A.getShape(), DType::getType()); - TensorList vA = TensorList::fromTensor(tensor); + TensorList vA = TensorList::fromTensor(A); TensorList vC = TensorList::fromTensor(C); CHECK(vA.getLength() == vC.getLength()); - TensorAccessor w = weight; - - MP::parallelFor({vA.getLength()}, [&vA, &vC, w, eps](MP::Partition partition) { + MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { for (int j : partition.getRange()) { TensorAccessor a = vA.getTensor(j); TensorAccessor c = vC.getTensor(j); - double sum = 0.0; - for (int i = 0; i < a.getShape(0); ++i) { - double va = a[i]; - sum += va * va; - } - double mean = sum / a.getShape(0); - double rms = std::sqrt(mean + eps); - - // compute rms-norm - for (int i = 0; i < a.getShape(0); ++i) { - double va = a[i]; - double vw = w[i]; - c[i] = static_cast(a[i] * w[i] / rms); + int n = c.getShape(0); + for (int i = 0; i < n; ++i) { + float x = a[i]; + x = x * 0.5f * (1.0f + erf(x / Sqrt2)); + c[i] = T(x); } } }); @@ -69,10 +57,12 @@ Tensor rmsNormKernel(const Tensor &tensor, const Tensor &weight, float eps) { return C; } -Tensor rmsNorm(const Tensor &tensor, const Tensor &weight, float eps) { - if (tensor.getDType() == DType::kFloat) return rmsNormKernel(tensor, weight, eps); +Tensor gelu(const Tensor &A) { + CHECK(A.getShape(-1) % 2 == 0); + + if (A.getDType() == DType::kFloat) return geluKernel(A); #if LUT_CPU_ARCH == LUT_AARCH64 - if (tensor.getDType() == DType::kFloat16) return rmsNormKernel(tensor, weight, eps); + if (A.getDType() == DType::kFloat16) return geluKernel(A); #endif NOT_IMPL(); diff --git a/src/libllm/cpu/rms_norm.h b/src/libllm/cpu/gelu.h similarity index 91% rename from src/libllm/cpu/rms_norm.h rename to src/libllm/cpu/gelu.h index 7c47fa6a..ca74a83d 100644 --- a/src/libllm/cpu/rms_norm.h +++ b/src/libllm/cpu/gelu.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. 
// @@ -25,8 +25,8 @@ namespace libllm { namespace op { namespace cpu { -Tensor rmsNorm(const Tensor &tensor, const Tensor &weight, float eps); +Tensor gelu(const Tensor &A); -} // cpu -} // op -} // libllm +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/log_mel_spectrogram.cc b/src/libllm/cpu/log_mel_spectrogram.cc new file mode 100644 index 00000000..e006ec87 --- /dev/null +++ b/src/libllm/cpu/log_mel_spectrogram.cc @@ -0,0 +1,268 @@ +// The MIT License (MIT) +// +// Copyright (c) 2023 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#define POCKETFFT_NO_MULTITHREADING + +#include + +#include +#include +#include + +#include "libllm/cpu/accessor.h" +#include "libllm/cpu/tensor.h" +#include "libllm/lut/thread_pool.h" +#include "libllm/mp.h" +#include "pocketfft/pocketfft_hdronly.h" + +namespace libllm { +namespace op { +namespace cpu { + +constexpr float PI = 3.1415926; +constexpr int NumFft = 400; +constexpr int NumPad = NumFft / 2; +constexpr int HopLength = 160; + +constexpr int kMel_InputDim = 201; +constexpr int kMel_OutputDim = 128; +constexpr int kMel_Offsets[] = { + 1, 1, 2, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 9, 10, 10, 11, + 12, 12, 13, 13, 14, 15, 15, 16, 16, 17, 17, 18, 19, 19, 20, 20, 21, 22, 22, + 23, 23, 24, 24, 25, 26, 26, 27, 28, 28, 29, 30, 30, 31, 32, 32, 33, 34, 35, + 36, 37, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 52, 54, 55, + 56, 58, 59, 60, 62, 63, 65, 66, 68, 70, 71, 73, 75, 77, 79, 80, 82, 84, 86, + 89, 91, 93, 95, 98, 100, 102, 105, 107, 110, 113, 115, 118, 121, 124, 127, 130, 133, 136, + 140, 143, 147, 150, 154, 158, 161, 165, 169, 174, 178, 182, 187, 191}; +constexpr int kMel_Lengths[] = { + 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, + 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 3, 3, 2, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 3, 4, 3, 3, 4, 4, 4, 3, 3, 4, 4, 5, 5, 4, + 4, 5, 5, 4, 5, 5, 5, 6, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 7, 7, 8, 9, 9, 8, 9, 9, 9, 9}; +constexpr float kMel_Weights[] = { + 1.23740e-02f, 3.03926e-02f, 2.47480e-02f, 1.80186e-02f, 3.71220e-02f, 5.64459e-03f, + 6.72939e-03f, 3.60372e-02f, 1.91034e-02f, 2.36632e-02f, 3.14774e-02f, 1.12892e-02f, + 1.08480e-03f, 4.16818e-02f, 1.34588e-02f, 2.93078e-02f, 2.58328e-02f, 1.69338e-02f, + 3.82068e-02f, 4.55979e-03f, 7.81420e-03f, 3.49524e-02f, 2.01882e-02f, 2.25784e-02f, + 3.25622e-02f, 1.02044e-02f, 
2.16960e-03f, 4.05969e-02f, 1.45436e-02f, 2.82230e-02f, + 2.69176e-02f, 1.58490e-02f, 3.92916e-02f, 3.47499e-03f, 8.89900e-03f, 3.38676e-02f, + 2.12730e-02f, 2.14936e-02f, 3.36470e-02f, 9.11958e-03f, 3.25441e-03f, 3.95121e-02f, + 1.56284e-02f, 2.71382e-02f, 2.80024e-02f, 1.47642e-02f, 4.03764e-02f, 2.38069e-03f, + 1.02026e-02f, 3.16115e-02f, 2.45470e-02f, 1.53292e-02f, 1.66584e-03f, 3.67291e-02f, + 2.00971e-02f, 1.69310e-02f, 2.90266e-03f, 3.28450e-02f, 2.35200e-02f, 1.10389e-02f, + 1.07258e-02f, 2.27183e-02f, 3.22787e-02f, 1.16268e-04f, 2.28535e-02f, 8.56344e-03f, + 1.49798e-02f, 1.55140e-02f, 8.51491e-03f, 2.11068e-02f, 3.32652e-03f, 2.54706e-02f, + 2.73591e-02f, 6.58536e-04f, 2.38381e-02f, 3.44359e-03f, 2.12246e-02f, 5.35842e-03f, + 1.94256e-02f, 6.49325e-03f, 1.83554e-02f, 6.93138e-03f, 1.79350e-02f, 6.74968e-03f, + 1.80915e-02f, 6.01899e-03f, 1.87577e-02f, 4.80453e-03f, 1.98717e-02f, 3.16628e-03f, + 2.13769e-02f, 1.25317e-03f, 1.15934e-03f, 2.08036e-02f, 4.04487e-03f, 1.75536e-02f, + 7.08320e-03f, 1.40754e-02f, 1.03266e-02f, 1.04092e-02f, 1.37370e-02f, 6.59188e-03f, + 1.72799e-02f, 1.46804e-03f, 2.65682e-03f, 1.80919e-02f, 5.85656e-03f, 1.33428e-02f, + 1.02827e-02f, 8.56800e-03f, 1.47223e-02f, 1.04040e-03f, 3.79086e-03f, 1.71468e-02f, + 6.11609e-03f, 1.17593e-02f, 1.11339e-02f, 6.43858e-03f, 1.60781e-02f, 4.23917e-03f, + 1.19989e-03f, 1.27567e-02f, 9.65299e-03f, 7.06935e-03f, 1.49405e-02f, 4.19025e-03f, + 1.51483e-03f, 1.20090e-02f, 9.84823e-03f, 6.10224e-03f, 1.53386e-02f, 5.57677e-03f, + 3.68273e-04f, 9.89749e-03f, 1.13534e-02f, 2.05122e-03f, 3.89297e-03f, 1.29735e-02f, + 8.06632e-03f, 6.74493e-03f, 1.38587e-02f, 5.41191e-03f, 7.42202e-04f, 8.98779e-03f, + 1.13787e-02f, 3.32958e-03f, 2.82314e-03f, 1.06805e-02f, 9.43341e-03f, 1.76326e-03f, + 4.39019e-03f, 1.18776e-02f, 7.97006e-03f, 6.61047e-04f, 5.49467e-03f, 1.26295e-02f, + 6.93988e-03f, 6.18402e-03f, 1.29347e-02f, 6.29779e-03f, 2.32521e-05f, 6.50207e-03f, + 1.23266e-02f, 6.00217e-03f, 3.15488e-04f, 6.48926e-03f, 1.20413e-02f, 6.01463e-03f, + 2.99796e-04f, 6.18288e-03f, 1.20427e-02f, 6.29981e-03f, 5.56896e-04f, 1.12047e-05f, + 5.61729e-03f, 1.12234e-02f, 6.82516e-03f, 1.35264e-03f, 4.82410e-03f, 1.01662e-02f, + 7.56076e-03f, 2.34590e-03f, 3.83236e-03f, 8.92296e-03f, 8.47910e-03f, 3.50979e-03f, + 2.66873e-03f, 7.51965e-03f, 9.55501e-03f, 4.81966e-03f, 8.43175e-05f, 1.35767e-03f, + 5.98020e-03f, 1.06027e-02f, 6.25298e-03f, 1.74060e-03f, 4.32644e-03f, 8.73132e-03f, + 7.78917e-03f, 3.48924e-03f, 2.57835e-03f, 6.77583e-03f, 9.40942e-03f, 5.31195e-03f, + 1.21448e-03f, 7.54112e-04f, 4.75396e-03f, 8.75380e-03f, 7.19209e-03f, 3.28754e-03f, + 2.68180e-03f, 6.49331e-03f, 9.11458e-03f, 5.39387e-03f, 1.67317e-03f, 5.73943e-04f, + 4.20600e-03f, 7.83806e-03f, 7.52023e-03f, 3.97471e-03f, 4.29187e-04f, 1.90464e-03f, + 5.36569e-03f, 8.82674e-03f, 6.27609e-03f, 2.89751e-03f, 2.89885e-03f, 6.19694e-03f, + 8.56699e-03f, 5.34748e-03f, 2.12797e-03f, 4.47502e-04f, 3.59030e-03f, 6.73311e-03f, + 7.77024e-03f, 4.70231e-03f, 1.63439e-03f, 1.01536e-03f, 4.01019e-03f, 7.00501e-03f, + 7.23443e-03f, 4.31096e-03f, 1.38748e-03f, 1.33349e-03f, 4.18731e-03f, 7.04113e-03f, + 6.93188e-03f, 4.14606e-03f, 1.36023e-03f, 1.42880e-03f, 4.14825e-03f, 6.86770e-03f, + 6.83705e-03f, 4.18239e-03f, 1.52774e-03f, 1.32610e-03f, 3.91751e-03f, 6.50892e-03f, + 6.92640e-03f, 4.39673e-03f, 1.86706e-03f, 1.04828e-03f, 3.51767e-03f, 5.98707e-03f, + 7.17824e-03f, 4.76768e-03f, 2.35712e-03f, 6.16364e-04f, 2.96949e-03f, 5.32262e-03f, + 7.57265e-03f, 5.27559e-03f, 2.97852e-03f, 6.81461e-04f, 
4.97140e-05f, 2.29205e-03f, + 4.53438e-03f, 6.77672e-03f, 5.90241e-03f, 3.71350e-03f, 1.52459e-03f, 1.50285e-03f, + 3.63961e-03f, 5.77637e-03f, 6.63159e-03f, 4.54574e-03f, 2.45990e-03f, 3.74049e-04f, + 6.17959e-04f, 2.65411e-03f, 4.69026e-03f, 6.72641e-03f, 5.46035e-03f, 3.47271e-03f, + 1.48507e-03f, 1.59234e-03f, 3.53262e-03f, 5.47290e-03f, 6.44368e-03f, 4.54963e-03f, + 2.65558e-03f, 7.61525e-04f, 4.67494e-04f, 2.31642e-03f, 4.16534e-03f, 6.01427e-03f, + 5.67845e-03f, 3.87357e-03f, 2.06870e-03f, 2.63827e-04f, 1.05349e-03f, 2.81536e-03f, + 4.57723e-03f, 6.33910e-03f, 5.12816e-03f, 3.40826e-03f, 1.68837e-03f, 1.43350e-03f, + 3.11242e-03f, 4.79133e-03f, 6.40944e-03f, 4.77052e-03f, 3.13161e-03f, 1.49269e-03f, + 2.93236e-05f, 1.62919e-03f, 3.22906e-03f, 4.82892e-03f, 6.14671e-03f, 4.58497e-03f, + 3.02322e-03f, 1.46147e-03f, 1.36017e-04f, 1.66056e-03f, 3.18509e-03f, 4.70963e-03f, + 6.04072e-03f, 4.55251e-03f, 3.06429e-03f, 1.57608e-03f, 8.78619e-05f, 9.32810e-05f, + 1.54604e-03f, 2.99880e-03f, 4.45155e-03f, 5.90431e-03f, 4.65566e-03f, 3.23752e-03f, + 1.81937e-03f, 4.01226e-04f, 1.30263e-03f, 2.68698e-03f, 4.07134e-03f, 5.45570e-03f, + 4.87832e-03f, 3.52695e-03f, 2.17558e-03f, 8.24205e-04f, 9.45950e-04f, 2.26513e-03f, + 3.58430e-03f, 4.90348e-03f, 5.20570e-03f, 3.91795e-03f, 2.63021e-03f, 1.34246e-03f, + 5.47149e-05f, 4.90379e-04f, 1.74744e-03f, 3.00451e-03f, 4.26157e-03f, 5.51864e-03f, + 4.39707e-03f, 3.16996e-03f, 1.94284e-03f, 7.15731e-04f, 1.14698e-03f, 2.34486e-03f, + 3.54273e-03f, 4.74061e-03f, 4.95198e-03f, 3.78265e-03f, 2.61331e-03f, 1.44397e-03f, + 2.74637e-04f, 4.75695e-04f, 1.61717e-03f, 2.75865e-03f, 3.90013e-03f, 5.04160e-03f, + 4.45712e-03f, 3.34284e-03f, 2.22856e-03f, 1.11428e-03f}; + +std::vector hannWindow(int windowSize) { + std::vector window(windowSize); + for (int i = 0; i < windowSize; ++i) { + window[i] = static_cast(0.5 - cosf(static_cast(2 * PI * i / windowSize)) / 2); + } + + return window; +} + +std::vector> fft(lut::Span input) { + CHECK(input.size() % 2 == 0); + + std::vector> output(input.size() / 2 + 1); + + pocketfft::shape_t shapeIn = {input.size()}; + pocketfft::stride_t strideIn = {sizeof(float)}; + pocketfft::stride_t strideOut = {sizeof(std::complex)}; + pocketfft::r2c( + shapeIn, + strideIn, + strideOut, + 0, + pocketfft::FORWARD, + input.data(), + output.data(), + 1.0f); + + return output; +} + +std::vector mel128FilterBank(lut::Span input) { + CHECK(input.size() == kMel_InputDim); + + std::vector output(kMel_OutputDim); + + const float *weight = kMel_Weights; + for (int mel_bin = 0; mel_bin < kMel_OutputDim; ++mel_bin) { + float v = 0; + + int begin = kMel_Offsets[mel_bin]; + int end = kMel_Offsets[mel_bin] + kMel_Lengths[mel_bin]; + for (int i = begin; i < end; ++i) { + v += input[i] * (*weight); + ++weight; + } + + output[mel_bin] = v; + } + + CHECK(weight - kMel_Weights == sizeof(kMel_Weights) / sizeof(float)); + return output; +} + +std::vector applyLogMelSpectrogramWindow( + lut::Span data, + lut::Span window) { + CHECK(data.size() == window.size()); + + // apply window. + std::vector windowData(data.size()); + for (int i = 0; i < data.size(); ++i) { + windowData[i] = data[i] * window[i]; + } + + // apply fft. + std::vector> fftResult = fft(windowData); + + // compute magnitudes. + std::vector magnitudes(fftResult.size()); + for (int i = 0; i < fftResult.size(); ++i) { + float v = std::abs(fftResult[i]); + magnitudes[i] = v * v; + } + + // apply mel filter-bank. 
+ std::vector melFbank = mel128FilterBank(magnitudes); + + // apply log10 to mel filter-bank. + for (int i = 0; i < melFbank.size(); ++i) { + float v = std::max(1e-10f, melFbank[i]); + melFbank[i] = log10f(v); + } + + return melFbank; +} + +Tensor logMelSpectrogram(Tensor inputs) { + CHECK(inputs.getDim() == 1); + CHECK(inputs.getShape(0) > NumPad); + + int numFrames = inputs.getShape(0) / HopLength; + Tensor outputs = op::cpu::tensor({numFrames, kMel_OutputDim}, DType::kFloat); + + TensorAccessor inputAccessor(inputs); + lut::Span inputSpan(inputAccessor.getData(), inputAccessor.getShape(0)); + + // padding. + std::vector paddedInputs(NumPad); + std::copy(inputSpan.begin() + 1, inputSpan.begin() + 1 + NumPad, paddedInputs.rbegin()); + paddedInputs.insert(paddedInputs.end(), inputSpan.begin(), inputSpan.end()); + paddedInputs.insert(paddedInputs.end(), inputSpan.rbegin() + 1, inputSpan.rbegin() + 1 + NumPad); + CHECK(paddedInputs.size() == inputSpan.size() + 2 * NumPad); + + // hanning window function. + std::vector window = hannWindow(NumFft); + + // for each window. + TensorAccessor outputAccessor(outputs); + lut::Span paddedSpan(paddedInputs); + for (int i = 0; i < numFrames; ++i) { + lut::Span windowSpan = paddedSpan.subspan(i * HopLength, NumFft); + std::vector feature = applyLogMelSpectrogramWindow(windowSpan, window); + CHECK(feature.size() == outputAccessor.getShape(1)); + + for (int j = 0; j < feature.size(); ++j) { + outputAccessor[i][j] = feature[j]; + } + } + + // whisper feature normalize. + float maxVal = -std::numeric_limits::infinity(); + for (int i = 0; i < numFrames; ++i) { + for (int j = 0; j < kMel_OutputDim; ++j) { + float val = outputAccessor[i][j]; + if (val > maxVal) maxVal = val; + } + } + + float featureMinVal = maxVal - 8.0f; + for (int i = 0; i < numFrames; ++i) { + for (int j = 0; j < kMel_OutputDim; ++j) { + float val = outputAccessor[i][j]; + val = std::max(val, featureMinVal); + outputAccessor[i][j] = (val + 4.0f) / 4.0f; + } + } + + return outputs; +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/log_mel_spectrogram.h b/src/libllm/cpu/log_mel_spectrogram.h new file mode 100644 index 00000000..8706e7a8 --- /dev/null +++ b/src/libllm/cpu/log_mel_spectrogram.h @@ -0,0 +1,30 @@ +// The MIT License (MIT) +// +// Copyright (c) 2023 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
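logMelSpectrogram above mirrors Whisper's audio front end: reflect-pad the waveform by NumFft/2 = 200 samples on each side, slide a 400-sample Hann window with a hop of 160 (10 ms at 16 kHz), take the real FFT, square the magnitudes into 201 power bins, project through the sparse 128-bin mel filter bank, floor at 1e-10 before log10, then clamp to max − 8 and rescale by (x + 4) / 4. The two ends of that pipeline in isolation, as a sketch:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hann window, as built by hannWindow(NumFft) above.
std::vector<float> hann(int n) {
  std::vector<float> w(n);
  for (int i = 0; i < n; ++i) {
    w[i] = 0.5f - 0.5f * std::cos(2.0f * 3.1415926f * i / n);
  }
  return w;
}

// Whisper-style feature normalization: clamp at (max - 8) and map into a
// roughly unit range, matching the final two loops of logMelSpectrogram.
void whisperNormalize(std::vector<float> &logMel) {
  float maxVal = -std::numeric_limits<float>::infinity();
  for (float v : logMel) maxVal = std::max(maxVal, v);
  const float floorVal = maxVal - 8.0f;
  for (float &v : logMel) v = (std::max(v, floorVal) + 4.0f) / 4.0f;
}
```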
+ +#include "libllm/cpu/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +Tensor logMelSpectrogram(Tensor inputs); + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/log_mel_spectrogram_test.cc b/src/libllm/cpu/log_mel_spectrogram_test.cc new file mode 100644 index 00000000..1ee3a378 --- /dev/null +++ b/src/libllm/cpu/log_mel_spectrogram_test.cc @@ -0,0 +1,158 @@ +// The MIT License (MIT) +// +// Copyright (c) 2023 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include "libllm/cpu/log_mel_spectrogram.h" + +#include "catch2/catch_amalgamated.hpp" +#include "libllm/cpu/fingerprint.h" +#include "libllm/functional.h" +#include "libllm/lut/base64.h" +#include "libllm/wave.h" + +namespace libllm { +namespace op { +namespace cpu { + +const char *gDataEnUsHelloBase64 = + "UklGRnwWAABXQVZFZm10IBAAAAABAAEAgD4AAAB9AAACABAAZGF0YVgWAACH/0L/gv+W/5D/dP+K/zX/MP5m/c/9DP+a/" + "1X/TP8dAIwACwA7/6f+fP63/nj/y/8K/1z+6f4v/wD+p/wD/Wr+Zv+F/7z+Av6o/T/+7P7i/cf8D/7R/38AGwADAGH/" + "Av7V/aL+Gv60/J78zv2V/j7+3/3i/Tj+p/49/lD98f0J/4n/SADmAAQAj/4p/f/8wP0M/s791/0Q/nr9rPwT/f39S/" + "23+3T83P/JAG3/UP/L/2D/2v4M/8b/3P7t/VD+ev29+9j5KPiB93X5Dv1l/8L/FgGtAtYCLQFS/279v/u2+3T8e/" + "zp+6z87f36/f79GP/b/ygBrQLZAjIC5QAG/6T92vtJ+5v8Vf2g/TD+jP/IAaoCoQG4ATYDPgS3AgAA6v5L/o38mvss/" + "KD9G/9R//L+0P4k/pD9B/6+/vj+5/4R/0P/Vv9r/p3+NgA4AQ4BUwE+AmgCDwEEAKkA5ABY/4b+Rv+E/13/" + "MgCPAVgBgQDEALICAwPeAREC8wIvA/YB5QBtAX0CmAKbAmMCaQHa/9H/fwCA/9b+dwCIAuABu//K/mL+Mf2a/E/" + "+AQCoAI8BsgKCAtb/Lv5t/1EAXf/4/s//kwB6AFP/Fv5L/" + "Sf+ZgCQAfsBewJFA3EEGgTTAfj+Jf6tAIQCxQL8AoUD8wMIBPMDSARgA3EE2QZ6CDsIogVxBKgCNAC7/Y39v/7T/" + "wkBqgHCAtwB7QCEAfMBRQEQAA8AAQDZ/pr8kfs//KX7zPtu/YT+f/7E/RH9wft4+kf76/ub+6D88/3D/h3+T/1g/" + "YT9jf5jAHgCUwRVBsEHygh8CbwJUwkZCAUIcQiCCGIJkgvlDAQNhw00DoEMJArxCA4IVwclBt4FOAQYAjYCnAFz/" + "hj6WPje90r2N/Un9rb3n/ib+dn59PhC+MP4G/p/+tT5Zfkh+f/4dfhX9/P2ZPfn93H4Rvm2+g/" + "9BgC9A8sGtgmzDjgSihDeCtUGAge4CH0LIQ6SEdwV/xgwGu8X2xRRExoScBDpDgsMWwT++o/" + "ypuo45cnje+bc7MH0kvxpAEEBHALKATcBUwCJ/x4AxAC5/" + "7n7ZvVD8ALt7uox6ifr+u6g9Oz63v+QAZ8BjgKxAycF3gbICVUO0BIkFgoV9g+" + "GC2AK4woDDLwOyBT5Gj8eGx8aG5kVxBDnDUMJ+/8R9jvvZO0V69fobugg7Fv0+Pq7/" + "iEAtgHTA0sEhwNEAk4C2AMyA5wAfPuT9WTyou9W7WXrWevI7g/" + "0HPgJ+6H8rP22AIwCXAOABFUG4glEDeoOWhDfEc0RNg87CzgJ6QnsDIUQrRRdGGYaVxtrGR8VUBDxCw0HH/" + "8o9WTvau6q7UHtg+0c8Av1fPi3+uX7jfwL/8IB0wLqAoEC+QJTAuH+YfrE9VrzNvJc8Abv8u6p8BfzBvUD96D4RPoJ/YH/" + "9ADOAt0FPwlhCxsMzAxSDroPFBCkDtoMcw0wD0MRPxOQFbwXohgxGCMWWxKXCxADJvkv8HLr4+h36Lrqiu6J8i/" + 
"2NPiH+HT4LPkD+xz8pf1vAJsD/wQ6BC0BVvxj+MD0m/Fk7z/unO/v8o31p/" + "Z29mX2sffq+P75GPyAAB8GpwrEDMQNCw89EM4Paw1uC9ELcQ4IEVMT8BXYF18ZQRkTFogSHA4nB439o/" + "NY7hvs9+pZ6xrt7u8U82z1uPYE91T3Nfnh+0H+OADeArEF3galBdsCQgCX/m78t/hx9Qf0s/" + "Mw84XyofGw8aDzCvYv+Ef6m/" + "0zAgAGkghiCgEMBQ5MELAREBGpD4kPkBAqERkRcBILFdwXphleGcYXFhXpEFMKJQAm9ofwDO5K7ELsP+7W8RD2tPj2+" + "Xb5O/nM+kb87/0yAG8D4wY+CXEIqQWOAiv/jfun97z0NfMT87LzP/Xy9Sf2t/" + "do+oP8uv1BAJQDUwelCs4MiQ4UEH8RShKBETEQ/Q+HEOMRPBQFF/" + "oZqxzcHQQdTRltE9YKfv+29ZHuReqO6KboV+ta773yjPX49uv2Vfdj9034lvqf/f0BxwWbCFwJugdtBaoC+f2o+F/" + "1yPPt8mPyEvP89OD28/ck+e75xfq0/O/" + "+GwJ0BdIIaQxPD9IQ+xBhEGYPcA4PDgIPXxHZFP4XYBphGycafReqEVIJbv9M9XXtjOi75mznTOov7h/yYfRA9d70k/" + "TN9Af1bvYZ+cP9UgJLBYwGTwbVBFgC3v59+oP25fOM8ofxe/AB8VbzP/" + "Vs9kj3Jfjx+en7+f2vALcDbQfzCl0Nqw59D3MP4g43DrcNwA4iEf8TTBYPGPoYcBj/" + "FFQOwQQy+qPxlesE53vl+uZg6lHu1vCD8jPz8PLf8sHyJ/OD9V75Ov6YAg0F4QW7BRcECwAx+wz3RvTG8r3xCfK288j15/" + "fc+Nf47fht+fv6ovzM/lUCkAZqCtcM/g1jDt4NRA3HDPoMEQ5RENMSgBUHGAkZ2BhsFvQPmQZt/" + "J7zPO1P6JDmueet6nHukfFw88rzlvNN87Dy8/Kr9Ab43vwrAY0EsQZsBwUHpgT9/9v6aven9MjygPIm84f1A/iE+Yr6m/" + "oU+wH81/xR/vYASwQgCGALKw2ZDUgO/g71Du4O1Q7pD9oR4BN9FsQXXRfDFRAR3QmeADn4PPKH7K7pduku68buR/" + "HG83H1KvVI9V30iPNe9Dn2uPkE/q8BZAQeBvgGDgb3AtD+QPs6+B/2PfWM9Sv39fjQ+lH8qfwq/ar9HP5L/" + "8oAiwLqBLAHzQnZCsULSwzfDE8NlA3/Dj8QJBKDFKwVPBbTFdITnA+lCEgBYvsA9svxTO/A7sfv6/B88tfzO/So9L30D/" + "RP9L/0T/Yx+c37H/6CANsCYwQMBGEB3P6m/Hr5V/cT9i/2ZPiG+gv9sv/EAOgBfQLXAaUBLgJnA/" + "kFKwgICooMVg6gD6cQkhFWEnoTORVtFqIWTBalE1AObwfR//L4lvKm7b7qFOpI63ftpe8A8XXyV/" + "Sl9DD1dfak93H67P3SAEgDoAWyBwYIwAT+AAT79/OH8TPvI+7l7xzz1PnY/" + "gYBhQMlBX4GAgaLBWwHcQmoDIsQ1RLDFbIY5xlZG1kbuhpmG3YaBRkdFOgK1QEg96ftNudg4ETdxt9x5H3p6e1S8ov26Pk" + "P+5z75/w9/zMDwwbzCIcKPgvVC2kIRwAj+Vfymep45Y7j+eQj6orvsfaZ/" + "WcBNQUfB8YHjQjhCK0Lmg9VE4QW1Rn9HFsdNh0bHlYdmRuiGiMZ3hbvD68ChPa57DDi7tjK1NjWo9vV4sTrHPRs+oD/" + "0wRABigG1AbfB2sKgwu9C1ALqAhPBb3+SfVM7iflVtxX3F/e9eKX6uLxa/xDBKAHMwvdC+oLrQ3gD/0SdRaCGbscSB/" + "hHUgdYBsUGnsa2xifGDcVzA96A+Hwk+XU2w7SE8+b0SzbYuet8tH8KAPtB5sKJAr5B9IHWwljC80NHg3DCQQFX/" + "8N97Dr9eHb3NLZPtk83yLmWO77+KwB8AcSCgQLnA17DooNBA8GE70W9Bl4G9YaShpAGTwXdxTDEzAVAhbnFU0Skwh6+" + "MnpseDC1orQGdMO20LnVfNv/uIFrgmvDAsNMwplB68HIAl+CpEKEQjgA23+7/bQ7FHjRtyC1wfW1dmt4r7sM/" + "cqAnsKzA54EBIR0xCfD38PzRHLFJIX1Rm4GhoaihhCFuITsBJHEvQS5RN+" + "EE8HqfmN6wPjzdpQ0zvU99x66gD2vP5YB94LDw3eCxYIkQUWBgUICgkvCKMFjgH9+" + "1H0MOoQ4Y7bq9rL3EzfBOVb71v6AgShCjwNkw+" + "8EY0S8xGzD7sPHBTaFoIW0RV8FREWLhZDFEwSNxJNFDEWOBNqB2H4jOxj4dbYZtMQ1AHdSukG960BHQfxC84OcAwjCG8Fq" + "wUQB+YI8AhgBOz+L/" + "oV8kfnK92e17rWsdhG4IPrPPY4AQALVRABEtMRThC4DiwOkQ8OE6UWSRk1GTUXexUMFBUTRRALD9ESjBY0F5sNNf5O8ZTj" + "Gdkl0XXPA9cf5Bn1NQIaCeQOzxHMD1ULVAbaAyQF7gcNCNYEKADS+070KeoS4eHZd9jC2TffTOhO8Rr9+" + "we4DkkRjhGyETwRhg8TDtsOLRLLFdoXChj8FSAV4RN+Et0SqhImE2YVCxJtBvL12uQZ3GLVPtFJ1cPf3e/T/" + "lgJNA7DDl0PtA0CCBMEmQPrBQkIFQesAin7dfRA7QHkxtrH1nDanuAI6bzzFP45B4sOohHREEoOAQ3FDTsOfQ73EN0Vahi" + "/GLMX7RXzFKsT+xGDEq8UJReNEyYFofQ95jrcAtMRzmDTFN/" + "M8K8AlgmbDgESVhOGDzAJcgWOBJMGjwiZBpMBsfoG9EbsZOEn2BjVy9mQ4pvrqvQ6/2gJyRCCE3kROQ/" + "FDoIPRA8nD3URwBXCGKQXcRVuE1gTvxM0Eu4R4BOMF3gTPAWs8+" + "zjyNoS1afSAtaA4HPyVQMKDHkPWREJErgPbApqBcYDyQUWCMwETP3J9fnve+h53pjWpdWx3GfmQvHB+" + "wAGGg9ZFaQVZBFVDi8OYw+eDw8RUBTYF9EZfxk9FscSLhL3EdwRfRLiFI0TOAcw92/" + "m6trT00fQFtVp3zzwzAGbDgEUGRQSE4YQ9gp3BRoDeAP5BPcDeP8599nuVufE3Z/VKdQZ3KTnFvNZ/" + "9MJGRGDFKQUnBAJDCkLPg0PDzIRYhX9F+AYaBj9FasSshC+EEgQuBB3E3cTlgiE9gHnV92/" + "1SLRdtTo3qvujf98DF0STRRkFIkR8Av6BeICFgKgAv4BO/0J9S7thuab3b/" + 
"VwdRd3NPnufNOAH8L6RP+F3AXXxOkDmEMewz2DKUOyRLMFhAYVxdRFYMTuBFuDzkOew7NEc8O1gA28XTjT9ql1G/" + "T0dff4Xn0KwaIDzUURhZoFdwQQAtTBhsCeAHyAYr/d/" + "nB8FLoZN9a1m7RpNJ22wDqXfqzCKcT4BtOH2IcExWkDwsNGQybDd0RCxZXF/QXYBabErcOlAyADJEMWA7JDW8D8vM/" + "5tHbQNVh04/YCOQl9HgFeBG+FtEXpRXEEPUJvQPQAAUAVf8m/Qr5RPJ/6ZbgnNgm0x7UG9/" + "57jL9sgrwFk0eKB+9GiwTpgsUCWAL2A4eEi0WDhraGmsYbxPaDnkMmguZDAMOKAjQ++7u5+EN2eTSEdQy3J/" + "oqfpECvUUvxm6Gb0WixC0CYkEtACT/lb8/PjM81LsO+Tf2hvU89Km2Uzn/fW7BFcSORxdIKYd9xayD/" + "YKpApJDYYRuhasGeUaHxoAFjsRaAyjCscJQAqpCZYAnvJO5XzdYdiB1ZTZV+SL8xIEIhG9F+UYLBheFRMObQbeAcf+/" + "vwo+ln1Du7Y5b3dbdZa003WauI/8tgAXQ4HGrEg7x86GksTKA0JCjULXw/" + "7E1sXwhm4GrkYxxMqD4IMPgudCgQJgf9N8o7mdd251xfWK9wv5r3zygPkEOIWlRcxFvkSDg2pBQQAuPyH+" + "rD3PPPH7ODkMd5R2nnZ6N1Q6X34tgZ1EroakB35GuoUVQ7XCMMG5QgMDuMT8hZNGDQYqBaIE84Phg30DMANFA3YBbT5/" + "urY3CHURtIe1gTd9OoT/qkOVhhAHEEbJBVFDQkHuwEh/N/" + "42vix+NjzFuyO5KDcVtfL2PjgHexx+TgJMRe5Hg8f3xnXEb8J0QJ5/" + "20AtQTaCrkRQBdxGMEWOxP6D30M8QlPChcMYg7WD+YPGQoH/" + "fTptNla0kjRzdN42yPtOAJrEy0esyBBGyES9QnUAtL75Pb09sL6gv3l+hH1tuwq407aUNXt15bgxe/" + "wAUMTih9rJDgjPxxVEB0DUvkc9ab2wfz8BTIPShYcGgAaRBdvEqINtQqaCkMNoxHwFnUXzg/b/" + "Bnl5NVw0N3OmNDM3OnxGgc4GBshzB5ZFtENswVy/Ib00fNi+SkAPAIT/" + "yv51O8m5czchte11jTeju4HAkcSMx3mIHUe9xa/C/v/h/" + "bn8tb1if0ZBl0O9hRcGLYXARSBDz4LAQlFC3cR1RjXHQEgkR+SFeH8It2XydDE1sSxyM3X6/" + "DOBzoZ3SM6IiEWBAqKBD0Bxfs7+bT+4gdJDC4IG/7G7wjh2NZ20lbRKNUz41b5CQ6lGR0dWByFFuAMwQJY+0X3C/ht/" + "jsGTgkNCLIJ5w4FEb0MYAiiB3kM6xJfGesc4h19IN0iWyGaFGL7B9zfyRjHksi0ymvVPOzmAUoSLRwfHRcUugkhBi0FRQK" + "lAFgFaQtbDKsHUQCx9JzmntxS19rVedcC4A/umP1jCFsNnRBmEVAQggycB6UBDf55/6QCPQKl/jn94v/" + "PA98FgQm9DPQPJBTfGQ0eIx2XHEQcSxy3G1gaYRaLBgjpINJ+zXzOyMo4ypXayfBVA74RZBnlGP8TMRPYE/" + "kOFwfRBBsHbQePA77+NfeH7MDk6uD33TPbHd3d5Mbu1fVM/VQG6g3UET8SABIWDrQHlAO6AQv+z/" + "fi9Lv2UPkX+73+AQRVCnUQrBWtGSIaExs2HdgdbBwWHK0eaSEAIFkZ+AFm3jvI/" + "sOFxQzCgcdK3Qv4tg8FH0oiuhwlFhoTKA9jBaT/zgK+CsMNbwolBpr8r+6M4t3aD9Vx0NbT2OHg8vH/" + "fQjGDTwRCRGGDrsK/AQ+ADT/OwLwAy0C4v6D+6z4O/ZJ9Vb1zfex/zIMcxcuHWsefh+xIDEgOxxcF2QXQRz1If4i/" + "xpmAAPeqcvmybbGub2ewWPYPPUZDbgboR8xHNUX3RWjEl0KcgPAA5oKSw8GDmoHZ/r57Yzm3+HY2c3Ra9QC4w/zl/" + "xVAqIJ/xB+EgsQZwqaAxz+tvxS/3cA3v6V+yP6lPls9znzKvAX8dP2Vf9rCBwRJxk5H4Yh7x/" + "vGnMXyxVIFB0TsBT9G5AhmiFXFbb5RtzdyobF/cBfv9nJ8N/F+gAUACK7IuQdIBokFqsMLAQMAo4FnQr/DVwP4wpp/" + "3vx1+aa3+LW8c+00qDfOu8e+xEEtAyvEfIQvgwUCK0Dkf7a+239RgE/AzoCzf8d/" + "bj5JvRi7EDn0OnT8SH73AMFDk0ZfSE0IrkdvhjwE1kQBQ8ZEhcYnx17IqglgiUgHNEDpuWtzvfDBsAMwPDI7dqg8pYKkhs" + "wIfwdfRjvE7wN8QWIAIIBCglGEfAVDhQJCyb/8fJl5yfa8M4dzTfV/+O58/UBIw2nE2AV/hP8DVwERfuk99X59PyP/" + "2gBewP5A5MAavoT8yrtUup46vntgvT//LYFWgylEAITQxRrFT0VoBRlFMwVchjtGqocdR3iHc4bqxaZDdsCh/" + "Vi5avWw86+z3TV095M6kH3aAQ0DwAV1xS+EQwPkw3NC1UKJQoZC4oLDwpWB8ABE/" + "kV71bnw+IK4SfiVOUC7Pn0lf1IBMQH3wn2Cs0K1gmDBxsFpwMMA3YCbADd/Cv5WvXi8ubw4O/m7+fw1/" + "QC+pH+ZAEFBPkGAAkNCvsKGAxrDpcRtRQUFzoYZBn5GX4ZzhgCF/" + "wUwhGoDCsHDgBQ9pLqaOCJ3BzdCeAZ5WPs1PZ7AXkJRA0YDnQOtA6xDXwLyQh3B60HgwdBBt4DVwCl+5z2ufK+" + "75LsMuu27Kvx+/XO+ET96wGgBigJZwm6CO0GxgW2BMECBgAU/bH7GvoW92L0kfLd8ebxvvJy9Wb49for/g=="; + +CATCH_TEST_CASE("test logMelSpectrogram", "[op][cpu][logmelspectrogram]") { + std::vector pcmData = lut::decodeBase64(gDataEnUsHelloBase64); + CATCH_REQUIRE(std::string(reinterpret_cast(pcmData.data()) + 36, 4) == "data"); + + lut::Span pcmSpan(pcmData); + pcmSpan = pcmSpan.subspan(44); // 44: pcm header size. 
+ + Tensor wave = Wave::read(pcmSpan, WaveFormat::Wave16kHz16bitMonoPCM); + Tensor features = logMelSpectrogram(wave); + + F::print(op::cpu::fingerprint(features)); + + CATCH_REQUIRE(F::allClose( + op::cpu::fingerprint(features), + Tensor::create( + {8}, + {0.5365, 0.6787, 0.1886, 0.4008, -0.2633, -0.3035, -0.4268, -0.6635}))); +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/matmul.cc b/src/libllm/cpu/matmul.cc index 6aa07ae7..c0671400 100644 --- a/src/libllm/cpu/matmul.cc +++ b/src/libllm/cpu/matmul.cc @@ -26,10 +26,6 @@ #include "libllm/lut/strings.h" #include "libllm/mp.h" -#ifndef _OPENMP -#error OpenMP required -#endif - namespace libllm { namespace op { namespace cpu { diff --git a/src/libllm/cpu/normalizations.cc b/src/libllm/cpu/normalizations.cc new file mode 100644 index 00000000..0344737f --- /dev/null +++ b/src/libllm/cpu/normalizations.cc @@ -0,0 +1,138 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
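normalizations.cc, added below, folds the old rms_norm.cc into one file and adds the LayerNorm needed by the Whisper path. For reference, the per-row math its LayerNorm kernel implements; note it divides by n, i.e. the biased variance, which is what torch.nn.LayerNorm also uses:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Per-row LayerNorm matching layerNormKernel below: subtract the mean,
// divide by sqrt(biased variance + eps), then scale and shift.
std::vector<float> layerNormRow(const std::vector<float> &x,
                                const std::vector<float> &w,
                                const std::vector<float> &b, float eps) {
  double sum = 0.0;
  for (float v : x) sum += v;
  const double mean = sum / x.size();

  double sq = 0.0;
  for (float v : x) sq += (v - mean) * (v - mean);
  const double sd = std::sqrt(sq / x.size() + eps);  // biased variance

  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = static_cast<float>((x[i] - mean) / sd) * w[i] + b[i];
  }
  return y;
}
```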
+ +#include "libllm/cpu/normalizations.h" + +#include + +#include "libllm/cpu/accessor.h" +#include "libllm/cpu/common.h" +#include "libllm/cpu/tensor.h" +#include "libllm/mp.h" +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +template +Tensor rmsNormKernel(const Tensor &tensor, const Tensor &weight, float eps) { + CHECK(weight.getDim() == 1); + CHECK(tensor.getShape(-1) == weight.getShape(0)); + + Tensor C = tensorLike(tensor); + + TensorList vA = TensorList::fromTensor(tensor); + TensorList vC = TensorList::fromTensor(C); + CHECK(vA.getLength() == vC.getLength()); + + TensorAccessor w = weight; + + MP::parallelFor({vA.getLength()}, [&vA, &vC, w, eps](MP::Partition partition) { + for (int j : partition.getRange()) { + TensorAccessor a = vA.getTensor(j); + TensorAccessor c = vC.getTensor(j); + + double sum = 0.0; + for (int i = 0; i < a.getShape(0); ++i) { + double va = a[i]; + sum += va * va; + } + double mean = sum / a.getShape(0); + double rms = std::sqrt(mean + eps); + + // compute rms-norm + for (int i = 0; i < a.getShape(0); ++i) { + double va = a[i]; + double vw = w[i]; + c[i] = static_cast(a[i] * w[i] / rms); + } + } + }); + + return C; +} + +template +Tensor layerNormKernel(const Tensor &tensor, const Tensor &weight, const Tensor &bias, float eps) { + CHECK(weight.getDim() == 1); + CHECK(tensor.getShape(-1) == weight.getShape(0)); + + Tensor C = tensorLike(tensor); + + TensorList vA = TensorList::fromTensor(tensor); + TensorList vC = TensorList::fromTensor(C); + CHECK(vA.getLength() == vC.getLength()); + + TensorAccessor w = weight; + TensorAccessor b = bias; + + MP::parallelFor({vA.getLength()}, [&vA, &vC, w, b, eps](MP::Partition partition) { + for (int j : partition.getRange()) { + TensorAccessor a = vA.getTensor(j); + TensorAccessor c = vC.getTensor(j); + + double sum = 0.0f; + for (int i = 0; i < a.getShape(0); ++i) { + sum += a[i]; + } + double mean = sum / a.getShape(0); + + // var (unbiased) + sum = 0.0; + for (int i = 0; i < a.getShape(0); ++i) { + double d = a[i] - mean; + sum += d * d; + } + double var = sum / a.getShape(0); + double sd = sqrt(var + eps); + + // compute layer-norm + for (int i = 0; i < a.getShape(0); ++i) { + float elem = static_cast((a[i] - mean) / sd); + c[i] = elem * w[i] + b[i]; + } + } + }); + + return C; +} + +Tensor rmsNorm(Tensor tensor, Tensor weight, float eps) { + if (tensor.getDType() == DType::kFloat) return rmsNormKernel(tensor, weight, eps); +#if LUT_CPU_ARCH == LUT_AARCH64 + if (tensor.getDType() == DType::kFloat16) return rmsNormKernel(tensor, weight, eps); +#endif + + NOT_IMPL(); +} + +Tensor layerNorm(Tensor tensor, Tensor weight, Tensor bias, float eps) { + if (tensor.getDType() == DType::kFloat) return layerNormKernel(tensor, weight, bias, eps); +#if LUT_CPU_ARCH == LUT_AARCH64 + if (tensor.getDType() == DType::kFloat16) + return layerNormKernel(tensor, weight, bias, eps); +#endif + + NOT_IMPL(); +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/normalizations.h b/src/libllm/cpu/normalizations.h new file mode 100644 index 00000000..8606b02f --- /dev/null +++ b/src/libllm/cpu/normalizations.h @@ -0,0 +1,33 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, 
publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#pragma once + +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +Tensor rmsNorm(Tensor tensor, Tensor weight, float eps); +Tensor layerNorm(Tensor tensor, Tensor weight, Tensor bias, float eps); + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/reduce.cc b/src/libllm/cpu/reduce.cc new file mode 100644 index 00000000..4a5c405e --- /dev/null +++ b/src/libllm/cpu/reduce.cc @@ -0,0 +1,117 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
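Note on reduce.cc below: it reduces over the last dimension only, with SUM and MAX as the supported reductions. The MapType half of the MapType/ReduceType split anticipates fused map-reduce variants (e.g. sum of squares, sum of exps); both current cases use the identity map. A minimal sketch of the semantics on a flat buffer (the names here are illustrative, not library API):

#include <algorithm>
#include <limits>
#include <vector>

// Sketch: MAX-reduce each row of a (numRows x rowLen) buffer, which is what
// reduceKernel below does for each slice along the last dimension.
std::vector<float> reduceRowsMax(const std::vector<float> &data, int numRows, int rowLen) {
  std::vector<float> out(numRows, -std::numeric_limits<float>::infinity());
  for (int j = 0; j < numRows; ++j)
    for (int i = 0; i < rowLen; ++i) out[j] = std::max(out[j], data[size_t(j) * rowLen + i]);
  return out;
}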
+ +#include "libllm/cpu/reduce.h" + +#include "libllm/cpu/accessor.h" +#include "libllm/cpu/tensor.h" +#include "libllm/mp.h" +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +enum class MapType { EXP_FP16_FP32, SQUARE_FP16_FP32, IDENTITY, UNKNOWN }; +enum class ReduceType { SUM, MAX, UNKNOWN }; + +constexpr MapType getMapType(MapReduceType mapReduceType) { + switch (mapReduceType) { + case MapReduceType::SUM: + return MapType::IDENTITY; + case MapReduceType::MAX: + return MapType::IDENTITY; + default: + return MapType::UNKNOWN; + } +} + +constexpr ReduceType getReduceType(MapReduceType mapReduceType) { + switch (mapReduceType) { + case MapReduceType::SUM: + return ReduceType::SUM; + case MapReduceType::MAX: + return ReduceType::MAX; + default: + return ReduceType::UNKNOWN; + } +} + +template +T getReduceInitial() { + switch (REDUCE_TYPE) { + case ReduceType::SUM: + return T(0); + case ReduceType::MAX: + return -std::numeric_limits::infinity(); + default: + NOT_IMPL(); + } +} + +template +Tensor reduceKernel(Tensor A) { + std::vector shape = A.getShape(); + Tensor C = tensor(shape, A.getDType()); + + TensorList vA = TensorList::fromTensor(A); + TensorList vC = TensorList::fromTensor(C); + CHECK(vA.getLength() == vC.getLength()); + + MP::parallelFor({vA.getLength()}, [&vA, &vC](MP::Partition partition) { + for (int j : partition.getRange()) { + TensorAccessor a = vA.getTensor(j); + TensorAccessor c = vC.getTensor(j); + + float accumulator = getReduceInitial(); + for (int i = 0; i < a.getShape(0); i++) { + if (REDUCE_TYPE == ReduceType::SUM) { + accumulator += a[i]; + } else if (REDUCE_TYPE == ReduceType::MAX) { + if (a[i] > accumulator) accumulator = a[i]; + } else { + NOT_IMPL(); + } + } + + c[0] = accumulator; + } + }); + + return C; +} + +Tensor reduce(const Tensor &A, MapReduceType reduceType) { + if (A.getDType() == DType::kFloat && reduceType == MapReduceType::SUM) + return reduceKernel(A); + if (A.getDType() == DType::kFloat && reduceType == MapReduceType::MAX) + return reduceKernel(A); +#if LUT_CPU_ARCH == LUT_AARCH64 + if (A.getDType() == DType::kFloat16 && reduceType == MapReduceType::SUM) + return reduceKernel(A); + if (A.getDType() == DType::kFloat16 && reduceType == MapReduceType::MAX) + return reduceKernel(A); +#endif + + NOT_IMPL(); +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/reduce.h b/src/libllm/cpu/reduce.h new file mode 100644 index 00000000..8388d350 --- /dev/null +++ b/src/libllm/cpu/reduce.h @@ -0,0 +1,34 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#pragma once + +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +enum class MapReduceType { SUM, MAX }; + +Tensor reduce(const Tensor &A, MapReduceType reduceType); + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/swiglu.h b/src/libllm/cpu/swiglu.h index 0f9eab1c..06a03f51 100644 --- a/src/libllm/cpu/swiglu.h +++ b/src/libllm/cpu/swiglu.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -28,6 +28,6 @@ namespace cpu { Tensor swiglu(const Tensor &A); Tensor swigluFp32(const Tensor &A); -} // cpu -} // op -} // ly +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/unfold.cc b/src/libllm/cpu/unfold.cc new file mode 100644 index 00000000..64610384 --- /dev/null +++ b/src/libllm/cpu/unfold.cc @@ -0,0 +1,95 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
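Note on unfold.cc below: unfold is im2col for 1-D convolution. For output frame j it gathers the kernelSize input frames centered at j * stride, zero-padding past either edge, so out[j][d * K + k] = in[j * stride + k - K / 2][d] for k in [0, K). A scalar sketch of the index arithmetic (the names are illustrative only):

#include <vector>

// Sketch: unfold one (numFrames x numChannels) sequence into
// (numFrames / stride) x (numChannels * K), mirroring unfold1DKernel below.
std::vector<float> unfoldSeq(
    const std::vector<float> &in, int numFrames, int numChannels, int K, int stride) {
  int numOutFrames = numFrames / stride;
  std::vector<float> out(size_t(numOutFrames) * numChannels * K, 0.0f);  // zero = padding
  for (int j = 0; j < numOutFrames; ++j)
    for (int d = 0; d < numChannels; ++d)
      for (int k = 0; k < K; ++k) {
        int srcIdx = j * stride + k - K / 2;  // window centered at j * stride
        if (srcIdx >= 0 && srcIdx < numFrames)
          out[(size_t(j) * numChannels + d) * K + k] = in[size_t(srcIdx) * numChannels + d];
      }
  return out;
}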
+ +#include "libllm/cpu/unfold.h" + +#include "libllm/cpu/accessor.h" +#include "libllm/cpu/common.h" +#include "libllm/cpu/copy.h" +#include "libllm/cpu/tensor.h" +#include "libllm/mp.h" + +namespace libllm { +namespace op { +namespace cpu { + +template +void unfold1DKernel(const Tensor &src, Tensor &dest, int kernelSize, int stride) { + TensorList mA = TensorList::fromTensor(src); + TensorList mC = TensorList::fromTensor(dest); + CHECK(mA.getLength() == mC.getLength()); + + for (int i = 0; i < mA.getLength(); ++i) { + TensorAccessor vA = mA.getTensor(i); + TensorAccessor vC = mC.getTensor(i); + CHECK(vA.getShape(0) / stride == vC.getShape(0)); + + MP::parallelFor({vC.getShape(0)}, [&vA, &vC, kernelSize, stride](MP::Partition partition) { + int kernekIdxBegin = -(kernelSize / 2); + int kernekIdxEnd = (kernelSize - 1) / 2; + + for (int j : partition.getRange()) { + int numChannels = vA.getShape(1); + int numInFrames = vA.getShape(0); + + for (int d = 0; d < numChannels; ++d) { + for (int k = kernekIdxBegin; k <= kernekIdxEnd; ++k) { + int srcIdx = j * stride + k; + int offset = k - kernekIdxBegin; + if (srcIdx < 0 || srcIdx >= numInFrames) { + // padding. + vC[j][d * kernelSize + offset] = 0.0f; + } else { + vC[j][d * kernelSize + offset] = vA[srcIdx][d]; + } + } + } + } + }); + } +} + +Tensor unfold(const Tensor &src, int kernelSize, int stride) { + CHECK(src.getDim() >= 2); + CHECK(src.getShape(-1) >= kernelSize); + + std::vector shape = src.getShape(); + shape.back() *= kernelSize; + shape[shape.size() - 2] /= stride; + + Tensor dest = op::cpu::tensor(shape, src.getDType()); + + if (src.getDType() == DType::kFloat) { + unfold1DKernel(src, dest, kernelSize, stride); + } else if (src.getDType() == DType::kFloat16) { +#if LUT_CPU_ARCH == LUT_AARCH64 + unfold1DKernel(src, dest, kernelSize, stride); +#else + NOT_IMPL(); +#endif + } else { + NOT_IMPL(); + } + + return dest; +} + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/cpu/unfold.h b/src/libllm/cpu/unfold.h new file mode 100644 index 00000000..36246c6c --- /dev/null +++ b/src/libllm/cpu/unfold.h @@ -0,0 +1,33 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +#pragma once + +#include "libllm/lut/span.h" +#include "libllm/tensor.h" + +namespace libllm { +namespace op { +namespace cpu { + +Tensor unfold(const Tensor &src, int kernelSize, int stride); + +} // namespace cpu +} // namespace op +} // namespace libllm diff --git a/src/libllm/dialog_manager.cc b/src/libllm/dialog_manager.cc deleted file mode 100644 index 7ad81551..00000000 --- a/src/libllm/dialog_manager.cc +++ /dev/null @@ -1,244 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#include "libllm/dialog_manager.h" - -#include "libllm/lut/error.h" -#include "libllm/lut/strings.h" -#include "libllm/lut/time.h" - -namespace libllm { - -ChatOutput::ChatOutput() - : numAnswerTokens(0), - promptDuration(0.0), - answerDuration(0.0) { -} - -// -----------------------------------------------------------------------------------------------+ -// class ChatGLM2PromptBuilder | -// -----------------------------------------------------------------------------------------------+ - -std::shared_ptr ChatGLM2PromptBuilder::buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) { - std::string prompt; - - int round = 1; - for (const QA &qa : history) { - prompt += lut::sprintf("[Round %d]\n\n问:%s\n\n答:%s\n\n", round, qa.question, qa.answer); - ++round; - } - prompt += lut::sprintf("[Round %d]\n\n问:%s\n\n答:", round, question); - - std::shared_ptr pPrompt = llm::Prompt::fromModel(model); - pPrompt->appendText(prompt); - return pPrompt; -} - -std::string ChatGLM2PromptBuilder::getStopSeq() { - return ""; -} - -// -----------------------------------------------------------------------------------------------+ -// class ChatGLM3PromptBuilder | -// -----------------------------------------------------------------------------------------------+ - -std::shared_ptr ChatGLM3PromptBuilder::buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) { - std::shared_ptr prompt = llm::Prompt::fromModel(model); - for (const QA &qa : history) { - prompt->appendControlToken("<|user|>"); - prompt->appendText("\n"); - prompt->appendText(lut::sprintf("%s", qa.question)); - prompt->appendControlToken("<|assistant|>"); - prompt->appendText("\n"); - prompt->appendText(lut::sprintf("%s", qa.answer)); - } - - prompt->appendControlToken("<|user|>"); - prompt->appendText("\n"); - prompt->appendText(lut::sprintf("%s", question)); - prompt->appendControlToken("<|assistant|>"); - 
- return prompt; -} - -std::string ChatGLM3PromptBuilder::getStopSeq() { - return ""; -} - -// -----------------------------------------------------------------------------------------------+ -// class LlamaPromptBuilder | -// -----------------------------------------------------------------------------------------------+ - -std::shared_ptr LlamaPromptBuilder::buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) { - std::shared_ptr prompt = llm::Prompt::fromModel(model); - bool systemPromptAdded = false; - prompt->appendControlToken("<|begin_of_text|>"); - if (!systemPromptAdded) { - prompt->appendControlToken("<|start_header_id|>"); - prompt->appendText("system"); - prompt->appendControlToken("<|end_header_id|>"); - prompt->appendText("\n\nYou are a helpful assistant."); - prompt->appendControlToken("<|eot_id|>"); - prompt->appendText("\n"); - systemPromptAdded = true; - } - - for (const QA &qa : history) { - prompt->appendControlToken("<|start_header_id|>"); - prompt->appendText("user"); - prompt->appendControlToken("<|end_header_id|>"); - prompt->appendText("\n\n" + qa.question); - prompt->appendControlToken("<|eot_id|>"); - prompt->appendControlToken("<|start_header_id|>"); - prompt->appendText("assistant"); - prompt->appendControlToken("<|end_header_id|>"); - prompt->appendText("\n\n" + qa.answer); - prompt->appendControlToken("<|eot_id|>"); - prompt->appendText("\n"); - } - - prompt->appendControlToken("<|start_header_id|>"); - prompt->appendText("user"); - prompt->appendControlToken("<|end_header_id|>"); - prompt->appendText("\n\n" + question); - prompt->appendControlToken("<|eot_id|>"); - prompt->appendControlToken("<|start_header_id|>"); - prompt->appendText("assistant"); - prompt->appendControlToken("<|end_header_id|>"); - prompt->appendText("\n\n"); - return prompt; -} - -std::string LlamaPromptBuilder::getStopSeq() { - return ""; -} - -// -----------------------------------------------------------------------------------------------+ -// class QwenPromptBuilder | -// -----------------------------------------------------------------------------------------------+ - -std::shared_ptr QwenPromptBuilder::buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) { - std::shared_ptr prompt = llm::Prompt::fromModel(model); - prompt->appendControlToken("<|im_start|>"); - prompt->appendText("system\nYou are a helpful assistant."); - prompt->appendControlToken("<|im_end|>"); - for (const QA &qa : history) { - prompt->appendText("\n"); - prompt->appendControlToken("<|im_start|>"); - prompt->appendText(lut::sprintf("user\n%s", qa.question)); - prompt->appendControlToken("<|im_end|>"); - prompt->appendText("\n"); - prompt->appendControlToken("<|im_start|>"); - prompt->appendText(lut::sprintf("assistant\n%s", qa.answer)); - prompt->appendControlToken("<|im_end|>"); - } - prompt->appendText("\n"); - prompt->appendControlToken("<|im_start|>"); - prompt->appendText(lut::sprintf("user\n%s", question)); - prompt->appendControlToken("<|im_end|>"); - prompt->appendText("\n"); - prompt->appendControlToken("<|im_start|>"); - prompt->appendText("assistant\n"); - - return prompt; -} - -std::string QwenPromptBuilder::getStopSeq() { - return ""; -} - -// -----------------------------------------------------------------------------------------------+ -// class PromptBuilder | -// -----------------------------------------------------------------------------------------------+ - -std::shared_ptr PromptBulder::create(const std::string 
&modelName) { - if (modelName == "llama") return std::make_shared(); - if (modelName == "qwen") return std::make_shared(); - if (modelName == "chatglm2") return std::make_shared(); - if (modelName == "chatglm3") return std::make_shared(); - - THROW(Aborted, "unexpected model name: " + modelName); - return nullptr; -} - -// -----------------------------------------------------------------------------------------------+ -// class DialogManager | -// -----------------------------------------------------------------------------------------------+ - -DialogManager::DialogManager( - std::shared_ptr model, - std::shared_ptr promptBuilder) - : _model(model), - _promptBuilder(promptBuilder) { -} - -ChatOutput DialogManager::chat( - const std::string &question, - std::function onTokenCallback) { - ChatOutput output; - - std::shared_ptr prompt = _promptBuilder->buildPrompt(_model, _history, question); - llm::CompletionConfig config; - config.setTopK(1); - - double t0 = lut::now(); - std::shared_ptr comp = _model->complete(prompt); - output.promptDuration = lut::now() - t0; - - std::string answer; - std::string stopSeq = _promptBuilder->getStopSeq(); - t0 = lut::now(); - int numToken = 0; - while (comp->isActive()) { - llm::Chunk chunk = comp->nextChunk(); - std::string nextToken = chunk.getText(); - - if (onTokenCallback) onTokenCallback(nextToken); - answer += nextToken; - ++numToken; - - if ((!stopSeq.empty()) && (answer.find(stopSeq) != std::string::npos)) break; - } - output.numAnswerTokens = numToken; - output.answerDuration = lut::now() - t0; - - answer = lut::trim(answer); - output.answer = answer; - - QA qa; - qa.question = question; - qa.answer = answer; - _history.emplace_back(qa); - - return output; -} - -} // namespace libllm diff --git a/src/libllm/dialog_manager.h b/src/libllm/dialog_manager.h deleted file mode 100644 index 747aaafb..00000000 --- a/src/libllm/dialog_manager.h +++ /dev/null @@ -1,122 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#pragma once - -#include -#include -#include - -#include "libllm/llm_cpp.h" -#include "libllm/lut/span.h" - -namespace libllm { - -struct QA { - std::string question; - std::string answer; -}; - -// statistics for a LLM completion in dialog. -struct ChatOutput { - int numAnswerTokens; - std::string answer; - - double promptDuration; - double answerDuration; - - ChatOutput(); -}; - -/// @brief Interface for dialog prompt builder. 
-class PromptBulder { - public: - /// @brief Create an instance of PromptBulder according to the LLM model name. - /// @param modelName LLM model name. - /// @return An instance of PromptBulder. - static std::shared_ptr create(const std::string &modelName); - - virtual ~PromptBulder() = default; - - // build the prompt from history and current question. - virtual std::shared_ptr buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) = 0; - - /// @brief Get the step sequence for generation. - /// @return the stop sequence in string. - virtual std::string getStopSeq() = 0; -}; - -class ChatGLM2PromptBuilder : public PromptBulder { - public: - std::shared_ptr buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) override; - std::string getStopSeq() override; -}; - -class ChatGLM3PromptBuilder : public PromptBulder { - public: - std::shared_ptr buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) override; - std::string getStopSeq() override; -}; - -class LlamaPromptBuilder : public PromptBulder { - public: - std::shared_ptr buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) override; - std::string getStopSeq() override; -}; - -class QwenPromptBuilder : public PromptBulder { - public: - std::shared_ptr buildPrompt( - std::shared_ptr model, - lut::Span history, - const std::string &question) override; - std::string getStopSeq() override; -}; - -// llm dialog manager. -class DialogManager { - public: - DialogManager(std::shared_ptr model, std::shared_ptr promptBuilder); - - // chat with LLM. Once a token was generated by LLM, onTokenCallback will be invoked by the - // token (if onTokenCallback is not empty). Once answering was compeleted, returns ChatOutput. 
- ChatOutput chat( - const std::string &question, - std::function onTokenCallback = {}); - - private: - std::string _stop; - std::shared_ptr _model; - std::shared_ptr _promptBuilder; - std::vector _history; -}; - -} // namespace libllm diff --git a/src/libllm/functional.cc b/src/libllm/functional.cc index b5a7043a..b3d7908e 100644 --- a/src/libllm/functional.cc +++ b/src/libllm/functional.cc @@ -24,6 +24,7 @@ #include "libllm/lut/error.h" #include "libllm/lut/strings.h" #include "libllm/operators.h" +#include "libllm/tensor.h" namespace libllm { namespace F { @@ -150,6 +151,20 @@ void copy(Tensor src, Tensor dest) { } } +Tensor sum(Tensor tensor, int dim) { + CHECK(dim == -1 || dim == tensor.getDim() - 1); + return getOperators(tensor.getDevice().getType())->sum(tensor); +} + +Tensor max(Tensor tensor, int dim) { + CHECK(dim == -1 || dim == tensor.getDim() - 1); + return getOperators(tensor.getDevice().getType())->max(tensor); +} + +void fill(Tensor tensor, float value) { + getOperators(tensor.getDevice().getType())->fill(tensor, value); +} + Tensor attention(Tensor q, Tensor k, Tensor v, Tensor mask) { float dK = 1.0f / sqrtf(1.0f * q.getShape(-1)); q = F::mul(q, sqrtf(dK)); @@ -166,8 +181,16 @@ Tensor attention(Tensor q, Tensor k, Tensor v, Tensor mask) { return outputs; } -Tensor swiglu(Tensor input) { - return getOperators(input.getDevice().getType())->swiglu(input); +Tensor swiglu(Tensor inputs) { + return getOperators(inputs.getDevice().getType())->swiglu(inputs); +} + +Tensor logMelSpectrogram(Tensor wave) { + return getOperators(wave.getDevice().getType())->logMelSpectrogram(wave); +} + +Tensor unfold(Tensor input, int kernelSize, int stride) { + return getOperators(input.getDevice().getType())->unfold(input, kernelSize, stride); } Tensor to(Device device, Tensor tensor) { diff --git a/src/libllm/functional.h b/src/libllm/functional.h index 7ec0af21..6ecc9235 100644 --- a/src/libllm/functional.h +++ b/src/libllm/functional.h @@ -1,13 +1,13 @@ // The MIT License (MIT) // -// Copyright (c) 2023 Xiaoyang Chen +// Copyright (c) 2023-2024 Xiaoyang Chen // // Permission is hereby granted, free of charge, to any person obtaining a copy of this software // and associated documentation files (the "Software"), to deal in the Software without // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -19,9 +19,9 @@ #pragma once -#include "libllm/tensor.h" -#include "libllm/lut/span.h" #include "libllm/lut/random.h" +#include "libllm/lut/span.h" +#include "libllm/tensor.h" namespace libllm { namespace F { @@ -72,7 +72,7 @@ Tensor mul(Tensor input, Tensor other); Tensor softmax(Tensor input); // return input + other. -Tensor add(Tensor input, Tensor other) ; +Tensor add(Tensor input, Tensor other); // Applies the Gaussian Error Linear Units function for `input`. Here it use the approximate // version of GELU: @@ -99,8 +99,13 @@ Tensor tensor(lut::Span shape, DType dtype, Device device = Device::g /// @param min minimal value for random number generator. /// @param max maximum value for random number generator. /// @return Generated random tensor. 
-Tensor rand(lut::Span shape, DType dtype, Device device = Device::getCpu(), - lut::Random *generator = nullptr, float min = -1.0f, float max = 1.0f); +Tensor rand( + lut::Span shape, + DType dtype, + Device device = Device::getCpu(), + lut::Random *generator = nullptr, + float min = -1.0f, + float max = 1.0f); // returns a uninitialized tensor with the same shape and dtype as input Tensor tensorLike(Tensor input); @@ -167,6 +172,42 @@ Tensor attention(Tensor q, Tensor k, Tensor v, Tensor mask = Tensor()); // (..., D / 2): the output tensor. Tensor swiglu(Tensor input); +/// @brief Apply Gaussian error linear unit (GELU) activation to the inputs. it applies +/// element-wise the function GELU(x) = x * Phi(x) where Phi(x) is the Cumulative Distribution. In +/// the implementation, it did not use the approximate version. Function for Gaussian Distribution. +/// @param inputs: (..., D): the input tensor. +/// @return (..., D): the output tensor. +Tensor gelu(Tensor inputs); + +/// @brief fill tensor with value. +/// @param tensor the tensor to fill. +/// @param value the value. +void fill(Tensor tensor, float value); + +/// @brief Returns the sum of each row of the input tensor in the given dimension dim. +/// @param tensor (d1, d2, ..., dn) the input tensor. +/// @return (d1, d2, ..., dn-1): the output tensor. +Tensor sum(Tensor tensor, int dim = -1); + +/// @brief Returns the maximum value of each row of the input tensor in the given dimension dim. +/// @param tensor (d1, d2, ..., dn) the input tensor. +/// @return (d1, d2, ..., dn-1): the output tensor. +Tensor max(Tensor tensor, int dim = -1); + +/// @brief (im2col) Extracts sliding local blocks from the input tensor. To make +/// sure the input and output shape are the same after Conv, it will also pad the input tensor with +/// zero. +/// @param input (N, L, C): the input tensor. +/// @param kernelSize: the kernel size. +/// @param stride: the stride. +/// @return (N, L / stride, D * kernelSize): the output tensor. +Tensor unfold(Tensor input, int kernelSize, int stride); + +/// @brief Extract the log mel spectrogram feature from input wave. +/// @param wave (wave_len, ): the input wave. +/// @return (feature_len, FeatDim=80): the logMelSpectrogram feature. +Tensor logMelSpectrogram(Tensor wave); + /// @brief Copy the tensor to target device. If `castFloat` is true and the tensor type is float, // it will cast the data type to default float type of that device. /// @param tensor the source tensor. @@ -186,5 +227,5 @@ Tensor cast(Tensor tensor, DType dtype); /// @return float type as DType. 
DType getDefaultFloatType(Device device); -} // F -} // libllm +} // namespace F +} // namespace libllm diff --git a/src/libllm/generator.cc b/src/libllm/generator.cc index 491173bf..3db213a6 100644 --- a/src/libllm/generator.cc +++ b/src/libllm/generator.cc @@ -1,6 +1,6 @@ // The MIT License (MIT) // -// Copyright (c) 2023 Xiaoyang Chen +// Copyright (c) 2023-2024 Xiaoyang Chen // // Permission is hereby granted, free of charge, to any person obtaining a copy of this software // and associated documentation files (the "Software"), to deal in the Software without @@ -19,8 +19,14 @@ #include "libllm/generator.h" +#include + +#include + #include "libllm/functional.h" +#include "libllm/lut/error.h" #include "libllm/lut/strings.h" +#include "libllm/whisper.h" namespace libllm { @@ -30,60 +36,165 @@ GenerationConfig::GenerationConfig() temperature(1.0f) { } -Generator::Generator( - GenerationConfig config, - std::shared_ptr model, - std::shared_ptr tokenizer) - : _config(config), - _sampler(config.topK, config.topP), - _tokenizer(tokenizer), - _model(model), +// -----------------------------------------------------------------------------------------------+ +// class Sampler | +// -----------------------------------------------------------------------------------------------+ + +Sampler::Sampler(int topK, float topP) + : _topK(topK), + _topP(topP) { +} + +std::vector Sampler::getTopP(const Tensor &distribution, lut::Span topK) { + CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); + float sumP = 0.0f; + + std::vector topP; + const float *d = distribution.getData(); + for (int label : topK) { + float p = d[label]; + topP.push_back(label); + + sumP += p; + if (sumP >= _topP) { + break; + } + } + + return topP; +} + +std::vector Sampler::getTopK(const Tensor &distribution) { + CHECK(_topK <= distribution.getShape(0) && distribution.getStride(0) == 1); + if (_topBuffer.size() != distribution.getShape(0)) _topBuffer.resize(distribution.getShape(0)); + + const float *d = distribution.getData(); + for (int32_t i = 0; i < distribution.getShape(0); ++i) { + _topBuffer[i] = std::make_pair(i, d[i]); + } + + std::partial_sort( + _topBuffer.begin(), + _topBuffer.begin() + _topK, + _topBuffer.end(), + [](const std::pair &a, const std::pair &b) { + return a.second > b.second; + }); + + std::vector topK; + LOG(DEBUG) << "Sampler TopK (K=" << _topK << ")"; + for (int i = 0; i < _topK; ++i) { + topK.push_back(_topBuffer[i].first); + LOG(DEBUG) << i << ": " << _topBuffer[i].first << ", " << _topBuffer[i].second; + } + + return topK; +} + +int Sampler::sampleTopP(const Tensor &distribution, lut::Span topP) { + CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); + std::vector probAcc; + + float sumP = 0.0f; + const float *probData = distribution.getData(); + for (int label : topP) { + float p = probData[label]; + sumP += p; + probAcc.push_back(sumP); + } + + float r = _random.nextFloat() * sumP; + for (int i = 0; i < topP.size(); ++i) { + if (r < probAcc[i]) { + return topP[i]; + } + } + return topP.back(); +} + +int Sampler::sample(const Tensor &distribution) { + CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); + + std::vector topK = getTopK(distribution); // topK is sorted by its prob in x + std::vector topP = getTopP(distribution, topK); + + return sampleTopP(distribution, topP); +} + +// -----------------------------------------------------------------------------------------------+ +// class BaseGenerator | +// 
-----------------------------------------------------------------------------------------------+ + +BaseGenerator::BaseGenerator(std::shared_ptr model) + : _model(model), _currentToken(-1) { } -void Generator::forwardPrompt(const std::vector &prompt) { - for (LongType tokenId : prompt) { - LOG(DEBUG) << tokenId << " -> " << _tokenizer->getVocab()->getTokenString(tokenId); +bool BaseGenerator::generate() { + if (_model->isStopToken(_currentToken)) return false; + + if (_currentToken >= 0) { + _currentToken = searchToken(_model->decode(_past, _currentToken)); + } else { + _currentToken = searchToken(_model->prefill(_past, _prompt)); } - Tensor inputs = _model->buildInput(prompt); - Tensor hiddenState = _model->forward(_past, inputs); + LOG(DEBUG) << lut::sprintf( + "%d -> \"%s\"", + _currentToken, + _model->getVocab()->getTokenString(_currentToken)); + if (_model->isStopToken(_currentToken)) return false; - CHECK(hiddenState.getDim() == 3); - Tensor x = hiddenState.slice(1, {-1, None}); - Tensor logits = _model->forwardHidden(x); - _currentToken = sampleToken(logits); + return true; } -const char *Generator::nextToken() { - if (stopped()) return nullptr; +void BaseGenerator::setPrompt(const Prompt &prompt) { + _prompt = prompt; +} - const Vocab *vocab = _tokenizer->getVocab(); - const char *token = vocab->getTokenPiece(_currentToken).c_str(); - LOG(DEBUG) << lut::sprintf("%d -> \"%s\"", _currentToken, vocab->getTokenString(_currentToken)); +std::string BaseGenerator::getToken() { + if (_currentToken < 0) return ""; - std::array inputData{_currentToken}; - Tensor inputs = Tensor::create({1, 1}, inputData); - inputs = F::to(_model->getDevice(), inputs); + const Vocab *vocab = _model->getVocab(); + const char *token = vocab->getTokenPiece(_currentToken).c_str(); + return token; +} - Tensor x = _model->forward(_past, inputs); - Tensor logits = _model->forwardHidden(x); - _currentToken = sampleToken(logits); +std::string BaseGenerator::getTokenName() { + if (_currentToken < 0) return ""; + const Vocab *vocab = _model->getVocab(); + const char *token = vocab->getTokenString(_currentToken).c_str(); return token; } -bool Generator::stopped() const { - return _model->isStopToken(_currentToken) || _currentToken < 0; +// -----------------------------------------------------------------------------------------------+ +// class SamplingGenerator | +// -----------------------------------------------------------------------------------------------+ + +SamplingGenerator::SamplingGenerator( + const GenerationConfig &config, + std::shared_ptr model) + : BaseGenerator(model), + _sampler(config.topK, config.topP), + _temperature(config.temperature) { +} + +std::shared_ptr SamplingGenerator::newGenerator( + const GenerationConfig &config, + std::shared_ptr model) { + std::shared_ptr generator{new SamplingGenerator(config, model)}; + return generator; } -int Generator::sampleToken(const Tensor &logits) { +int SamplingGenerator::searchToken(const Tensor &logits) { CHECK(logits.getDim() == 3 && logits.getShape(0) == 1 && logits.getShape(1) == 1); Tensor x = logits.subtensor(0).subtensor(0); - if (_config.temperature != 1.0f) { - x = F::mul(x, 1.0f / _config.temperature); + if (_temperature != 1.0f) { + x = F::mul(x, 1.0f / _temperature); } + x = F::softmax(x); if (x.getDType() == DType::kFloat16) { x = F::cast(x, DType::kFloat); @@ -95,4 +206,71 @@ int Generator::sampleToken(const Tensor &logits) { return _sampler.sample(x); } +// 
-----------------------------------------------------------------------------------------------+ +// class WhisperGreedyGenerator | +// -----------------------------------------------------------------------------------------------+ + +WhisperGreedyGenerator::WhisperGreedyGenerator( + const GenerationConfig &config, + std::shared_ptr model) + : BaseGenerator(model), + _temperature(config.temperature) { +} + +std::shared_ptr WhisperGreedyGenerator::newGenerator( + const GenerationConfig &config, + std::shared_ptr model) { + std::shared_ptr generator{new WhisperGreedyGenerator(config, model)}; + std::string modelName = model->getName(); + if (modelName.find("whisper") == std::string::npos) { + throw lut::AbortedError("use WhisperGreedyGenerator for a non-whipser model"); + } + + generator->_whisperLogitsProcessor = whisper::WhisperLogitsProcessor::newProcessor( + model->getVocab()); + + return generator; +} + +void WhisperGreedyGenerator::setPrompt(const Prompt &prompt) { + CHECK(!prompt.empty()); + const PromptBlock &lastBlock = prompt.getBlocks().back(); + if (lastBlock.blockType != PromptBlock::ControlToken || + (lastBlock.text != "<|startoftranscript|>" && lastBlock.text != "<|transcript|>" && + lastBlock.text != "<|translate|>" && lastBlock.text != "<|notimestamps|>")) { + throw lut::AbortedError( + "last token of prompt for whisper should be one of <|startoftranscript|>, <|transcript|>, " + "<|translate|> or <|notimestamps|>"); + } + + _prompt = prompt; +} + +int WhisperGreedyGenerator::searchToken(const Tensor &logits) { + CHECK(logits.getDim() == 3 && logits.getShape(0) == 1 && logits.getShape(1) == 1); + + Tensor x = logits.subtensor(0).subtensor(0); + if (_temperature != 1.0f) { + x = F::mul(x, 1.0f / _temperature); + } + + _whisperLogitsProcessor->processLogits(x); + + x = F::softmax(x); + if (x.getDType() == DType::kFloat16) { + x = F::cast(x, DType::kFloat); + } + if (x.getDevice().getType() == Device::kCuda) { + x = F::to(Device::kCpu, x); + } + + CHECK(x.getDim() == 1 && x.getStride(0) == 1); + const float *data = x.getData(); + const float *best = std::max_element(data, data + x.getShape(0)); + + int tokenId = static_cast(best - data); + _whisperLogitsProcessor->notifyToken(tokenId); + return tokenId; +} + } // namespace libllm diff --git a/src/libllm/generator.h b/src/libllm/generator.h index 67b33813..b92addf0 100644 --- a/src/libllm/generator.h +++ b/src/libllm/generator.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -20,8 +20,12 @@ #pragma once #include +#include +#include + +#include "libllm/lut/random.h" #include "libllm/model_for_generation.h" -#include "libllm/sampler.h" +#include "libllm/prompt.h" namespace libllm { @@ -29,33 +33,106 @@ struct GenerationConfig { int topK; float topP; float temperature; + std::unordered_map kvConfig; GenerationConfig(); }; -// LLM text generator +/// @brief Given model and the generation config, generate tokens. 
class Generator { public: - Generator(GenerationConfig config, - std::shared_ptr model, - std::shared_ptr tokenizer); + enum { Sampling, Whisper }; - void forwardPrompt(const std::vector &prompt); + virtual ~Generator() = default; - // generate the next word (token). Returns nullptr if the generation is finished. - const char *nextToken(); - - bool stopped() const; + /// @brief set the prompt to prefill. + /// @param prompt the prompt; + virtual void setPrompt(const Prompt &prompt) = 0; - private: - GenerationConfig _config; - Sampler _sampler; + /// @brief generate next token. Return false if generation is finished. + /// @return if generation is finished. + virtual bool generate() = 0; + + /// @brief get the piece of current token. + /// @return piece of current token. + virtual std::string getToken() = 0; + + /// @brief get the display name of current token. + /// @return name of current token. + virtual std::string getTokenName() = 0; +}; + +class BaseGenerator : public Generator { + public: + BaseGenerator(std::shared_ptr model); + ~BaseGenerator() = default; + + bool generate() override; + std::string getToken() override; + std::string getTokenName() override; + void setPrompt(const Prompt &prompt) override; + + protected: + Prompt _prompt; StateMap _past; - std::shared_ptr _tokenizer; std::shared_ptr _model; int _currentToken; - int sampleToken(const Tensor &logits); + virtual int searchToken(const Tensor &logits) = 0; +}; + +class Sampler { + public: + Sampler(int topK, float topP); + + int sample(const Tensor &distribution); + + private: + lut::Random _random; + int _topK; + float _topP; + std::vector> _topBuffer; + + std::vector getTopK(const Tensor &distribution); + std::vector getTopP(const Tensor &distribution, lut::Span topK); + int sampleTopP(const Tensor &distribution, lut::Span topP); +}; + +// generator by sampling. 
+class SamplingGenerator : public BaseGenerator { + public: + static std::shared_ptr newGenerator( + const GenerationConfig &config, + std::shared_ptr model); + ~SamplingGenerator() = default; + + protected: + int searchToken(const Tensor &logits) override; + + private: + Sampler _sampler; + float _temperature; + + SamplingGenerator(const GenerationConfig &config, std::shared_ptr model); +}; + +class WhisperGreedyGenerator : public BaseGenerator { + public: + static std::shared_ptr newGenerator( + const GenerationConfig &config, + std::shared_ptr model); + ~WhisperGreedyGenerator() = default; + + void setPrompt(const Prompt &prompt) override; + + protected: + int searchToken(const Tensor &logits) override; + + private: + float _temperature; + std::shared_ptr _whisperLogitsProcessor; + + WhisperGreedyGenerator(const GenerationConfig &config, std::shared_ptr model); }; } // namespace libllm diff --git a/src/libllm/llama.cc b/src/libllm/llama.cc index bf2d1689..32904efe 100644 --- a/src/libllm/llama.cc +++ b/src/libllm/llama.cc @@ -385,11 +385,15 @@ Tensor LlamaModel::forward(StateMap &past, Tensor input) const { return x; } -Tensor LlamaModel::forwardHidden(Tensor hidden) const { +Tensor LlamaModel::forwardLmHead(Tensor hidden) const { Tensor logits = _outProj->forward(hidden); return logits; } +int LlamaModel::getOutputDim() const { + return _config.vocabSize; +} + // -----------------------------------------------------------------------------------------------+ // class LlamaModelForGeneration | // -----------------------------------------------------------------------------------------------+ @@ -398,43 +402,65 @@ LlamaModelForGeneration::LlamaModelForGeneration() : _eotId(0) { } -std::shared_ptr LlamaModelForGeneration::fromConfig( +std::shared_ptr LlamaModelForGeneration::fromPackage( const Context &ctx, - const lut::IniConfig &config) { + lut::ZipFile *package) { + std::shared_ptr reader = package->open(ModelConfig); + std::shared_ptr ini = lut::IniConfig::fromStream(reader.get()); + + std::string modelFile = ini->getSection(ModelSection).getString(ModelFileField); + std::string modelType = ini->getSection(ModelSection).getString(ModelTypeField); + + const lut::IniSection &llamaIni = ini->getSection(modelType); + std::shared_ptr model{new LlamaModelForGeneration()}; - model->init(ctx, config); + LlamaConfig llamaConfig = LlamaConfig::loadConfig(llamaIni); + + StateMap stateMap; + stateMap.read(package->open(modelFile).get()); + model->_model = LlamaModel::create(ctx, llamaConfig); + model->_model->initParameters(stateMap); + model->_eotId = llamaIni.getInt("eot_token_id"); + model->_modelName = modelType; + + model->initTokenizer(package); return model; } -void LlamaModelForGeneration::init(const Context &ctx, const lut::IniConfig &config) { - std::string modelType = config.getSection(ModelSection).getString(ModelTypeField); - const lut::IniSection &llamaSection = config.getSection(modelType); +Tensor LlamaModelForGeneration::prefill(StateMap &past, const Prompt &prompt) const { + Tensor x = _model->forward(past, buildInput(prompt)); + CHECK(x.getDim() == 3); - // create model - LlamaConfig llamaConfig = LlamaConfig::loadConfig(llamaSection); - _model = LlamaModel::create(ctx, llamaConfig); + x = x.slice(1, {-1, None}); + x = _model->forwardLmHead(x); - _eotId = llamaSection.getInt("eot_token_id"); - _modelName = modelType; + return x; } -void LlamaModelForGeneration::initParameters(const StateMap &stateDict) { - _model->initParameters(stateDict); -} +Tensor 
LlamaModelForGeneration::decode(StateMap &past, LongType inputToken) const { + std::array inputData{inputToken}; + Tensor inputs = Tensor::create({1, 1}, inputData); + inputs = F::to(getDevice(), inputs); -Tensor LlamaModelForGeneration::forward(StateMap &past, Tensor input) const { - Tensor x = _model->forward(past, input); - return x; -} + Tensor x = _model->forward(past, inputs); + x = _model->forwardLmHead(x); -Tensor LlamaModelForGeneration::forwardHidden(Tensor hidden) const { - return _model->forwardHidden(hidden); + return x; } -Tensor LlamaModelForGeneration::buildInput(const std::vector &prompt) const { +Tensor LlamaModelForGeneration::buildInput(const Prompt &prompt) const { std::vector inputData{}; - inputData.insert(inputData.end(), prompt.begin(), prompt.end()); + for (const PromptBlock &block : prompt.getBlocks()) { + if (block.blockType == PromptBlock::ControlToken || block.blockType == PromptBlock::Text) { + encodePromptBlock(block, inputData); + } else { + throw lut::AbortedError(lut::sprintf( + "unexpected prompt type %s for model %s", + PromptBlock::typeToString(block.blockType), + _modelName)); + } + } int len = inputData.size(); Tensor inputs = Tensor::create({1, len}, inputData); @@ -454,5 +480,9 @@ Device LlamaModelForGeneration::getDevice() const { return _model->getCtx().getDevice(); } +int LlamaModelForGeneration::getOutputDim() const { + return _model->getOutputDim(); +} + } // namespace llama } // namespace libllm diff --git a/src/libllm/llama.h b/src/libllm/llama.h index 2f657acb..488ed616 100644 --- a/src/libllm/llama.h +++ b/src/libllm/llama.h @@ -123,7 +123,8 @@ class LlamaModel : public Module { void initParameters(lut::Random *generator, DType weightType) override; Tensor forward(StateMap &past, Tensor input) const; - Tensor forwardHidden(Tensor hidden) const; + Tensor forwardLmHead(Tensor hidden) const; + int getOutputDim() const; private: LlamaConfig _config; @@ -137,19 +138,17 @@ class LlamaModel : public Module { class LlamaModelForGeneration : public ModelForGeneration { public: - static std::shared_ptr fromConfig( + static std::shared_ptr fromPackage( const Context &ctx, - const lut::IniConfig &config); + lut::ZipFile *package); - // implements interface ModelForGeneration - void initParameters(const StateMap &stateDict) override; + Tensor prefill(StateMap &past, const Prompt &prompt) const override; + Tensor decode(StateMap &past, LongType inputToken) const override; - Tensor forward(StateMap &past, Tensor input) const override; - Tensor forwardHidden(Tensor hidden) const override; - Tensor buildInput(const std::vector &prompt) const override; bool isStopToken(int tokenId) const override; const char *getName() const override; Device getDevice() const override; + int getOutputDim() const override; protected: std::shared_ptr _model; @@ -157,7 +156,7 @@ class LlamaModelForGeneration : public ModelForGeneration { int _eotId; LlamaModelForGeneration(); - void init(const Context &ctx, const lut::IniConfig &config); + Tensor buildInput(const Prompt &prompt) const; }; } // namespace llama diff --git a/src/libllm/c_api.cc b/src/libllm/llm.cc similarity index 65% rename from src/libllm/c_api.cc rename to src/libllm/llm.cc index 6389a1a5..afef38d3 100644 --- a/src/libllm/c_api.cc +++ b/src/libllm/llm.cc @@ -1,22 +1,24 @@ +#include "libllm/llm.h" + #include #include #include #include #include +#include -#include "libllm/chatglm.h" #include "libllm/context.h" #include "libllm/dtype.h" #include "libllm/functional.h" #include "libllm/generator.h" -#include 
"libllm/llm.h" #include "libllm/lut/error.h" #include "libllm/lut/ini_config.h" #include "libllm/lut/log.h" #include "libllm/lut/zip_file.h" #include "libllm/model_for_generation.h" #include "libllm/operators.h" +#include "libllm/prompt.h" #include "libllm/tokenizer.h" using libllm::Context; @@ -24,11 +26,14 @@ using libllm::GenerationConfig; using libllm::Generator; using libllm::LongType; using libllm::ModelForGeneration; +using libllm::Prompt; using libllm::Tokenizer; -using libllm::chatglm::ChatGlmConfig; -using libllm::chatglm::ChatGlmModel; using lut::IniConfig; +constexpr char LlmConfigKey_GeneratorType[] = "generator.type"; +constexpr char LlmConfigValue_Sampler[] = "sampler"; +constexpr char LlmConfigValue_Whisper[] = "whisper"; + struct llmModel_t { Context ctx; std::shared_ptr model_for_generation; @@ -41,10 +46,12 @@ struct llmCompletion_t { int top_k; float top_p; float temperature; - std::vector prompt; + std::shared_ptr prompt; std::weak_ptr model_for_generation; - std::weak_ptr tokenizer; std::shared_ptr generator; + lut::Error error; + std::string chunkText; + std::unordered_map kvConfig; }; struct llmChunk_t { @@ -52,8 +59,7 @@ struct llmChunk_t { }; struct llmPrompt_t { - std::weak_ptr tokenizer; - std::vector inputs; + std::shared_ptr prompt; }; namespace libllm { @@ -111,6 +117,16 @@ Device getDeviceFromApi(int apiDevice) { } } +int parseGeneratorType(const std::string &name) { + if (name == LlmConfigValue_Sampler) { + return Generator::Sampling; + } else if (name == LlmConfigValue_Whisper) { + return Generator::Whisper; + } else { + throw lut::AbortedError("invalid generator type: " + name); + } +} + } // namespace api } // namespace libllm @@ -208,14 +224,11 @@ const char *llmModel_GetName(llmModel_t *model) { nullptr); } -llmPrompt_t *llmPrompt_New(llmModel_t *model) { +llmPrompt_t *llmPrompt_New() { return runAndCatch( - [model]() { - if (!model) throw lut::InvalidArgError("model"); - if (!model->tokenizer) throw lut::InvalidArgError("model not initialized"); - + []() { llmPrompt_t *prompt = new llmPrompt_t(); - prompt->tokenizer = model->tokenizer; + prompt->prompt = std::make_shared(); return prompt; }, nullptr); @@ -231,14 +244,7 @@ llmStatus_t llmPrompt_AppendText(llmPrompt_t *prompt, const char *text) { if (!prompt) throw lut::InvalidArgError("prompt"); if (!text) throw lut::InvalidArgError("text"); - std::shared_ptr tokenizer = prompt->tokenizer.lock(); - if (!tokenizer) throw lut::AbortedError("tokenizer expired."); - - std::vector inputIds = tokenizer->encode(text); - for (int tokenId : inputIds) { - prompt->inputs.push_back(tokenId); - } - + prompt->prompt->appendText(text); return LLM_OK; }); } @@ -248,13 +254,26 @@ llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *name) if (!prompt) throw lut::InvalidArgError("prompt"); if (!name) throw lut::InvalidArgError("name"); - std::shared_ptr tokenizer = prompt->tokenizer.lock(); - if (!tokenizer) throw lut::AbortedError("tokenizer expired."); - - int tokenId = tokenizer->getVocab()->findControlToken(name); - prompt->inputs.push_back(tokenId); - LOG(DEBUG) << "control token " << name << " -> " << tokenId; + prompt->prompt->appendControlToken(name); + return LLM_OK; + }); +} +llmStatus_t llmPrompt_AppendAudio( + llmPrompt_t *prompt, + const llmByte_t *audio, + int64_t size, + int32_t format) { + return runAndCatch([prompt, audio, size, format]() { + if (!prompt) throw lut::InvalidArgError("prompt"); + if (!audio) throw lut::InvalidArgError("audio"); + if (size <= 0 || size > 1024 * 1024 * 
1024) + throw lut::AbortedError("invalid size, [1, 1G) expected"); + if (format != LLM_WAVE_FORMAT_PCM16KHZ16BITMONO) throw lut::AbortedError("invalid format"); + + prompt->prompt->appendWave( + lut::Span(reinterpret_cast(audio), size), + WaveFormat::Wave16kHz16bitMonoPCM); return LLM_OK; }); } @@ -267,7 +286,6 @@ llmCompletion_t *llmCompletion_New(llmModel_t *model) { std::unique_ptr comp = std::make_unique(); comp->model_for_generation = model->model_for_generation; - comp->tokenizer = model->tokenizer; comp->temperature = 1.0f; comp->top_k = 50; comp->top_p = 0.8f; @@ -282,14 +300,25 @@ llmStatus_t llmCompletion_Delete(llmCompletion_t *comp) { return LLM_OK; } +llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value) { + return runAndCatch([comp, key, value]() { + if (!comp) throw lut::InvalidArgError("comp"); + if (!key) throw lut::InvalidArgError("key"); + if (!value) throw lut::InvalidArgError("value"); + + comp->kvConfig[key] = value; + return LLM_OK; + }); +} + llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt) { return runAndCatch([comp, prompt]() { if (!comp) throw lut::InvalidArgError("comp"); if (!prompt) throw lut::InvalidArgError("prompt"); if (comp->generator) throw lut::InvalidArgError("completion already started"); - if (prompt->inputs.empty()) throw lut::InvalidArgError("prompt is empty"); + if (prompt->prompt->empty()) throw lut::InvalidArgError("prompt is empty"); - comp->prompt = prompt->inputs; + comp->prompt = prompt->prompt; return LLM_OK; }); } @@ -324,68 +353,93 @@ llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperatur }); } -llmStatus_t llmCompletion_Start(llmCompletion_t *comp) { - return runAndCatch([comp]() { +llmBool_t llmCompletion_Next(llmCompletion_t *comp) { + try { if (!comp) throw lut::InvalidArgError("comp"); - if (comp->generator) throw lut::InvalidArgError("completion already started"); - if (comp->prompt.empty()) throw lut::InvalidArgError("prompt is empty"); + if (comp->prompt->empty()) throw lut::InvalidArgError("prompt is empty"); - std::shared_ptr model = comp->model_for_generation.lock(); - std::shared_ptr tokenizer = comp->tokenizer.lock(); + if (comp->error.getCode() != lut::ErrorCode::OK) { + return LLM_FALSE; + } - if (!model) throw lut::InvalidArgError("model had been destroyed"); - if (!tokenizer) throw lut::InvalidArgError("tokenizer had been destroyed"); + if (!comp->generator) { + // prefill + std::shared_ptr model = comp->model_for_generation.lock(); + if (!model) throw lut::InvalidArgError("model had been destroyed"); + + GenerationConfig config; + config.temperature = comp->temperature; + config.topK = comp->top_k; + config.topP = comp->top_p; + + int generatorType = Generator::Sampling; + for (const auto &kv : comp->kvConfig) { + if (kv.first == LlmConfigKey_GeneratorType) { + generatorType = parseGeneratorType(kv.second); + } else { + throw lut::AbortedError("invalid configuration key: " + kv.first); + } + } - GenerationConfig config; - config.temperature = comp->temperature; - config.topK = comp->top_k; - config.topP = comp->top_p; + if (generatorType == Generator::Sampling) { + comp->generator = SamplingGenerator::newGenerator(config, model); + } else if (generatorType == Generator::Whisper) { + comp->generator = WhisperGreedyGenerator::newGenerator(config, model); + } else { + NOT_IMPL(); + } - comp->generator = std::make_shared(config, model, tokenizer); - comp->generator->forwardPrompt(comp->prompt); - return LLM_OK; - }); -} + 
comp->generator->setPrompt(*comp->prompt); + } -llmBool_t llmCompletion_IsActive(llmCompletion_t *comp) { - return runAndCatch( - [comp]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!comp->generator) throw lut::InvalidArgError("completion not started"); + bool ok = comp->generator->generate(); + if (ok) { + return LLM_TRUE; + } else { + return LLM_FALSE; + } - return !comp->generator->stopped(); - }, - false); + } catch (const lut::Error &e) { + if (comp) comp->error = e; + return LLM_FALSE; + } } -llmStatus_t llmCompletion_GenerateNextChunk(llmCompletion_t *comp, llmChunk_t *chunk) { - return runAndCatch([comp, chunk]() { - if (!comp) throw lut::InvalidArgError("comp"); - if (!comp->generator) throw lut::InvalidArgError("completion not started"); - if (comp->generator->stopped()) throw lut::AbortedError("completion stopped"); - - const char *token = comp->generator->nextToken(); - if (!token) throw lut::AbortedError("unexpected empty token"); +llmStatus_t llmCompletion_GetError(llmCompletion_t *comp) { + if (!comp) { + lut::Error err = lut::InvalidArgError("comp"); + setErrorCodeAndMessage(err); + return static_cast(err.getCode()); + } - chunk->text = token; + if (comp->error.getCode() == lut::ErrorCode::OK) { return LLM_OK; - }); + } else { + setErrorCodeAndMessage(comp->error); + return static_cast(comp->error.getCode()); + } } -llmChunk_t *llmChunk_New() { - return new llmChunk_t(); -} +const char *llmCompletion_GetText(llmCompletion_t *comp) { + return runAndCatch( + [comp]() { + if (!comp) throw lut::InvalidArgError("comp"); + if (!comp->generator) throw lut::InvalidArgError("completion not started"); -llmStatus_t llmChunk_Delete(llmChunk_t *chunk) { - delete chunk; - return LLM_OK; + comp->chunkText = comp->generator->getToken(); + return comp->chunkText.c_str(); + }, + nullptr); } -const char *llmChunk_GetText(llmChunk_t *chunk) { +const char *llmCompletion_GetToken(llmCompletion_t *comp) { return runAndCatch( - [chunk]() { - if (!chunk) throw lut::InvalidArgError("chunk"); - return chunk->text.c_str(); + [comp]() { + if (!comp) throw lut::InvalidArgError("comp"); + if (!comp->generator) throw lut::InvalidArgError("completion not started"); + + comp->chunkText = comp->generator->getTokenName(); + return comp->chunkText.c_str(); }, nullptr); } diff --git a/src/libllm/llm.h b/src/libllm/llm.h index 6a7d1772..af5752d0 100644 --- a/src/libllm/llm.h +++ b/src/libllm/llm.h @@ -42,7 +42,10 @@ extern "C" { #define LLM_DEVICE_CPU 0x0000 #define LLM_DEVICE_CUDA 0x0100 #define LLM_DEVICE_AUTO 0x1f00 +#define LLM_WAVE_FORMAT_PCM16KHZ16BITMONO 0x0001 #define LLM_API_VERSION 20240101 +#define LLM_TRUE 1 +#define LLM_FALSE 0 #define LLM_OK 0 typedef int32_t llmStatus_t; @@ -50,7 +53,9 @@ typedef struct llmModel_t llmModel_t; typedef struct llmChunk_t llmChunk_t; typedef struct llmPrompt_t llmPrompt_t; typedef struct llmCompletion_t llmCompletion_t; +typedef struct llmLogitsFilter_t llmLogitsFilter_t; typedef int32_t llmBool_t; +typedef int8_t llmByte_t; // global state LLMAPI llmStatus_t llmInit(int32_t apiVersion); @@ -66,8 +71,15 @@ LLMAPI llmStatus_t llmModel_Load(llmModel_t *model); LLMAPI const char *llmModel_GetName(llmModel_t *model); // llmPrompt_t -LLMAPI llmPrompt_t *llmPrompt_New(llmModel_t *model); +LLMAPI llmPrompt_t *llmPrompt_New(); LLMAPI llmStatus_t llmPrompt_Delete(llmPrompt_t *prompt); +LLMAPI +llmStatus_t llmPrompt_AppendAudio( + llmPrompt_t *prompt, + const llmByte_t *audio, + int64_t size, + int32_t format); + LLMAPI llmStatus_t llmPrompt_AppendText(llmPrompt_t 
*prompt, const char *text); LLMAPI llmStatus_t llmPrompt_AppendControlToken(llmPrompt_t *prompt, const char *token); @@ -77,15 +89,21 @@ LLMAPI llmStatus_t llmCompletion_Delete(llmCompletion_t *comp); LLMAPI llmStatus_t llmCompletion_SetPrompt(llmCompletion_t *comp, llmPrompt_t *prompt); LLMAPI llmStatus_t llmCompletion_SetTopP(llmCompletion_t *comp, float topP); LLMAPI llmStatus_t llmCompletion_SetTopK(llmCompletion_t *comp, int32_t topK); +LLMAPI +llmStatus_t llmCompletion_SetConfig(llmCompletion_t *comp, const char *key, const char *value); LLMAPI llmStatus_t llmCompletion_SetTemperature(llmCompletion_t *comp, float temperature); -LLMAPI llmStatus_t llmCompletion_Start(llmCompletion_t *comp); -LLMAPI llmBool_t llmCompletion_IsActive(llmCompletion_t *comp); -LLMAPI llmStatus_t llmCompletion_GenerateNextChunk(llmCompletion_t *comp, llmChunk_t *chunk); +LLMAPI llmBool_t llmCompletion_Next(llmCompletion_t *comp); +LLMAPI llmStatus_t llmCompletion_GetError(llmCompletion_t *comp); +LLMAPI const char *llmCompletion_GetText(llmCompletion_t *comp); -// llmChunk_t -LLMAPI llmChunk_t *llmChunk_New(); -LLMAPI llmStatus_t llmChunk_Delete(llmChunk_t *chunk); -LLMAPI const char *llmChunk_GetText(llmChunk_t *chunk); +/// @brief Get the name of the last generated token. +/// For a normal token, llmCompletion_GetText() returns its byte piece, for example, "hello "; +/// llmCompletion_GetToken() returns its name in the model, for example, "hello_". +/// For a control token, llmCompletion_GetText() returns an empty string ("") and +/// llmCompletion_GetToken() returns its name, for example, "<|endoftext|>". +/// @param comp the llmCompletion_t. +/// @return name of the token. +LLMAPI const char *llmCompletion_GetToken(llmCompletion_t *comp); #ifdef __cplusplus } // extern "C" diff --git a/src/libllm/llm_cpp.h b/src/libllm/llm_cpp.h deleted file mode 100644 index 4a19f52a..00000000 --- a/src/libllm/llm_cpp.h +++ /dev/null @@ -1,280 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -// C++ wrapper for libllm C API. - -#pragma once - -#include -#include -#include -#include - -#include "llm.h" - -namespace llm { - -class Model; -class ModelFactory; -class Completion; - -enum class DeviceType { CPU = LLM_DEVICE_CPU, CUDA = LLM_DEVICE_CUDA, AUTO = LLM_DEVICE_AUTO }; - -// configuration for LLM completion task.
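The llm.h hunk above replaces the old Start/IsActive/GenerateNextChunk flow with a pull-style Next/GetText/GetError loop, and drops the llmChunk_t object entirely. A minimal sketch of a caller of the new API (llmCompletion_New is taken from the existing API surface; the `model` handle, the prompt text, and the error handling here are illustrative, not from the patch):

```cpp
#include <stdio.h>

#include "llm.h"  // the header revised above

void runCompletion(llmModel_t *model) {
  llmPrompt_t *prompt = llmPrompt_New();  // no longer takes a model argument
  llmPrompt_AppendText(prompt, "What is a llama?");

  llmCompletion_t *comp = llmCompletion_New(model);
  llmCompletion_SetPrompt(comp, prompt);

  // llmCompletion_Next() generates one token per call and returns LLM_FALSE
  // on stop or error; the error, if any, is retrieved separately.
  while (llmCompletion_Next(comp) == LLM_TRUE) {
    printf("%s", llmCompletion_GetText(comp));
  }
  if (llmCompletion_GetError(comp) != LLM_OK) {
    fprintf(stderr, "completion failed: %s\n", llmGetLastErrorMessage());
  }

  llmCompletion_Delete(comp);
  llmPrompt_Delete(prompt);
}
```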
-class CompletionConfig { - public: - CompletionConfig() - : _topP(0.8f), - _topK(50), - _temperature(1.0f) { - } - - // setters for the config. - void setTopP(float topP) { - _topP = topP; - } - void setTopK(int topK) { - _topK = topK; - } - void setTemperature(float temperature) { - _temperature = temperature; - } - - // getters for the config. - float getTopP() const { - return _topP; - } - int getTopK() const { - return _topK; - } - float getTemperature() const { - return _temperature; - } - - private: - float _topP; - int _topK; - float _temperature; -}; - -class Chunk { - public: - friend class Completion; - - std::string getText() const { - return _text; - } - - private: - std::string _text; -}; - -/// @brief Store the state of ongoing completion task. -class Completion { - public: - friend class Model; - - /// @brief If completion is ongoing (active) returns true, if stopped returns false. - /// @return If completion is active. - bool isActive(); - - /// @brief Get the next chunk of tokens generated by the model. - /// @return instance of Chunk. - Chunk nextChunk(); - - private: - std::shared_ptr _handle; - std::shared_ptr _chunkHandle; - std::shared_ptr _model; - - Completion() { - } -}; - -/// @brief Input prompt for Model::comeplete(). -class Prompt { - public: - Prompt(Prompt &) = delete; - Prompt &operator=(Prompt &) = delete; - - /// @brief Create a prompt from model. - static std::shared_ptr fromModel(std::shared_ptr model); - - /// @brief Append text to the prompt. - /// @param text text to append. - void appendText(const std::string &text); - - /// @brief Append a control token to the prompt. - /// @param text name of the control token. - void appendControlToken(const std::string &text); - - /// @brief Get internal handle for the prompt. - /// @return A pointer of llmPrompt_t. - llmPrompt_t *getHandle() { - return _handle.get(); - } - - private: - std::shared_ptr _handle; - std::shared_ptr _model; - - Prompt() = default; -}; - -/// @brief Stores an instance of LLM Model. -class Model : public std::enable_shared_from_this { - public: - Model(Model &) = delete; - Model &operator=(Model &) = delete; - - /// @brief Create an instance of Model from the package file; - /// @param configFile config file of the model. - /// @param device device of the model storage and computation device. Use DeviceType::AUTO to - /// let libllm determine the best one. - /// @return A shared pointer of the Model instance. - static std::shared_ptr fromFile( - const std::string &filename, - DeviceType device = DeviceType::AUTO); - - /// @brief Get the name of model, for example, "llama". - /// @return name of the model. - std::string getName(); - - /// @brief Complete the string version of given `prompt` with LLM. - /// @param prompt The prompt to complete. - /// @param config The config for completion. - /// @return A `Completion` object. - std::shared_ptr complete( - std::shared_ptr prompt, - CompletionConfig config = CompletionConfig()); - - /// @brief Get internal handle for the model. - /// @return A pointer of llmModel_t. 
- llmModel_t *getHandle() { - return _handle.get(); - } - - private: - std::shared_ptr _handle; - - Model() = default; -}; - -// -- Implementation of libLLM C++ API (wrapper for C api) ---------------------------------------- - -namespace internal { - -inline void throwLastError() { - std::string lastError = llmGetLastErrorMessage(); - throw std::runtime_error(lastError); -} - -} // namespace internal - -// -- Completion ---------- - -inline bool Completion::isActive() { - return llmCompletion_IsActive(_handle.get()) != 0; -} - -inline Chunk Completion::nextChunk() { - if (LLM_OK != llmCompletion_GenerateNextChunk(_handle.get(), _chunkHandle.get())) { - internal::throwLastError(); - } - - const char *text = llmChunk_GetText(_chunkHandle.get()); - if (!text) internal::throwLastError(); - - Chunk c; - c._text = text; - return c; -} - -// -- Prompt ---------- - -inline std::shared_ptr Prompt::fromModel(std::shared_ptr model) { - std::shared_ptr pPrompt(llmPrompt_New(model->getHandle()), llmPrompt_Delete); - if (!pPrompt) internal::throwLastError(); - - std::shared_ptr prompt{new Prompt()}; - prompt->_model = model; - prompt->_handle = pPrompt; - - return prompt; -} - -inline void Prompt::appendText(const std::string &text) { - if (LLM_OK != llmPrompt_AppendText(_handle.get(), text.c_str())) { - internal::throwLastError(); - } -} - -inline void Prompt::appendControlToken(const std::string &name) { - if (LLM_OK != llmPrompt_AppendControlToken(_handle.get(), name.c_str())) { - internal::throwLastError(); - } -} - -// -- Model ---------- - -inline std::shared_ptr Model::fromFile(const std::string &config, DeviceType device) { - std::shared_ptr pModel(llmModel_New(), llmModel_Delete); - int32_t dwDevice = static_cast(device); - if (LLM_OK != llmModel_SetFile(pModel.get(), config.c_str())) internal::throwLastError(); - if (LLM_OK != llmModel_SetDevice(pModel.get(), dwDevice)) internal::throwLastError(); - if (LLM_OK != llmModel_Load(pModel.get())) internal::throwLastError(); - - std::shared_ptr model{new Model()}; - model->_handle = pModel; - - return model; -} - -inline std::string Model::getName() { - const char *name = llmModel_GetName(_handle.get()); - if (!name) internal::throwLastError(); - - return name; -} - -inline std::shared_ptr Model::complete( - std::shared_ptr prompt, - CompletionConfig config) { - std::shared_ptr pComp(llmCompletion_New(_handle.get()), llmCompletion_Delete); - if (!pComp) internal::throwLastError(); - - if (LLM_OK != llmCompletion_SetPrompt(pComp.get(), prompt->getHandle())) - internal::throwLastError(); - if (LLM_OK != llmCompletion_SetTopK(pComp.get(), config.getTopK())) internal::throwLastError(); - if (LLM_OK != llmCompletion_SetTopP(pComp.get(), config.getTopP())) internal::throwLastError(); - if (LLM_OK != llmCompletion_SetTemperature(pComp.get(), config.getTemperature())) { - internal::throwLastError(); - } - if (LLM_OK != llmCompletion_Start(pComp.get())) internal::throwLastError(); - - std::shared_ptr pChunk(llmChunk_New(), llmChunk_Delete); - std::shared_ptr comp{new Completion()}; - comp->_model = shared_from_this(); - comp->_chunkHandle = pChunk; - comp->_handle = pComp; - - return comp; -} - -} // namespace llm \ No newline at end of file diff --git a/src/libllm/llm_main.cc b/src/libllm/llm_main.cc deleted file mode 100644 index b95dd0d6..00000000 --- a/src/libllm/llm_main.cc +++ /dev/null @@ -1,104 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a 
copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#include - -#include -#include - -#include "libllm/dialog_manager.h" -#include "libllm/llm.h" -#include "libllm/lut/error.h" -#include "libllm/lut/flags.h" -#include "libllm/lut/strings.h" -#include "libllm/lut/time.h" - -using libllm::ChatOutput; -using libllm::DialogManager; -using libllm::PromptBulder; - -/// @brief Print the chat statistical data. -/// @param chatOutput Output from chat. -void printChatStat(const ChatOutput &chatOutput) { - double msPerToken = 1000 * chatOutput.answerDuration / chatOutput.numAnswerTokens; - std::cout << std::endl - << lut::sprintf( - "(%d token, time=%.2fs, %.2fms per token)", - chatOutput.numAnswerTokens, - chatOutput.answerDuration, - msPerToken) - << std::endl; -} - -int main(int argc, char **argv) { - std::string configPath; - std::string deviceType = "auto"; - - const char *usage = - "Command line interface for libllm.\n" - "Usage: llm -m [-d (cpu|gpu|cuda)]"; - - lut::Flags flags(usage); - flags.define("-m", &configPath, "filename of libllm config file."); - flags.define("-d", &deviceType, "device of the model. 
(cpu|cuda|auto)"); - flags.parse(argc, argv); - - if (configPath.empty()) { - flags.printUsage(); - return 1; - } - - llm::DeviceType device = llm::DeviceType::AUTO; - if (deviceType == "auto") - device = llm::DeviceType::AUTO; - else if (deviceType == "cuda") - device = llm::DeviceType::CUDA; - else if (deviceType == "cpu") - device = llm::DeviceType::CPU; - else { - printf("invalid device"); - return 1; - } - - if (llmInit(LLM_API_VERSION) != LLM_OK) { - printf("init libllm failed: %s\n", llmGetLastErrorMessage()); - return 1; - } - - std::shared_ptr model = llm::Model::fromFile(configPath, device); - std::shared_ptr promptBuilder = PromptBulder::create(model->getName()); - DialogManager dialogManager(model, promptBuilder); - - std::cout << "> "; - - std::string query; - while (std::getline(std::cin, query)) { - if (lut::trim(query) == "") continue; - - ChatOutput chatOutput = dialogManager.chat(query, [](const std::string &token) { - std::cout << token; - std::cout.flush(); - }); - - printChatStat(chatOutput); - std::cout << "> "; - } - - return 0; -} diff --git a/src/libllm/lut/base64.cc b/src/libllm/lut/base64.cc new file mode 100644 index 00000000..846fee4b --- /dev/null +++ b/src/libllm/lut/base64.cc @@ -0,0 +1,169 @@ +/* + * Base64 encoding/decoding (RFC1341) + * Copyright (c) 2005, Jouni Malinen <j@w1.fi> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + * + * Alternatively, this software may be distributed under the terms of BSD + * license. + * + * See README and COPYING for more details. + */ + +// original file: +// https://android.googlesource.com/platform/external/wpa_supplicant/+/4d8c3c1ca334d1319decf3e2c5d2be0cf472e3f9/base64.c + +#include +#include +#include + +#include +#include +#include + +#include "libllm/lut/span.h" + +static const unsigned char + base64_table[65] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +namespace lut { +namespace internal { + +/** + * base64_encode - Base64 encode + * @src: Data to be encoded + * @len: Length of the data to be encoded + * @out_len: Pointer to output length variable, or %NULL if not used + * Returns: Allocated buffer of out_len bytes of encoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. Returned buffer is + * nul terminated to make it easier to use as a C string. The nul terminator is + * not included in out_len.
+ */ +unsigned char *base64_encode(const unsigned char *src, size_t len, size_t *out_len) { + unsigned char *out, *pos; + const unsigned char *end, *in; + size_t olen; + int line_len; + olen = len * 4 / 3 + 4; /* 3-byte blocks to 4-byte */ + olen += olen / 72; /* line feeds */ + olen++; /* nul termination */ + out = (unsigned char *)malloc(olen); + if (out == nullptr) return nullptr; + end = src + len; + in = src; + pos = out; + line_len = 0; + while (end - in >= 3) { + *pos++ = base64_table[in[0] >> 2]; + *pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64_table[in[2] & 0x3f]; + in += 3; + line_len += 4; + if (line_len >= 72) { + *pos++ = '\n'; + line_len = 0; + } + } + if (end - in) { + *pos++ = base64_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64_table[(in[0] & 0x03) << 4]; + *pos++ = '='; + } else { + *pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64_table[(in[1] & 0x0f) << 2]; + } + *pos++ = '='; + line_len += 4; + } + if (line_len) *pos++ = '\n'; + *pos = '\0'; + if (out_len) *out_len = pos - out; + return out; +} +/** + * base64_decode - Base64 decode + * @src: Data to be decoded + * @len: Length of the data to be decoded + * @out_len: Pointer to output length variable + * Returns: Allocated buffer of out_len bytes of decoded data, + * or %NULL on failure + * + * Caller is responsible for freeing the returned buffer. + */ +unsigned char *base64_decode(const unsigned char *src, size_t len, size_t *out_len) { + unsigned char dtable[256], *out, *pos, in[4], block[4], tmp; + size_t i, count, olen; + memset(dtable, 0x80, 256); + for (i = 0; i < sizeof(base64_table) - 1; i++) dtable[base64_table[i]] = (unsigned char)i; + dtable['='] = 0; + count = 0; + for (i = 0; i < len; i++) { + if (dtable[src[i]] != 0x80) count++; + } + if (count == 0 || count % 4) return nullptr; + olen = count / 4 * 3; + pos = out = (unsigned char *)malloc(olen); + if (out == nullptr) return nullptr; + count = 0; + for (i = 0; i < len; i++) { + tmp = dtable[src[i]]; + if (tmp == 0x80) continue; + in[count] = src[i]; + block[count] = tmp; + count++; + if (count == 4) { + *pos++ = (block[0] << 2) | (block[1] >> 4); + *pos++ = (block[1] << 4) | (block[2] >> 2); + *pos++ = (block[2] << 6) | block[3]; + count = 0; + } + } + if (pos > out) { + if (in[2] == '=') + pos -= 2; + else if (in[3] == '=') + pos--; + } + *out_len = pos - out; + return out; +} + +} // namespace internal +} // namespace lut + +namespace lut { + +std::vector decodeBase64(const std::string &base64String) { + size_t outlen; + int8_t *pdata = reinterpret_cast(internal::base64_decode( + reinterpret_cast(base64String.c_str()), + base64String.size(), + &outlen)); + + std::vector output(outlen); + std::copy(pdata, pdata + outlen, output.begin()); + + free(pdata); + return output; +} + +std::string encodeBase64(lut::Span data) { + size_t outlen; + char *pstring = reinterpret_cast(internal::base64_encode( + reinterpret_cast(data.data()), + data.size(), + &outlen)); + + std::string s(pstring, outlen); + + free(pstring); + return s; +} + +} // namespace lut \ No newline at end of file diff --git a/src/libllm/lut/base64.h b/src/libllm/lut/base64.h new file mode 100644 index 00000000..4cec405b --- /dev/null +++ b/src/libllm/lut/base64.h @@ -0,0 +1,34 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and 
associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#pragma once + +#include + +#include +#include + +#include "libllm/lut/span.h" + +namespace lut { + +std::vector decodeBase64(const std::string &base64String); +std::string encodeBase64(lut::Span data); + +} // namespace lut diff --git a/src/libllm/lut/error.cc b/src/libllm/lut/error.cc index 6e328d04..22baa452 100644 --- a/src/libllm/lut/error.cc +++ b/src/libllm/lut/error.cc @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -39,9 +39,15 @@ std::string getErrorCodeName(ErrorCode code) { std::string buildErrorMsg(ErrorCode code, const std::string &what) { return getErrorCodeName(code) + ": " + what; } - -Error::Error(ErrorCode code, const std::string &what) : _code(code), _what(what) {} -Error::~Error() {} +Error::Error() + : _code(ErrorCode::OK) { +} +Error::Error(ErrorCode code, const std::string &what) + : _code(code), + _what(what) { +} +Error::~Error() { +} ErrorCode Error::getCode() const { return _code; @@ -52,15 +58,19 @@ const char *Error::what() const noexcept { } AbortedError::AbortedError(const std::string &what) - : Error(ErrorCode::Aborted, buildErrorMsg(ErrorCode::Aborted, what)){} + : Error(ErrorCode::Aborted, buildErrorMsg(ErrorCode::Aborted, what)) { +} OutOfRangeError::OutOfRangeError(const std::string &what) - : Error(ErrorCode::OutOfRange, buildErrorMsg(ErrorCode::OutOfRange, what)) {} + : Error(ErrorCode::OutOfRange, buildErrorMsg(ErrorCode::OutOfRange, what)) { +} InvalidArgError::InvalidArgError(const std::string &what) - : Error(ErrorCode::InvalidArg, buildErrorMsg(ErrorCode::InvalidArg, what)) {} + : Error(ErrorCode::InvalidArg, buildErrorMsg(ErrorCode::InvalidArg, what)) { +} NotImplementedError::NotImplementedError(const std::string &what) - : Error(ErrorCode::InvalidArg, buildErrorMsg(ErrorCode::NotImplemented, what)) {} + : Error(ErrorCode::InvalidArg, buildErrorMsg(ErrorCode::NotImplemented, what)) { +} -} // namespace lut +} // namespace lut diff --git a/src/libllm/lut/error.h b/src/libllm/lut/error.h index 05085475..0029497f 100644 --- a/src/libllm/lut/error.h +++ b/src/libllm/lut/error.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or 
sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -22,11 +22,12 @@ #include #include -#define THROW(exType, msg) { \ - std::string msg_ = msg; \ - LOG(ERROR) << "an " << #exType << " exception was thrown: " << msg_; \ - throw lut::exType ## Error(msg_); \ -} +#define THROW(exType, msg) \ + { \ + std::string msg_ = msg; \ + LOG(ERROR) << "an " << #exType << " exception was thrown: " << msg_; \ + throw lut::exType##Error(msg_); \ + } namespace lut { @@ -40,14 +41,15 @@ enum class ErrorCode : int { class Error : public std::exception { public: + Error(); Error(ErrorCode code, const std::string &what); ~Error(); - + // get error code. ErrorCode getCode() const; // implement std::exception. - const char* what() const noexcept override; + const char *what() const noexcept override; private: ErrorCode _code; @@ -74,4 +76,4 @@ class NotImplementedError : public Error { NotImplementedError(const std::string &what); }; -} // namespace lut +} // namespace lut diff --git a/src/libllm/model_for_generation.cc b/src/libllm/model_for_generation.cc index c14c3c37..74b56b7e 100644 --- a/src/libllm/model_for_generation.cc +++ b/src/libllm/model_for_generation.cc @@ -19,7 +19,6 @@ #include "libllm/model_for_generation.h" -#include "libllm/chatglm.h" #include "libllm/constants.h" #include "libllm/llama.h" #include "libllm/lut/error.h" @@ -27,6 +26,7 @@ #include "libllm/lut/strings.h" #include "libllm/lut/zip_file.h" #include "libllm/qwen.h" +#include "libllm/whisper.h" namespace libllm { @@ -45,27 +45,47 @@ std::shared_ptr ModelForGeneration::fromPackage( Context ctx = fromCtx.withName(modelType); std::shared_ptr model; - if (modelType == "chatglm2" || modelType == "chatglm3") { - model = chatglm::ChatGlmModelForGeneration::fromConfig(ctx, *ini); - } else if (modelType == "llama") { - model = llama::LlamaModelForGeneration::fromConfig(ctx, *ini); + if (modelType == "llama") { + model = llama::LlamaModelForGeneration::fromPackage(ctx, package); + } else if (modelType == "whisper") { + model = whisper::WhisperModelForGeneration::fromPackage(ctx, package); } else if (modelType == "index") { - model = llama::LlamaModelForGeneration::fromConfig(ctx, *ini); + model = llama::LlamaModelForGeneration::fromPackage(ctx, package); } else if (modelType == "qwen") { - model = qwen::QwenModelForGeneration::fromConfig(ctx, *ini); + model = qwen::QwenModelForGeneration::fromPackage(ctx, package); } else { throw lut::AbortedError(lut::sprintf("unexpected model type: %s", modelType)); } - // read state map. - std::string modelFile = ini->getSection(ModelSection).getString(ModelFileField); + return model; +} - StateMap stateMap; - stateMap.read(package->open(modelFile).get()); +void ModelForGeneration::initTokenizer(lut::ZipFile *package) { + _tokenizer = Tokenizer::fromPackage(package); +} - // initialize parameters. 
- model->initParameters(stateMap); - return model; +const Vocab *ModelForGeneration::getVocab() const { + return _tokenizer->getVocab(); +} + +void ModelForGeneration::encodePromptBlock( + const PromptBlock &block, + std::vector<int> &tokenIds) const { + int tokenId; + switch (block.blockType) { + case PromptBlock::ControlToken: + tokenId = _tokenizer->getVocab()->findControlToken(block.text); + tokenIds.push_back(tokenId); + LOG(DEBUG) << "control token " << block.text << " -> " << tokenId; + break; + case PromptBlock::Text: + for (int tokenId : _tokenizer->encode(block.text)) { + tokenIds.push_back(tokenId); + } + break; + default: + NOT_IMPL(); + } } } // namespace libllm diff --git a/src/libllm/model_for_generation.h b/src/libllm/model_for_generation.h index 61827fd5..74c3dc3c 100644 --- a/src/libllm/model_for_generation.h +++ b/src/libllm/model_for_generation.h @@ -21,12 +21,27 @@ #include "libllm/context.h" #include "libllm/lut/zip_file.h" +#include "libllm/prompt.h" #include "libllm/state_map.h" #include "libllm/tensor.h" #include "libllm/tokenizer.h" namespace libllm { +/// @brief logits processor used in the generator. +class LogitsProcessor { + public: + virtual ~LogitsProcessor() = default; + + /// @brief tells the logits processor that a token was emitted by the input prompt or the + /// generator. + /// @param tokenId the id of the token. + virtual void notifyToken(int tokenId) = 0; + + /// @brief process the logits tensor. + /// @param logits the logits tensor to process. + virtual void processLogits(Tensor logits) = 0; +}; + // base class for language model. class ModelForGeneration { public: @@ -37,28 +52,19 @@ class ModelForGeneration { virtual ~ModelForGeneration() = default; - // initialize the parameters from stateMap. - virtual void initParameters(const StateMap &stateMap) = 0; - - // Forward input token ids through this language model, update the `past` state and return the - // hidden state of last layer. - // Args: - // past (StateMap): key-value cache. - // inputs (N, L): prompt token ids. - // Returns: - // (N, L, D): hidden state from last layer. - virtual Tensor forward(StateMap &past, Tensor input) const = 0; - - // Forward the hidden state from last layer and get the logits. hiddenState is usually the - // return value of forward(). - // Args: - // hidden_state (N, L, D): hidden state from last layer. - // Returns: - // (N, L, V): logits. V is vocabulary size. - virtual Tensor forwardHidden(Tensor hiddenState) const = 0; - - // build model input from the prompt token-ids. - virtual Tensor buildInput(const std::vector &prompt) const = 0; + /// @brief Used in the prefill phase. Forward the input prompt through this language model, + /// update the `past` state and return the logits for the next token. + /// @param past (StateMap): key-value cache. + /// @param prompt (Prompt): the input prompt for prefill. + /// @return (N, 1, V): logits for the next token. + virtual Tensor prefill(StateMap &past, const Prompt &prompt) const = 0; + + /// @brief Used in the decoding phase. Forward one input token through this language model, + /// update the `past` state and return the logits for the next token. + /// @param past (StateMap): key-value cache. + /// @param inputToken (LongType): the input token. + /// @return (N, 1, V): logits for the next token. + virtual Tensor decode(StateMap &past, LongType inputToken) const = 0; /// @brief Return true if tokenId is a stop token. (stop generating texts) /// @param tokenId the token id.
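The prefill/decode split above replaces the old forward/forwardHidden/buildInput trio: one call consumes the whole prompt, then each subsequent call advances by a single token against the KV cache. A sketch of how a generator might drive it (greedy decoding; `argmax` and `onToken` are hypothetical helpers, not part of the patch):

```cpp
// Assumes `model` is a loaded ModelForGeneration and `prompt` a libllm::Prompt.
StateMap past;
Tensor logits = model->prefill(past, prompt);  // one pass over the whole prompt
for (;;) {
  int tokenId = argmax(logits);                // hypothetical: pick the most likely token
  if (model->isStopToken(tokenId)) break;
  onToken(tokenId);                            // hypothetical: emit the token to the caller
  logits = model->decode(past, tokenId);       // single-token step, reusing the KV cache
}
```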
@@ -71,6 +77,30 @@ class ModelForGeneration { /// @brief Get device of the model. /// @return the device. virtual Device getDevice() const = 0; + + /// @brief Get the output dimension of the model. This dimension is usually the same as the + /// vocabulary size, but for some models the two differ. + /// @return the output dimension of the model. + virtual int getOutputDim() const = 0; + + /// @brief Get the vocabulary (tokenId to token string) of the model. + /// @return The vocabulary. + const Vocab *getVocab() const; + + protected: + std::shared_ptr<Tokenizer> _tokenizer; + + ModelForGeneration() = default; + + /// @brief Encode a prompt block and append the tokens into `tokenIds`. It ONLY processes two + /// types of PromptBlock: text and controlToken. Any other block type is a fatal error. + /// @param block The block to process. + /// @param tokenIds The vector to append processed tokens to. + void encodePromptBlock(const PromptBlock &block, std::vector<int> &tokenIds) const; + + /// @brief Initialize the tokenizer. + /// @param package The model package. + void initTokenizer(lut::ZipFile *package); }; } // namespace libllm diff --git a/src/libllm/module.cc b/src/libllm/module.cc index 37fce498..544e7825 100644 --- a/src/libllm/module.cc +++ b/src/libllm/module.cc @@ -211,6 +211,14 @@ std::unique_ptr<LayerNorm> LayerNorm::create(const Context &ctx, int dModel, float eps) { return layer; } +void LayerNorm::initParameters(lut::Random *generator, DType _) { + _w = F::rand({_dModel}, DType::kFloat, Device::getCpu(), generator); + _w = moveAndCastFloat(_w, getCtx()); + + _b = F::rand({_dModel}, DType::kFloat, Device::getCpu(), generator); + _b = moveAndCastFloat(_b, getCtx()); +} + void LayerNorm::initParameters(const StateMap &stateDict) { const Context &ctx = getCtx(); @@ -231,4 +239,96 @@ Tensor LayerNorm::forward(const Tensor &input) const { return F::layerNorm(input, _w, _b, _eps); } +// -----------------------------------------------------------------------------------------------+ +// Conv1D | +// -----------------------------------------------------------------------------------------------+ + +constexpr char Conv1D::kWeight[]; +constexpr char Conv1D::kBias[]; + +Conv1D::Conv1D() + : _inChannels(0), + _outChannels(0), + _kernelSize(0), + _hasBias(false) { +} + +std::shared_ptr<Conv1D> Conv1D::create( + const Context &ctx, + int inChannels, + int outChannels, + int kernelSize, + int stride, + bool bias) { + std::shared_ptr<Conv1D> layer{new Conv1D()}; + layer->setCtx(ctx); + + if (kernelSize == 0 || kernelSize >= 16) { + throw lut::AbortedError("invalid kernelSize"); + } + + layer->_hasBias = bias; + layer->_inChannels = inChannels; + layer->_outChannels = outChannels; + layer->_kernelSize = kernelSize; + layer->_stride = stride; + return layer; +} + +void Conv1D::initParameters(const StateMap &stateDict) { + const Context &ctx = getCtx(); + + std::string nameW = getCtx().name(kWeight); + std::string nameB = ctx.name(kBias); + + _w = stateDict.getTensor(nameW); + _w.throwIfInvalidShape({_outChannels, _inChannels, _kernelSize}, nameW); + _w = moveAndCastFloat(_w, ctx); + _w = _w.view({_outChannels, -1}); + + if (_hasBias) { + _b = stateDict.getTensor(nameB); + _b.throwIfInvalidShape({_outChannels}, nameB); + _b = moveAndCastFloat(_b, ctx); + } else { + if (stateDict.hasTensor(nameB)) { + throw lut::AbortedError(lut::sprintf( + "In module %s: hasBias=false but bias weight found in state_map.", + ctx.name())); + } + } +} + +void Conv1D::initParameters(lut::Random *generator, DType weightType) { + float xs = sqrtf(3.0f /
(_inChannels * _kernelSize)); + _w = F::rand( + {_outChannels, _inChannels * _kernelSize}, + weightType, + Device::getCpu(), + generator, + -xs, + xs); + _w = moveAndCastFloat(_w, getCtx()); + + if (_hasBias) { + _b = F::rand({_outChannels}, DType::kFloat, Device::getCpu(), generator, -0.2f, 0.2f); + _b = moveAndCastFloat(_b, getCtx()); + } +} + +Tensor Conv1D::forward(const Tensor &input) const { + Tensor x = F::unfold(input, _kernelSize, _stride); + if (input.getDim() >= 2) { + x = F::matmul(x, _w.transpose(0, 1)); + } else { + NOT_IMPL(); + } + + if (_hasBias) { + x = F::add(x, _b); + } + + return x; +} + } // namespace libllm diff --git a/src/libllm/module.h b/src/libllm/module.h index e1f1e593..72e6c258 100644 --- a/src/libllm/module.h +++ b/src/libllm/module.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -20,10 +20,11 @@ #pragma once #include + #include "libllm/context.h" +#include "libllm/lut/random.h" #include "libllm/state_map.h" #include "libllm/tensor.h" -#include "libllm/lut/random.h" #include "tensor.h" namespace libllm { @@ -50,11 +51,15 @@ class Module { /// @brief Get context of current module. /// @return reference of Context. - const Context &getCtx() const { return _ctx; } + const Context &getCtx() const { + return _ctx; + } /// @brief Set the context of current module. /// @param ctx reference of Context. - void setCtx(const Context &ctx) { _ctx = ctx; } + void setCtx(const Context &ctx) { + _ctx = ctx; + } private: Context _ctx; @@ -64,13 +69,14 @@ class LayerNorm : public Module { public: static std::unique_ptr<LayerNorm> create(const Context &ctx, int d_model, float eps = 1e-5); - + // implement interface Module void initParameters(const StateMap &state_dict) override; + void initParameters(lut::Random *generator, DType weightType) override; // forward input and return the output. Tensor forward(const Tensor &input) const; - + private: // tensor names. static constexpr char kWeight[] = "weight"; @@ -135,9 +141,12 @@ class Embedding : public Module { class Linear : public Module { public: - // create Linear module from context. + // create Linear module from context. static std::unique_ptr<Linear> create( - const Context &ctx, int inDim, int outDim, bool hasBias = true); + const Context &ctx, + int inDim, + int outDim, + bool hasBias = true); // implement interface Module void initParameters(const StateMap &state_dict) override; @@ -161,4 +170,41 @@ class Linear : public Module { Linear(); }; +/// @brief Apply 1D Convolution to the input tensor. Unlike Conv1D in pytorch, it uses (N, L, C) as +/// input format. +class Conv1D : public Module { + public: + // create Conv1D module from context. + static std::shared_ptr<Conv1D> create( + const Context &ctx, + int inChannels, + int outChannels, + int kernelSize, + int stride = 1, + bool bias = true); + + // implement interface Module + void initParameters(const StateMap &stateDict) override; + void initParameters(lut::Random *generator, DType weightType) override; + + // forward input and return the output. + Tensor forward(const Tensor &input) const; + + private: + // tensor names.
+ static constexpr char kWeight[] = "weight"; + static constexpr char kBias[] = "bias"; + + Tensor _w; + Tensor _b; + + int _inChannels; + int _outChannels; + int _kernelSize; + int _stride; + bool _hasBias; + + Conv1D(); +}; + } // namespace libllm diff --git a/src/libllm/operator_tester.h b/src/libllm/operator_tester.h index f82de51d..8ce21fca 100644 --- a/src/libllm/operator_tester.h +++ b/src/libllm/operator_tester.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -20,22 +20,17 @@ #pragma once #include "libllm/device.h" -#include "libllm/operators.h" #include "libllm/lut/attributes.h" +#include "libllm/operators.h" namespace libllm { class OperatorTester { public: using ShapeType = std::initializer_list; - + static constexpr uint32_t MagicNumber = 0x33; - enum class OperatorType { - Add, - Mul, - Softmax, - Swiglu - }; + enum class OperatorType { Add, Mul, Softmax, Swiglu }; OperatorTester(); @@ -68,7 +63,7 @@ class OperatorTester { float _rtol; Operators *_op; Operators *_referenceOp; - + Device _testDevice; DType _testFloatType; }; diff --git a/src/libllm/operators.cc b/src/libllm/operators.cc index 5ee7d3dc..111c4b57 100644 --- a/src/libllm/operators.cc +++ b/src/libllm/operators.cc @@ -53,10 +53,22 @@ Tensor Operators::softmax(Tensor input) { NOT_IMPL(); } +Tensor Operators::sum(Tensor input) { + NOT_IMPL(); +} + +Tensor Operators::max(Tensor input) { + NOT_IMPL(); +} + Tensor Operators::gelu(Tensor input) { NOT_IMPL(); } +void Operators::fill(Tensor input, float value) { + NOT_IMPL(); +} + Tensor Operators::add(Tensor a, Tensor b) { NOT_IMPL(); } @@ -105,10 +117,18 @@ Tensor Operators::swiglu(Tensor A) { NOT_IMPL(); } +Tensor Operators::melFbank(Tensor A) { + NOT_IMPL(); +} + Tensor Operators::to(Device device, Tensor tensor) { NOT_IMPL(); } +Tensor Operators::unfold(Tensor input, int kernelSize, int stride) { + NOT_IMPL(); +} + Tensor Operators::cast(Tensor tensor, DType dtype) { NOT_IMPL(); } @@ -117,6 +137,10 @@ DType Operators::getDefaultFloatType() { NOT_IMPL(); } +Tensor Operators::logMelSpectrogram(Tensor wave) { + NOT_IMPL(); +} + Tensor Operators::rand( lut::Span shape, DType dtype, diff --git a/src/libllm/operators.h b/src/libllm/operators.h index 9ba5665d..8594989f 100644 --- a/src/libllm/operators.h +++ b/src/libllm/operators.h @@ -41,7 +41,11 @@ class Operators { virtual Tensor mul(Tensor input, Tensor other); virtual Tensor softmax(Tensor input); virtual Tensor add(Tensor input, Tensor other); + virtual Tensor sum(Tensor input); + virtual Tensor max(Tensor input); + virtual Tensor melFbank(Tensor input); virtual Tensor gelu(Tensor input); + virtual void fill(Tensor input, float value); virtual Tensor tensor(lut::Span shape, DType dtype); virtual Tensor tensorLike(Tensor input); virtual Tensor zeros(lut::Span shape, DType dtype); @@ -52,9 +56,15 @@ class Operators { virtual void copy(Tensor src, Tensor dest); virtual Tensor swiglu(Tensor A); virtual Tensor to(Device device, Tensor tensor); + virtual Tensor unfold(Tensor input, int kernelSize, int stride); virtual Tensor cast(Tensor tensor, DType dtype); - virtual Tensor - rand(lut::Span shape, DType dtype, lut::Random *generator, float min, float 
max); + virtual Tensor logMelSpectrogram(Tensor wave); + virtual Tensor rand( + lut::Span shape, + DType dtype, + lut::Random *generator, + float min, + float max); virtual DType getDefaultFloatType(); }; diff --git a/src/libllm/prompt.cc b/src/libllm/prompt.cc new file mode 100644 index 00000000..ebcf408d --- /dev/null +++ b/src/libllm/prompt.cc @@ -0,0 +1,77 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include "libllm/prompt.h" + +namespace libllm { + +PromptBlock::PromptBlock() + : waveFormat(WaveFormat::Unknown), + blockType(Type::Unknown) { +} + +std::string PromptBlock::typeToString(Type blockType) { + switch (blockType) { + case Type::ControlToken: + return "controlToken"; + case Type::Text: + return "text"; + case Type::Wave: + return "wave"; + case Type::Unknown: + return "unknown"; + default: + NOT_IMPL(); + } +} + +void Prompt::appendText(const std::string &text) { + PromptBlock block; + block.text = text; + block.blockType = PromptBlock::Text; + + _blocks.emplace_back(std::move(block)); +} + +void Prompt::appendControlToken(const std::string &controlToken) { + PromptBlock block; + block.text = controlToken; + block.blockType = PromptBlock::ControlToken; + + _blocks.emplace_back(std::move(block)); +} + +void Prompt::appendWave(lut::Span payload, WaveFormat format) { + PromptBlock block; + block.data = std::vector(payload.begin(), payload.end()); + block.waveFormat = format; + block.blockType = PromptBlock::Wave; + + _blocks.emplace_back(std::move(block)); +} + +bool Prompt::empty() const { + return _blocks.empty(); +} + +lut::Span Prompt::getBlocks() const { + return _blocks; +} + +} // namespace libllm \ No newline at end of file diff --git a/src/libllm/prompt.h b/src/libllm/prompt.h new file mode 100644 index 00000000..37a9d465 --- /dev/null +++ b/src/libllm/prompt.h @@ -0,0 +1,63 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all 
copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#pragma once + +#include + +#include +#include + +#include "libllm/dtype.h" +#include "libllm/wave.h" + +namespace libllm { + +struct PromptBlock { + enum Type { + Text, + ControlToken, + Wave, + Unknown, + }; + + std::string text; + std::vector data; + WaveFormat waveFormat; + Type blockType; + + PromptBlock(); + static std::string typeToString(Type blockType); +}; + +class Prompt { + public: + void appendText(const std::string &text); + void appendControlToken(const std::string &controlToken); + void appendWave(lut::Span payload, WaveFormat format); + + bool empty() const; + + lut::Span getBlocks() const; + + private: + std::vector _blocks; +}; + +} // namespace libllm diff --git a/src/libllm/qwen.cc b/src/libllm/qwen.cc index 49e92f9d..d829672a 100644 --- a/src/libllm/qwen.cc +++ b/src/libllm/qwen.cc @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -22,26 +22,41 @@ namespace libllm { namespace qwen { -QwenModelForGeneration::QwenModelForGeneration() : _imStartId(-1), _imEndId(-1) {} +QwenModelForGeneration::QwenModelForGeneration() + : _imStartId(-1), + _imEndId(-1) { +} -std::shared_ptr QwenModelForGeneration::fromConfig( +std::shared_ptr QwenModelForGeneration::fromPackage( const Context &ctx, - const lut::IniConfig &config) { + lut::ZipFile *package) { + std::shared_ptr reader = package->open(ModelConfig); + std::shared_ptr ini = lut::IniConfig::fromStream(reader.get()); + + std::string modelFile = ini->getSection(ModelSection).getString(ModelFileField); + std::string modelType = ini->getSection(ModelSection).getString(ModelTypeField); + CHECK(modelType == "qwen"); + + const lut::IniSection &qwenIni = ini->getSection(modelType); + std::shared_ptr model{new QwenModelForGeneration()}; - model->init(ctx, config); + llama::LlamaConfig llamaConfig = llama::LlamaConfig::loadConfig(qwenIni); - std::string modelType = config.getSection(ModelSection).getString(ModelTypeField); - const lut::IniSection &qwenSection = config.getSection(modelType); + StateMap stateMap; + stateMap.read(package->open(modelFile).get()); - model->_imStartId = qwenSection.getInt("im_start_token_id"); - model->_imEndId = qwenSection.getInt("im_end_token_id"); + model->_model = llama::LlamaModel::create(ctx, llamaConfig); + model->_model->initParameters(stateMap); + model->_imStartId = qwenIni.getInt("im_start_token_id"); + model->_imEndId = qwenIni.getInt("im_end_token_id"); + model->_modelName = modelType; + model->initTokenizer(package); return model; } bool QwenModelForGeneration::isStopToken(int tokenId) const { - if (llama::LlamaModelForGeneration::isStopToken(tokenId) || - tokenId == _imEndId || + if 
(llama::LlamaModelForGeneration::isStopToken(tokenId) || tokenId == _imEndId || tokenId == _imStartId) { return true; } else { diff --git a/src/libllm/qwen.h b/src/libllm/qwen.h index 71f8f9f3..dcdc53e8 100644 --- a/src/libllm/qwen.h +++ b/src/libllm/qwen.h @@ -7,7 +7,7 @@ // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -20,8 +20,9 @@ #pragma once #include -#include "libllm/lut/ini_config.h" + #include "libllm/llama.h" +#include "libllm/lut/ini_config.h" #include "libllm/model_for_generation.h" namespace libllm { @@ -32,9 +33,9 @@ namespace qwen { /// here. class QwenModelForGeneration : public llama::LlamaModelForGeneration { public: - static std::shared_ptr fromConfig( + static std::shared_ptr fromPackage( const Context &ctx, - const lut::IniConfig &config); + lut::ZipFile *package); // noncopyable QwenModelForGeneration(QwenModelForGeneration &) = delete; diff --git a/src/libllm/sampler.cc b/src/libllm/sampler.cc deleted file mode 100644 index 1423ab9b..00000000 --- a/src/libllm/sampler.cc +++ /dev/null @@ -1,105 +0,0 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Xiaoyang Chen -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software -// and associated documentation files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING -// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
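sampler.cc below is removed outright. For reference, the algorithm it implemented was top-k filtering followed by nucleus (top-p) sampling; the following is a self-contained restatement of that scheme using only std:: types (names and signature here are illustrative, not the deleted class):

```cpp
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Draw a label from `probs` restricted to the top-k most probable labels,
// further truncated to the smallest prefix whose cumulative mass reaches topP.
int sampleTopKTopP(const std::vector<float> &probs, int topK, float topP, std::mt19937 &rng) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  topK = std::min<int>(topK, static_cast<int>(idx.size()));
  std::partial_sort(idx.begin(), idx.begin() + topK, idx.end(), [&](int a, int b) {
    return probs[a] > probs[b];  // candidates ordered by probability, descending
  });

  // Shrink the candidate set until the cumulative probability reaches topP.
  float sum = 0.0f;
  int n = 0;
  while (n < topK) {
    sum += probs[idx[n]];
    ++n;
    if (sum >= topP) break;
  }

  // Weighted draw from the remaining candidates (the "nucleus").
  std::uniform_real_distribution<float> dist(0.0f, sum);
  float r = dist(rng);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += probs[idx[i]];
    if (r < acc) return idx[i];
  }
  return idx[n - 1];
}
```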
- -#include "libllm/generator.h" - -#include -#include -#include "libllm/tensor.h" -#include "libllm/lut/log.h" - -namespace libllm { - -Sampler::Sampler(int topK, float topP) : _topK(topK), _topP(topP) {} - -std::vector Sampler::getTopP(const Tensor &distribution, lut::Span topK) { - CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); - float sumP = 0.0f; - - std::vector topP; - const float *d = distribution.getData(); - for (int label : topK) { - float p = d[label]; - topP.push_back(label); - - sumP += p; - if (sumP >= _topP) { - break; - } - } - - return topP; -} - -std::vector Sampler::getTopK(const Tensor &distribution) { - CHECK(_topK <= distribution.getShape(0) && distribution.getStride(0) == 1); - if (_topBuffer.size() != distribution.getShape(0)) - _topBuffer.resize(distribution.getShape(0)); - - const float *d = distribution.getData(); - for (int32_t i = 0; i < distribution.getShape(0); ++i) { - _topBuffer[i] = std::make_pair(i, d[i]); - } - - std::partial_sort(_topBuffer.begin(), _topBuffer.begin() + _topK, _topBuffer.end(), - [](const std::pair &a, const std::pair &b) { - return a.second > b.second; - }); - - std::vector topK; - LOG(DEBUG) << "Sampler TopK (K=" << _topK << ")"; - for (int i = 0; i < _topK; ++i) { - topK.push_back(_topBuffer[i].first); - LOG(DEBUG) << i << ": " <<_topBuffer[i].first << ", " << _topBuffer[i].second; - } - - return topK; -} - -int Sampler::sampleTopP(const Tensor &distribution, lut::Span topP) { - CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); - std::vector probAcc; - - float sumP = 0.0f; - const float *probData = distribution.getData(); - for (int label : topP) { - float p = probData[label]; - sumP += p; - probAcc.push_back(sumP); - } - - float r = _random.nextFloat() * sumP; - for (int i = 0; i < topP.size(); ++i) { - if (r < probAcc[i]) { - return topP[i]; - } - } - return topP.back(); -} - -int Sampler::sample(const Tensor &distribution) { - CHECK(distribution.getDim() == 1 && distribution.getDType() == DType::kFloat); - - std::vector topK = getTopK(distribution); // topK is sorted by its prob in x - std::vector topP = getTopP(distribution, topK); - - return sampleTopP(distribution, topP); -} - -} // namespace libllm diff --git a/src/libllm/state_map.cc b/src/libllm/state_map.cc index ba068744..fcddc5e3 100644 --- a/src/libllm/state_map.cc +++ b/src/libllm/state_map.cc @@ -54,6 +54,7 @@ void StateMap::read(lut::Reader *fp) { if (tag != " ") throw lut::AbortedError("invalid tensor map file"); std::pair kv = readTensor(fp); + LOG(DEBUG) << "Load tensor: " << kv.first; _dict[kv.first] = kv.second; tag = fp->readString(4); diff --git a/src/libllm/wave.cc b/src/libllm/wave.cc new file mode 100644 index 00000000..b05a8363 --- /dev/null +++ b/src/libllm/wave.cc @@ -0,0 +1,46 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include "libllm/wave.h" + +#include "libllm/functional.h" +#include "libllm/lut/error.h" + +namespace libllm { + +Tensor Wave::read(lut::Span data, WaveFormat format) { + if (format == WaveFormat::Wave16kHz16bitMonoPCM) { + int numSamples = static_cast<int>(data.size() / 2); + if (data.size() % 2 != 0) { + throw lut::AbortedError("Wave::read: invalid size of data"); + } + + std::vector<float> wave(numSamples); + const int16_t *phData = reinterpret_cast<const int16_t *>(data.data()); + for (int i = 0; i < numSamples; ++i) { + wave[i] = static_cast<float>(phData[i]) / 32768.0f; + } + + return Tensor::create({numSamples}, lut::makeConstSpan(wave)); + } else { + NOT_IMPL(); + } +} + +} // namespace libllm diff --git a/src/libllm/sampler.h b/src/libllm/wave.h similarity index 70% rename from src/libllm/sampler.h rename to src/libllm/wave.h index c18df3f5..5eb1b044 100644 --- a/src/libllm/sampler.h +++ b/src/libllm/wave.h @@ -1,13 +1,13 @@ // The MIT License (MIT) // -// Copyright (c) 2023 Xiaoyang Chen +// Copyright (c) 2024 Xiaoyang Chen // // Permission is hereby granted, free of charge, to any person obtaining a copy of this software // and associated documentation files (the "Software"), to deal in the Software without // restriction, including without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // @@ -19,28 +19,22 @@ #pragma once -#include -#include "libllm/lut/random.h" +#include + #include "libllm/lut/span.h" #include "libllm/tensor.h" namespace libllm { -class Sampler { - public: - Sampler(int topK, float topP); - - int sample(const Tensor &distribution); - - private: - lut::Random _random; - int _topK; - float _topP; - std::vector> _topBuffer; +enum class WaveFormat { + Wave16kHz16bitMonoPCM, + Unknown, +}; - std::vector getTopK(const Tensor &distribution); - std::vector getTopP(const Tensor &distribution, lut::Span topK); - int sampleTopP(const Tensor &distribution, lut::Span topP); +// interface for reading wave (audio) data.
+class Wave { + public: + static Tensor read(lut::Span data, WaveFormat format); }; } // namespace libllm diff --git a/src/libllm/whisper.cc b/src/libllm/whisper.cc new file mode 100644 index 00000000..dc4a76fb --- /dev/null +++ b/src/libllm/whisper.cc @@ -0,0 +1,824 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include "libllm/whisper.h" + +#include + +#include "libllm/constants.h" +#include "libllm/functional.h" +#include "libllm/lut/error.h" +#include "libllm/lut/strings.h" + +namespace libllm { +namespace whisper { + +WhisperConfig::WhisperConfig() + : hiddenSize(0), + encoderNumHeads(0), + encoderFfnDim(0), + encoderNumLayers(0), + decoderNumLayers(0), + decoderFfnDim(0), + vocabSize(0), + maxTgtLength(0) { +} + +WhisperConfig WhisperConfig::loadConfig(const lut::IniSection §ion) { + WhisperConfig config; + + config.hiddenSize = section.getInt("hidden_size"); + config.encoderNumHeads = section.getInt("encoder_num_heads"); + config.encoderFfnDim = section.getInt("encoder_ffn_dim"); + config.encoderNumLayers = section.getInt("encoder_num_layers"); + config.decoderNumLayers = section.getInt("decoder_num_layers"); + config.decoderFfnDim = section.getInt("decoder_ffn_dim"); + config.vocabSize = section.getInt("vocab_size"); + config.maxTgtLength = section.getInt("max_tgt_length"); + return config; +} + +// -----------------------------------------------------------------------------------------------+ +// class EncoderAttention | +// -----------------------------------------------------------------------------------------------+ + +EncoderAttention::EncoderAttention() + : _numHeads(0), + _hiddenSize(0) { +} + +EncoderAttention::~EncoderAttention() { +} + +std::shared_ptr EncoderAttention::fromConfig( + const Context &ctx, + WhisperConfig config) { + std::shared_ptr model{new EncoderAttention()}; + model->setCtx(ctx); + + if (config.hiddenSize % config.encoderNumHeads != 0) { + throw lut::AbortedError("invalid hiddenSize and numHeads"); + } + + model->_qkvProj = Linear::create( + ctx.withName("qkv_proj"), + config.hiddenSize, + config.hiddenSize * 3); + model->_outProj = Linear::create(ctx.withName("out_proj"), config.hiddenSize, config.hiddenSize); + model->_hiddenSize = config.hiddenSize; + model->_numHeads = config.encoderNumHeads; + return model; +} + +void EncoderAttention::initParameters(const StateMap &stateDict) { + _qkvProj->initParameters(stateDict); + _outProj->initParameters(stateDict); +} + 
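WhisperConfig::loadConfig above pulls its hyperparameters from the model package's INI file. For orientation, the [whisper] section it reads would look roughly like this (key names are taken from the code; the values shown are illustrative, matching a small Whisper configuration, and are not taken from the patch):

```ini
[whisper]
hidden_size=512
encoder_num_heads=8
encoder_ffn_dim=2048
encoder_num_layers=6
decoder_num_layers=6
decoder_ffn_dim=2048
vocab_size=51865
max_tgt_length=448
```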
+void EncoderAttention::initParameters(lut::Random *generator, DType weightType) { + _qkvProj->initParameters(generator, weightType); + _outProj->initParameters(generator, weightType); +} + +Tensor EncoderAttention::forward(Tensor inputs) { + CHECK(inputs.getDim() == 3); + Tensor qkv = _qkvProj->forward(inputs); + + Tensor q = qkv.slice(-1, {0, _hiddenSize}); + Tensor k = qkv.slice(-1, {_hiddenSize, _hiddenSize * 2}); + Tensor v = qkv.slice(-1, {_hiddenSize * 2, _hiddenSize * 3}); + + int bsz = inputs.getShape(0); + int len = inputs.getShape(1); + int headDim = _hiddenSize / _numHeads; + q = q.view({bsz, len, _numHeads, headDim}); + k = k.view({bsz, len, _numHeads, headDim}); + v = v.view({bsz, len, _numHeads, headDim}); + + q = q.transpose(1, 2); + k = k.transpose(1, 2); + v = v.transpose(1, 2); + Tensor x = F::attention(q, k, v); + + x = F::contiguous(x.transpose(1, 2)).view({bsz, len, _hiddenSize}); + x = _outProj->forward(x); + + return x; +} + +// -----------------------------------------------------------------------------------------------+ +// class EncoderLayer | +// -----------------------------------------------------------------------------------------------+ + +EncoderLayer::EncoderLayer() { +} + +EncoderLayer::~EncoderLayer() { +} + +std::shared_ptr EncoderLayer::fromConfig(const Context &ctx, WhisperConfig config) { + std::shared_ptr model{new EncoderLayer()}; + model->setCtx(ctx); + + model->_norm1 = LayerNorm::create(ctx.withName("norm1"), config.hiddenSize); + model->_norm2 = LayerNorm::create(ctx.withName("norm2"), config.hiddenSize); + model->_attn = EncoderAttention::fromConfig(ctx.withName("attn"), config); + model->_fc1 = Linear::create(ctx.withName("fc1"), config.hiddenSize, config.encoderFfnDim); + model->_fc2 = Linear::create(ctx.withName("fc2"), config.encoderFfnDim, config.hiddenSize); + return model; +} + +void EncoderLayer::initParameters(const StateMap &stateDict) { + _norm1->initParameters(stateDict); + _norm2->initParameters(stateDict); + _attn->initParameters(stateDict); + _fc1->initParameters(stateDict); + _fc2->initParameters(stateDict); +} + +void EncoderLayer::initParameters(lut::Random *generator, DType weightType) { + _norm1->initParameters(generator, weightType); + _norm2->initParameters(generator, weightType); + _attn->initParameters(generator, weightType); + _fc1->initParameters(generator, weightType); + _fc2->initParameters(generator, weightType); +} + +Tensor EncoderLayer::forward(Tensor inputs) { + Tensor residual = inputs; + + Tensor x = _norm1->forward(inputs); + x = _attn->forward(x); + x = F::add(x, residual); + + residual = x; + x = _norm2->forward(x); + + x = _fc1->forward(x); + x = F::gelu(x); + + x = _fc2->forward(x); + x = F::add(x, residual); + return x; +} + +// -----------------------------------------------------------------------------------------------+ +// class EncoderModel | +// -----------------------------------------------------------------------------------------------+ + +EncoderModel::EncoderModel() + : _hiddenSize(0) { +} + +EncoderModel::~EncoderModel() { +} + +std::shared_ptr EncoderModel::fromConfig(const Context &ctx, WhisperConfig config) { + std::shared_ptr model{new EncoderModel()}; + model->setCtx(ctx); + + model->_conv1 = Conv1D::create(ctx.withName("conv1"), FeatDim, config.hiddenSize, 3); + model->_conv2 = Conv1D::create(ctx.withName("conv2"), config.hiddenSize, config.hiddenSize, 3, 2); + model->_hiddenSize = config.hiddenSize; + for (int i = 0; i < config.encoderNumLayers; ++i) { + 
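+    // One EncoderLayer per transformer block; each is registered under the ctx name
+    // "layer<i>" so that initParameters() can locate its weights in the state map.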
model->_layers.emplace_back( + EncoderLayer::fromConfig(ctx.withName(lut::sprintf("layer%d", i)), config)); + } + model->_norm = LayerNorm::create(ctx.withName("norm"), config.hiddenSize); + return model; +} + +void EncoderModel::initParameters(const StateMap &stateDict) { + Context ctx = getCtx(); + + _conv1->initParameters(stateDict); + _conv2->initParameters(stateDict); + + _posEmbd = stateDict.getTensor(ctx.name("pos_embd")); + _posEmbd = moveAndCastFloat(_posEmbd, ctx); + + for (std::shared_ptr &layer : _layers) { + layer->initParameters(stateDict); + } + + _norm->initParameters(stateDict); +} + +void EncoderModel::initParameters(lut::Random *generator, DType weightType) { + _conv1->initParameters(generator, weightType); + _conv2->initParameters(generator, weightType); + + float r = 0.2f; + Device dCpu = Device::getCpu(); + _posEmbd = F::rand({NumFrames, _hiddenSize}, DType::kFloat, dCpu, generator, -r, r); + _posEmbd = moveAndCastFloat(_posEmbd, getCtx()); + + for (std::shared_ptr &layer : _layers) { + layer->initParameters(generator, weightType); + } + + _norm->initParameters(generator, weightType); +} + +Tensor EncoderModel::forward(Tensor wave) { + CHECK(wave.getDim() == 1 && wave.getShape(-1) <= InputSamples); + + // pad wave. + if (wave.getShape(-1) < InputSamples) { + Tensor pad = F::zeros({InputSamples}, wave.getDType(), wave.getDevice()); + F::copy(wave, pad.slice({0, wave.getShape(-1)})); + wave = pad; + } + + Tensor features = F::logMelSpectrogram(wave); + + CHECK(features.getDim() == 2); + features = features.unsqueeze(0); + + Tensor x = _conv1->forward(features); + x = F::gelu(x); + + x = _conv2->forward(x); + x = F::gelu(x); + x = F::add(x, _posEmbd); + + for (const std::shared_ptr &layer : _layers) { + x = layer->forward(x); + } + + x = _norm->forward(x); + return x; +} + +// -----------------------------------------------------------------------------------------------+ +// class DecoderInitModel | +// -----------------------------------------------------------------------------------------------+ + +DecoderInitModel::DecoderInitModel() + : _dModel(0) { +} + +DecoderInitModel::~DecoderInitModel() { +} + +std::shared_ptr DecoderInitModel::fromConfig( + const Context &ctx, + WhisperConfig config) { + std::shared_ptr model{new DecoderInitModel()}; + model->setCtx(ctx); + + int dModel = config.hiddenSize; + for (int i = 0; i < config.encoderNumLayers; ++i) { + Context ctxLayer = ctx.withName(lut::sprintf("layer%d", i)).withName(DecoderLayer::CrossAttn); + model->_kvProjs.emplace_back(Linear::create(ctxLayer.withName("kv_proj"), dModel, dModel * 2)); + } + model->_dModel = dModel; + return model; +} + +void DecoderInitModel::initParameters(const StateMap &stateDict) { + for (std::shared_ptr &layer : _kvProjs) { + layer->initParameters(stateDict); + } +} + +void DecoderInitModel::initParameters(lut::Random *generator, DType weightType) { + for (std::shared_ptr &layer : _kvProjs) { + layer->initParameters(generator, weightType); + } +} + +void DecoderInitModel::forward(StateMap &past, Tensor encoderHidden) { + CHECK(encoderHidden.getDim() == 3); + + for (int i = 0; i < _kvProjs.size(); ++i) { + Context ctxLayer = getCtx().withName(lut::sprintf("layer%d", i)); + Context ctxAttn = ctxLayer.withName(DecoderLayer::CrossAttn); + + Tensor x = _kvProjs[i]->forward(encoderHidden); + Tensor cacheK = x.slice(2, {0, _dModel}); + Tensor cacheV = x.slice(2, {_dModel, 2 * _dModel}); + + past.putTensor(ctxAttn.name("k"), cacheK); + past.putTensor(ctxAttn.name("v"), cacheV); + } +} + +// 
-----------------------------------------------------------------------------------------------+ +// class Attention | +// -----------------------------------------------------------------------------------------------+ + +Attention::Attention() + : _numHeads(0), + _hiddenSize(0) { +} + +Attention::~Attention() { +} + +std::shared_ptr Attention::selfAttn(const Context &ctx, WhisperConfig config) { + std::shared_ptr model{new Attention()}; + model->setCtx(ctx); + model->initCommon(config); + + model->_proj = Linear::create(ctx.withName("qkv_proj"), config.hiddenSize, config.hiddenSize * 3); + model->_selfAttn = true; + return model; +} + +std::shared_ptr Attention::crossAttn(const Context &ctx, WhisperConfig config) { + std::shared_ptr model{new Attention()}; + model->setCtx(ctx); + model->initCommon(config); + + model->_proj = Linear::create(ctx.withName("q_proj"), config.hiddenSize, config.hiddenSize); + model->_selfAttn = false; + return model; +} + +int Attention::getCtxLength(const StateMap &past) const { + if (past.hasValue(_namePastLen)) { + return past.getValue(_namePastLen); + } else { + return 0; + } +} + +void Attention::initCommon(WhisperConfig config) { + if (config.hiddenSize % config.encoderNumHeads != 0) { + throw lut::AbortedError("invalid hiddenSize and numHeads"); + } + + _outProj = Linear::create(getCtx().withName("out_proj"), config.hiddenSize, config.hiddenSize); + _hiddenSize = config.hiddenSize; + _numHeads = config.encoderNumHeads; + + _namePastK = getCtx().name("k"); + _namePastV = getCtx().name("v"); + _namePastLen = getCtx().name("len"); +} + +void Attention::initParameters(const StateMap &stateDict) { + _proj->initParameters(stateDict); + _outProj->initParameters(stateDict); +} + +void Attention::initParameters(lut::Random *generator, DType weightType) { + _proj->initParameters(generator, weightType); + _outProj->initParameters(generator, weightType); +} + +std::pair Attention::getPresentKV(StateMap &past, Tensor k, Tensor v) { + Tensor pastK, pastV; + + int pastLen = getCtxLength(past); + int presentLen = pastLen + k.getShape(1); + + int cacheLen = 0; + if (pastLen > 0) { + pastK = past.getTensor(_namePastK); + pastV = past.getTensor(_namePastV); + cacheLen = pastK.getShape(1); + CHECK(pastK.getDim() == 3 && pastV.getDim() == 3 && pastK.getShape(1) == pastV.getShape(1)); + } + + if (cacheLen < presentLen) { + LOG(DEBUG) << lut::sprintf( + "update kv cache cacheLen=%d pastLen=%d presentLen=%d", + cacheLen, + pastLen, + presentLen); + + // to reduce memory allocation, we extend the kv cache block by block. 
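+    // A worked example, using PastBlockSize == 2 (its value in whisper.h): with
+    // pastLen == 4 and one new token, presentLen == 5, so nextNumBlocks ==
+    // (5 + 2 - 1) / 2 == 3 and the cache is reallocated to nextLen == 6; the following
+    // single-token step (presentLen == 6) then fits without another allocation.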
+    int nextNumBlocks = (presentLen + PastBlockSize - 1) / PastBlockSize;
+    int nextLen = PastBlockSize * nextNumBlocks;
+
+    int d0, d2;
+    if (pastLen) {
+      d0 = pastK.getShape(0);
+      d2 = pastK.getShape(2);
+    } else {
+      d0 = k.getShape(0);
+      d2 = k.getShape(2);
+    }
+    Tensor nextK = F::zeros({d0, nextLen, d2}, k.getDType(), k.getDevice());
+    Tensor nextV = F::zeros({d0, nextLen, d2}, v.getDType(), v.getDevice());
+
+    if (pastLen) {
+      F::copy(pastK.slice(1, {0, pastLen}), nextK.slice(1, {0, pastLen}));
+      F::copy(pastV.slice(1, {0, pastLen}), nextV.slice(1, {0, pastLen}));
+    }
+
+    past.putTensor(_namePastK, nextK);
+    past.putTensor(_namePastV, nextV);
+
+    pastK = nextK;
+    pastV = nextV;
+  }
+
+  F::copy(k, pastK.slice(1, {pastLen, presentLen}));
+  F::copy(v, pastV.slice(1, {pastLen, presentLen}));
+
+  Tensor presentK = pastK.slice(1, {0, presentLen});
+  Tensor presentV = pastV.slice(1, {0, presentLen});
+  past.putValue<int>(_namePastLen, presentLen);
+
+  return std::make_pair(presentK, presentV);
+}
+
+Tensor Attention::forward(StateMap &past, Tensor inputs) {
+  CHECK(inputs.getDim() == 3);
+
+  Tensor q, k, v;
+  if (_selfAttn) {
+    Tensor qkv = _proj->forward(inputs);
+    q = qkv.slice(-1, {0, _hiddenSize});
+    k = qkv.slice(-1, {_hiddenSize, _hiddenSize * 2});
+    v = qkv.slice(-1, {_hiddenSize * 2, _hiddenSize * 3});
+
+    std::tie(k, v) = getPresentKV(past, k, v);
+  } else {
+    q = _proj->forward(inputs);
+
+    // initialized in the DecoderInitModel.
+    k = past.getTensor(_namePastK);
+    v = past.getTensor(_namePastV);
+  }
+
+  int bsz = inputs.getShape(0);
+  int len = inputs.getShape(1);
+  int headDim = _hiddenSize / _numHeads;
+  q = q.view({bsz, len, _numHeads, headDim});
+  k = k.view({bsz, k.getShape(1), _numHeads, headDim});
+  v = v.view({bsz, v.getShape(1), _numHeads, headDim});
+
+  q = q.transpose(1, 2);
+  k = k.transpose(1, 2);
+  v = v.transpose(1, 2);
+
+  Tensor x;
+  if (_selfAttn && inputs.getShape(1) != 1) {
+    // the causal mask is only needed for multi-token (prefill) self attention; a
+    // single-token decode step attends to the whole cache and needs no mask.
+    x = F::attention(q, k, v, F::causalMask(q.getShape(2), getCtx().getDevice()));
+  } else {
+    x = F::attention(q, k, v);
+  }
+
+  x = F::contiguous(x.transpose(1, 2)).view({bsz, len, _hiddenSize});
+  x = _outProj->forward(x);
+  return x;
+}
+
+// -----------------------------------------------------------------------------------------------+
+// class DecoderLayer                                                                              |
+// -----------------------------------------------------------------------------------------------+
+
+constexpr char DecoderLayer::CrossAttn[];
+constexpr char DecoderLayer::SelfAttn[];
+
+DecoderLayer::DecoderLayer() {
+}
+
+DecoderLayer::~DecoderLayer() {
+}
+
+std::shared_ptr<DecoderLayer> DecoderLayer::fromConfig(const Context &ctx, WhisperConfig config) {
+  std::shared_ptr<DecoderLayer> model{new DecoderLayer()};
+  model->setCtx(ctx);
+
+  model->_norm1 = LayerNorm::create(ctx.withName("norm1"), config.hiddenSize);
+  model->_norm2 = LayerNorm::create(ctx.withName("norm2"), config.hiddenSize);
+  model->_norm3 = LayerNorm::create(ctx.withName("norm3"), config.hiddenSize);
+  model->_selfAttn = Attention::selfAttn(ctx.withName(SelfAttn), config);
+  model->_crossAttn = Attention::crossAttn(ctx.withName(CrossAttn), config);
+  model->_fc1 = Linear::create(ctx.withName("fc1"), config.hiddenSize, config.decoderFfnDim);
+  model->_fc2 = Linear::create(ctx.withName("fc2"), config.decoderFfnDim, config.hiddenSize);
+  return model;
+}
+
+void DecoderLayer::initParameters(const StateMap &stateDict) {
+  _norm1->initParameters(stateDict);
+  _norm2->initParameters(stateDict);
+  _norm3->initParameters(stateDict);
+  _selfAttn->initParameters(stateDict);
+
_crossAttn->initParameters(stateDict); + _fc1->initParameters(stateDict); + _fc2->initParameters(stateDict); +} + +void DecoderLayer::initParameters(lut::Random *generator, DType weightType) { + _norm1->initParameters(generator, weightType); + _norm2->initParameters(generator, weightType); + _norm3->initParameters(generator, weightType); + _selfAttn->initParameters(generator, weightType); + _crossAttn->initParameters(generator, weightType); + _fc1->initParameters(generator, weightType); + _fc2->initParameters(generator, weightType); +} + +Tensor DecoderLayer::forward(StateMap &past, Tensor inputs) { + Tensor residual = inputs; + + Tensor x = _norm1->forward(inputs); + x = _selfAttn->forward(past, x); + x = F::add(x, residual); + + residual = x; + x = _norm2->forward(x); + x = _crossAttn->forward(past, x); + x = F::add(x, residual); + + residual = x; + x = _norm3->forward(x); + x = _fc1->forward(x); + x = F::gelu(x); + + x = _fc2->forward(x); + x = F::add(x, residual); + return x; +} + +// -----------------------------------------------------------------------------------------------+ +// class DecoderModel | +// -----------------------------------------------------------------------------------------------+ + +DecoderModel::DecoderModel() + : _dModel(0), + _maxTgtLength(0), + _outputDim(0) { +} + +DecoderModel::~DecoderModel() { +} + +std::shared_ptr DecoderModel::fromConfig(const Context &ctx, WhisperConfig config) { + std::shared_ptr model{new DecoderModel()}; + model->setCtx(ctx); + + model->_embd = Embedding::create(ctx.withName("embd"), config.hiddenSize, config.vocabSize); + for (int i = 0; i < config.decoderNumLayers; ++i) { + model->_layers.emplace_back( + DecoderLayer::fromConfig(ctx.withName(lut::sprintf("layer%d", i)), config)); + } + model->_norm = LayerNorm::create(ctx.withName("norm"), config.hiddenSize); + model->_outProj = Linear::create( + ctx.withName("out_proj"), + config.hiddenSize, + config.vocabSize, + false); + model->_maxTgtLength = config.maxTgtLength; + model->_dModel = config.hiddenSize; + model->_namePastLen = ctx.name("len"); + model->_outputDim = config.vocabSize; + return model; +} + +void DecoderModel::initParameters(const StateMap &stateDict) { + Context ctx = getCtx(); + + _embd->initParameters(stateDict); + _norm->initParameters(stateDict); + _outProj->initParameters(stateDict); + + _posEmbd = stateDict.getTensor(ctx.name("pos_embd")); + _posEmbd.throwIfInvalidShape({_maxTgtLength, _dModel}, ctx.name("pos_embd")); + _posEmbd = moveAndCastFloat(_posEmbd, ctx); + + for (std::shared_ptr &layer : _layers) { + layer->initParameters(stateDict); + } +} + +void DecoderModel::initParameters(lut::Random *generator, DType weightType) { + _embd->initParameters(generator, weightType); + _norm->initParameters(generator, weightType); + _outProj->initParameters(generator, weightType); + + float r = 0.2f; + Device dCpu = Device::getCpu(); + _posEmbd = F::rand({_maxTgtLength, _dModel}, DType::kFloat, dCpu, generator, -r, r); + _posEmbd = moveAndCastFloat(_posEmbd, getCtx()); + + for (std::shared_ptr &layer : _layers) { + layer->initParameters(generator, weightType); + } +} + +int DecoderModel::getCtxLength(const StateMap &past) const { + if (past.hasValue(_namePastLen)) { + return past.getValue(_namePastLen); + } else { + return 0; + } +} + +Tensor DecoderModel::forward(StateMap &past, Tensor inputs) { + Tensor x = _embd->forward(inputs); + + // positional embedding. 
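+  // Learned absolute positions: the slice [pastLen, presentLen) of _posEmbd lines up
+  // with the tokens appended in this call, e.g. with 3 tokens already cached a
+  // single-token decode step adds row 3 only (illustrative indices).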
+ int pastLen = getCtxLength(past); + int presentLen = pastLen + inputs.getShape(1); + x = F::add(x, _posEmbd.slice({pastLen, presentLen})); + past.putValue(_namePastLen, presentLen); + + for (const std::shared_ptr &layer : _layers) { + x = layer->forward(past, x); + } + + x = _norm->forward(x); + return x; +} + +Tensor DecoderModel::forwardLmHead(Tensor inputs) { + return _outProj->forward(inputs); +} + +int DecoderModel::getOutputDim() const { + return _outputDim; +} + +// -----------------------------------------------------------------------------------------------+ +// class WhisperLogitsProcessor | +// -----------------------------------------------------------------------------------------------+ + +WhisperLogitsProcessor::WhisperLogitsProcessor() + : _lastTimeToken(-1), + _beginTimeToken(-1), + _endTimeToken(-1), + _eotToken(-1) { +} + +std::shared_ptr WhisperLogitsProcessor::newProcessor(const Vocab *vocab) { + std::shared_ptr processor{new WhisperLogitsProcessor()}; + processor->_lastTimeToken = -1; + processor->_beginTimeToken = vocab->findControlToken("<|0.00|>"); + processor->_endTimeToken = vocab->findControlToken("<|30.00|>"); + processor->_eotToken = vocab->findControlToken("<|endoftext|>"); + processor->_transcribeToken = vocab->findControlToken("<|transcribe|>"); + processor->_translateToken = vocab->findControlToken("<|translate|>"); + processor->_noTimestampToken = vocab->findControlToken("<|notimestamps|>"); + + return processor; +} + +void WhisperLogitsProcessor::notifyToken(int tokenId) { + _history.push_back(tokenId); + if (tokenId >= _beginTimeToken && tokenId <= _endTimeToken) { + _lastTimeToken = tokenId; + } +} + +void WhisperLogitsProcessor::processLogits(Tensor logits) { + bool lastWasTimestamp = _history.size() >= 1 && _history.back() >= _beginTimeToken; + bool lastWasTranscribe = _history.size() >= 1 && _history.back() == _transcribeToken; + bool penultimateWasTimestamp = _history.size() < 2 || + _history[_history.size() - 2] >= _beginTimeToken || + _history[_history.size() - 2] == _transcribeToken || + _history[_history.size() - 2] == _translateToken; + + if (lastWasTranscribe) { + F::fill(logits.slice(-1, {_noTimestampToken, _noTimestampToken + 1}), -Inf); + } + + if (lastWasTimestamp) { + if (penultimateWasTimestamp) { + F::fill(logits.slice(-1, {_beginTimeToken, _endTimeToken + 1}), -Inf); + } else { + F::fill(logits.slice(-1, {0, _eotToken}), -Inf); + } + } + + if (_lastTimeToken > 0) { + F::fill(logits.slice(-1, {_beginTimeToken, _lastTimeToken + 1}), -Inf); + } + + Tensor probs = F::softmax(logits); + Tensor maxText = F::max(probs.slice(-1, {0, _eotToken + 1})); + Tensor sumTimestamp = F::sum(probs.slice(-1, {_beginTimeToken, _endTimeToken + 1})); + + float maxTextVal = *maxText.getData(); + float sumTimestampVal = *sumTimestamp.getData(); + if (sumTimestampVal >= maxTextVal) { + F::fill(logits.slice(-1, {0, _eotToken}), -Inf); + } +} + +// -----------------------------------------------------------------------------------------------+ +// class WhisperModelForGeneration | +// -----------------------------------------------------------------------------------------------+ + +WhisperModelForGeneration::WhisperModelForGeneration() + : _eotId(0) { +} + +std::shared_ptr WhisperModelForGeneration::fromPackage( + const Context &ctx, + lut::ZipFile *package) { + std::shared_ptr reader = package->open(ModelConfig); + std::shared_ptr ini = lut::IniConfig::fromStream(reader.get()); + + std::string modelFile = 
ini->getSection(ModelSection).getString(ModelFileField); + std::string modelType = ini->getSection(ModelSection).getString(ModelTypeField); + + const lut::IniSection &llamaIni = ini->getSection(modelType); + + std::shared_ptr model{new WhisperModelForGeneration()}; + WhisperConfig llamaConfig = WhisperConfig::loadConfig(llamaIni); + + StateMap stateMap; + + stateMap.read(package->open(modelFile).get()); + model->_encoder = EncoderModel::fromConfig(ctx.withName("encoder"), llamaConfig); + model->_decoderInit = DecoderInitModel::fromConfig(ctx.withName("decoder"), llamaConfig); + model->_decoder = DecoderModel::fromConfig(ctx.withName("decoder"), llamaConfig); + + model->_encoder->initParameters(stateMap); + model->_decoderInit->initParameters(stateMap); + model->_decoder->initParameters(stateMap); + model->_eotId = llamaIni.getInt("eot_token_id"); + model->_modelName = modelType; + + model->initTokenizer(package); + return model; +} + +Tensor WhisperModelForGeneration::buildDecoderInput(lut::Span prompt) const { + std::vector inputData{}; + for (const PromptBlock &block : prompt) { + if (block.blockType == PromptBlock::ControlToken || block.blockType == PromptBlock::Text) { + encodePromptBlock(block, inputData); + } else { + throw lut::AbortedError("in whisper prompt, only one audio input is supported"); + } + } + + int len = inputData.size(); + Tensor inputs = Tensor::create({1, len}, inputData); + inputs = F::to(_decoder->getCtx().getDevice(), inputs); + return inputs; +} + +Tensor WhisperModelForGeneration::prefill(StateMap &past, const Prompt &prompt) const { + if (prompt.empty()) throw lut::AbortedError("prompt is empty"); + + const PromptBlock &audioBlock = prompt.getBlocks()[0]; + if (audioBlock.blockType != PromptBlock::Wave) { + throw lut::AbortedError("in whisper model, the first element in prompt should be the audio"); + } + if (prompt.getBlocks().size() <= 1) throw lut::AbortedError("decoder prompt is empty"); + + Tensor wave = Wave::read(audioBlock.data, audioBlock.waveFormat); + Tensor encoderHidden = _encoder->forward(wave); + _decoderInit->forward(past, encoderHidden); + + Tensor inputs = buildDecoderInput(prompt.getBlocks().subspan(1)); + Tensor x = _decoder->forward(past, inputs); + + x = x.slice(1, {-1, None}); + x = _decoder->forwardLmHead(x); + return x; +} + +Tensor WhisperModelForGeneration::decode(StateMap &past, LongType inputToken) const { + std::array inputData{inputToken}; + Tensor inputs = Tensor::create({1, 1}, inputData); + inputs = F::to(getDevice(), inputs); + + Tensor x = _decoder->forward(past, inputs); + x = _decoder->forwardLmHead(x); + return x; +} + +bool WhisperModelForGeneration::isStopToken(int tokenId) const { + return tokenId == _eotId; +} + +const char *WhisperModelForGeneration::getName() const { + return _modelName.c_str(); +} + +Device WhisperModelForGeneration::getDevice() const { + return _decoder->getCtx().getDevice(); +} + +int WhisperModelForGeneration::getOutputDim() const { + return _decoder->getOutputDim(); +} + +} // namespace whisper +} // namespace libllm diff --git a/src/libllm/whisper.h b/src/libllm/whisper.h new file mode 100644 index 00000000..263d0465 --- /dev/null +++ b/src/libllm/whisper.h @@ -0,0 +1,279 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Xiaoyang Chen +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software +// and associated documentation files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, 
copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+#pragma once
+
+#include <memory>
+
+#include "libllm/lut/ini_config.h"
+#include "libllm/model_for_generation.h"
+#include "libllm/module.h"
+
+namespace libllm {
+namespace whisper {
+
+struct WhisperConfig {
+  int hiddenSize;
+  int encoderNumHeads;
+  int encoderFfnDim;
+  int encoderNumLayers;
+  int decoderNumLayers;
+  int decoderFfnDim;
+  int vocabSize;
+  int maxTgtLength;
+
+  WhisperConfig();
+
+  static WhisperConfig loadConfig(const lut::IniSection &section);
+};
+
+class EncoderAttention : public Module {
+ public:
+  static std::shared_ptr<EncoderAttention> fromConfig(const Context &ctx, WhisperConfig config);
+  ~EncoderAttention();
+
+  void initParameters(const StateMap &stateDict) override;
+  void initParameters(lut::Random *generator, DType weightType) override;
+  Tensor forward(Tensor inputs);
+
+ private:
+  std::shared_ptr<Linear> _qkvProj;
+  std::shared_ptr<Linear> _outProj;
+  int _numHeads;
+  int _hiddenSize;
+
+  EncoderAttention();
+};
+
+class EncoderLayer : public Module {
+ public:
+  static std::shared_ptr<EncoderLayer> fromConfig(const Context &ctx, WhisperConfig config);
+  ~EncoderLayer();
+
+  void initParameters(const StateMap &stateDict) override;
+  void initParameters(lut::Random *generator, DType weightType) override;
+  Tensor forward(Tensor inputs);
+
+ private:
+  std::shared_ptr<LayerNorm> _norm1;
+  std::shared_ptr<LayerNorm> _norm2;
+  std::shared_ptr<EncoderAttention> _attn;
+  std::shared_ptr<Linear> _fc1;
+  std::shared_ptr<Linear> _fc2;
+
+  EncoderLayer();
+};
+
+class EncoderModel : public Module {
+ public:
+  static std::shared_ptr<EncoderModel> fromConfig(const Context &ctx, WhisperConfig config);
+  ~EncoderModel();
+
+  void initParameters(const StateMap &stateDict) override;
+  void initParameters(lut::Random *generator, DType weightType) override;
+
+  /// @brief Forward the wave through the whisper encoder model and return the encoder hidden
+  /// states.
+  /// @param wave the input wave.
+  Tensor forward(Tensor wave);
+
+ private:
+  static constexpr int FeatDim = 128;
+  static constexpr int NumFrames = 30;
+  static constexpr int InputSamples = 16000 * NumFrames;
+  std::shared_ptr<Conv1D> _conv1;
+  std::shared_ptr<Conv1D> _conv2;
+  std::vector<std::shared_ptr<EncoderLayer>> _layers;
+  std::shared_ptr<LayerNorm> _norm;
+  Tensor _posEmbd;
+  int _hiddenSize;
+
+  EncoderModel();
+};
+
+class DecoderInitModel : public Module {
+ public:
+  static std::shared_ptr<DecoderInitModel> fromConfig(const Context &ctx, WhisperConfig config);
+  ~DecoderInitModel();
+
+  void initParameters(const StateMap &stateDict) override;
+  void initParameters(lut::Random *generator, DType weightType) override;
+
+  /// @brief Forward the encoderHidden through the cross attention kv-projection layers and update
+  /// the key-value cache for cross attention in `past`.
+  /// @param past the kv_cache to update.
+  /// @param encoderHidden the hidden output from the encoder model.
+ void forward(StateMap &past, Tensor encoderHidden); + + private: + std::vector> _kvProjs; + int _dModel; + + DecoderInitModel(); +}; + +class Attention : public Module { + public: + static std::shared_ptr selfAttn(const Context &ctx, WhisperConfig config); + static std::shared_ptr crossAttn(const Context &ctx, WhisperConfig config); + ~Attention(); + + void initParameters(const StateMap &stateDict) override; + void initParameters(lut::Random *generator, DType weightType) override; + Tensor forward(StateMap &past, Tensor inputs); + + private: + static constexpr int PastBlockSize = 2; + + std::shared_ptr _proj; + std::shared_ptr _outProj; + int _numHeads; + int _hiddenSize; + bool _selfAttn; + + std::string _namePastK; + std::string _namePastV; + std::string _namePastLen; + + Attention(); + + /// @brief Common part of initialization for cross attention and self attention. + /// @param config + void initCommon(WhisperConfig config); + + /// @brief Get the present kv tensor from input kv and past kv tensors. NOTE: do not modify the + /// content of returned tensors since they were the kv cache in next iteration. + /// @param past the kv cache. + /// @param k the input k. + /// @param v the input v. + std::pair getPresentKV(StateMap &past, Tensor k, Tensor v); + + /// @brief Get context (history) length for the self attention. + /// @param past the kv_cache. + /// @return the context length. + int getCtxLength(const StateMap &past) const; +}; + +class DecoderLayer : public Module { + public: + static constexpr char CrossAttn[] = "cross_attn"; + static constexpr char SelfAttn[] = "self_attn"; + + static std::shared_ptr fromConfig(const Context &ctx, WhisperConfig config); + ~DecoderLayer(); + + void initParameters(const StateMap &stateDict) override; + void initParameters(lut::Random *generator, DType weightType) override; + + Tensor forward(StateMap &past, Tensor inputs); + + private: + std::shared_ptr _norm1; + std::shared_ptr _norm2; + std::shared_ptr _norm3; + std::shared_ptr _selfAttn; + std::shared_ptr _crossAttn; + std::shared_ptr _fc1; + std::shared_ptr _fc2; + + DecoderLayer(); +}; + +class DecoderModel : public Module { + public: + static std::shared_ptr fromConfig(const Context &ctx, WhisperConfig config); + ~DecoderModel(); + + void initParameters(const StateMap &stateDict) override; + void initParameters(lut::Random *generator, DType weightType) override; + + Tensor forward(StateMap &past, Tensor inputs); + Tensor forwardLmHead(Tensor inputs); + int getOutputDim() const; + + private: + std::vector> _layers; + std::shared_ptr _embd; + std::shared_ptr _norm; + std::shared_ptr _outProj; + Tensor _posEmbd; + std::string _namePastLen; + int _dModel; + int _maxTgtLength; + int _outputDim; + + DecoderModel(); + + /// @brief Get context (history) length for the positional embedding. + /// @param past the kv_cache. + /// @return the context length. 
+ int getCtxLength(const StateMap &past) const; +}; + +class WhisperLogitsProcessor : public LogitsProcessor { + public: + static std::shared_ptr newProcessor(const Vocab *vocab); + + void notifyToken(int tokenId) override; + void processLogits(Tensor logits) override; + + private: + static constexpr float Inf = std::numeric_limits::infinity(); + + std::vector _history; + + int _lastTimeToken; + int _beginTimeToken; + int _endTimeToken; + int _eotToken; + int _transcribeToken; + int _translateToken; + int _noTimestampToken; + + WhisperLogitsProcessor(); +}; + +class WhisperModelForGeneration : public ModelForGeneration { + public: + static std::shared_ptr fromPackage( + const Context &ctx, + lut::ZipFile *package); + + Tensor prefill(StateMap &past, const Prompt &prompt) const override; + Tensor decode(StateMap &past, LongType inputToken) const override; + + bool isStopToken(int tokenId) const override; + const char *getName() const override; + Device getDevice() const override; + int getOutputDim() const override; + + protected: + std::shared_ptr _encoder; + std::shared_ptr _decoderInit; + std::shared_ptr _decoder; + std::string _modelName; + int _eotId; + + WhisperModelForGeneration(); + void init(const Context &ctx, const lut::IniConfig &config); + Tensor buildDecoderInput(lut::Span prompt) const; +}; + +} // namespace whisper +} // namespace libllm diff --git a/third_party/pocketfft/pocketfft_hdronly.h b/third_party/pocketfft/pocketfft_hdronly.h new file mode 100644 index 00000000..6fd7288f --- /dev/null +++ b/third_party/pocketfft/pocketfft_hdronly.h @@ -0,0 +1,3994 @@ +/* +This file is part of pocketfft. + +Copyright (C) 2010-2022 Max-Planck-Society +Copyright (C) 2019-2020 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +For the prev_good_size search: + Copyright (C) 2024 Tan Ping Liang, Peter Bell + +Authors: Martin Reinecke, Peter Bell + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#ifndef POCKETFFT_HDRONLY_H +#define POCKETFFT_HDRONLY_H + +#ifndef __cplusplus +#error This file is C++ and requires a C++ compiler. +#endif + +#if !(__cplusplus >= 201103L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201103L)) +#error This file requires at least C++11 support. +#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE != 0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +#include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::ptrdiff_t; +using std::size_t; + +// Always use std:: for functions +template +T cos(T) = delete; +template +T sin(T) = delete; +template +T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. +#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +#if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +#undef POCKETFFT_NO_VECTORS +#endif +#elif __clang_major__ >= 5 +#undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__ >= 5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template +struct VLEN { + static constexpr size_t val = 1; +}; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> +struct VLEN { + static constexpr size_t val = 16; +}; +template<> +struct VLEN { + static constexpr size_t val = 8; +}; +#elif (defined(__AVX__)) +template<> +struct VLEN { + static constexpr size_t val = 8; +}; +template<> +struct VLEN { + static constexpr size_t val = 4; +}; +#elif (defined(__SSE2__)) +template<> +struct VLEN { + static constexpr size_t val = 4; +}; +template<> +struct VLEN { + static constexpr size_t val = 2; +}; +#elif (defined(__VSX__)) +template<> +struct VLEN { + static constexpr size_t val = 4; +}; +template<> +struct VLEN { + static constexpr size_t val = 2; +}; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> +struct VLEN { + static constexpr size_t val = 4; +}; +template<> +struct VLEN { + static constexpr size_t val = 2; +}; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// std::aligned_alloc is a bit cursed ... it doesn't exist on MacOS < 10.15 +// and in musl, and other OSes seem to have even more peculiarities. +// Let's unconditionally work around it for now. 
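+// The portable emulation below over-allocates by `align` bytes, rounds the result up to
+// the next alignment boundary, and stashes the original malloc() pointer in the void*
+// slot just below the returned address so aligned_dealloc() can recover and free it.
+// Illustrative layout for align == 64: malloc() -> p; res = (p & ~63) + 64; res[-1] = p.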
+#if 0 +//#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) && (__MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_15) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size + align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast( + (reinterpret_cast(ptr) & ~(uintptr_t(align - 1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; +} +inline void aligned_dealloc(void *ptr) { + if (ptr) free((reinterpret_cast(ptr))[-1]); +} +#endif + +template +class arr { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) { + if (num == 0) return nullptr; + void *res = malloc(num * sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) { + free(ptr); + } +#else + static T *ralloc(size_t num) { + if (num == 0) return nullptr; + void *ptr = aligned_alloc(64, num * sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) { + aligned_dealloc(ptr); + } +#endif + + public: + arr() + : p(0), + sz(0) { + } + arr(size_t n) + : p(ralloc(n)), + sz(n) { + } + arr(arr &&other) + : p(other.p), + sz(other.sz) { + other.p = nullptr; + other.sz = 0; + } + ~arr() { + dealloc(p); + } + + void resize(size_t n) { + if (n == sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { + return p[idx]; + } + const T &operator[](size_t idx) const { + return p[idx]; + } + + T *data() { + return p; + } + const T *data() const { + return p; + } + + size_t size() const { + return sz; + } +}; + +template +struct cmplx { + T r, i; + cmplx() { + } + cmplx(T r_, T i_) + : r(r_), + i(i_) { + } + void Set(T r_, T i_) { + r = r_; + i = i_; + } + void Set(T r_) { + r = r_; + i = T(0); + } + cmplx &operator+=(const cmplx &other) { + r += other.r; + i += other.i; + return *this; + } + template + cmplx &operator*=(T2 other) { + r *= other; + i *= other; + return *this; + } + template + cmplx &operator*=(const cmplx &other) { + T tmp = r * other.r - i * other.i; + i = r * other.i + i * other.r; + r = tmp; + return *this; + } + template + cmplx &operator+=(const cmplx &other) { + r += other.r; + i += other.i; + return *this; + } + template + cmplx &operator-=(const cmplx &other) { + r -= other.r; + i -= other.i; + return *this; + } + template + auto operator*(const T2 &other) const -> cmplx { + return {r * other, i * other}; + } + template + auto operator+(const cmplx &other) const -> cmplx { + return {r + other.r, i + other.i}; + } + template + auto operator-(const cmplx &other) const -> cmplx { + return {r - other.r, i - other.i}; + } + template + auto operator*(const cmplx &other) const -> cmplx { + return {r * other.r - i * other.i, r * other.i + i * other.r}; + } + template + auto special_mul(const cmplx &other) const -> cmplx { + using Tres = cmplx; + return fwd ? 
Tres(r * other.r + i * other.i, i * other.r - r * other.i) + : Tres(r * other.r - i * other.i, r * other.i + i * other.r); + } +}; +template +inline void PM(T &a, T &b, T c, T d) { + a = c + d; + b = c - d; +} +template +inline void PMINPLACE(T &a, T &b) { + T t = a; + a += b; + b = t - b; +} +template +inline void MPINPLACE(T &a, T &b) { + T t = a; + a -= b; + b = t + b; +} +template +cmplx conj(const cmplx &a) { + return {a.r, -a.i}; +} +template +void special_mul(const cmplx &v1, const cmplx &v2, cmplx &res) { + res = fwd ? cmplx(v1.r * v2.r + v1.i * v2.i, v1.i * v2.r - v1.r * v2.i) + : cmplx(v1.r * v2.r - v1.i * v2.i, v1.r * v2.i + v1.i * v2.r); +} + +template +void ROT90(cmplx &a) { + auto tmp_ = a.r; + a.r = -a.i; + a.i = tmp_; +} +template +void ROTX90(cmplx &a) { + auto tmp_ = fwd ? -a.r : a.r; + a.r = fwd ? a.i : -a.i; + a.i = tmp_; +} + +// +// twiddle factor section +// +template +class sincos_2pibyn { + private: + using Thigh = typename std::conditional<(sizeof(T) > sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) { + x <<= 3; + if (x < 4 * n) // first half + { + if (x < 2 * n) // first quadrant + { + if (x < n) return cmplx(std::cos(Thigh(x) * ang), std::sin(Thigh(x) * ang)); + return cmplx(std::sin(Thigh(2 * n - x) * ang), std::cos(Thigh(2 * n - x) * ang)); + } else // second quadrant + { + x -= 2 * n; + if (x < n) return cmplx(-std::sin(Thigh(x) * ang), std::cos(Thigh(x) * ang)); + return cmplx(-std::cos(Thigh(2 * n - x) * ang), std::sin(Thigh(2 * n - x) * ang)); + } + } else { + x = 8 * n - x; + if (x < 2 * n) // third quadrant + { + if (x < n) return cmplx(std::cos(Thigh(x) * ang), -std::sin(Thigh(x) * ang)); + return cmplx(std::sin(Thigh(2 * n - x) * ang), -std::cos(Thigh(2 * n - x) * ang)); + } else // fourth quadrant + { + x -= 2 * n; + if (x < n) return cmplx(-std::sin(Thigh(x) * ang), -std::cos(Thigh(x) * ang)); + return cmplx(-std::cos(Thigh(2 * n - x) * ang), -std::sin(Thigh(2 * n - x) * ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L * pi / n); + size_t nval = (n + 2) / 2; + shift = 1; + while ((size_t(1) << shift) * (size_t(1) << shift) < nval) ++shift; + mask = (size_t(1) << shift) - 1; + v1.resize(mask + 1); + v1[0].Set(Thigh(1), Thigh(0)); + for (size_t i = 1; i < v1.size(); ++i) v1[i] = calc(i, n, ang); + v2.resize((nval + mask) / (mask + 1)); + v2[0].Set(Thigh(1), Thigh(0)); + for (size_t i = 1; i < v2.size(); ++i) v2[i] = calc(i * (mask + 1), n, ang); + } + + cmplx operator[](size_t idx) const { + if (2 * idx <= N) { + auto x1 = v1[idx & mask], x2 = v2[idx >> shift]; + return cmplx(T(x1.r * x2.r - x1.i * x2.i), T(x1.r * x2.i + x1.i * x2.r)); + } + idx = N - idx; + auto x1 = v1[idx & mask], x2 = v2[idx >> shift]; + return cmplx(T(x1.r * x2.r - x1.i * x2.i), -T(x1.r * x2.i + x1.i * x2.r)); + } +}; + +struct util // hack to avoid duplicate symbols +{ + static POCKETFFT_NOINLINE size_t largest_prime_factor(size_t n) { + size_t res = 1; + while ((n & 1) == 0) { + res = 2; + n >>= 1; + } + for (size_t x = 3; x * x <= n; x += 2) + while ((n % x) == 0) { + res = x; + n /= x; + } + if (n > 1) res = n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess(size_t n) { + constexpr double lfp = 1.1; // penalty for non-hardcoded larger factors + size_t ni = n; + double result = 0.; + while ((n & 1) == 0) { + result += 2; + n >>= 1; + } + for (size_t x = 
3; x * x <= n; x += 2) + while ((n % x) == 0) { + result += (x <= 5) ? double(x) : lfp * double(x); // penalize larger prime factors + n /= x; + } + if (n > 1) result += (n <= 5) ? double(n) : lfp * double(n); + return result * double(ni); + } + + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) { + if (n <= 12) return n; + + size_t bestfac = 2 * n; + for (size_t f11 = 1; f11 < bestfac; f11 *= 11) + for (size_t f117 = f11; f117 < bestfac; f117 *= 7) + for (size_t f1175 = f117; f1175 < bestfac; f1175 *= 5) { + size_t x = f1175; + while (x < n) x *= 2; + for (;;) { + if (x < n) + x *= 3; + else if (x > n) { + if (x < bestfac) bestfac = x; + if (x & 1) break; + x >>= 1; + } else + return n; + } + } + return bestfac; + } + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n + and a multiple of required_factor. */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n, size_t required_factor) { + if (required_factor < 1) throw std::runtime_error("required factor must not be 0"); + return good_size_cmplx((n + required_factor - 1) / required_factor) * required_factor; + } + + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) { + if (n <= 6) return n; + + size_t bestfac = 2 * n; + for (size_t f5 = 1; f5 < bestfac; f5 *= 5) { + size_t x = f5; + while (x < n) x *= 2; + for (;;) { + if (x < n) + x *= 3; + else if (x > n) { + if (x < bestfac) bestfac = x; + if (x & 1) break; + x >>= 1; + } else + return n; + } + } + return bestfac; + } + /* returns the smallest composite of 2, 3, 5 which is >= n + and a multiple of required_factor. */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n, size_t required_factor) { + if (required_factor < 1) throw std::runtime_error("required factor must not be 0"); + return good_size_real((n + required_factor - 1) / required_factor) * required_factor; + } + + /* returns the largest composite of 2, 3, 5, 7 and 11 which is <= n */ + static POCKETFFT_NOINLINE size_t prev_good_size_cmplx(size_t n) { + if (n <= 12) return n; + + size_t bestfound = 1; + for (size_t f11 = 1; f11 <= n; f11 *= 11) + for (size_t f117 = f11; f117 <= n; f117 *= 7) + for (size_t f1175 = f117; f1175 <= n; f1175 *= 5) { + size_t x = f1175; + while (x * 2 <= n) x *= 2; + if (x > bestfound) bestfound = x; + while (true) { + if (x * 3 <= n) + x *= 3; + else if (x % 2 == 0) + x /= 2; + else + break; + + if (x > bestfound) bestfound = x; + } + } + return bestfound; + } + + /* returns the largest composite of 2, 3, 5 which is <= n */ + static POCKETFFT_NOINLINE size_t prev_good_size_real(size_t n) { + if (n <= 6) return n; + + size_t bestfound = 1; + for (size_t f5 = 1; f5 <= n; f5 *= 5) { + size_t x = f5; + while (x * 2 <= n) x *= 2; + if (x > bestfound) bestfound = x; + while (true) { + if (x * 3 <= n) + x *= 3; + else if (x % 2 == 0) + x /= 2; + else + break; + + if (x > bestfound) bestfound = x; + } + } + return bestfound; + } + + static size_t prod(const shape_t &shape) { + size_t res = 1; + for (auto sz : shape) res *= sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check( + const shape_t &shape, + const stride_t &stride_in, + const stride_t &stride_out, + bool inplace) { + auto ndim = shape.size(); + if (ndim < 1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size() != ndim) || (stride_out.size() != ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && 
(stride_in != stride_out)) throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check( + const shape_t &shape, + const stride_t &stride_in, + const stride_t &stride_out, + bool inplace, + const shape_t &axes) { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim, 0); + for (auto ax : axes) { + if (ax >= ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax] > 1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check( + const shape_t &shape, + const stride_t &stride_in, + const stride_t &stride_out, + bool inplace, + size_t axis) { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis >= shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t + thread_count(size_t /*nthreads*/, const shape_t & /*shape*/, size_t /*axis*/, size_t /*vlen*/) { + return 1; + } +#else + static size_t thread_count(size_t nthreads, const shape_t &shape, size_t axis, size_t vlen) { + if (nthreads == 1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) parallel /= 4; + size_t max_threads = nthreads == 0 ? std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif +}; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { + return 0; +} +constexpr inline size_t num_threads() { + return 1; +} + +template +void thread_map(size_t /* nthreads */, Func f) { + f(); +} + +#else + +inline size_t &thread_id() { + static thread_local size_t thread_id_ = 0; + return thread_id_; +} +inline size_t &num_threads() { + static thread_local size_t num_threads_ = 1; + return num_threads_; +} +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n) + : num_left_(n) { + } + + void count_down() { + lock_t lock(mut_); + if (--num_left_) return; + completed_.notify_all(); + } + + void wait() { + lock_t lock(mut_); + completed_.wait(lock, [this] { return is_ready(); }); + } + bool is_ready() { + return num_left_ == 0; + } +}; + +template +class concurrent_queue { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + void push(T val) { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool try_pop(T &val) { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { + return size_ == 0; + } +}; + +// C++ allocator with support for over-aligned types +template +struct aligned_allocator { + using value_type = T; + template + aligned_allocator(const aligned_allocator &) { + } + aligned_allocator() = default; + + T *allocate(size_t n) { + void *mem = aligned_alloc(alignof(T), n * sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) { + aligned_dealloc(p); + } +}; + +class thread_pool { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker { + std::thread thread; + 
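+    // Per-worker state: a dedicated mutex/condvar pair and a single pending `work` slot.
+    // submit() probes `busy_flag` (an atomic test_and_set) to find an idle worker before
+    // taking its mutex, and the cache-line alignment above keeps adjacent workers from
+    // false sharing.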
std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&] { return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) { + if (!marked_busy && busy_flag.test_and_set()) { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() { + lock_t lock(mut_); + size_t nthreads = workers_.size(); + for (size_t i = 0; i < nthreads; ++i) { + try { + auto *worker = &workers_[i]; + worker->busy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread( + [worker, this] { worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); }); + } catch (...) { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() { + shutdown_ = true; + for (auto &worker : workers_) worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads) + : workers_(nthreads) { + create_threads(); + } + + thread_pool() + : thread_pool(max_threads) { + } + + ~thread_pool() { + shutdown(); + } + + void submit(std::function work) { + lock_t lock(mut_); + if (shutdown_) throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() { + shutdown_ = false; + create_threads(); + } +}; + +inline thread_pool &get_pool() { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, [] { + pthread_atfork( + +[] { get_pool().shutdown(); }, // prepare + +[] { get_pool().restart(); }, // parent + +[] { get_pool().restart(); } // child + ); + }); +#endif + + return pool; +} + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) { + if (nthreads == 0) nthreads = max_threads; + + if (nthreads == 1) { + f(); + return; + } + + auto &pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i = 0; i < nthreads; ++i) { + pool.submit([&f, &counter, &ex, &ex_mut, i, nthreads] { + thread_id() = i; + num_threads() = nthreads; + try { + f(); + } catch (...) 
{ + std::lock_guard lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) std::rethrow_exception(ex); +} + +#endif + +} // namespace threading + +// +// complex FFTPACK transforms +// + +template +class cfftp { + private: + struct fctdata { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) { + fact.push_back({factor, nullptr, nullptr}); + } + + template + void pass2( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 2 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k); + CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k); + } + else + for (size_t k = 0; k < l1; ++k) { + CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k); + CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k); + for (size_t i = 1; i < ido; ++i) { + CH(i, k, 0) = CC(i, 0, k) + CC(i, 1, k); + special_mul(CC(i, 0, k) - CC(i, 1, k), WA(0, i), CH(i, k, 1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx, 0, k), t1, t2; \ + PM(t1, t2, CC(idx, 1, k), CC(idx, 2, k)); \ + CH(idx, k, 0) = t0 + t1; +#define POCKETFFT_PARTSTEP3a(u1, u2, twr, twi) \ + { \ + T ca = t0 + t1 * twr; \ + T cb{-t2.i * twi, t2.r * twi}; \ + PM(CH(0, k, u1), CH(0, k, u2), ca, cb); \ + } +#define POCKETFFT_PARTSTEP3b(u1, u2, twr, twi) \ + { \ + T ca = t0 + t1 * twr; \ + T cb{-t2.i * twi, t2.r * twi}; \ + special_mul(ca + cb, WA(u1 - 1, i), CH(i, k, u1)); \ + special_mul(ca - cb, WA(u2 - 1, i), CH(i, k, u2)); \ + } + template + void pass3( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + constexpr T0 tw1r = -0.5, tw1i = (fwd ? 
-1 : 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 3 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + POCKETFFT_PREP3(0) + POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i) + } + else + for (size_t k = 0; k < l1; ++k) { + { + POCKETFFT_PREP3(0) + POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i) + } + for (size_t i = 1; i < ido; ++i) { + POCKETFFT_PREP3(i) + POCKETFFT_PARTSTEP3b(1, 2, tw1r, tw1i) + } + } + } + +#undef POCKETFFT_PARTSTEP3b +#undef POCKETFFT_PARTSTEP3a +#undef POCKETFFT_PREP3 + + template + void pass4( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 4 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + T t1, t2, t3, t4; + PM(t2, t1, CC(0, 0, k), CC(0, 2, k)); + PM(t3, t4, CC(0, 1, k), CC(0, 3, k)); + ROTX90(t4); + PM(CH(0, k, 0), CH(0, k, 2), t2, t3); + PM(CH(0, k, 1), CH(0, k, 3), t1, t4); + } + else + for (size_t k = 0; k < l1; ++k) { + { + T t1, t2, t3, t4; + PM(t2, t1, CC(0, 0, k), CC(0, 2, k)); + PM(t3, t4, CC(0, 1, k), CC(0, 3, k)); + ROTX90(t4); + PM(CH(0, k, 0), CH(0, k, 2), t2, t3); + PM(CH(0, k, 1), CH(0, k, 3), t1, t4); + } + for (size_t i = 1; i < ido; ++i) { + T t1, t2, t3, t4; + T cc0 = CC(i, 0, k), cc1 = CC(i, 1, k), cc2 = CC(i, 2, k), cc3 = CC(i, 3, k); + PM(t2, t1, cc0, cc2); + PM(t3, t4, cc1, cc3); + ROTX90(t4); + CH(i, k, 0) = t2 + t3; + special_mul(t1 + t4, WA(0, i), CH(i, k, 1)); + special_mul(t2 - t3, WA(1, i), CH(i, k, 2)); + special_mul(t1 - t4, WA(2, i), CH(i, k, 3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx, 0, k), t1, t2, t3, t4; \ + PM(t1, t4, CC(idx, 1, k), CC(idx, 4, k)); \ + PM(t2, t3, CC(idx, 2, k), CC(idx, 3, k)); \ + CH(idx, k, 0).r = t0.r + t1.r + t2.r; \ + CH(idx, k, 0).i = t0.i + t1.i + t2.i; + +#define POCKETFFT_PARTSTEP5a(u1, u2, twar, twbr, twai, twbi) \ + { \ + T ca, cb; \ + ca.r = t0.r + twar * t1.r + twbr * t2.r; \ + ca.i = t0.i + twar * t1.i + twbr * t2.i; \ + cb.i = twai * t4.r twbi * t3.r; \ + cb.r = -(twai * t4.i twbi * t3.i); \ + PM(CH(0, k, u1), CH(0, k, u2), ca, cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1, u2, twar, twbr, twai, twbi) \ + { \ + T ca, cb, da, db; \ + ca.r = t0.r + twar * t1.r + twbr * t2.r; \ + ca.i = t0.i + twar * t1.i + twbr * t2.i; \ + cb.i = twai * t4.r twbi * t3.r; \ + cb.r = -(twai * t4.i twbi * t3.i); \ + special_mul(ca + cb, WA(u1 - 1, i), CH(i, k, u1)); \ + special_mul(ca - cb, WA(u2 - 1, i), CH(i, k, u2)); \ + } + template + void pass5( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + constexpr T0 tw1r = T0(0.3090169943749474241022934171828191L), + tw1i = (fwd ? -1 : 1) * T0(0.9510565162951535721164393333793821L), + tw2r = T0(-0.8090169943749474241022934171828191L), + tw2i = (fwd ? 
-1 : 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 5 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + POCKETFFT_PREP5(0) + POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i) + POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i) + } + else + for (size_t k = 0; k < l1; ++k) { + { + POCKETFFT_PREP5(0) + POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i) + POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i) + } + for (size_t i = 1; i < ido; ++i) { + POCKETFFT_PREP5(i) + POCKETFFT_PARTSTEP5b(1, 4, tw1r, tw2r, +tw1i, +tw2i) + POCKETFFT_PARTSTEP5b(2, 3, tw2r, tw1r, +tw2i, -tw1i) + } + } + } + +#undef POCKETFFT_PARTSTEP5b +#undef POCKETFFT_PARTSTEP5a +#undef POCKETFFT_PREP5 + +#define POCKETFFT_PREP7(idx) \ + T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7; \ + PM(t2, t7, CC(idx, 1, k), CC(idx, 6, k)); \ + PM(t3, t6, CC(idx, 2, k), CC(idx, 5, k)); \ + PM(t4, t5, CC(idx, 3, k), CC(idx, 4, k)); \ + CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r; \ + CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i; + +#define POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, out1, out2) \ + { \ + T ca, cb; \ + ca.r = t1.r + x1 * t2.r + x2 * t3.r + x3 * t4.r; \ + ca.i = t1.i + x1 * t2.i + x2 * t3.i + x3 * t4.i; \ + cb.i = y1 * t7.r y2 * t6.r y3 * t5.r; \ + cb.r = -(y1 * t7.i y2 * t6.i y3 * t5.i); \ + PM(out1, out2, ca, cb); \ + } +#define POCKETFFT_PARTSTEP7a(u1, u2, x1, x2, x3, y1, y2, y3) \ + POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, CH(0, k, u1), CH(0, k, u2)) +#define POCKETFFT_PARTSTEP7(u1, u2, x1, x2, x3, y1, y2, y3) \ + { \ + T da, db; \ + POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, da, db) \ + special_mul(da, WA(u1 - 1, i), CH(i, k, u1)); \ + special_mul(db, WA(u2 - 1, i), CH(i, k, u2)); \ + } + + template + void pass7( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + constexpr T0 tw1r = T0(0.6234898018587335305250048840042398L), + tw1i = (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r = T0(-0.2225209339563144042889025644967948L), + tw2i = (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r = T0(-0.9009688679024191262361023195074451L), + tw3i = (fwd ? 
-1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 7 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + POCKETFFT_PREP7(0) + POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) + POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) + POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) + } + else + for (size_t k = 0; k < l1; ++k) { + { + POCKETFFT_PREP7(0) + POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) + POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) + POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) + } + for (size_t i = 1; i < ido; ++i) { + POCKETFFT_PREP7(i) + POCKETFFT_PARTSTEP7(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i) + POCKETFFT_PARTSTEP7(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i) + POCKETFFT_PARTSTEP7(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i) + } + } + } + +#undef POCKETFFT_PARTSTEP7 +#undef POCKETFFT_PARTSTEP7a0 +#undef POCKETFFT_PARTSTEP7a +#undef POCKETFFT_PREP7 + + template + void ROTX45(T &a) const { + constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); + if (fwd) { + auto tmp_ = a.r; + a.r = hsqt2 * (a.r + a.i); + a.i = hsqt2 * (a.i - tmp_); + } else { + auto tmp_ = a.r; + a.r = hsqt2 * (a.r - a.i); + a.i = hsqt2 * (a.i + tmp_); + } + } + template + void ROTX135(T &a) const { + constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); + if (fwd) { + auto tmp_ = a.r; + a.r = hsqt2 * (a.i - a.r); + a.i = hsqt2 * (-tmp_ - a.i); + } else { + auto tmp_ = a.r; + a.r = hsqt2 * (-a.r - a.i); + a.i = hsqt2 * (tmp_ - a.i); + } + } + + template + void pass8( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 8 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + T a0, a1, a2, a3, a4, a5, a6, a7; + PM(a1, a5, CC(0, 1, k), CC(0, 5, k)); + PM(a3, a7, CC(0, 3, k), CC(0, 7, k)); + PMINPLACE(a1, a3); + ROTX90(a3); + + ROTX90(a7); + PMINPLACE(a5, a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0, a4, CC(0, 0, k), CC(0, 4, k)); + PM(a2, a6, CC(0, 2, k), CC(0, 6, k)); + PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1); + PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3); + ROTX90(a6); + PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5); + PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7); + } + else + for (size_t k = 0; k < l1; ++k) { + { + T a0, a1, a2, a3, a4, a5, a6, a7; + PM(a1, a5, CC(0, 1, k), CC(0, 5, k)); + PM(a3, a7, CC(0, 3, k), CC(0, 7, k)); + PMINPLACE(a1, a3); + ROTX90(a3); + + ROTX90(a7); + PMINPLACE(a5, a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0, a4, CC(0, 0, k), CC(0, 4, k)); + PM(a2, a6, CC(0, 2, k), CC(0, 6, k)); + PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1); + PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3); + ROTX90(a6); + PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5); + PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7); + } + for (size_t i = 1; i < ido; ++i) { + T a0, a1, a2, a3, a4, a5, a6, a7; + PM(a1, a5, CC(i, 1, k), CC(i, 5, k)); + 
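+            // annotation: same radix-8 butterfly as the ido == 1 branch above
+            // (sum/difference pairs via PM, then the fixed ROTX45/ROTX90/ROTX135
+            // rotations); the i > 0 columns additionally finish each output with
+            // a twiddle multiplication through special_mul.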
PM(a3, a7, CC(i, 3, k), CC(i, 7, k)); + ROTX90(a7); + PMINPLACE(a1, a3); + ROTX90(a3); + PMINPLACE(a5, a7); + ROTX45(a5); + ROTX135(a7); + PM(a0, a4, CC(i, 0, k), CC(i, 4, k)); + PM(a2, a6, CC(i, 2, k), CC(i, 6, k)); + PMINPLACE(a0, a2); + CH(i, k, 0) = a0 + a1; + special_mul(a0 - a1, WA(3, i), CH(i, k, 4)); + special_mul(a2 + a3, WA(1, i), CH(i, k, 2)); + special_mul(a2 - a3, WA(5, i), CH(i, k, 6)); + ROTX90(a6); + PMINPLACE(a4, a6); + special_mul(a4 + a5, WA(0, i), CH(i, k, 1)); + special_mul(a4 - a5, WA(4, i), CH(i, k, 5)); + special_mul(a6 + a7, WA(2, i), CH(i, k, 3)); + special_mul(a6 - a7, WA(6, i), CH(i, k, 7)); + } + } + } + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM(t2, t11, CC(idx, 1, k), CC(idx, 10, k)); \ + PM(t3, t10, CC(idx, 2, k), CC(idx, 9, k)); \ + PM(t4, t9, CC(idx, 3, k), CC(idx, 8, k)); \ + PM(t5, t8, CC(idx, 4, k), CC(idx, 7, k)); \ + PM(t6, t7, CC(idx, 5, k), CC(idx, 6, k)); \ + CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r + t5.r + t6.r; \ + CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i + t5.i + t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, out1, out2) \ + { \ + T ca = t1 + t2 * x1 + t3 * x2 + t4 * x3 + t5 * x4 + t6 * x5, cb; \ + cb.i = y1 * t11.r y2 * t10.r y3 * t9.r y4 * t8.r y5 * t7.r; \ + cb.r = -(y1 * t11.i y2 * t10.i y3 * t9.i y4 * t8.i y5 * t7.i); \ + PM(out1, out2, ca, cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5) \ + POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, CH(0, k, u1), CH(0, k, u2)) +#define POCKETFFT_PARTSTEP11(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5) \ + { \ + T da, db; \ + POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, da, db) \ + special_mul(da, WA(u1 - 1, i), CH(i, k, u1)); \ + special_mul(db, WA(u2 - 1, i), CH(i, k, u2)); \ + } + + template + void pass11( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa) const { + constexpr T0 tw1r = T0(0.8412535328311811688618116489193677L), + tw1i = (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r = T0(0.4154150130018864255292741492296232L), + tw2i = (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r = T0(-0.1423148382732851404437926686163697L), + tw3i = (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r = T0(-0.6548607339452850640569250724662936L), + tw4i = (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r = T0(-0.9594929736144973898903680570663277L), + tw5i = (fwd ? 
-1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 11 * c)]; + }; + auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; }; + + if (ido == 1) + for (size_t k = 0; k < l1; ++k) { + POCKETFFT_PREP11(0) + POCKETFFT_PARTSTEP11a( + 1, + 10, + tw1r, + tw2r, + tw3r, + tw4r, + tw5r, + +tw1i, + +tw2i, + +tw3i, + +tw4i, + +tw5i) + POCKETFFT_PARTSTEP11a( + 2, + 9, + tw2r, + tw4r, + tw5r, + tw3r, + tw1r, + +tw2i, + +tw4i, + -tw5i, + -tw3i, + -tw1i) + POCKETFFT_PARTSTEP11a( + 3, + 8, + tw3r, + tw5r, + tw2r, + tw1r, + tw4r, + +tw3i, + -tw5i, + -tw2i, + +tw1i, + +tw4i) + POCKETFFT_PARTSTEP11a( + 4, + 7, + tw4r, + tw3r, + tw1r, + tw5r, + tw2r, + +tw4i, + -tw3i, + +tw1i, + +tw5i, + -tw2i) + POCKETFFT_PARTSTEP11a( + 5, + 6, + tw5r, + tw1r, + tw4r, + tw2r, + tw3r, + +tw5i, + -tw1i, + +tw4i, + -tw2i, + +tw3i) + } + else + for (size_t k = 0; k < l1; ++k) { + { + POCKETFFT_PREP11(0) + POCKETFFT_PARTSTEP11a( + 1, + 10, + tw1r, + tw2r, + tw3r, + tw4r, + tw5r, + +tw1i, + +tw2i, + +tw3i, + +tw4i, + +tw5i) + POCKETFFT_PARTSTEP11a( + 2, + 9, + tw2r, + tw4r, + tw5r, + tw3r, + tw1r, + +tw2i, + +tw4i, + -tw5i, + -tw3i, + -tw1i) + POCKETFFT_PARTSTEP11a( + 3, + 8, + tw3r, + tw5r, + tw2r, + tw1r, + tw4r, + +tw3i, + -tw5i, + -tw2i, + +tw1i, + +tw4i) + POCKETFFT_PARTSTEP11a( + 4, + 7, + tw4r, + tw3r, + tw1r, + tw5r, + tw2r, + +tw4i, + -tw3i, + +tw1i, + +tw5i, + -tw2i) + POCKETFFT_PARTSTEP11a( + 5, + 6, + tw5r, + tw1r, + tw4r, + tw2r, + tw3r, + +tw5i, + -tw1i, + +tw4i, + -tw2i, + +tw3i) + } + for (size_t i = 1; i < ido; ++i) { + POCKETFFT_PREP11(i) + POCKETFFT_PARTSTEP11( + 1, + 10, + tw1r, + tw2r, + tw3r, + tw4r, + tw5r, + +tw1i, + +tw2i, + +tw3i, + +tw4i, + +tw5i) + POCKETFFT_PARTSTEP11( + 2, + 9, + tw2r, + tw4r, + tw5r, + tw3r, + tw1r, + +tw2i, + +tw4i, + -tw5i, + -tw3i, + -tw1i) + POCKETFFT_PARTSTEP11( + 3, + 8, + tw3r, + tw5r, + tw2r, + tw1r, + tw4r, + +tw3i, + -tw5i, + -tw2i, + +tw1i, + +tw4i) + POCKETFFT_PARTSTEP11( + 4, + 7, + tw4r, + tw3r, + tw1r, + tw5r, + tw2r, + +tw4i, + -tw3i, + +tw1i, + +tw5i, + -tw2i) + POCKETFFT_PARTSTEP11( + 5, + 6, + tw5r, + tw1r, + tw4r, + tw2r, + tw3r, + +tw5i, + -tw1i, + +tw4i, + -tw2i, + +tw3i) + } + } + } + +#undef POCKETFFT_PARTSTEP11 +#undef POCKETFFT_PARTSTEP11a0 +#undef POCKETFFT_PARTSTEP11a +#undef POCKETFFT_PREP11 + + template + void passg( + size_t ido, + size_t ip, + size_t l1, + T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const cmplx *POCKETFFT_RESTRICT wa, + const cmplx *POCKETFFT_RESTRICT csarr) const { + const size_t cdim = ip; + size_t ipph = (ip + 1) / 2; + size_t idl1 = ido * l1; + + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + cdim * c)]; + }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T & { return cc[a + idl1 * b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T & { return ch[a + idl1 * b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i = 1; i < ip; ++i) wal[i] = cmplx(csarr[i].r, fwd ? 
-csarr[i].i : csarr[i].i); + + for (size_t k = 0; k < l1; ++k) + for (size_t i = 0; i < ido; ++i) CH(i, k, 0) = CC(i, 0, k); + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) + for (size_t k = 0; k < l1; ++k) + for (size_t i = 0; i < ido; ++i) PM(CH(i, k, j), CH(i, k, jc), CC(i, j, k), CC(i, jc, k)); + for (size_t k = 0; k < l1; ++k) + for (size_t i = 0; i < ido; ++i) { + T tmp = CH(i, k, 0); + for (size_t j = 1; j < ipph; ++j) tmp += CH(i, k, j); + CX(i, k, 0) = tmp; + } + for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) { + // j=0 + for (size_t ik = 0; ik < idl1; ++ik) { + CX2(ik, l).r = CH2(ik, 0).r + wal[l].r * CH2(ik, 1).r + wal[2 * l].r * CH2(ik, 2).r; + CX2(ik, l).i = CH2(ik, 0).i + wal[l].r * CH2(ik, 1).i + wal[2 * l].r * CH2(ik, 2).i; + CX2(ik, lc).r = -wal[l].i * CH2(ik, ip - 1).i - wal[2 * l].i * CH2(ik, ip - 2).i; + CX2(ik, lc).i = wal[l].i * CH2(ik, ip - 1).r + wal[2 * l].i * CH2(ik, ip - 2).r; + } + + size_t iwal = 2 * l; + size_t j = 3, jc = ip - 3; + for (; j < ipph - 1; j += 2, jc -= 2) { + iwal += l; + if (iwal > ip) iwal -= ip; + cmplx xwal = wal[iwal]; + iwal += l; + if (iwal > ip) iwal -= ip; + cmplx xwal2 = wal[iwal]; + for (size_t ik = 0; ik < idl1; ++ik) { + CX2(ik, l).r += CH2(ik, j).r * xwal.r + CH2(ik, j + 1).r * xwal2.r; + CX2(ik, l).i += CH2(ik, j).i * xwal.r + CH2(ik, j + 1).i * xwal2.r; + CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i + CH2(ik, jc - 1).i * xwal2.i; + CX2(ik, lc).i += CH2(ik, jc).r * xwal.i + CH2(ik, jc - 1).r * xwal2.i; + } + } + for (; j < ipph; ++j, --jc) { + iwal += l; + if (iwal > ip) iwal -= ip; + cmplx xwal = wal[iwal]; + for (size_t ik = 0; ik < idl1; ++ik) { + CX2(ik, l).r += CH2(ik, j).r * xwal.r; + CX2(ik, l).i += CH2(ik, j).i * xwal.r; + CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i; + CX2(ik, lc).i += CH2(ik, jc).r * xwal.i; + } + } + } + + // shuffling and twiddling + if (ido == 1) + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) + for (size_t ik = 0; ik < idl1; ++ik) { + T t1 = CX2(ik, j), t2 = CX2(ik, jc); + PM(CX2(ik, j), CX2(ik, jc), t1, t2); + } + else { + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) + for (size_t k = 0; k < l1; ++k) { + T t1 = CX(0, k, j), t2 = CX(0, k, jc); + PM(CX(0, k, j), CX(0, k, jc), t1, t2); + for (size_t i = 1; i < ido; ++i) { + T x1, x2; + PM(x1, x2, CX(i, k, j), CX(i, k, jc)); + size_t idij = (j - 1) * (ido - 1) + i - 1; + special_mul(x1, wa[idij], CX(i, k, j)); + idij = (jc - 1) * (ido - 1) + i - 1; + special_mul(x2, wa[idij], CX(i, k, jc)); + } + } + } + } + + template + void pass_all(T c[], T0 fct) const { + if (length == 1) { + c[0] *= fct; + return; + } + size_t l1 = 1; + arr ch(length); + T *p1 = c, *p2 = ch.data(); + + for (size_t k1 = 0; k1 < fact.size(); k1++) { + size_t ip = fact[k1].fct; + size_t l2 = ip * l1; + size_t ido = length / l2; + if (ip == 4) + pass4(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 3) + pass3(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 5) + pass5(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 7) + pass7(ido, l1, p1, p2, fact[k1].tw); + else if (ip == 11) + pass11(ido, l1, p1, p2, fact[k1].tw); + else { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1, p2); + } + std::swap(p1, p2); + l1 = l2; + } + if (p1 != c) { + if (fct != 1.) + for (size_t i = 0; i < length; ++i) c[i] = ch[i] * fct; + else + std::copy_n(p1, length, c); + } else if (fct != 1.) 
+      for (size_t i = 0; i < length; ++i) c[i] *= fct;
+  }
+
+ public:
+  template <typename T>
+  void exec(T c[], T0 fct, bool fwd) const {
+    fwd ? pass_all<true>(c, fct) : pass_all<false>(c, fct);
+  }
+
+ private:
+  POCKETFFT_NOINLINE void factorize() {
+    size_t len = length;
+    while ((len & 7) == 0) {
+      add_factor(8);
+      len >>= 3;
+    }
+    while ((len & 3) == 0) {
+      add_factor(4);
+      len >>= 2;
+    }
+    if ((len & 1) == 0) {
+      len >>= 1;
+      // factor 2 should be at the front of the factor list
+      add_factor(2);
+      std::swap(fact[0].fct, fact.back().fct);
+    }
+    for (size_t divisor = 3; divisor * divisor <= len; divisor += 2)
+      while ((len % divisor) == 0) {
+        add_factor(divisor);
+        len /= divisor;
+      }
+    if (len > 1) add_factor(len);
+  }
+
+  size_t twsize() const {
+    size_t twsize = 0, l1 = 1;
+    for (size_t k = 0; k < fact.size(); ++k) {
+      size_t ip = fact[k].fct, ido = length / (l1 * ip);
+      twsize += (ip - 1) * (ido - 1);
+      if (ip > 11) twsize += ip;
+      l1 *= ip;
+    }
+    return twsize;
+  }
+
+  void comp_twiddle() {
+    sincos_2pibyn<T0> twiddle(length);
+    size_t l1 = 1;
+    size_t memofs = 0;
+    for (size_t k = 0; k < fact.size(); ++k) {
+      size_t ip = fact[k].fct, ido = length / (l1 * ip);
+      fact[k].tw = mem.data() + memofs;
+      memofs += (ip - 1) * (ido - 1);
+      for (size_t j = 1; j < ip; ++j)
+        for (size_t i = 1; i < ido; ++i)
+          fact[k].tw[(j - 1) * (ido - 1) + i - 1] = twiddle[j * l1 * i];
+      if (ip > 11) {
+        fact[k].tws = mem.data() + memofs;
+        memofs += ip;
+        for (size_t j = 0; j < ip; ++j) fact[k].tws[j] = twiddle[j * l1 * ido];
+      }
+      l1 *= ip;
+    }
+  }
+
+ public:
+  POCKETFFT_NOINLINE cfftp(size_t length_)
+      : length(length_) {
+    if (length == 0) throw std::runtime_error("zero-length FFT requested");
+    if (length == 1) return;
+    factorize();
+    mem.resize(twsize());
+    comp_twiddle();
+  }
+};
+
+//
+// real-valued FFTPACK transforms
+//
+
+template <typename T0>
+class rfftp {
+ private:
+  struct fctdata {
+    size_t fct;
+    T0 *tw, *tws;
+  };
+
+  size_t length;
+  arr<T0> mem;
+  std::vector<fctdata> fact;
+
+  void add_factor(size_t factor) {
+    fact.push_back({factor, nullptr, nullptr});
+  }
+
+  /* (a+ib) = conj(c+id) * (e+if) */
+  template <typename T1, typename T2, typename T3>
+  inline void MULPM(T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const {
+    a = c * e + d * f;
+    b = c * f - d * e;
+  }
+
+  template <typename T>
+  void radf2(
+      size_t ido,
+      size_t l1,
+      const T *POCKETFFT_RESTRICT cc,
+      T *POCKETFFT_RESTRICT ch,
+      const T0 *POCKETFFT_RESTRICT wa) const {
+    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };
+    auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T & {
+      return cc[a + ido * (b + l1 * c)];
+    };
+    auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T & { return ch[a + ido * (b + 2 * c)]; };
+
+    for (size_t k = 0; k < l1; k++) PM(CH(0, 0, k), CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 1));
+    if ((ido & 1) == 0)
+      for (size_t k = 0; k < l1; k++) {
+        CH(0, 1, k) = -CC(ido - 1, k, 1);
+        CH(ido - 1, 0, k) = CC(ido - 1, k, 0);
+      }
+    if (ido <= 2) return;
+    for (size_t k = 0; k < l1; k++)
+      for (size_t i = 2; i < ido; i += 2) {
+        size_t ic = ido - i;
+        T tr2, ti2;
+        MULPM(tr2, ti2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1));
+        PM(CH(i - 1, 0, k), CH(ic - 1, 1, k), CC(i - 1, k, 0), tr2);
+        PM(CH(i, 0, k), CH(ic, 1, k), ti2, CC(i, k, 0));
+      }
+  }
+
+// a2=a+b; b2=i*(b-a);
+#define POCKETFFT_REARRANGE(rx, ix, ry, iy) \
+  { \
+    auto t1 = rx + ry, t2 = ry - rx, t3 = ix + iy, t4 = ix - iy; \
+    rx = t1; \
+    ix = t3; \
+    ry = t4; \
+    iy = t2; \
+  }
+
+  template <typename T>
+  void radf3(
+      size_t ido,
+      size_t l1,
+      const T *POCKETFFT_RESTRICT cc,
+      T *POCKETFFT_RESTRICT ch,
+      const T0
*POCKETFFT_RESTRICT wa) const { + constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T & { return ch[a + ido * (b + 3 * c)]; }; + + for (size_t k = 0; k < l1; k++) { + T cr2 = CC(0, k, 1) + CC(0, k, 2); + CH(0, 0, k) = CC(0, k, 0) + cr2; + CH(0, 2, k) = taui * (CC(0, k, 2) - CC(0, k, 1)); + CH(ido - 1, 1, k) = CC(0, k, 0) + taur * cr2; + } + if (ido == 1) return; + for (size_t k = 0; k < l1; k++) + for (size_t i = 2; i < ido; i += 2) { + size_t ic = ido - i; + T di2, di3, dr2, dr3; + MULPM( + dr2, + di2, + WA(0, i - 2), + WA(0, i - 1), + CC(i - 1, k, 1), + CC(i, k, 1)); // d2=conj(WA0)*CC1 + MULPM( + dr3, + di3, + WA(1, i - 2), + WA(1, i - 1), + CC(i - 1, k, 2), + CC(i, k, 2)); // d3=conj(WA1)*CC2 + POCKETFFT_REARRANGE(dr2, di2, dr3, di3); + CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2; // c add + CH(i, 0, k) = CC(i, k, 0) + di2; + T tr2 = CC(i - 1, k, 0) + taur * dr2; // c add + T ti2 = CC(i, k, 0) + taur * di2; + T tr3 = taui * dr3; // t3 = taui*i*(d3-d2)? + T ti3 = taui * di3; + PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr3); // PM(i) = t2+t3 + PM(CH(i, 2, k), CH(ic, 1, k), ti3, ti2); // PM(ic) = conj(t2-t3) + } + } + + template + void radf4( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T & { return ch[a + ido * (b + 4 * c)]; }; + + for (size_t k = 0; k < l1; k++) { + T tr1, tr2; + PM(tr1, CH(0, 2, k), CC(0, k, 3), CC(0, k, 1)); + PM(tr2, CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 2)); + PM(CH(0, 0, k), CH(ido - 1, 3, k), tr2, tr1); + } + if ((ido & 1) == 0) + for (size_t k = 0; k < l1; k++) { + T ti1 = -hsqt2 * (CC(ido - 1, k, 1) + CC(ido - 1, k, 3)); + T tr1 = hsqt2 * (CC(ido - 1, k, 1) - CC(ido - 1, k, 3)); + PM(CH(ido - 1, 0, k), CH(ido - 1, 2, k), CC(ido - 1, k, 0), tr1); + PM(CH(0, 3, k), CH(0, 1, k), ti1, CC(ido - 1, k, 2)); + } + if (ido <= 2) return; + for (size_t k = 0; k < l1; k++) + for (size_t i = 2; i < ido; i += 2) { + size_t ic = ido - i; + T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; + MULPM(cr2, ci2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); + MULPM(cr3, ci3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2)); + MULPM(cr4, ci4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3)); + PM(tr1, tr4, cr4, cr2); + PM(ti1, ti4, ci2, ci4); + PM(tr2, tr3, CC(i - 1, k, 0), cr3); + PM(ti2, ti3, CC(i, k, 0), ci3); + PM(CH(i - 1, 0, k), CH(ic - 1, 3, k), tr2, tr1); + PM(CH(i, 0, k), CH(ic, 3, k), ti1, ti2); + PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr3, ti4); + PM(CH(i, 2, k), CH(ic, 1, k), tr4, ti3); + } + } + + template + void radf5( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L), + ti11 = T0(0.9510565162951535721164393333793821L), + tr12 = T0(-0.8090169943749474241022934171828191L), + ti12 = 
T0(0.5877852522924731291687059546390728L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T & { return ch[a + ido * (b + 5 * c)]; }; + + for (size_t k = 0; k < l1; k++) { + T cr2, cr3, ci4, ci5; + PM(cr2, ci5, CC(0, k, 4), CC(0, k, 1)); + PM(cr3, ci4, CC(0, k, 3), CC(0, k, 2)); + CH(0, 0, k) = CC(0, k, 0) + cr2 + cr3; + CH(ido - 1, 1, k) = CC(0, k, 0) + tr11 * cr2 + tr12 * cr3; + CH(0, 2, k) = ti11 * ci5 + ti12 * ci4; + CH(ido - 1, 3, k) = CC(0, k, 0) + tr12 * cr2 + tr11 * cr3; + CH(0, 4, k) = ti12 * ci5 - ti11 * ci4; + } + if (ido == 1) return; + for (size_t k = 0; k < l1; ++k) + for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) { + T di2, di3, di4, di5, dr2, dr3, dr4, dr5; + MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1)); + MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2)); + MULPM(dr4, di4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3)); + MULPM(dr5, di5, WA(3, i - 2), WA(3, i - 1), CC(i - 1, k, 4), CC(i, k, 4)); + POCKETFFT_REARRANGE(dr2, di2, dr5, di5); + POCKETFFT_REARRANGE(dr3, di3, dr4, di4); + CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2 + dr3; + CH(i, 0, k) = CC(i, k, 0) + di2 + di3; + T tr2 = CC(i - 1, k, 0) + tr11 * dr2 + tr12 * dr3; + T ti2 = CC(i, k, 0) + tr11 * di2 + tr12 * di3; + T tr3 = CC(i - 1, k, 0) + tr12 * dr2 + tr11 * dr3; + T ti3 = CC(i, k, 0) + tr12 * di2 + tr11 * di3; + T tr5 = ti11 * dr5 + ti12 * dr4; + T ti5 = ti11 * di5 + ti12 * di4; + T tr4 = ti12 * dr5 - ti11 * dr4; + T ti4 = ti12 * di5 - ti11 * di4; + PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr5); + PM(CH(i, 2, k), CH(ic, 1, k), ti5, ti2); + PM(CH(i - 1, 4, k), CH(ic - 1, 3, k), tr3, tr4); + PM(CH(i, 4, k), CH(ic, 3, k), ti4, ti3); + } + } + +#undef POCKETFFT_REARRANGE + + template + void radfg( + size_t ido, + size_t ip, + size_t l1, + T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa, + const T0 *POCKETFFT_RESTRICT csarr) const { + const size_t cdim = ip; + size_t ipph = (ip + 1) / 2; + size_t idl1 = ido * l1; + + auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> T & { + return cc[a + ido * (b + cdim * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> const T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto C2 = [cc, idl1](size_t a, size_t b) -> T & { return cc[a + idl1 * b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> T & { return ch[a + idl1 * b]; }; + + if (ido > 1) { + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 114 + { + size_t is = (j - 1) * (ido - 1), is2 = (jc - 1) * (ido - 1); + for (size_t k = 0; k < l1; ++k) // 113 + { + size_t idij = is; + size_t idij2 = is2; + for (size_t i = 1; i <= ido - 2; i += 2) // 112 + { + T t1 = C1(i, k, j), t2 = C1(i + 1, k, j), t3 = C1(i, k, jc), t4 = C1(i + 1, k, jc); + T x1 = wa[idij] * t1 + wa[idij + 1] * t2, x2 = wa[idij] * t2 - wa[idij + 1] * t1, + x3 = wa[idij2] * t3 + wa[idij2 + 1] * t4, x4 = wa[idij2] * t4 - wa[idij2 + 1] * t3; + PM(C1(i, k, j), C1(i + 1, k, jc), x3, x1); + PM(C1(i + 1, k, j), C1(i, k, jc), x2, x4); + idij += 2; + idij2 += 2; + } + } + } + } + + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 123 + for (size_t k = 0; k < l1; ++k) // 122 + MPINPLACE(C1(0, k, jc), C1(0, k, j)); + + 
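+    // annotation: the input is now twiddled and folded into +/- pairs
+    // C1(., ., j) / C1(., ., jc); the loops below accumulate the length-ip
+    // real DFT of those pairs into CH using the precomputed cos/sin table
+    // csarr (interleaved as csarr[2*i] = cos, csarr[2*i + 1] = sin, as the
+    // indexing here suggests), unrolled 4-way and then 2-way over j.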
// everything in C + // memset(ch,0,ip*l1*ido*sizeof(double)); + + for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) // 127 + { + for (size_t ik = 0; ik < idl1; ++ik) // 124 + { + CH2(ik, l) = C2(ik, 0) + csarr[2 * l] * C2(ik, 1) + csarr[4 * l] * C2(ik, 2); + CH2(ik, lc) = csarr[2 * l + 1] * C2(ik, ip - 1) + csarr[4 * l + 1] * C2(ik, ip - 2); + } + size_t iang = 2 * l; + size_t j = 3, jc = ip - 3; + for (; j < ipph - 3; j += 4, jc -= 4) // 126 + { + iang += l; + if (iang >= ip) iang -= ip; + T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; + iang += l; + if (iang >= ip) iang -= ip; + T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; + iang += l; + if (iang >= ip) iang -= ip; + T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1]; + iang += l; + if (iang >= ip) iang -= ip; + T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) // 125 + { + CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1) + ar3 * C2(ik, j + 2) + + ar4 * C2(ik, j + 3); + CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1) + ai3 * C2(ik, jc - 2) + + ai4 * C2(ik, jc - 3); + } + } + for (; j < ipph - 1; j += 2, jc -= 2) // 126 + { + iang += l; + if (iang >= ip) iang -= ip; + T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; + iang += l; + if (iang >= ip) iang -= ip; + T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) // 125 + { + CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1); + CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1); + } + } + for (; j < ipph; ++j, --jc) // 126 + { + iang += l; + if (iang >= ip) iang -= ip; + T0 ar = csarr[2 * iang], ai = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) // 125 + { + CH2(ik, l) += ar * C2(ik, j); + CH2(ik, lc) += ai * C2(ik, jc); + } + } + } + for (size_t ik = 0; ik < idl1; ++ik) // 101 + CH2(ik, 0) = C2(ik, 0); + for (size_t j = 1; j < ipph; ++j) // 129 + for (size_t ik = 0; ik < idl1; ++ik) // 128 + CH2(ik, 0) += C2(ik, j); + + // everything in CH at this point! 
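+    // annotation: the loops below repack CH into FFTPACK's halfcomplex
+    // order: for harmonic j (with j2 = 2*j - 1) the real part lands in
+    // CC(ido - 1, j2, k) and the imaginary part in CC(0, j2 + 1, k); the
+    // ido != 1 block afterwards fills in the remaining mixed terms.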
+ // memset(cc,0,ip*l1*ido*sizeof(double)); + + for (size_t k = 0; k < l1; ++k) // 131 + for (size_t i = 0; i < ido; ++i) // 130 + CC(i, 0, k) = CH(i, k, 0); + + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 137 + { + size_t j2 = 2 * j - 1; + for (size_t k = 0; k < l1; ++k) // 136 + { + CC(ido - 1, j2, k) = CH(0, k, j); + CC(0, j2 + 1, k) = CH(0, k, jc); + } + } + + if (ido == 1) return; + + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 140 + { + size_t j2 = 2 * j - 1; + for (size_t k = 0; k < l1; ++k) // 139 + for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2) // 138 + { + CC(i, j2 + 1, k) = CH(i, k, j) + CH(i, k, jc); + CC(ic, j2, k) = CH(i, k, j) - CH(i, k, jc); + CC(i + 1, j2 + 1, k) = CH(i + 1, k, j) + CH(i + 1, k, jc); + CC(ic + 1, j2, k) = CH(i + 1, k, jc) - CH(i + 1, k, j); + } + } + } + + template + void radb2( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 2 * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + + for (size_t k = 0; k < l1; k++) PM(CH(0, k, 0), CH(0, k, 1), CC(0, 0, k), CC(ido - 1, 1, k)); + if ((ido & 1) == 0) + for (size_t k = 0; k < l1; k++) { + CH(ido - 1, k, 0) = 2 * CC(ido - 1, 0, k); + CH(ido - 1, k, 1) = -2 * CC(0, 1, k); + } + if (ido <= 2) return; + for (size_t k = 0; k < l1; ++k) + for (size_t i = 2; i < ido; i += 2) { + size_t ic = ido - i; + T ti2, tr2; + PM(CH(i - 1, k, 0), tr2, CC(i - 1, 0, k), CC(ic - 1, 1, k)); + PM(ti2, CH(i, k, 0), CC(i, 0, k), CC(ic, 1, k)); + MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ti2, tr2); + } + } + + template + void radb3( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 3 * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + + for (size_t k = 0; k < l1; k++) { + T tr2 = 2 * CC(ido - 1, 1, k); + T cr2 = CC(0, 0, k) + taur * tr2; + CH(0, k, 0) = CC(0, 0, k) + tr2; + T ci3 = 2 * taui * CC(0, 2, k); + PM(CH(0, k, 2), CH(0, k, 1), cr2, ci3); + } + if (ido == 1) return; + for (size_t k = 0; k < l1; k++) + for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) { + T tr2 = CC(i - 1, 2, k) + CC(ic - 1, 1, k); // t2=CC(I) + conj(CC(ic)) + T ti2 = CC(i, 2, k) - CC(ic, 1, k); + T cr2 = CC(i - 1, 0, k) + taur * tr2; // c2=CC +taur*t2 + T ci2 = CC(i, 0, k) + taur * ti2; + CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2; // CH=CC+t2 + CH(i, k, 0) = CC(i, 0, k) + ti2; + T cr3 = taui * (CC(i - 1, 2, k) - CC(ic - 1, 1, k)); // c3=taui*(CC(i)-conj(CC(ic))) + T ci3 = taui * (CC(i, 2, k) + CC(ic, 1, k)); + T di2, di3, dr2, dr3; + PM(dr3, dr2, cr2, ci3); // d2= (cr2-ci3, ci2+cr3) = c2+i*c3 + PM(di2, di3, ci2, cr3); // d3= (cr2+ci3, ci2-cr3) = c2-i*c3 + MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2); // ch = WA*d2 + MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3); + } + } + + template + void radb4( + size_t ido, + size_t l1, + 
const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 4 * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + + for (size_t k = 0; k < l1; k++) { + T tr1, tr2; + PM(tr2, tr1, CC(0, 0, k), CC(ido - 1, 3, k)); + T tr3 = 2 * CC(ido - 1, 1, k); + T tr4 = 2 * CC(0, 2, k); + PM(CH(0, k, 0), CH(0, k, 2), tr2, tr3); + PM(CH(0, k, 3), CH(0, k, 1), tr1, tr4); + } + if ((ido & 1) == 0) + for (size_t k = 0; k < l1; k++) { + T tr1, tr2, ti1, ti2; + PM(ti1, ti2, CC(0, 3, k), CC(0, 1, k)); + PM(tr2, tr1, CC(ido - 1, 0, k), CC(ido - 1, 2, k)); + CH(ido - 1, k, 0) = tr2 + tr2; + CH(ido - 1, k, 1) = sqrt2 * (tr1 - ti1); + CH(ido - 1, k, 2) = ti2 + ti2; + CH(ido - 1, k, 3) = -sqrt2 * (tr1 + ti1); + } + if (ido <= 2) return; + for (size_t k = 0; k < l1; ++k) + for (size_t i = 2; i < ido; i += 2) { + T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; + size_t ic = ido - i; + PM(tr2, tr1, CC(i - 1, 0, k), CC(ic - 1, 3, k)); + PM(ti1, ti2, CC(i, 0, k), CC(ic, 3, k)); + PM(tr4, ti3, CC(i, 2, k), CC(ic, 1, k)); + PM(tr3, ti4, CC(i - 1, 2, k), CC(ic - 1, 1, k)); + PM(CH(i - 1, k, 0), cr3, tr2, tr3); + PM(CH(i, k, 0), ci3, ti2, ti3); + PM(cr4, cr2, tr1, tr4); + PM(ci2, ci4, ti1, ti4); + MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ci2, cr2); + MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), ci3, cr3); + MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), ci4, cr4); + } + } + + template + void radb5( + size_t ido, + size_t l1, + const T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa) const { + constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L), + ti11 = T0(0.9510565162951535721164393333793821L), + tr12 = T0(-0.8090169943749474241022934171828191L), + ti12 = T0(0.5877852522924731291687059546390728L); + + auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; }; + auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + 5 * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + + for (size_t k = 0; k < l1; k++) { + T ti5 = CC(0, 2, k) + CC(0, 2, k); + T ti4 = CC(0, 4, k) + CC(0, 4, k); + T tr2 = CC(ido - 1, 1, k) + CC(ido - 1, 1, k); + T tr3 = CC(ido - 1, 3, k) + CC(ido - 1, 3, k); + CH(0, k, 0) = CC(0, 0, k) + tr2 + tr3; + T cr2 = CC(0, 0, k) + tr11 * tr2 + tr12 * tr3; + T cr3 = CC(0, 0, k) + tr12 * tr2 + tr11 * tr3; + T ci4, ci5; + MULPM(ci5, ci4, ti5, ti4, ti11, ti12); + PM(CH(0, k, 4), CH(0, k, 1), cr2, ci5); + PM(CH(0, k, 3), CH(0, k, 2), cr3, ci4); + } + if (ido == 1) return; + for (size_t k = 0; k < l1; ++k) + for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) { + T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5; + PM(tr2, tr5, CC(i - 1, 2, k), CC(ic - 1, 1, k)); + PM(ti5, ti2, CC(i, 2, k), CC(ic, 1, k)); + PM(tr3, tr4, CC(i - 1, 4, k), CC(ic - 1, 3, k)); + PM(ti4, ti3, CC(i, 4, k), CC(ic, 3, k)); + CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2 + tr3; + CH(i, k, 0) = CC(i, 0, k) + ti2 + ti3; + T cr2 = CC(i - 1, 0, k) + tr11 * tr2 + tr12 * tr3; + T ci2 = CC(i, 0, k) + tr11 * ti2 + tr12 * ti3; + T cr3 = CC(i - 1, 0, k) + tr12 * tr2 + tr11 * 
tr3; + T ci3 = CC(i, 0, k) + tr12 * ti2 + tr11 * ti3; + T ci4, ci5, cr5, cr4; + MULPM(cr5, cr4, tr5, tr4, ti11, ti12); + MULPM(ci5, ci4, ti5, ti4, ti11, ti12); + T dr2, dr3, dr4, dr5, di2, di3, di4, di5; + PM(dr4, dr3, cr3, ci4); + PM(di3, di4, ci3, cr4); + PM(dr5, dr2, cr2, ci5); + PM(di2, di5, ci2, cr5); + MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2); + MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3); + MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), di4, dr4); + MULPM(CH(i, k, 4), CH(i - 1, k, 4), WA(3, i - 2), WA(3, i - 1), di5, dr5); + } + } + + template + void radbg( + size_t ido, + size_t ip, + size_t l1, + T *POCKETFFT_RESTRICT cc, + T *POCKETFFT_RESTRICT ch, + const T0 *POCKETFFT_RESTRICT wa, + const T0 *POCKETFFT_RESTRICT csarr) const { + const size_t cdim = ip; + size_t ipph = (ip + 1) / 2; + size_t idl1 = ido * l1; + + auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + cdim * c)]; + }; + auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T & { + return ch[a + ido * (b + l1 * c)]; + }; + auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T & { + return cc[a + ido * (b + l1 * c)]; + }; + auto C2 = [cc, idl1](size_t a, size_t b) -> T & { return cc[a + idl1 * b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> T & { return ch[a + idl1 * b]; }; + + for (size_t k = 0; k < l1; ++k) // 102 + for (size_t i = 0; i < ido; ++i) // 101 + CH(i, k, 0) = CC(i, 0, k); + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 108 + { + size_t j2 = 2 * j - 1; + for (size_t k = 0; k < l1; ++k) { + CH(0, k, j) = 2 * CC(ido - 1, j2, k); + CH(0, k, jc) = 2 * CC(0, j2 + 1, k); + } + } + + if (ido != 1) { + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 111 + { + size_t j2 = 2 * j - 1; + for (size_t k = 0; k < l1; ++k) + for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2) // 109 + { + CH(i, k, j) = CC(i, j2 + 1, k) + CC(ic, j2, k); + CH(i, k, jc) = CC(i, j2 + 1, k) - CC(ic, j2, k); + CH(i + 1, k, j) = CC(i + 1, j2 + 1, k) - CC(ic + 1, j2, k); + CH(i + 1, k, jc) = CC(i + 1, j2 + 1, k) + CC(ic + 1, j2, k); + } + } + } + for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) { + for (size_t ik = 0; ik < idl1; ++ik) { + C2(ik, l) = CH2(ik, 0) + csarr[2 * l] * CH2(ik, 1) + csarr[4 * l] * CH2(ik, 2); + C2(ik, lc) = csarr[2 * l + 1] * CH2(ik, ip - 1) + csarr[4 * l + 1] * CH2(ik, ip - 2); + } + size_t iang = 2 * l; + size_t j = 3, jc = ip - 3; + for (; j < ipph - 3; j += 4, jc -= 4) { + iang += l; + if (iang > ip) iang -= ip; + T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; + iang += l; + if (iang > ip) iang -= ip; + T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; + iang += l; + if (iang > ip) iang -= ip; + T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1]; + iang += l; + if (iang > ip) iang -= ip; + T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) { + C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1) + ar3 * CH2(ik, j + 2) + + ar4 * CH2(ik, j + 3); + C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1) + ai3 * CH2(ik, jc - 2) + + ai4 * CH2(ik, jc - 3); + } + } + for (; j < ipph - 1; j += 2, jc -= 2) { + iang += l; + if (iang > ip) iang -= ip; + T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1]; + iang += l; + if (iang > ip) iang -= ip; + T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) { + C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1); + 
C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1); + } + } + for (; j < ipph; ++j, --jc) { + iang += l; + if (iang > ip) iang -= ip; + T0 war = csarr[2 * iang], wai = csarr[2 * iang + 1]; + for (size_t ik = 0; ik < idl1; ++ik) { + C2(ik, l) += war * CH2(ik, j); + C2(ik, lc) += wai * CH2(ik, jc); + } + } + } + for (size_t j = 1; j < ipph; ++j) + for (size_t ik = 0; ik < idl1; ++ik) CH2(ik, 0) += CH2(ik, j); + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 124 + for (size_t k = 0; k < l1; ++k) PM(CH(0, k, jc), CH(0, k, j), C1(0, k, j), C1(0, k, jc)); + + if (ido == 1) return; + + for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc) // 127 + for (size_t k = 0; k < l1; ++k) + for (size_t i = 1; i <= ido - 2; i += 2) { + CH(i, k, j) = C1(i, k, j) - C1(i + 1, k, jc); + CH(i, k, jc) = C1(i, k, j) + C1(i + 1, k, jc); + CH(i + 1, k, j) = C1(i + 1, k, j) + C1(i, k, jc); + CH(i + 1, k, jc) = C1(i + 1, k, j) - C1(i, k, jc); + } + + // All in CH + + for (size_t j = 1; j < ip; ++j) { + size_t is = (j - 1) * (ido - 1); + for (size_t k = 0; k < l1; ++k) { + size_t idij = is; + for (size_t i = 1; i <= ido - 2; i += 2) { + T t1 = CH(i, k, j), t2 = CH(i + 1, k, j); + CH(i, k, j) = wa[idij] * t1 - wa[idij + 1] * t2; + CH(i + 1, k, j) = wa[idij] * t2 + wa[idij + 1] * t1; + idij += 2; + } + } + } + } + + template + void copy_and_norm(T *c, T *p1, T0 fct) const { + if (p1 != c) { + if (fct != 1.) + for (size_t i = 0; i < length; ++i) c[i] = fct * p1[i]; + else + std::copy_n(p1, length, c); + } else if (fct != 1.) + for (size_t i = 0; i < length; ++i) c[i] *= fct; + } + + public: + template + void exec(T c[], T0 fct, bool r2hc) const { + if (length == 1) { + c[0] *= fct; + return; + } + size_t nf = fact.size(); + arr ch(length); + T *p1 = c, *p2 = ch.data(); + + if (r2hc) + for (size_t k1 = 0, l1 = length; k1 < nf; ++k1) { + size_t k = nf - k1 - 1; + size_t ip = fact[k].fct; + size_t ido = length / l1; + l1 /= ip; + if (ip == 4) + radf4(ido, l1, p1, p2, fact[k].tw); + else if (ip == 2) + radf2(ido, l1, p1, p2, fact[k].tw); + else if (ip == 3) + radf3(ido, l1, p1, p2, fact[k].tw); + else if (ip == 5) + radf5(ido, l1, p1, p2, fact[k].tw); + else { + radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); + std::swap(p1, p2); + } + std::swap(p1, p2); + } + else + for (size_t k = 0, l1 = 1; k < nf; k++) { + size_t ip = fact[k].fct, ido = length / (ip * l1); + if (ip == 4) + radb4(ido, l1, p1, p2, fact[k].tw); + else if (ip == 2) + radb2(ido, l1, p1, p2, fact[k].tw); + else if (ip == 3) + radb3(ido, l1, p1, p2, fact[k].tw); + else if (ip == 5) + radb5(ido, l1, p1, p2, fact[k].tw); + else + radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); + std::swap(p1, p2); + l1 *= ip; + } + + copy_and_norm(c, p1, fct); + } + + private: + void factorize() { + size_t len = length; + while ((len % 4) == 0) { + add_factor(4); + len >>= 2; + } + if ((len % 2) == 0) { + len >>= 1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor = 3; divisor * divisor <= len; divisor += 2) + while ((len % divisor) == 0) { + add_factor(divisor); + len /= divisor; + } + if (len > 1) add_factor(len); + } + + size_t twsize() const { + size_t twsz = 0, l1 = 1; + for (size_t k = 0; k < fact.size(); ++k) { + size_t ip = fact[k].fct, ido = length / (l1 * ip); + twsz += (ip - 1) * (ido - 1); + if (ip > 5) twsz += 2 * ip; + l1 *= ip; + } + return twsz; + } + + void comp_twiddle() { + sincos_2pibyn twid(length); + size_t l1 = 1; + T0 *ptr = mem.data(); + for 
(size_t k = 0; k < fact.size(); ++k) { + size_t ip = fact[k].fct, ido = length / (l1 * ip); + if (k < fact.size() - 1) // last factor doesn't need twiddles + { + fact[k].tw = ptr; + ptr += (ip - 1) * (ido - 1); + for (size_t j = 1; j < ip; ++j) + for (size_t i = 1; i <= (ido - 1) / 2; ++i) { + fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 2] = twid[j * l1 * i].r; + fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 1] = twid[j * l1 * i].i; + } + } + if (ip > 5) // special factors required by *g functions + { + fact[k].tws = ptr; + ptr += 2 * ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i = 2, ic = 2 * ip - 2; i <= ic; i += 2, ic -= 2) { + fact[k].tws[i] = twid[i / 2 * (length / ip)].r; + fact[k].tws[i + 1] = twid[i / 2 * (length / ip)].i; + fact[k].tws[ic] = twid[i / 2 * (length / ip)].r; + fact[k].tws[ic + 1] = -twid[i / 2 * (length / ip)].i; + } + } + l1 *= ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) { + if (length == 0) throw std::runtime_error("zero-length FFT requested"); + if (length == 1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template +class fftblue { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template + void fft(cmplx c[], T0 fct) const { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m = 0; m < n; ++m) special_mul(c[m], bk[m], akf[m]); + auto zero = akf[0] * T0(0); + for (size_t m = n; m < n2; ++m) akf[m] = zero; + + plan.exec(akf.data(), 1., true); + + /* do the convolution */ + akf[0] = akf[0].template special_mul(bkf[0]); + for (size_t m = 1; m < (n2 + 1) / 2; ++m) { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2 - m] = akf[n2 - m].template special_mul(bkf[m]); + } + if ((n2 & 1) == 0) akf[n2 / 2] = akf[n2 / 2].template special_mul(bkf[n2 / 2]); + + /* inverse FFT */ + plan.exec(akf.data(), 1., false); + + /* multiply by b_k */ + for (size_t m = 0; m < n; ++m) c[m] = akf[m].template special_mul(bk[m]) * fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), + n2(util::good_size_cmplx(n * 2 - 1)), + plan(n2), + mem(n + n2 / 2 + 1), + bk(mem.data()), + bkf(mem.data() + n) { + /* initialize b_k */ + sincos_2pibyn tmp(2 * n); + bk[0].Set(1, 0); + + size_t coeff = 0; + for (size_t m = 1; m < n; ++m) { + coeff += 2 * m - 1; + if (coeff >= 2 * n) coeff -= 2 * n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ + arr> tbkf(n2); + T0 xn2 = T0(1) / T0(n2); + tbkf[0] = bk[0] * xn2; + for (size_t m = 1; m < n; ++m) tbkf[m] = tbkf[n2 - m] = bk[m] * xn2; + for (size_t m = n; m <= (n2 - n); ++m) tbkf[m].Set(0., 0.); + plan.exec(tbkf.data(), 1., true); + for (size_t i = 0; i < n2 / 2 + 1; ++i) bkf[i] = tbkf[i]; + } + + template + void exec(cmplx c[], T0 fct, bool fwd) const { + fwd ? 
fft(c, fct) : fft(c, fct); + } + + template + void exec_r(T c[], T0 fct, bool fwd) { + arr> tmp(n); + if (fwd) { + auto zero = T0(0) * c[0]; + for (size_t m = 0; m < n; ++m) tmp[m].Set(c[m], zero); + fft(tmp.data(), fct); + c[0] = tmp[0].r; + std::copy_n(&tmp[1].r, n - 1, &c[1]); + } else { + tmp[0].Set(c[0], c[0] * 0); + std::copy_n(c + 1, n - 1, &tmp[1].r); + if ((n & 1) == 0) tmp[n / 2].i = T0(0) * c[0]; + for (size_t m = 1; 2 * m < n; ++m) tmp[n - m].Set(tmp[m].r, -tmp[m].i); + fft(tmp.data(), fct); + for (size_t m = 0; m < n; ++m) c[m] = tmp[m].r; + } + } +}; + +// +// flexible (FFTPACK/Bluestein) complex 1D transform +// + +template +class pocketfft_c { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) { + if (length == 0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length); + if (tmp * tmp <= length) { + packplan = std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1)); + comp2 *= 1.5; /* fudge factor that appears to give good overall performance */ + if (comp2 < comp1) // use Bluestein + blueplan = std::unique_ptr>(new fftblue(length)); + else + packplan = std::unique_ptr>(new cfftp(length)); + } + + template + POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const { + packplan ? packplan->exec(c, fct, fwd) : blueplan->exec(c, fct, fwd); + } + + size_t length() const { + return len; + } +}; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template +class pocketfft_r { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) { + if (length == 0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length); + if (tmp * tmp <= length) { + packplan = std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5 * util::cost_guess(length); + double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1)); + comp2 *= 1.5; /* fudge factor that appears to give good overall performance */ + if (comp2 < comp1) // use Bluestein + blueplan = std::unique_ptr>(new fftblue(length)); + else + packplan = std::unique_ptr>(new rfftp(length)); + } + + template + POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const { + packplan ? 
packplan->exec(c, fct, fwd) : blueplan->exec_r(c, fct, fwd); + } + + size_t length() const { + return len; + } +}; + +// +// sine/cosine transforms +// + +template +class T_dct1 { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2 * (length - 1)) { + } + + template + POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/, bool /*cosine*/) const { + constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); + size_t N = fftplan.length(), n = N / 2 + 1; + if (ortho) { + c[0] *= sqrt2; + c[n - 1] *= sqrt2; + } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i = 1; i < n; ++i) tmp[i] = tmp[N - i] = c[i]; + fftplan.exec(tmp.data(), fct, true); + c[0] = tmp[0]; + for (size_t i = 1; i < n; ++i) c[i] = tmp[2 * i - 1]; + if (ortho) { + c[0] *= sqrt2 * T0(0.5); + c[n - 1] *= sqrt2 * T0(0.5); + } + } + + size_t length() const { + return fftplan.length() / 2 + 1; + } +}; + +template +class T_dst1 { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2 * (length + 1)) { + } + + template + POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool /*cosine*/) const { + size_t N = fftplan.length(), n = N / 2 - 1; + arr tmp(N); + tmp[0] = tmp[n + 1] = c[0] * 0; + for (size_t i = 0; i < n; ++i) { + tmp[i + 1] = c[i]; + tmp[N - 1 - i] = -c[i]; + } + fftplan.exec(tmp.data(), fct, true); + for (size_t i = 0; i < n; ++i) c[i] = -tmp[2 * i + 2]; + } + + size_t length() const { + return fftplan.length() / 2 - 1; + } +}; + +template +class T_dcst23 { + private: + pocketfft_r fftplan; + std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), + twiddle(length) { + sincos_2pibyn tw(4 * length); + for (size_t i = 0; i < length; ++i) twiddle[i] = tw[i + 1].r; + } + + template + POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int type, bool cosine) const { + constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); + size_t N = length(); + size_t NS2 = (N + 1) / 2; + if (type == 2) { + if (!cosine) + for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; + c[0] *= 2; + if ((N & 1) == 0) c[N - 1] *= 2; + for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k + 1], c[k]); + fftplan.exec(c, fct, false); + for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) { + T t1 = twiddle[k - 1] * c[kc] + twiddle[kc - 1] * c[k]; + T t2 = twiddle[k - 1] * c[k] - twiddle[kc - 1] * c[kc]; + c[k] = T0(0.5) * (t1 + t2); + c[kc] = T0(0.5) * (t1 - t2); + } + if ((N & 1) == 0) c[NS2] *= twiddle[NS2 - 1]; + if (!cosine) + for (size_t k = 0, kc = N - 1; k < kc; ++k, --kc) std::swap(c[k], c[kc]); + if (ortho) cosine ? c[0] *= sqrt2 * T0(0.5) : c[N - 1] *= sqrt2 * T0(0.5); + } else { + if (ortho) cosine ? 
c[0] *= sqrt2 : c[N - 1] *= sqrt2; + if (!cosine) + for (size_t k = 0, kc = N - 1; k < NS2; ++k, --kc) std::swap(c[k], c[kc]); + for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) { + T t1 = c[k] + c[kc], t2 = c[k] - c[kc]; + c[k] = twiddle[k - 1] * t2 + twiddle[kc - 1] * t1; + c[kc] = twiddle[k - 1] * t1 - twiddle[kc - 1] * t2; + } + if ((N & 1) == 0) c[NS2] *= 2 * twiddle[NS2 - 1]; + fftplan.exec(c, fct, true); + for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k], c[k + 1]); + if (!cosine) + for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; + } + } + + size_t length() const { + return fftplan.length(); + } +}; + +template +class T_dcst4 { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N & 1) ? nullptr : new pocketfft_c(N / 2)), + rfft((N & 1) ? new pocketfft_r(N) : nullptr), + C2((N & 1) ? 0 : N / 2) { + if ((N & 1) == 0) { + sincos_2pibyn tw(16 * N); + for (size_t i = 0; i < N / 2; ++i) C2[i] = conj(tw[8 * i + 1]); + } + } + + template + POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool cosine) const { + size_t n2 = N / 2; + if (!cosine) + for (size_t k = 0, kc = N - 1; k < n2; ++k, --kc) std::swap(c[k], c[kc]); + if (N & 1) { + // The following code is derived from the FFTW3 function apply_re11() + // and is released under the 3-clause BSD license with friendly + // permission of Matteo Frigo and Steven G. Johnson. + + arr y(N); + { + size_t i = 0, m = n2; + for (; m < N; ++i, m += 4) y[i] = c[m]; + for (; m < 2 * N; ++i, m += 4) y[i] = -c[2 * N - m - 1]; + for (; m < 3 * N; ++i, m += 4) y[i] = -c[m - 2 * N]; + for (; m < 4 * N; ++i, m += 4) y[i] = c[4 * N - m - 1]; + for (; i < N; ++i, m += 4) y[i] = c[m - 4 * N]; + } + rfft->exec(y.data(), fct, true); + { + auto SGN = [](size_t i) { + constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L); + return (i & 2) ? 
-sqrt2 : sqrt2; + }; + c[n2] = y[0] * SGN(n2 + 1); + size_t i = 0, i1 = 1, k = 1; + for (; k < n2; ++i, ++i1, k += 2) { + c[i] = y[2 * k - 1] * SGN(i1) + y[2 * k] * SGN(i); + c[N - i1] = y[2 * k - 1] * SGN(N - i) - y[2 * k] * SGN(N - i1); + c[n2 - i1] = y[2 * k + 1] * SGN(n2 - i) - y[2 * k + 2] * SGN(n2 - i1); + c[n2 + i1] = y[2 * k + 1] * SGN(n2 + i + 2) + y[2 * k + 2] * SGN(n2 + i1); + } + if (k == n2) { + c[i] = y[2 * k - 1] * SGN(i + 1) + y[2 * k] * SGN(i); + c[N - i1] = y[2 * k - 1] * SGN(i + 2) + y[2 * k] * SGN(i1); + } + } + + // FFTW-derived code ends here + } else { + // even length algorithm from + // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ + arr> y(n2); + for (size_t i = 0; i < n2; ++i) { + y[i].Set(c[2 * i], c[N - 1 - 2 * i]); + y[i] *= C2[i]; + } + fft->exec(y.data(), fct, true); + for (size_t i = 0, ic = n2 - 1; i < n2; ++i, --ic) { + c[2 * i] = 2 * (y[i].r * C2[i].r - y[i].i * C2[i].i); + c[2 * i + 1] = -2 * (y[ic].i * C2[ic].r + y[ic].r * C2[ic].i); + } + } + if (!cosine) + for (size_t k = 1; k < N; k += 2) c[k] = -c[k]; + } + + size_t length() const { + return N; + } +}; + +// +// multi-D infrastructure +// + +template +std::shared_ptr get_plan(size_t length) { +#if POCKETFFT_CACHE_SIZE == 0 + return std::make_shared(length); +#else + constexpr size_t nmax = POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr { + for (size_t i = 0; i < nmax; ++i) + if (cache[i] && (cache[i]->length() == length)) { + // no need to update if this is already the most recent entry + if (last_access[i] != access_counter) { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i = 1; i < nmax; ++i) + if (last_access[i] < last_access[lru]) lru = i; + + cache[lru] = plan; + last_access[lru] = ++access_counter; + } + return plan; +#endif +} + +class arr_info { + protected: + shape_t shp; + stride_t str; + + public: + arr_info(const shape_t &shape_, const stride_t &stride_) + : shp(shape_), + str(stride_) { + } + size_t ndim() const { + return shp.size(); + } + size_t size() const { + return util::prod(shp); + } + const shape_t &shape() const { + return shp; + } + size_t shape(size_t i) const { + return shp[i]; + } + const stride_t &stride() const { + return str; + } + const ptrdiff_t &stride(size_t i) const { + return str[i]; + } +}; + +template +class cndarr : public arr_info { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) { + } + const T &operator[](ptrdiff_t ofs) const { + return *reinterpret_cast(d + ofs); + } +}; + +template +class ndarr : public cndarr { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) { + } + T &operator[](ptrdiff_t ofs) { + return *reinterpret_cast(const_cast(cndarr::d + ofs)); + } +}; + +template +class multi_iter { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], 
str_o; + size_t idim, rem; + + void advance_i() { + for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { + auto i = size_t(i_); + if (i == idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if (++pos[i] < iarr.shape(i)) return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i)) * iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i)) * oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), + iarr(iarr_), + oarr(oarr_), + p_ii(0), + str_i(iarr.stride(idim_)), + p_oi(0), + str_o(oarr.stride(idim_)), + idim(idim_), + rem(iarr.size() / iarr.shape(idim)) { + auto nshares = threading::num_threads(); + if (nshares == 1) return; + if (nshares == 0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare >= nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem / nshares; + size_t additional = rem % nshares; + size_t lo = myshare * nbase + ((myshare < additional) ? myshare : additional); + size_t hi = lo + nbase + (myshare < additional); + size_t todo = hi - lo; + + size_t chunk = rem; + for (size_t i = 0; i < pos.size(); ++i) { + if (i == idim) continue; + chunk /= iarr.shape(i); + size_t n_advance = lo / chunk; + pos[i] += n_advance; + p_ii += ptrdiff_t(n_advance) * iarr.stride(i); + p_oi += ptrdiff_t(n_advance) * oarr.stride(i); + lo -= n_advance * chunk; + } + rem = todo; + } + void advance(size_t n) { + if (rem < n) throw std::runtime_error("underrun"); + for (size_t i = 0; i < n; ++i) { + p_i[i] = p_ii; + p_o[i] = p_oi; + advance_i(); + } + rem -= n; + } + ptrdiff_t iofs(size_t i) const { + return p_i[0] + ptrdiff_t(i) * str_i; + } + ptrdiff_t iofs(size_t j, size_t i) const { + return p_i[j] + ptrdiff_t(i) * str_i; + } + ptrdiff_t oofs(size_t i) const { + return p_o[0] + ptrdiff_t(i) * str_o; + } + ptrdiff_t oofs(size_t j, size_t i) const { + return p_o[j] + ptrdiff_t(i) * str_o; + } + size_t length_in() const { + return iarr.shape(idim); + } + size_t length_out() const { + return oarr.shape(idim); + } + ptrdiff_t stride_in() const { + return str_i; + } + ptrdiff_t stride_out() const { + return str_o; + } + size_t remaining() const { + return rem; + } +}; + +class simple_iter { + private: + shape_t pos; + const arr_info &arr; + ptrdiff_t p; + size_t rem; + + public: + simple_iter(const arr_info &arr_) + : pos(arr_.ndim(), 0), + arr(arr_), + p(0), + rem(arr_.size()) { + } + void advance() { + --rem; + for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i)) * arr.stride(i); + } + } + ptrdiff_t ofs() const { + return p; + } + size_t remaining() const { + return rem; + } +}; + +class rev_iter { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), + arr(arr_), + rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), + p(0), + rp(0) { + for (auto ax : axes) rev_axis[ax] = 1; + last_axis = axes.back(); + last_size = arr.shape(last_axis) / 2 + 1; + shp = arr.shape(); + shp[last_axis] = last_size; + rem = 1; + for (auto i : shp) rem *= i; + } + void advance() { + --rem; + for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += 
+class rev_iter {
+ private:
+  shape_t pos;
+  const arr_info &arr;
+  std::vector<char> rev_axis;
+  std::vector<char> rev_jump;
+  size_t last_axis, last_size;
+  shape_t shp;
+  ptrdiff_t p, rp;
+  size_t rem;
+
+ public:
+  rev_iter(const arr_info &arr_, const shape_t &axes)
+      : pos(arr_.ndim(), 0),
+        arr(arr_),
+        rev_axis(arr_.ndim(), 0),
+        rev_jump(arr_.ndim(), 1),
+        p(0),
+        rp(0) {
+    for (auto ax : axes) rev_axis[ax] = 1;
+    last_axis = axes.back();
+    last_size = arr.shape(last_axis) / 2 + 1;
+    shp = arr.shape();
+    shp[last_axis] = last_size;
+    rem = 1;
+    for (auto i : shp) rem *= i;
+  }
+  void advance() {
+    --rem;
+    for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {
+      auto i = size_t(i_);
+      p += arr.stride(i);
+      if (!rev_axis[i])
+        rp += arr.stride(i);
+      else {
+        rp -= arr.stride(i);
+        if (rev_jump[i]) {
+          rp += ptrdiff_t(arr.shape(i)) * arr.stride(i);
+          rev_jump[i] = 0;
+        }
+      }
+      if (++pos[i] < shp[i]) return;
+      pos[i] = 0;
+      p -= ptrdiff_t(shp[i]) * arr.stride(i);
+      if (rev_axis[i]) {
+        rp -= ptrdiff_t(arr.shape(i) - shp[i]) * arr.stride(i);
+        rev_jump[i] = 1;
+      } else
+        rp -= ptrdiff_t(shp[i]) * arr.stride(i);
+    }
+  }
+  ptrdiff_t ofs() const {
+    return p;
+  }
+  ptrdiff_t rev_ofs() const {
+    return rp;
+  }
+  size_t remaining() const {
+    return rem;
+  }
+};
+
+template<typename T>
+struct VTYPE {};
+template<typename T>
+using vtype_t = typename VTYPE<T>::type;
+
+#ifndef POCKETFFT_NO_VECTORS
+template<>
+struct VTYPE<float> {
+  using type = float __attribute__((vector_size(VLEN<float>::val * sizeof(float))));
+};
+template<>
+struct VTYPE<double> {
+  using type = double __attribute__((vector_size(VLEN<double>::val * sizeof(double))));
+};
+template<>
+struct VTYPE<long double> {
+  using type = long double
+      __attribute__((vector_size(VLEN<long double>::val * sizeof(long double))));
+};
+#endif
+
+template<typename T>
+arr<char> alloc_tmp(const shape_t &shape, size_t axsize, size_t elemsize) {
+  auto othersize = util::prod(shape) / axsize;
+  auto tmpsize = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);
+  return arr<char>(tmpsize * elemsize);
+}
+template<typename T>
+arr<char> alloc_tmp(const shape_t &shape, const shape_t &axes, size_t elemsize) {
+  size_t fullsize = util::prod(shape);
+  size_t tmpsize = 0;
+  for (size_t i = 0; i < axes.size(); ++i) {
+    auto axsize = shape[axes[i]];
+    auto othersize = fullsize / axsize;
+    auto sz = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);
+    if (sz > tmpsize) tmpsize = sz;
+  }
+  return arr<char>(tmpsize * elemsize);
+}
+
+template<size_t vlen, typename T>
+void copy_input(
+    const multi_iter<vlen> &it,
+    const cndarr<cmplx<T>> &src,
+    cmplx<vtype_t<T>> *POCKETFFT_RESTRICT dst) {
+  for (size_t i = 0; i < it.length_in(); ++i)
+    for (size_t j = 0; j < vlen; ++j) {
+      dst[i].r[j] = src[it.iofs(j, i)].r;
+      dst[i].i[j] = src[it.iofs(j, i)].i;
+    }
+}
+
+template<size_t vlen, typename T>
+void copy_input(
+    const multi_iter<vlen> &it,
+    const cndarr<T> &src,
+    vtype_t<T> *POCKETFFT_RESTRICT dst) {
+  for (size_t i = 0; i < it.length_in(); ++i)
+    for (size_t j = 0; j < vlen; ++j) dst[i][j] = src[it.iofs(j, i)];
+}
+
+template<size_t vlen, typename T>
+void copy_input(const multi_iter<vlen> &it, const cndarr<T> &src, T *POCKETFFT_RESTRICT dst) {
+  if (dst == &src[it.iofs(0)]) return; // in-place
+  for (size_t i = 0; i < it.length_in(); ++i) dst[i] = src[it.iofs(i)];
+}
+
+template<size_t vlen, typename T>
+void copy_output(
+    const multi_iter<vlen> &it,
+    const cmplx<vtype_t<T>> *POCKETFFT_RESTRICT src,
+    ndarr<cmplx<T>> &dst) {
+  for (size_t i = 0; i < it.length_out(); ++i)
+    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)].Set(src[i].r[j], src[i].i[j]);
+}
+
+template<size_t vlen, typename T>
+void copy_output(
+    const multi_iter<vlen> &it,
+    const vtype_t<T> *POCKETFFT_RESTRICT src,
+    ndarr<T> &dst) {
+  for (size_t i = 0; i < it.length_out(); ++i)
+    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)] = src[i][j];
+}
+
+template<size_t vlen, typename T>
+void copy_output(const multi_iter<vlen> &it, const T *POCKETFFT_RESTRICT src, ndarr<T> &dst) {
+  if (src == &dst[it.oofs(0)]) return; // in-place
+  for (size_t i = 0; i < it.length_out(); ++i) dst[it.oofs(i)] = src[i];
+}
+
+template<typename T>
+struct add_vec {
+  using type = vtype_t<T>;
+};
+template<typename T>
+struct add_vec<cmplx<T>> {
+  using type = cmplx<vtype_t<T>>;
+};
+template<typename T>
+using add_vec_t = typename add_vec<T>::type;
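
rev_iter pairs each element of the half-size r2c output with its conjugate-symmetric partner in the full-size array: along every transformed axis, position i maps to (-i) mod n. A numpy sketch of that index map (rev_index is hypothetical, for illustration only):

import numpy as np

def rev_index(idx, shape, axes):
    """Map an index to its conjugate-symmetric partner along `axes`."""
    return tuple((-i) % n if d in axes else i
                 for d, (i, n) in enumerate(zip(idx, shape)))

x = np.random.rand(4, 6)
f = np.fft.fftn(x)
for idx in np.ndindex(*x.shape):
    ridx = rev_index(idx, x.shape, axes={0, 1})
    # Hermitian symmetry of a real input's FFT
    assert np.allclose(f[idx], np.conj(f[ridx]))
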
+template<typename Tplan, typename T, typename T0, typename Exec>
+POCKETFFT_NOINLINE void general_nd(
+    const cndarr<T> &in,
+    ndarr<T> &out,
+    const shape_t &axes,
+    T0 fct,
+    size_t nthreads,
+    const Exec &exec,
+    const bool allow_inplace = true) {
+  std::shared_ptr<Tplan> plan;
+
+  for (size_t iax = 0; iax < axes.size(); ++iax) {
+    size_t len = in.shape(axes[iax]);
+    if ((!plan) || (len != plan->length())) plan = get_plan<Tplan>(len);
+
+    threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T0>::val), [&] {
+      constexpr auto vlen = VLEN<T0>::val;
+      auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
+      const auto &tin(iax == 0 ? in : out);
+      multi_iter<vlen> it(tin, out, axes[iax]);
+#ifndef POCKETFFT_NO_VECTORS
+      if (vlen > 1)
+        while (it.remaining() >= vlen) {
+          it.advance(vlen);
+          auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
+          exec(it, tin, out, tdatav, *plan, fct);
+        }
+#endif
+      while (it.remaining() > 0) {
+        it.advance(1);
+        auto buf = allow_inplace && it.stride_out() == sizeof(T)
+            ? &out[it.oofs(0)]
+            : reinterpret_cast<T *>(storage.data());
+        exec(it, tin, out, buf, *plan, fct);
+      }
+    }); // end of parallel region
+    fct = T0(1); // factor has been applied, use 1 for remaining axes
+  }
+}
+
+struct ExecC2C {
+  bool forward;
+
+  template<typename T0, typename T, size_t vlen>
+  void operator()(
+      const multi_iter<vlen> &it,
+      const cndarr<cmplx<T0>> &in,
+      ndarr<cmplx<T0>> &out,
+      T *buf,
+      const pocketfft_c<T0> &plan,
+      T0 fct) const {
+    copy_input(it, in, buf);
+    plan.exec(buf, fct, forward);
+    copy_output(it, buf, out);
+  }
+};
+
+template<size_t vlen, typename T>
+void copy_hartley(
+    const multi_iter<vlen> &it,
+    const vtype_t<T> *POCKETFFT_RESTRICT src,
+    ndarr<T> &dst) {
+  for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, 0)] = src[0][j];
+  size_t i = 1, i1 = 1, i2 = it.length_out() - 1;
+  for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2)
+    for (size_t j = 0; j < vlen; ++j) {
+      dst[it.oofs(j, i1)] = src[i][j] + src[i + 1][j];
+      dst[it.oofs(j, i2)] = src[i][j] - src[i + 1][j];
+    }
+  if (i < it.length_out())
+    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i1)] = src[i][j];
+}
+
+template<size_t vlen, typename T>
+void copy_hartley(const multi_iter<vlen> &it, const T *POCKETFFT_RESTRICT src, ndarr<T> &dst) {
+  dst[it.oofs(0)] = src[0];
+  size_t i = 1, i1 = 1, i2 = it.length_out() - 1;
+  for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2) {
+    dst[it.oofs(i1)] = src[i] + src[i + 1];
+    dst[it.oofs(i2)] = src[i] - src[i + 1];
+  }
+  if (i < it.length_out()) dst[it.oofs(i1)] = src[i];
+}
+
+struct ExecHartley {
+  template<typename T0, typename T, size_t vlen>
+  void operator()(
+      const multi_iter<vlen> &it,
+      const cndarr<T0> &in,
+      ndarr<T0> &out,
+      T *buf,
+      const pocketfft_r<T0> &plan,
+      T0 fct) const {
+    copy_input(it, in, buf);
+    plan.exec(buf, fct, true);
+    copy_hartley(it, buf, out);
+  }
+};
+
+struct ExecDcst {
+  bool ortho;
+  int type;
+  bool cosine;
+
+  template<typename T0, typename T, typename Tplan, size_t vlen>
+  void operator()(
+      const multi_iter<vlen> &it,
+      const cndarr<T0> &in,
+      ndarr<T0> &out,
+      T *buf,
+      const Tplan &plan,
+      T0 fct) const {
+    copy_input(it, in, buf);
+    plan.exec(buf, fct, ortho, type, cosine);
+    copy_output(it, buf, out);
+  }
+};
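
general_nd runs the 1-D plan axis by axis, reusing the output as input after the first pass and folding the scale factor fct into the first axis only. The same decomposition in numpy terms (an illustrative reference, not this header's API):

import numpy as np

x = np.random.rand(4, 5, 6) + 1j * np.random.rand(4, 5, 6)
full = np.fft.fftn(x, axes=(0, 2))

step = x
for ax in (0, 2):
    step = np.fft.fft(step, axis=ax)  # one general_nd pass per axis
assert np.allclose(full, step)
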
+template<typename T>
+POCKETFFT_NOINLINE void general_r2c(
+    const cndarr<T> &in,
+    ndarr<cmplx<T>> &out,
+    size_t axis,
+    bool forward,
+    T fct,
+    size_t nthreads) {
+  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
+  size_t len = in.shape(axis);
+  threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {
+    constexpr auto vlen = VLEN<T>::val;
+    auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
+    multi_iter<vlen> it(in, out, axis);
+#ifndef POCKETFFT_NO_VECTORS
+    if (vlen > 1)
+      while (it.remaining() >= vlen) {
+        it.advance(vlen);
+        auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
+        copy_input(it, in, tdatav);
+        plan->exec(tdatav, fct, true);
+        for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, 0)].Set(tdatav[0][j]);
+        size_t i = 1, ii = 1;
+        if (forward)
+          for (; i < len - 1; i += 2, ++ii)
+            for (size_t j = 0; j < vlen; ++j)
+              out[it.oofs(j, ii)].Set(tdatav[i][j], tdatav[i + 1][j]);
+        else
+          for (; i < len - 1; i += 2, ++ii)
+            for (size_t j = 0; j < vlen; ++j)
+              out[it.oofs(j, ii)].Set(tdatav[i][j], -tdatav[i + 1][j]);
+        if (i < len)
+          for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, ii)].Set(tdatav[i][j]);
+      }
+#endif
+    while (it.remaining() > 0) {
+      it.advance(1);
+      auto tdata = reinterpret_cast<T *>(storage.data());
+      copy_input(it, in, tdata);
+      plan->exec(tdata, fct, true);
+      out[it.oofs(0)].Set(tdata[0]);
+      size_t i = 1, ii = 1;
+      if (forward)
+        for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], tdata[i + 1]);
+      else
+        for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], -tdata[i + 1]);
+      if (i < len) out[it.oofs(ii)].Set(tdata[i]);
+    }
+  }); // end of parallel region
+}
+template<typename T>
+POCKETFFT_NOINLINE void general_c2r(
+    const cndarr<cmplx<T>> &in,
+    ndarr<T> &out,
+    size_t axis,
+    bool forward,
+    T fct,
+    size_t nthreads) {
+  auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
+  size_t len = out.shape(axis);
+  threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {
+    constexpr auto vlen = VLEN<T>::val;
+    auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
+    multi_iter<vlen> it(in, out, axis);
+#ifndef POCKETFFT_NO_VECTORS
+    if (vlen > 1)
+      while (it.remaining() >= vlen) {
+        it.advance(vlen);
+        auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
+        for (size_t j = 0; j < vlen; ++j) tdatav[0][j] = in[it.iofs(j, 0)].r;
+        {
+          size_t i = 1, ii = 1;
+          if (forward)
+            for (; i < len - 1; i += 2, ++ii)
+              for (size_t j = 0; j < vlen; ++j) {
+                tdatav[i][j] = in[it.iofs(j, ii)].r;
+                tdatav[i + 1][j] = -in[it.iofs(j, ii)].i;
+              }
+          else
+            for (; i < len - 1; i += 2, ++ii)
+              for (size_t j = 0; j < vlen; ++j) {
+                tdatav[i][j] = in[it.iofs(j, ii)].r;
+                tdatav[i + 1][j] = in[it.iofs(j, ii)].i;
+              }
+          if (i < len)
+            for (size_t j = 0; j < vlen; ++j) tdatav[i][j] = in[it.iofs(j, ii)].r;
+        }
+        plan->exec(tdatav, fct, false);
+        copy_output(it, tdatav, out);
+      }
+#endif
+    while (it.remaining() > 0) {
+      it.advance(1);
+      auto tdata = reinterpret_cast<T *>(storage.data());
+      tdata[0] = in[it.iofs(0)].r;
+      {
+        size_t i = 1, ii = 1;
+        if (forward)
+          for (; i < len - 1; i += 2, ++ii) {
+            tdata[i] = in[it.iofs(ii)].r;
+            tdata[i + 1] = -in[it.iofs(ii)].i;
+          }
+        else
+          for (; i < len - 1; i += 2, ++ii) {
+            tdata[i] = in[it.iofs(ii)].r;
+            tdata[i + 1] = in[it.iofs(ii)].i;
+          }
+        if (i < len) tdata[i] = in[it.iofs(ii)].r;
+      }
+      plan->exec(tdata, fct, false);
+      copy_output(it, tdata, out);
+    }
+  }); // end of parallel region
+}
+
+struct ExecR2R {
+  bool r2h, forward;
+
+  template<typename T0, typename T, size_t vlen>
+  void operator()(
+      const multi_iter<vlen> &it,
+      const cndarr<T0> &in,
+      ndarr<T0> &out,
+      T *buf,
+      const pocketfft_r<T0> &plan,
+      T0 fct) const {
+    copy_input(it, in, buf);
+    if ((!r2h) && forward)
+      for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i];
+    plan.exec(buf, fct, r2h);
+    if (r2h && (!forward))
+      for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i];
+    copy_output(it, buf, out);
+  }
+};
+
+template<typename T>
+void c2c(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    bool forward,
+    const std::complex<T> *data_in,
+    std::complex<T> *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape) == 0) return;
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  cndarr<cmplx<T>> ain(data_in, shape, stride_in);
+  ndarr<cmplx<T>> aout(data_out, shape, stride_out);
+  general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
+}
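
general_r2c writes the half-spectrum in packed real/imag pairs (with the imaginary sign flipped for the backward convention), and general_c2r unpacks it again. The essential contract, shown with numpy's rfft/irfft (an illustrative reference, not this header's API):

import numpy as np

x = np.random.rand(8)
half = np.fft.rfft(x)           # length 8//2 + 1 == 5, like r2c's output axis
assert half.shape[0] == x.shape[0] // 2 + 1
back = np.fft.irfft(half, n=8)  # c2r needs the output length to disambiguate
assert np.allclose(back, x)
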
+template<typename T>
+void dct(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    int type,
+    const T *data_in,
+    T *data_out,
+    T fct,
+    bool ortho,
+    size_t nthreads = 1) {
+  if ((type < 1) || (type > 4)) throw std::invalid_argument("invalid DCT type");
+  if (util::prod(shape) == 0) return;
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  cndarr<T> ain(data_in, shape, stride_in);
+  ndarr<T> aout(data_out, shape, stride_out);
+  const ExecDcst exec{ortho, type, true};
+  if (type == 1)
+    general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
+  else if (type == 4)
+    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
+  else
+    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
+}
+
+template<typename T>
+void dst(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    int type,
+    const T *data_in,
+    T *data_out,
+    T fct,
+    bool ortho,
+    size_t nthreads = 1) {
+  if ((type < 1) || (type > 4)) throw std::invalid_argument("invalid DST type");
+  if (util::prod(shape) == 0) return;
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  cndarr<T> ain(data_in, shape, stride_in);
+  ndarr<T> aout(data_out, shape, stride_out);
+  const ExecDcst exec{ortho, type, false};
+  if (type == 1)
+    general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
+  else if (type == 4)
+    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
+  else
+    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
+}
+
+template<typename T>
+void r2c(
+    const shape_t &shape_in,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    size_t axis,
+    bool forward,
+    const T *data_in,
+    std::complex<T> *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape_in) == 0) return;
+  util::sanity_check(shape_in, stride_in, stride_out, false, axis);
+  cndarr<T> ain(data_in, shape_in, stride_in);
+  shape_t shape_out(shape_in);
+  shape_out[axis] = shape_in[axis] / 2 + 1;
+  ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
+  general_r2c(ain, aout, axis, forward, fct, nthreads);
+}
+
+template<typename T>
+void r2c(
+    const shape_t &shape_in,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    bool forward,
+    const T *data_in,
+    std::complex<T> *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape_in) == 0) return;
+  util::sanity_check(shape_in, stride_in, stride_out, false, axes);
+  r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, fct, nthreads);
+  if (axes.size() == 1) return;
+
+  shape_t shape_out(shape_in);
+  shape_out[axes.back()] = shape_in[axes.back()] / 2 + 1;
+  auto newaxes = shape_t{axes.begin(), --axes.end()};
+  c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, T(1), nthreads);
+}
+
+template<typename T>
+void c2r(
+    const shape_t &shape_out,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    size_t axis,
+    bool forward,
+    const std::complex<T> *data_in,
+    T *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape_out) == 0) return;
+  util::sanity_check(shape_out, stride_in, stride_out, false, axis);
+  shape_t shape_in(shape_out);
+  shape_in[axis] = shape_out[axis] / 2 + 1;
+  cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
+  ndarr<T> aout(data_out, shape_out, stride_out);
+  general_c2r(ain, aout, axis, forward, fct, nthreads);
+}
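
The multi-axis r2c overload above transforms axes.back() with the real-input kernel, then finishes with an in-place c2c over the remaining axes of the half-spectrum; the multi-axis c2r that follows runs the same composition in reverse through an intermediate buffer. numpy equivalent of the r2c composition (illustrative):

import numpy as np

x = np.random.rand(4, 6)
axes = (0, 1)
half = np.fft.rfft(x, axis=axes[-1])  # r2c along the last axis
half = np.fft.fft(half, axis=axes[0])  # c2c over the remaining axes
assert np.allclose(half, np.fft.rfftn(x, axes=axes))
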
+template<typename T>
+void c2r(
+    const shape_t &shape_out,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    bool forward,
+    const std::complex<T> *data_in,
+    T *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape_out) == 0) return;
+  if (axes.size() == 1)
+    return c2r(
+        shape_out,
+        stride_in,
+        stride_out,
+        axes[0],
+        forward,
+        data_in,
+        data_out,
+        fct,
+        nthreads);
+  util::sanity_check(shape_out, stride_in, stride_out, false, axes);
+  auto shape_in = shape_out;
+  shape_in[axes.back()] = shape_out[axes.back()] / 2 + 1;
+  auto nval = util::prod(shape_in);
+  stride_t stride_inter(shape_in.size());
+  stride_inter.back() = sizeof(cmplx<T>);
+  for (int i = int(shape_in.size()) - 2; i >= 0; --i)
+    stride_inter[size_t(i)] = stride_inter[size_t(i + 1)] * ptrdiff_t(shape_in[size_t(i + 1)]);
+  arr<std::complex<T>> tmp(nval);
+  auto newaxes = shape_t{axes.begin(), --axes.end()};
+  c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), T(1), nthreads);
+  c2r(shape_out,
+      stride_inter,
+      stride_out,
+      axes.back(),
+      forward,
+      tmp.data(),
+      data_out,
+      fct,
+      nthreads);
+}
+
+template<typename T>
+void r2r_fftpack(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    bool real2hermitian,
+    bool forward,
+    const T *data_in,
+    T *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape) == 0) return;
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  cndarr<T> ain(data_in, shape, stride_in);
+  ndarr<T> aout(data_out, shape, stride_out);
+  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecR2R{real2hermitian, forward});
+}
+
+template<typename T>
+void r2r_separable_hartley(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    const T *data_in,
+    T *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape) == 0) return;
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  cndarr<T> ain(data_in, shape, stride_in);
+  ndarr<T> aout(data_out, shape, stride_out);
+  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{}, false);
+}
+
+template<typename T>
+void r2r_genuine_hartley(
+    const shape_t &shape,
+    const stride_t &stride_in,
+    const stride_t &stride_out,
+    const shape_t &axes,
+    const T *data_in,
+    T *data_out,
+    T fct,
+    size_t nthreads = 1) {
+  if (util::prod(shape) == 0) return;
+  if (axes.size() == 1)
+    return r2r_separable_hartley(
+        shape,
+        stride_in,
+        stride_out,
+        axes,
+        data_in,
+        data_out,
+        fct,
+        nthreads);
+  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);
+  shape_t tshp(shape);
+  tshp[axes.back()] = tshp[axes.back()] / 2 + 1;
+  arr<std::complex<T>> tdata(util::prod(tshp));
+  stride_t tstride(shape.size());
+  tstride.back() = sizeof(std::complex<T>);
+  for (size_t i = tstride.size() - 1; i > 0; --i) tstride[i - 1] = tstride[i] * ptrdiff_t(tshp[i]);
+  r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
+  cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
+  ndarr<T> aout(data_out, shape, stride_out);
+  simple_iter iin(atmp);
+  rev_iter iout(aout, axes);
+  while (iin.remaining() > 0) {
+    auto v = atmp[iin.ofs()];
+    aout[iout.ofs()] = v.r + v.i;
+    aout[iout.rev_ofs()] = v.r - v.i;
+    iin.advance();
+    iout.advance();
+  }
+}
+
+} // namespace detail
+
+using detail::BACKWARD;
+using detail::c2c;
+using detail::c2r;
+using detail::dct;
+using detail::dst;
+using detail::FORWARD;
+using detail::r2c;
+using detail::r2r_fftpack;
+using detail::r2r_genuine_hartley;
+using detail::r2r_separable_hartley;
+using detail::shape_t;
+using detail::stride_t;
+
+} // namespace pocketfft
+
+#undef POCKETFFT_NOINLINE
+#undef POCKETFFT_RESTRICT
+
+#endif // POCKETFFT_HDRONLY_H
\ No newline at end of file
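
r2r_genuine_hartley derives the N-D transform from the r2c half-spectrum: the forward offset receives Re+Im and the conjugate-mirrored offset receives Re-Im. A numpy check of exactly that identity (illustration only, using numpy's FFT as the reference spectrum):

import numpy as np

x = np.random.rand(4, 6)
f = np.fft.fftn(x)
hartley = f.real + f.imag  # what the final loop writes at the forward offset
for idx in np.ndindex(*x.shape):
    ridx = tuple((-i) % n for i, n in zip(idx, x.shape))
    # ...and Re - Im lands at the conjugate-mirrored offset
    assert np.isclose(hartley[ridx], f[idx].real - f[idx].imag)
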
diff --git a/tools/model_exporter.py b/tools/model_exporter.py
index 14aad331..72578dd6 100644
--- a/tools/model_exporter.py
+++ b/tools/model_exporter.py
@@ -253,6 +253,11 @@     def export_embedding(self, ctx: Context, module: nn.Embedding):
         ctx = ctx.with_subname("weight")
         self._write(ctx, module.weight)
 
+    def export_linear(self, ctx: Context, module, has_bias=True):
+        self._write(ctx.with_subname("weight"), module.weight)
+        if has_bias:
+            self._write(ctx.with_subname("bias").with_quant(Quant.NONE), module.bias)
+
     def export_layer_norm(self, ctx: Context, module: nn.LayerNorm):
         self._write(ctx.with_subname("weight").with_quant(Quant.NONE), module.weight)
         self._write(ctx.with_subname("bias").with_quant(Quant.NONE), module.bias)
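
export_linear added above writes a weight that follows the context's active quantization and a bias that is always stored unquantized. A hypothetical sketch of the tensor names it produces for a context like "whisper.decoder.layer0.fc1" (names and dimensions illustrative, derived from with_subname):

import torch
from torch import nn

fc1 = nn.Linear(1280, 5120)  # placeholder dimensions
exported = {
    "whisper.decoder.layer0.fc1.weight": fc1.weight,  # quantized per the context's Quant
    "whisper.decoder.layer0.fc1.bias": fc1.bias,      # always written with Quant.NONE
}
print({name: tuple(t.shape) for name, t in exported.items()})
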
diff --git a/tools/whisper_exporter.py b/tools/whisper_exporter.py
new file mode 100644
index 00000000..bc767456
--- /dev/null
+++ b/tools/whisper_exporter.py
@@ -0,0 +1,228 @@
+# The MIT License (MIT)
+#
+# Copyright (c) 2023 Xiaoyang Chen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+# and associated documentation files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or
+# substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import argparse
+import torch
+import tempfile
+import configparser
+import zipfile
+import io
+import sys
+import os
+import urllib.request
+from os import path
+from model_exporter import Context, ModelExporter, TensorWriter, Quant
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from bpe_exporter import read_spm_model, read_transformers_fast_bpe_model
+
+class WhisperExporter(ModelExporter):
+    def __init__(self, writer: TensorWriter) -> None:
+        super().__init__(writer)
+
+    def _export_conv1d(self, ctx: Context, module):
+        self._write(ctx.with_subname("weight"), module.weight)
+        self._write(ctx.with_subname("bias").with_quant(Quant.NONE), module.bias)
+        assert module.dilation == (1, )
+        assert module.padding == (1, )
+
+    def _export_enc_pos_embd(self, ctx: Context, module):
+        self._write(ctx, module.weight)
+
+    def _export_attn(self, ctx: Context, module, cross_attn=False):
+        q_proj = module.q_proj.weight
+        k_proj = module.k_proj.weight
+        v_proj = module.v_proj.weight
+
+        if cross_attn:
+            kv_proj = torch.cat((k_proj, v_proj), dim=0)
+            self._write(ctx.with_subname("kv_proj.weight"), kv_proj)
+            self._write(ctx.with_subname("q_proj.weight"), q_proj)
+        else:
+            qkv_proj = torch.cat((q_proj, k_proj, v_proj), dim=0)
+            self._write(ctx.with_subname("qkv_proj.weight"), qkv_proj)
+
+        assert module.q_proj.bias is not None
+        assert module.v_proj.bias is not None
+        assert module.k_proj.bias is None
+        q_bias = module.q_proj.bias
+        k_bias = torch.zeros_like(q_bias)
+        v_bias = module.v_proj.bias
+
+        if cross_attn:
+            kv_bias = torch.cat((k_bias, v_bias), dim=0)
+            self._write(ctx.with_subname("kv_proj.bias").with_quant(Quant.NONE), kv_bias)
+            self._write(ctx.with_subname("q_proj.bias").with_quant(Quant.NONE), q_bias)
+        else:
+            qkv_bias = torch.cat((q_bias, k_bias, v_bias), dim=0)
+            self._write(ctx.with_subname("qkv_proj.bias").with_quant(Quant.NONE), qkv_bias)
+
+        self._write(ctx.with_subname("out_proj.weight"), module.out_proj.weight)
+        self._write(ctx.with_subname("out_proj.bias").with_quant(Quant.NONE), module.out_proj.bias)
+
+    def _export_encoder_layer(self, ctx: Context, module):
+        self.export_layer_norm(ctx.with_subname("norm1"), module.self_attn_layer_norm)
+        self.export_layer_norm(ctx.with_subname("norm2"), module.final_layer_norm)
+        self._export_attn(ctx.with_subname("attn"), module.self_attn)
+        self.export_linear(ctx.with_subname("fc1"), module.fc1)
+        self.export_linear(ctx.with_subname("fc2"), module.fc2)
+
+    def _export_encoder(self, ctx: Context, module):
+        self._export_conv1d(ctx.with_subname("conv1"), module.conv1)
+        self._export_conv1d(ctx.with_subname("conv2"), module.conv2)
+        self._export_enc_pos_embd(ctx.with_subname("pos_embd").with_quant(Quant.NONE), module.embed_positions)
+        for idx, layer in enumerate(module.layers):
+            self._export_encoder_layer(ctx.with_subname(f"layer{idx}"), layer)
+        self.export_layer_norm(ctx.with_subname("norm"), module.layer_norm)
+
+    def _export_decoder_layer(self, ctx: Context, module):
+        self.export_layer_norm(ctx.with_subname("norm1"), module.self_attn_layer_norm)
+        self.export_layer_norm(ctx.with_subname("norm2"), module.encoder_attn_layer_norm)
+        self.export_layer_norm(ctx.with_subname("norm3"), module.final_layer_norm)
+        self._export_attn(ctx.with_subname("self_attn"), module.self_attn)
+        self._export_attn(ctx.with_subname("cross_attn"), module.encoder_attn, cross_attn=True)
+        self.export_linear(ctx.with_subname("fc1"), module.fc1)
+        self.export_linear(ctx.with_subname("fc2"), module.fc2)
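
_export_attn fuses the three projections into one qkv_proj (or kv_proj plus q_proj for cross-attention), substituting a zero bias for Whisper's bias-free k_proj. Concatenating along dim 0 is safe for row-major Linear weights, as this quick torch check illustrates (a sketch, not part of the exporter):

import torch

d = 8
x = torch.randn(d)
wq, wk, wv = (torch.randn(d, d) for _ in range(3))

# One fused matmul equals the three separate projections, stacked.
fused = torch.cat((wq, wk, wv), dim=0) @ x
separate = torch.cat((wq @ x, wk @ x, wv @ x))
assert torch.allclose(fused, separate)
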
+    def _export_decoder(self, ctx: Context, module):
+        self.export_embedding(ctx.with_subname("embd"), module.embed_tokens)
+        self._export_enc_pos_embd(ctx.with_subname("pos_embd").with_quant(Quant.NONE), module.embed_positions)
+        for idx, layer in enumerate(module.layers):
+            self._export_decoder_layer(ctx.with_subname(f"layer{idx}"), layer)
+        self.export_layer_norm(ctx.with_subname("norm"), module.layer_norm)
+
+    @classmethod
+    def generate_config(cls, whisper_config) -> configparser.ConfigParser:
+        config = configparser.ConfigParser()
+        config["whisper"] = {}
+
+        section = config["whisper"]
+        section["hidden_size"] = str(whisper_config.hidden_size)
+        section["encoder_num_heads"] = str(whisper_config.encoder_attention_heads)
+        section["encoder_ffn_dim"] = str(whisper_config.encoder_ffn_dim)
+        section["encoder_num_layers"] = str(whisper_config.encoder_layers)
+        section["decoder_num_layers"] = str(whisper_config.decoder_layers)
+        section["decoder_ffn_dim"] = str(whisper_config.decoder_ffn_dim)
+        section["vocab_size"] = str(whisper_config.vocab_size)
+        section["max_tgt_length"] = str(whisper_config.max_target_positions)
+
+        return config
+
+    def _export(self, ctx: Context, whisper_model):
+        model = whisper_model.base_model
+        self._export_encoder(ctx.with_subname("encoder"), model.encoder)
+        self._export_decoder(ctx.with_subname("decoder"), model.decoder)
+        self.export_linear(
+            ctx.with_subname("decoder").with_subname("out_proj"),
+            whisper_model.proj_out,
+            has_bias=False)
+
+    @classmethod
+    def export(cls, whisper_model, fp, quant: Quant):
+        config = whisper_model.config
+
+        assert config.activation_function == "gelu"
+
+        ctx = Context("whisper", quant=quant)
+        with TensorWriter(fp) as writer:
+            exporter = cls(writer)
+            exporter._export(ctx, whisper_model)
+
+        ini_config = cls.generate_config(config)
+        ini_config["model"] = {}
+        ini_config["model"]["type"] = "whisper"
+        ini_config["model"]["model_file"] = path.basename(MODEL_BIN)
+
+        return ini_config
+
+HELLO_URL = "https://upload.wikimedia.org/wikipedia/commons/9/9a/En-us-hello-2.ogg"
+
+def run_whisper(huggingface_name):
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+    model_id = huggingface_name
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+    )
+    model.to(device)
+
+    processor = AutoProcessor.from_pretrained(model_id)
+    pipe = pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        max_new_tokens=128,
+        chunk_length_s=30,
+        batch_size=16,
+        return_timestamps=True,
+        torch_dtype=torch_dtype,
+        device=device,
+    )
+
+    audio_file = path.join(tempfile.gettempdir(), "en-us-hello.ogg")
+    if not path.exists(audio_file):
+        urllib.request.urlretrieve(HELLO_URL, audio_file)
+
+    print(pipe(audio_file))
+
+MODEL_NAME = "openai/whisper-large-v3"
+MODEL_BIN = "model.bin"
+MODEL_INI = "model.ini"
+TOKENIZER_BIN = "tokenizer.bin"
+TOKENIZER_INI = "tokenizer.ini"
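
generate_config flattens the relevant WhisperConfig fields into a [whisper] section, and export() appends a [model] section pointing at model.bin. A sketch of the resulting model.ini shape (field values illustrative, not taken from a real checkpoint):

import configparser
import io

config = configparser.ConfigParser()
config["whisper"] = {"hidden_size": "1280", "vocab_size": "51866"}  # illustrative values
config["model"] = {"type": "whisper", "model_file": "model.bin"}

buf = io.StringIO()
config.write(buf)
print(buf.getvalue())
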
+if __name__ == '__main__':
+    from transformers import AutoTokenizer
+
+    parser = argparse.ArgumentParser(description='export whisper model from huggingface to libllm format.')
+    parser.add_argument('-huggingface_name', type=str, help='the whisper model name in huggingface.', default=MODEL_NAME)
+    parser.add_argument('-quant', type=Quant.parse, help='quantization type, "q4" or "none"', default=Quant.Q4)
+    parser.add_argument('-output', type=str, help='output file name.', default="whisper.llmpkg")
+    parser.add_argument('-run', action="store_true")
+    args = parser.parse_args()
+
+    if args.run:
+        run_whisper(args.huggingface_name)
+        sys.exit(0)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.huggingface_name, trust_remote_code=True)
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        args.huggingface_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+    )
+    model = model.eval()
+
+    with zipfile.ZipFile(args.output, "w", compression=zipfile.ZIP_STORED) as package:
+        whisper_tokenizer = read_transformers_fast_bpe_model(args.huggingface_name)
+
+        with package.open(MODEL_BIN, "w", force_zip64=True) as fp:
+            config = WhisperExporter.export(model, fp, args.quant)
+
+        config["whisper"]["eot_token_id"] = str(tokenizer.eos_token_id)
+        with package.open(MODEL_INI, "w", force_zip64=True) as fp:
+            config.write(io.TextIOWrapper(fp))
+
+        with package.open(TOKENIZER_BIN, "w", force_zip64=True) as fp:
+            whisper_tokenizer.save(fp)
+
+        with package.open(TOKENIZER_INI, "w", force_zip64=True) as fp:
+            whisper_tokenizer.get_config().to_ini(TOKENIZER_BIN).write(io.TextIOWrapper(fp))
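
The resulting .llmpkg is a plain uncompressed ZIP holding model.bin, model.ini, tokenizer.bin and tokenizer.ini. A quick way to sanity-check a freshly exported package (assumes whisper.llmpkg exists in the working directory):

import zipfile

with zipfile.ZipFile("whisper.llmpkg") as pkg:
    # expect: model.bin, model.ini, tokenizer.bin, tokenizer.ini
    print(pkg.namelist())
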