Skip to content

Commit

Permalink
Add TestLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Feb 12, 2025
1 parent fcaad91 commit 44c22c4
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 11 deletions.
29 changes: 26 additions & 3 deletions common/mock/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ package mock

import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"github.com/emicklei/go-restful/v3"
"github.com/sashabaranov/go-openai"
"net"
"net/http"

"github.com/emicklei/go-restful/v3"
"github.com/samber/lo"
"github.com/sashabaranov/go-openai"
)

type OpenAIServer struct {
Expand Down Expand Up @@ -114,15 +117,35 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
}

// embeddings handles the mock /embeddings endpoint. It parses an OpenAI
// embedding request and replies with a deterministic vector derived from the
// input string, so tests can predict the embedding for any given input.
// Only string inputs are supported; other input types yield 400.
func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {
	// parse request
	var r openai.EmbeddingRequest
	err := req.ReadEntity(&r)
	if err != nil {
		_ = resp.WriteError(http.StatusBadRequest, err)
		return
	}
	input, ok := r.Input.(string)
	if !ok {
		_ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type"))
		return
	}

	// write response: the embedding is a hash of the input, not a fixed
	// zero vector, so distinct inputs produce distinct embeddings
	_ = resp.WriteEntity(openai.EmbeddingResponse{
		Data: []openai.Embedding{{
			Embedding: Hash(input),
		}},
	})
}

// Hash maps a string to a deterministic 16-element embedding vector by
// taking the MD5 digest of the input and widening each digest byte to a
// float32. It is used by the mock OpenAI server so tests can predict the
// embedding returned for a given input.
func Hash(input string) []float32 {
	// md5.Sum replaces the New/Write/Sum(nil) dance; hash.Hash.Write is
	// documented to never return an error, so the old panic path was dead.
	sum := md5.Sum([]byte(input))
	return lo.Map(sum[:], func(b byte, _ int) float32 {
		return float32(b)
	})
}
9 changes: 5 additions & 4 deletions common/mock/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package mock

import (
"context"
"github.com/juju/errors"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/suite"
"io"
"strings"
"testing"

"github.com/juju/errors"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/suite"
)

type OpenAITestSuite struct {
Expand Down Expand Up @@ -105,7 +106,7 @@ func (suite *OpenAITestSuite) TestEmbeddings() {
},
)
suite.NoError(err)
suite.Equal(make([]float32, 1024), resp.Data[0].Embedding)
suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding)
}

func TestOpenAITestSuite(t *testing.T) {
Expand Down
93 changes: 91 additions & 2 deletions logics/item_to_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@
package logics

import (
"context"
"errors"
"fmt"
"sort"
"strings"
"time"

"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
"github.com/nikolalohinski/gonja/v2"
"github.com/nikolalohinski/gonja/v2/exec"
"github.com/samber/lo"
"github.com/sashabaranov/go-openai"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/common/ann"
Expand All @@ -35,8 +41,9 @@ import (
)

type ItemToItemOptions struct {
TagsIDF []float32
UsersIDF []float32
TagsIDF []float32
UsersIDF []float32
OpenAIConfig config.OpenAIConfig
}

type ItemToItem interface {
Expand Down Expand Up @@ -64,6 +71,11 @@ func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts
return nil, errors.New("tags and users IDF are required for auto item-to-item")
}
return newAutoItemToItem(cfg, n, timestamp, opts.TagsIDF, opts.UsersIDF)
case "llm":
if opts == nil || opts.OpenAIConfig.BaseURL == "" || opts.OpenAIConfig.AuthToken == "" {
return nil, errors.New("OpenAI config is required for LLM item-to-item")
}
return newGenerativeItemToItem(cfg, n, timestamp, opts.OpenAIConfig)
default:
return nil, errors.New("invalid item-to-item type")
}
Expand Down Expand Up @@ -338,3 +350,80 @@ func flatten(o any, tSet mapset.Set[dataset.ID]) {
}
}
}

// llmItemToItem recommends similar items by asking a chat-completion model
// to generate text for an item, embedding the reply, and searching the
// vector index of the embedded embeddingItemToItem recommender.
type llmItemToItem struct {
	*embeddingItemToItem                // underlying vector index and pushed items
	template             *exec.Template // gonja prompt template rendered per item
	client               *openai.Client // OpenAI-compatible API client
	chatModel            string         // model name used for chat completions
	embeddingModel       string         // model name used for embedding requests
}

// newGenerativeItemToItem constructs an LLM-backed item-to-item recommender.
// It wraps an embedding recommender for the vector index, compiles the prompt
// template from cfg.Prompt, and configures an OpenAI-compatible client from
// openaiConfig.
func newGenerativeItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*llmItemToItem, error) {
	// the embedding recommender supplies the index searched by PopAll
	base, err := newEmbeddingItemToItem(cfg, n, timestamp)
	if err != nil {
		return nil, err
	}
	// compile the prompt template once up front
	tmpl, err := gonja.FromString(cfg.Prompt)
	if err != nil {
		return nil, err
	}
	// point the client at the configured (possibly mock) endpoint
	cc := openai.DefaultConfig(openaiConfig.AuthToken)
	cc.BaseURL = openaiConfig.BaseURL
	return &llmItemToItem{
		embeddingItemToItem: base,
		template:            tmpl,
		client:              openai.NewClientWithConfig(cc),
		chatModel:           openaiConfig.ChatCompletionModel,
		embeddingModel:      openaiConfig.EmbeddingsModel,
	}, nil
}

// PopAll returns similarity scores for item i: it renders the prompt
// template with the item, asks the chat model to generate similar items,
// embeds the generated text, and returns the nearest neighbors from the
// embedding index. Any failure is logged and yields a nil result.
func (g *llmItemToItem) PopAll(i int) []cache.Score {
	// render template
	var buf strings.Builder
	ctx := exec.NewContext(map[string]any{
		"item": g.items[i],
	})
	if err := g.template.Execute(&buf, ctx); err != nil {
		log.Logger().Error("failed to execute template", zap.Error(err))
		return nil
	}
	// was a stray fmt.Println debug print; keep the diagnostic at debug level
	log.Logger().Debug("rendered prompt", zap.String("prompt", buf.String()))
	// chat completion
	resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: g.chatModel,
		Messages: []openai.ChatCompletionMessage{{
			Role:    openai.ChatMessageRoleUser,
			Content: buf.String(),
		}},
	})
	if err != nil {
		log.Logger().Error("failed to chat completion", zap.Error(err))
		return nil
	}
	// guard against a degenerate reply before indexing Choices[0]
	if len(resp.Choices) == 0 {
		log.Logger().Error("no chat completion choices returned")
		return nil
	}
	message := resp.Choices[0].Message.Content
	// message embedding
	resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Input: message,
		Model: openai.EmbeddingModel(g.embeddingModel),
	})
	if err != nil {
		log.Logger().Error("failed to create embeddings", zap.Error(err))
		return nil
	}
	if len(resp2.Data) == 0 {
		log.Logger().Error("no embeddings returned")
		return nil
	}
	embedding := resp2.Data[0].Embedding
	// search index; n+1 because the query item itself may be among the results
	scores := g.index.SearchVector(embedding, g.n+1, true)
	return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
		return cache.Score{
			Id:         g.items[v.A].ItemId,
			Categories: g.items[v.A].Categories,
			Score:      -float64(v.B), // negate distance so higher score = more similar
			Timestamp:  g.timestamp,
		}
	})
}
39 changes: 39 additions & 0 deletions logics/item_to_item_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"time"

"github.com/stretchr/testify/suite"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/common/mock"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/dataset"
"github.com/zhenghaoz/gorse/storage/data"
Expand Down Expand Up @@ -194,6 +196,43 @@ func (suite *ItemToItemTestSuite) TestAuto() {
}
}

func (suite *ItemToItemTestSuite) TestLLM() {
mockAI := mock.NewOpenAIServer()
go mockAI.Start()

Check failure on line 201 in logics/item_to_item_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `mockAI.Start` is not checked (errcheck)
mockAI.Ready()
defer mockAI.Close()

timestamp := time.Now()
item2item, err := newGenerativeItemToItem(config.ItemToItemConfig{
Column: "item.Labels.embeddings",
Prompt: "Please generate similar items for {{ item.Labels.title }}.",
}, 10, timestamp, config.OpenAIConfig{
BaseURL: mockAI.BaseURL(),
AuthToken: mockAI.AuthToken(),
ChatCompletionModel: "gpt-3.5-turbo",
EmbeddingsModel: "text-similarity-ada-001",
})
suite.NoError(err)

for i := 0; i < 100; i++ {
embedding := mock.Hash("Please generate similar items for item_0.")
floats.AddConst(embedding, float32(i))
item2item.Push(&data.Item{
ItemId: strconv.Itoa(i),
Labels: map[string]any{
"title": "item_" + strconv.Itoa(i),
"embeddings": embedding,
},
}, nil)
}

scores := item2item.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
}
}

// TestItemToItem runs ItemToItemTestSuite under testify's suite runner.
func TestItemToItem(t *testing.T) {
	suite.Run(t, new(ItemToItemTestSuite))
}
5 changes: 3 additions & 2 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,8 +1012,9 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error {
itemToItemRecommenders := make([]logics.ItemToItem, 0, len(itemToItemConfigs))
for _, cfg := range itemToItemConfigs {
recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.ItemToItemOptions{
TagsIDF: dataset.GetItemColumnValuesIDF(),
UsersIDF: dataset.GetUserIDF(),
TagsIDF: dataset.GetItemColumnValuesIDF(),
UsersIDF: dataset.GetUserIDF(),
OpenAIConfig: m.Config.OpenAI,
})
if err != nil {
return errors.Trace(err)
Expand Down

0 comments on commit 44c22c4

Please sign in to comment.