Skip to content

Commit

Permalink
Add Go bindings (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Jun 23, 2024
1 parent 67bb2a9 commit d14f88a
Show file tree
Hide file tree
Showing 11 changed files with 997 additions and 16 deletions.
9 changes: 9 additions & 0 deletions go/cmd/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module main

go 1.22.4

replace github.com/ling0322/libllm/go/llm => ../llm

require (
github.com/ling0322/libllm/go/llm v1.0.0
)
42 changes: 42 additions & 0 deletions go/cmd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package main

import (
"fmt"
"log"

"github.com/ling0322/libllm/go/llm"
)

func main() {
model, err := llm.NewModel("../../tools/llama.llmpkg", llm.Cuda)
if err != nil {
log.Fatal(err)
}

prompt := llm.NewPrompt()
prompt.AppendControlToken("<|begin_of_text|>")
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("user")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\nHi")
prompt.AppendControlToken("<|eot_id|>")
prompt.AppendControlToken("<|start_header_id|>")
prompt.AppendText("assistant")
prompt.AppendControlToken("<|end_header_id|>")
prompt.AppendText("\n\n")

log.Println(model)

comp, err := model.Complete(llm.NewCompletionConfig(), prompt)
if err != nil {
log.Fatal(err)
}

for comp.IsActive() {
chunk, err := comp.GenerateNextChunk()
if err != nil {
log.Fatal(err)
}
fmt.Printf(chunk.Text)
}
}
48 changes: 32 additions & 16 deletions bindings/go/libllm/llm.go → go/llm/chunk.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2023 Xiaoyang Chen
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
Expand All @@ -17,26 +17,42 @@
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package libllm
package llm

// #include <stdlib.h>
// #include "llm_api.h"
import "C"
import (
"unsafe"
"errors"
"fmt"
"os"
"runtime"
)

type Device int

const (
Cpu Device = iota
Cuda
Auto
)

type Model struct {
handle unsafe.Pointer
// Generate by Compeltion.
type Chunk struct {
Text string
}

func LoadModel(filename string) (model *Model, err error) {

type chunkHandle struct {
handle *C.llmChunk_t
}

func (m *Model) Complete()
func newChunkHandle() (*chunkHandle, error) {
cHandle := C.llmChunk_New()
if cHandle == nil {
return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

handle := &chunkHandle{
cHandle,
}
runtime.SetFinalizer(handle, func(h *chunkHandle) {
status := C.llmChunk_Delete(h.handle)
if status != C.LLM_OK {
fmt.Fprintln(os.Stderr, "failed to call llmPrompt_Delete()")
}
})

return handle, nil
}
101 changes: 101 additions & 0 deletions go/llm/completion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package llm

// #include <stdlib.h>
// #include "llm_api.h"
import "C"
import (
"errors"
"fmt"
"os"
"runtime"
)

// Config for LLM completion.
type Completion interface {
IsActive() bool
GenerateNextChunk() (Chunk, error)
}

type completionHandle struct {
handle *C.llmCompletion_t
}

type completionImpl struct {
handle *completionHandle
chunkHandle *chunkHandle
}

func (c *completionImpl) IsActive() bool {
return C.llmCompletion_IsActive(c.handle.handle) != 0
}

func (c *completionImpl) GenerateNextChunk() (Chunk, error) {
status := C.llmCompletion_GenerateNextChunk(c.handle.handle, c.chunkHandle.handle)
if status != C.LLM_OK {
return Chunk{}, errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

chunk := Chunk{}
cText := C.llmChunk_GetText(c.chunkHandle.handle)
if cText == nil {
return Chunk{}, errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

chunk.Text = C.GoString(cText)
return chunk, nil
}

func newCompletionImpl(modelHandle *modelHandle) (*completionImpl, error) {
handle, err := newCompletionHandle(modelHandle)
if err != nil {
return nil, err
}

chunkHandle, err := newChunkHandle()
if err != nil {
return nil, err
}

return &completionImpl{
handle: handle,
chunkHandle: chunkHandle,
}, nil
}

func newCompletionHandle(m *modelHandle) (*completionHandle, error) {
cHandle := C.llmCompletion_New(m.handle)
if cHandle == nil {
return nil, errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

handle := &completionHandle{
cHandle,
}
runtime.SetFinalizer(handle, func(h *completionHandle) {
status := C.llmCompletion_Delete(h.handle)
if status != C.LLM_OK {
fmt.Fprintln(os.Stderr, "failed to call llmCompletion_Delete()")
}
})

return handle, nil
}
94 changes: 94 additions & 0 deletions go/llm/completion_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// The MIT License (MIT)
//
// Copyright (c) 2024 Xiaoyang Chen
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software
// and associated documentation files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
// BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package llm

// #include <stdlib.h>
// #include "llm_api.h"
import "C"
import "errors"

// Config for LLM completion.
type CompletionConfig interface {
SetTopP(topP float32)
GetTopP() float32

SetTopK(topK int)
GetTopK() int

SetTemperature(temperature float32)
GetTemperature() float32

// update the llmCompletion_t according to the config.
updateCompHandle(compHandle *completionHandle) error
}

type completionConfigImpl struct {
topP float32
topK int
temperature float32
}

func NewCompletionConfig() CompletionConfig {
return &completionConfigImpl{
topP: 0.8,
topK: 50,
temperature: 1.0,
}
}

func (c *completionConfigImpl) SetTopP(topP float32) {
c.topP = topP
}

func (c *completionConfigImpl) GetTopP() float32 {
return c.topP
}

func (c *completionConfigImpl) SetTopK(topK int) {
c.topK = topK
}

func (c *completionConfigImpl) GetTopK() int {
return c.topK
}

func (c *completionConfigImpl) SetTemperature(temperature float32) {
c.temperature = temperature
}

func (c *completionConfigImpl) GetTemperature() float32 {
return c.temperature
}

func (c *completionConfigImpl) updateCompHandle(compHandle *completionHandle) error {
if C.llmCompletion_SetTopP(compHandle.handle, C.float(c.topP)) != C.LLM_OK {
return errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

if C.llmCompletion_SetTopK(compHandle.handle, C.int(c.topK)) != C.LLM_OK {
return errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

if C.llmCompletion_SetTemperature(compHandle.handle, C.float(c.temperature)) != C.LLM_OK {
return errors.New(C.GoString(C.llmGetLastErrorMessage()))
}

return nil
}
3 changes: 3 additions & 0 deletions go/llm/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/ling0322/libllm/go/llm

go 1.22.4
Loading

0 comments on commit d14f88a

Please sign in to comment.