From 44c22c4f5ec92972b2c60096b41eabdf1271a510 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Wed, 12 Feb 2025 20:48:10 +0800 Subject: [PATCH] Add TestLLM --- common/mock/openai.go | 29 ++++++++++-- common/mock/openai_test.go | 9 ++-- logics/item_to_item.go | 93 ++++++++++++++++++++++++++++++++++++- logics/item_to_item_test.go | 39 ++++++++++++++++ master/tasks.go | 5 +- 5 files changed, 164 insertions(+), 11 deletions(-) diff --git a/common/mock/openai.go b/common/mock/openai.go index 390b3849d..834b78509 100644 --- a/common/mock/openai.go +++ b/common/mock/openai.go @@ -16,12 +16,15 @@ package mock import ( "bytes" + "crypto/md5" "encoding/json" "fmt" - "github.com/emicklei/go-restful/v3" - "github.com/sashabaranov/go-openai" "net" "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/samber/lo" + "github.com/sashabaranov/go-openai" ) type OpenAIServer struct { @@ -114,15 +117,35 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon } func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) { + // parse request var r openai.EmbeddingRequest err := req.ReadEntity(&r) if err != nil { _ = resp.WriteError(http.StatusBadRequest, err) return } + input, ok := r.Input.(string) + if !ok { + _ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type")) + return + } + + // write response _ = resp.WriteEntity(openai.EmbeddingResponse{ Data: []openai.Embedding{{ - Embedding: make([]float32, 1024), + Embedding: Hash(input), }}, }) } + +func Hash(input string) []float32 { + hasher := md5.New() + _, err := hasher.Write([]byte(input)) + if err != nil { + panic(err) + } + h := hasher.Sum(nil) + return lo.Map(h, func(b byte, _ int) float32 { + return float32(b) + }) +} diff --git a/common/mock/openai_test.go b/common/mock/openai_test.go index 7453de940..592d1ac7a 100644 --- a/common/mock/openai_test.go +++ b/common/mock/openai_test.go @@ -16,12 +16,13 @@ package mock import ( "context" - "github.com/juju/errors" - "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/suite" "io" "strings" "testing" + + "github.com/juju/errors" + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/suite" ) type OpenAITestSuite struct { @@ -105,7 +106,7 @@ func (suite *OpenAITestSuite) TestEmbeddings() { }, ) suite.NoError(err) - suite.Equal(make([]float32, 1024), resp.Data[0].Embedding) + suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding) } func TestOpenAITestSuite(t *testing.T) { diff --git a/logics/item_to_item.go b/logics/item_to_item.go index 9d19be35c..7358e5f4b 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -15,15 +15,21 @@ package logics import ( + "context" "errors" + "fmt" "sort" + "strings" "time" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" + "github.com/nikolalohinski/gonja/v2" + "github.com/nikolalohinski/gonja/v2/exec" "github.com/samber/lo" + "github.com/sashabaranov/go-openai" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/common/ann" @@ -35,8 +41,9 @@ import ( ) type ItemToItemOptions struct { - TagsIDF []float32 - UsersIDF []float32 + TagsIDF []float32 + UsersIDF []float32 + OpenAIConfig config.OpenAIConfig } type ItemToItem interface { @@ -64,6 +71,11 @@ func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts return nil, errors.New("tags and users IDF are required for auto item-to-item") } return newAutoItemToItem(cfg, n, timestamp, opts.TagsIDF, opts.UsersIDF) + case "llm": + if opts == nil || opts.OpenAIConfig.BaseURL == "" || opts.OpenAIConfig.AuthToken == "" { + return nil, errors.New("OpenAI config is required for LLM item-to-item") + } + return newGenerativeItemToItem(cfg, n, timestamp, opts.OpenAIConfig) default: return nil, errors.New("invalid item-to-item type") } @@ -338,3 +350,80 @@ func flatten(o any, tSet mapset.Set[dataset.ID]) { } } } + +type llmItemToItem struct { + *embeddingItemToItem + template *exec.Template + client *openai.Client + chatModel string + embeddingModel string +} + +func newGenerativeItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*llmItemToItem, error) { + // create embedding item-to-item recommender + embedding, err := newEmbeddingItemToItem(cfg, n, timestamp) + if err != nil { + return nil, err + } + // parse template + template, err := gonja.FromString(cfg.Prompt) + if err != nil { + return nil, err + } + // create openai client + clientConfig := openai.DefaultConfig(openaiConfig.AuthToken) + clientConfig.BaseURL = openaiConfig.BaseURL + return &llmItemToItem{ + embeddingItemToItem: embedding, + template: template, + client: openai.NewClientWithConfig(clientConfig), + chatModel: openaiConfig.ChatCompletionModel, + embeddingModel: openaiConfig.EmbeddingsModel, + }, nil +} + +func (g *llmItemToItem) PopAll(i int) []cache.Score { + // render template + var buf strings.Builder + ctx := exec.NewContext(map[string]any{ + "item": g.items[i], + }) + if err := g.template.Execute(&buf, ctx); err != nil { + log.Logger().Error("failed to execute template", zap.Error(err)) + return nil + } + fmt.Println(buf.String()) + // chat completion + resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: g.chatModel, + Messages: []openai.ChatCompletionMessage{{ + Role: openai.ChatMessageRoleUser, + Content: buf.String(), + }}, + }) + if err != nil { + log.Logger().Error("failed to chat completion", zap.Error(err)) + return nil + } + message := resp.Choices[0].Message.Content + // message embedding + resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Input: message, + Model: openai.EmbeddingModel(g.embeddingModel), + }) + if err != nil { + log.Logger().Error("failed to create embeddings", zap.Error(err)) + return nil + } + embedding := resp2.Data[0].Embedding + // search index + scores := g.index.SearchVector(embedding, g.n+1, true) + return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: g.items[v.A].ItemId, + Categories: g.items[v.A].Categories, + Score: -float64(v.B), + Timestamp: g.timestamp, + } + }) +} diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go index 1778f779c..6c490d971 100644 --- a/logics/item_to_item_test.go +++ b/logics/item_to_item_test.go @@ -20,6 +20,8 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/zhenghaoz/gorse/base/floats" + "github.com/zhenghaoz/gorse/common/mock" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/dataset" "github.com/zhenghaoz/gorse/storage/data" @@ -194,6 +196,43 @@ func (suite *ItemToItemTestSuite) TestAuto() { } } +func (suite *ItemToItemTestSuite) TestLLM() { + mockAI := mock.NewOpenAIServer() + go mockAI.Start() + mockAI.Ready() + defer mockAI.Close() + + timestamp := time.Now() + item2item, err := newGenerativeItemToItem(config.ItemToItemConfig{ + Column: "item.Labels.embeddings", + Prompt: "Please generate similar items for {{ item.Labels.title }}.", + }, 10, timestamp, config.OpenAIConfig{ + BaseURL: mockAI.BaseURL(), + AuthToken: mockAI.AuthToken(), + ChatCompletionModel: "gpt-3.5-turbo", + EmbeddingsModel: "text-similarity-ada-001", + }) + suite.NoError(err) + + for i := 0; i < 100; i++ { + embedding := mock.Hash("Please generate similar items for item_0.") + floats.AddConst(embedding, float32(i)) + item2item.Push(&data.Item{ + ItemId: strconv.Itoa(i), + Labels: map[string]any{ + "title": "item_" + strconv.Itoa(i), + "embeddings": embedding, + }, + }, nil) + } + + scores := item2item.PopAll(0) + suite.Len(scores, 10) + for i := 1; i <= 10; i++ { + suite.Equal(strconv.Itoa(i), scores[i-1].Id) + } +} + func TestItemToItem(t *testing.T) { suite.Run(t, new(ItemToItemTestSuite)) } diff --git a/master/tasks.go b/master/tasks.go index 9d4086ef9..b73ce668f 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -1012,8 +1012,9 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error { itemToItemRecommenders := make([]logics.ItemToItem, 0, len(itemToItemConfigs)) for _, cfg := range itemToItemConfigs { recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.ItemToItemOptions{ - TagsIDF: dataset.GetItemColumnValuesIDF(), - UsersIDF: dataset.GetUserIDF(), + TagsIDF: dataset.GetItemColumnValuesIDF(), + UsersIDF: dataset.GetUserIDF(), + OpenAIConfig: m.Config.OpenAI, }) if err != nil { return errors.Trace(err)