From 2271038b270297585a7a6b290dd277aa03c6b3d3 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Sun, 16 Feb 2025 22:07:28 +0800 Subject: [PATCH 1/5] implement LLM-based personalized recommenders --- config/config.go | 13 +++++ logics/chat.go | 114 ++++++++++++++++++++++++++++++++++++++ logics/chat_test.go | 52 +++++++++++++++++ storage/cache/database.go | 4 ++ 4 files changed, 183 insertions(+) create mode 100644 logics/chat.go create mode 100644 logics/chat_test.go diff --git a/config/config.go b/config/config.go index 2f524f2cd..28bb2ce01 100644 --- a/config/config.go +++ b/config/config.go @@ -163,6 +163,19 @@ func (config *ItemToItemConfig) Hash() string { return string(hash.Sum(nil)) } +type ChatConfig struct { + Name string `mapstructure:"name" json:"name"` + Column string `mapstructure:"column" json:"column" validate:"item_expr"` + Prompt string `mapstructure:"prompt" json:"prompt"` +} + +func (config *ChatConfig) Hash() string { + hash := md5.New() + hash.Write([]byte(config.Name)) + hash.Write([]byte(config.Column)) + return string(hash.Sum(nil)) +} + type CollaborativeConfig struct { ModelFitPeriod time.Duration `mapstructure:"model_fit_period" validate:"gt=0"` ModelSearchPeriod time.Duration `mapstructure:"model_search_period" validate:"gt=0"` diff --git a/logics/chat.go b/logics/chat.go new file mode 100644 index 000000000..f4c4f5fb0 --- /dev/null +++ b/logics/chat.go @@ -0,0 +1,114 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package logics + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/expr-lang/expr" + "github.com/nikolalohinski/gonja/v2" + "github.com/nikolalohinski/gonja/v2/exec" + "github.com/samber/lo" + "github.com/sashabaranov/go-openai" + "github.com/zhenghaoz/gorse/base/floats" + "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/common/ann" + "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/storage/cache" + "github.com/zhenghaoz/gorse/storage/data" + "go.uber.org/zap" +) + +type Chat chatItemToItem + +func NewChat(cfg config.ChatConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*Chat, error) { + // Compile column expression + columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ + "item": data.Item{}, + })) + if err != nil { + return nil, err + } + // parse template + template, err := gonja.FromString(cfg.Prompt) + if err != nil { + return nil, err + } + // create openai client + clientConfig := openai.DefaultConfig(openaiConfig.AuthToken) + clientConfig.BaseURL = openaiConfig.BaseURL + return &Chat{ + embeddingItemToItem: &embeddingItemToItem{baseItemToItem: baseItemToItem[[]float32]{ + name: cfg.Name, + n: n, + timestamp: timestamp, + columnFunc: columnFunc, + index: ann.NewHNSW(floats.Euclidean), + }}, + template: template, + client: openai.NewClientWithConfig(clientConfig), + chatModel: openaiConfig.ChatCompletionModel, + embeddingModel: openaiConfig.EmbeddingsModel, + }, nil +} + +func (c *Chat) PopAll(i int) []cache.Score { + // render template + var buf strings.Builder + ctx := exec.NewContext(map[string]any{ + "item": c.items[i], + }) + if err := c.template.Execute(&buf, ctx); err != nil { + log.Logger().Error("failed to execute template", zap.Error(err)) + return nil + } + fmt.Println(buf.String()) + // chat completion + resp, err := 
c.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: c.chatModel, + Messages: []openai.ChatCompletionMessage{{ + Role: openai.ChatMessageRoleUser, + Content: buf.String(), + }}, + }) + if err != nil { + log.Logger().Error("failed to chat completion", zap.Error(err)) + return nil + } + message := stripThink(resp.Choices[0].Message.Content) + // message embedding + resp2, err := c.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Input: message, + Model: openai.EmbeddingModel(c.embeddingModel), + }) + if err != nil { + log.Logger().Error("failed to create embeddings", zap.Error(err)) + return nil + } + embedding := resp2.Data[0].Embedding + // search index + scores := c.index.SearchVector(embedding, c.n+1, true) + return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: c.items[v.A].ItemId, + Categories: c.items[v.A].Categories, + Score: -float64(v.B), + Timestamp: c.timestamp, + } + }) +} diff --git a/logics/chat_test.go b/logics/chat_test.go new file mode 100644 index 000000000..14a1a94fc --- /dev/null +++ b/logics/chat_test.go @@ -0,0 +1,52 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package logics + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/common/mock" + "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/storage/data" +) + +func TestChat(t *testing.T) { + mockAI := mock.NewOpenAIServer() + go func() { + _ = mockAI.Start() + }() + mockAI.Ready() + defer mockAI.Close() + + chat, err := NewChat(config.ChatConfig{ + Column: "item.Labels.description", + }, 10, time.Now(), config.OpenAIConfig{ + BaseURL: mockAI.BaseURL(), + AuthToken: mockAI.AuthToken(), + ChatCompletionModel: "deepseek-r1", + EmbeddingsModel: "text-similarity-ada-001", + }) + assert.NoError(t, err) + + chat.Push(&data.Item{ + ItemId: "1", + Labels: map[string]any{ + "description": []float32{0.1, 0.2, 0.3}, + }, + }, nil) + assert.Len(t, chat.Items(), 1) +} diff --git a/storage/cache/database.go b/storage/cache/database.go index 00dbece81..b10df44d3 100644 --- a/storage/cache/database.go +++ b/storage/cache/database.go @@ -72,6 +72,10 @@ const ( UserToUserUpdateTime = "user-to-user_update_time" Neighbors = "neighbors" + Chat = "chat" + ChatDigest = "chat_digest" + ChatUpdateTime = "chat_update_time" + // ItemCategories is the set of item categories. 
The format of key: // Global item categories - item_categories ItemCategories = "item_categories" From e088c82fdb35bb1e5fa59100855333bc32a75bc4 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Tue, 18 Feb 2025 08:16:23 +0800 Subject: [PATCH 2/5] chat: support JSON output --- logics/chat.go | 78 +++++++++++++++++++++++++++++++++++------- logics/chat_test.go | 59 +++++++++++++++++++++++++++----- logics/item_to_item.go | 59 +++++++++++++++++--------------- 3 files changed, 147 insertions(+), 49 deletions(-) diff --git a/logics/chat.go b/logics/chat.go index f4c4f5fb0..a18bb2827 100644 --- a/logics/chat.go +++ b/logics/chat.go @@ -16,7 +16,7 @@ package logics import ( "context" - "fmt" + "encoding/json" "strings" "time" @@ -25,6 +25,9 @@ import ( "github.com/nikolalohinski/gonja/v2/exec" "github.com/samber/lo" "github.com/sashabaranov/go-openai" + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/text" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/common/ann" @@ -67,20 +70,21 @@ func NewChat(cfg config.ChatConfig, n int, timestamp time.Time, openaiConfig con }, nil } -func (c *Chat) PopAll(i int) []cache.Score { +func (g *Chat) PopAll(indices []int) []cache.Score { // render template var buf strings.Builder ctx := exec.NewContext(map[string]any{ - "item": c.items[i], + "items": lo.Map(indices, func(i int, _ int) any { + return g.items[i] + }), }) - if err := c.template.Execute(&buf, ctx); err != nil { + if err := g.template.Execute(&buf, ctx); err != nil { log.Logger().Error("failed to execute template", zap.Error(err)) return nil } - fmt.Println(buf.String()) // chat completion - resp, err := c.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: c.chatModel, + resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: g.chatModel, Messages: []openai.ChatCompletionMessage{{ Role: 
openai.ChatMessageRoleUser, Content: buf.String(), @@ -92,9 +96,9 @@ func (c *Chat) PopAll(i int) []cache.Score { } message := stripThink(resp.Choices[0].Message.Content) // message embedding - resp2, err := c.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ Input: message, - Model: openai.EmbeddingModel(c.embeddingModel), + Model: openai.EmbeddingModel(g.embeddingModel), }) if err != nil { log.Logger().Error("failed to create embeddings", zap.Error(err)) @@ -102,13 +106,61 @@ func (c *Chat) PopAll(i int) []cache.Score { } embedding := resp2.Data[0].Embedding // search index - scores := c.index.SearchVector(embedding, c.n+1, true) + scores := g.index.SearchVector(embedding, g.n, true) return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { return cache.Score{ - Id: c.items[v.A].ItemId, - Categories: c.items[v.A].Categories, + Id: g.items[v.A].ItemId, + Categories: g.items[v.A].Categories, Score: -float64(v.B), - Timestamp: c.timestamp, + Timestamp: g.timestamp, } }) } + +// stripThink strips the tag from the message. +func stripThink(s string) string { + if len(s) < 7 || s[:7] != "" { + return s + } + end := strings.Index(s, "") + if end == -1 { + return s + } + return s[end+8:] +} + +// parseMessage parse message from chat completion response. +// If there is any JSON in the message, it returns the JSON. +// Otherwise, it returns the message. 
+func parseMessage(message string) []string { + source := []byte(stripThink(message)) + root := goldmark.DefaultParser().Parse(text.NewReader(source)) + for n := root.FirstChild(); n != nil; n = n.NextSibling() { + if n.Kind() != ast.KindFencedCodeBlock { + continue + } + if codeBlock, ok := n.(*ast.FencedCodeBlock); ok { + if string(codeBlock.Language(source)) == "json" { + bytes := codeBlock.Text(source) + if bytes[0] == '[' { + var temp []any + err := json.Unmarshal(bytes, &temp) + if err != nil { + return []string{string(bytes)} + } + var result []string + for _, v := range temp { + bytes, err := json.Marshal(v) + if err != nil { + return []string{string(bytes)} + } + result = append(result, string(bytes)) + } + return result + } + return []string{string(bytes)} + } + } + } + return []string{string(source)} +} diff --git a/logics/chat_test.go b/logics/chat_test.go index 14a1a94fc..302857c90 100644 --- a/logics/chat_test.go +++ b/logics/chat_test.go @@ -15,10 +15,12 @@ package logics import ( + "strconv" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/common/mock" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/data" @@ -33,7 +35,16 @@ func TestChat(t *testing.T) { defer mockAI.Close() chat, err := NewChat(config.ChatConfig{ - Column: "item.Labels.description", + Column: "item.Labels.embeddings", + Prompt: `You are a ecommerce recommender system. I have purchased: +{%- for item in items %} + {%- if loop.index0 == 0 %} + {{- ' ' + item.ItemId -}} + {%- else %} + {{- ', ' + item.ItemId -}} + {%- endif %} +{%- endfor %}. Please recommend me more items. 
+`, }, 10, time.Now(), config.OpenAIConfig{ BaseURL: mockAI.BaseURL(), AuthToken: mockAI.AuthToken(), @@ -42,11 +53,43 @@ func TestChat(t *testing.T) { }) assert.NoError(t, err) - chat.Push(&data.Item{ - ItemId: "1", - Labels: map[string]any{ - "description": []float32{0.1, 0.2, 0.3}, - }, - }, nil) - assert.Len(t, chat.Items(), 1) + for i := 0; i < 100; i++ { + embedding := mock.Hash("You are a ecommerce recommender system. I have purchased: 3, 1, 5. Please recommend me more items.") + floats.AddConst(embedding, float32(i)) + chat.Push(&data.Item{ + ItemId: strconv.Itoa(i), + Labels: map[string]any{ + "embeddings": embedding, + }, + }, nil) + } + assert.Len(t, chat.Items(), 100) + + scores := chat.PopAll([]int{3, 1, 5}) + assert.Len(t, scores, 10) + for i := 0; i < 10; i++ { + assert.Equal(t, strconv.Itoa(i), scores[i].Id) + } +} + +func TestParseMessage(t *testing.T) { + // parse JSON object + message := "```json\n{\"a\": 1, \"b\": 2}\n```" + contents := parseMessage(message) + assert.Equal(t, []string{"{\"a\": 1, \"b\": 2}\n"}, contents) + + // parse JSON array + message = "```json\n[1, 2]\n```" + contents = parseMessage(message) + assert.Equal(t, []string{"1", "2"}, contents) + + // parse text + message = "Hello, world!" + contents = parseMessage(message) + assert.Equal(t, []string{"Hello, world!"}, contents) + + // strip think + message = "helloWorld!" 
+ content := stripThink(message) + assert.Equal(t, "World!", content) } diff --git a/logics/item_to_item.go b/logics/item_to_item.go index 857952e95..61ea85e9f 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -17,7 +17,6 @@ package logics import ( "context" "errors" - "fmt" "sort" "strings" "time" @@ -31,6 +30,7 @@ import ( "github.com/samber/lo" "github.com/sashabaranov/go-openai" "github.com/zhenghaoz/gorse/base/floats" + "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/common/ann" "github.com/zhenghaoz/gorse/config" @@ -392,7 +392,6 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { log.Logger().Error("failed to execute template", zap.Error(err)) return nil } - fmt.Println(buf.String()) // chat completion resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ Model: g.chatModel, @@ -405,36 +404,40 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { log.Logger().Error("failed to chat completion", zap.Error(err)) return nil } - message := stripThink(resp.Choices[0].Message.Content) + messages := parseMessage(resp.Choices[0].Message.Content) // message embedding - resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ - Input: message, - Model: openai.EmbeddingModel(g.embeddingModel), - }) - if err != nil { - log.Logger().Error("failed to create embeddings", zap.Error(err)) - return nil + embeddings := make([][]float32, len(messages)) + for i, message := range messages { + resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Input: message, + Model: openai.EmbeddingModel(g.embeddingModel), + }) + if err != nil { + log.Logger().Error("failed to create embeddings", zap.Error(err)) + return nil + } + embeddings[i] = resp2.Data[0].Embedding } - embedding := resp2.Data[0].Embedding // search index - scores := g.index.SearchVector(embedding, g.n+1, true) - return lo.Map(scores, 
func(v lo.Tuple2[int, float32], _ int) cache.Score { - return cache.Score{ - Id: g.items[v.A].ItemId, - Categories: g.items[v.A].Categories, - Score: -float64(v.B), - Timestamp: g.timestamp, + pq := heap.NewPriorityQueue(false) + for _, embedding := range embeddings { + scores := g.index.SearchVector(embedding, g.n+1, true) + for _, score := range scores { + pq.Push(int32(score.A), score.B) + if pq.Len() > g.n { + pq.Pop() + } } - }) -} - -func stripThink(s string) string { - if len(s) < 7 || s[:7] != "" { - return s } - end := strings.Index(s, "") - if end == -1 { - return s + scores := make([]cache.Score, pq.Len()) + for i := range scores { + id, score := pq.Pop() + scores[i] = cache.Score{ + Id: g.items[id].ItemId, + Categories: g.items[id].Categories, + Score: -float64(score), + Timestamp: g.timestamp, + } } - return s[end+8:] + return scores } From 04301a328c2163fd369f2f7d8b6731f443488f6b Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Thu, 20 Feb 2025 00:40:19 +0800 Subject: [PATCH 3/5] Add embedding dimensions to config --- config/config.go | 3 ++- config/config_test.go | 3 ++- logics/item_to_item.go | 7 ++++++- logics/item_to_item_test.go | 2 +- master/tasks.go | 9 +++++++++ 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/config/config.go b/config/config.go index 28bb2ce01..d7e0f4637 100644 --- a/config/config.go +++ b/config/config.go @@ -233,7 +233,8 @@ type OpenAIConfig struct { BaseURL string `mapstructure:"base_url"` AuthToken string `mapstructure:"auth_token"` ChatCompletionModel string `mapstructure:"chat_completion_model"` - EmbeddingsModel string `mapstructure:"embeddings_model"` + EmbeddingModel string `mapstructure:"embedding_model"` + EmbeddingDimensions int `mapstructure:"embedding_dimensions"` } func GetDefaultConfig() *Config { diff --git a/config/config_test.go b/config/config_test.go index ea1eb9599..4ee92399c 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -161,7 +161,8 @@ func TestUnmarshal(t 
*testing.T) { assert.Equal(t, "http://localhost:11434/v1", config.OpenAI.BaseURL) assert.Equal(t, "ollama", config.OpenAI.AuthToken) assert.Equal(t, "qwen2.5", config.OpenAI.ChatCompletionModel) - assert.Equal(t, "mxbai-embed-large", config.OpenAI.EmbeddingsModel) + assert.Equal(t, "mxbai-embed-large", config.OpenAI.EmbeddingModel) + assert.Equal(t, 1024, config.OpenAI.EmbeddingDimensions) }) } } diff --git a/logics/item_to_item.go b/logics/item_to_item.go index 61ea85e9f..d016e9a2e 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -47,6 +47,7 @@ type ItemToItemOptions struct { } type ItemToItem interface { + Timestamp() time.Time Items() []*data.Item Push(item *data.Item, feedback []dataset.ID) PopAll(i int) []cache.Score @@ -90,6 +91,10 @@ type baseItemToItem[T any] struct { items []*data.Item } +func (b *baseItemToItem[T]) Timestamp() time.Time { + return b.timestamp +} + func (b *baseItemToItem[T]) Items() []*data.Item { return b.items } @@ -419,7 +424,7 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { embeddings[i] = resp2.Data[0].Embedding } // search index - pq := heap.NewPriorityQueue(false) + pq := heap.NewPriorityQueue(true) for _, embedding := range embeddings { scores := g.index.SearchVector(embedding, g.n+1, true) for _, score := range scores { diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go index 8a8ba351a..e1f41e56b 100644 --- a/logics/item_to_item_test.go +++ b/logics/item_to_item_test.go @@ -212,7 +212,7 @@ func (suite *ItemToItemTestSuite) TestChat() { BaseURL: mockAI.BaseURL(), AuthToken: mockAI.AuthToken(), ChatCompletionModel: "deepseek-r1", - EmbeddingsModel: "text-similarity-ada-001", + EmbeddingModel: "text-similarity-ada-001", }) suite.NoError(err) diff --git a/master/tasks.go b/master/tasks.go index b73ce668f..5133bd6e4 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -1060,6 +1060,15 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error { zap.String("item_id", 
item.ItemId), zap.Error(err)) continue } + // Remove stale item-to-item recommendation + if err := m.CacheClient.DeleteScores(ctx, []string{cache.ItemToItem}, cache.ScoreCondition{ + Subset: lo.ToPtr(cache.Key(itemToItemConfig.Name, item.ItemId)), + Before: lo.ToPtr(recommender.Timestamp()), + }); err != nil { + log.Logger().Error("failed to remove stale item-to-item recommendation", + zap.String("item_id", item.ItemId), zap.Error(err)) + continue + } } span.Add(1) } From 0cfdc691ebd4b5f9d76aad283847865c7e29b355 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Mon, 24 Feb 2025 20:46:11 +0800 Subject: [PATCH 4/5] remove chatItemToItem --- config/config.go | 13 --- config/config.toml | 7 +- go.mod | 1 + go.sum | 2 + logics/chat.go | 166 ------------------------------------ logics/chat_test.go | 95 --------------------- logics/item_to_item.go | 102 ++++++++++++++++++---- logics/item_to_item_test.go | 25 +++++- 8 files changed, 119 insertions(+), 292 deletions(-) delete mode 100644 logics/chat.go delete mode 100644 logics/chat_test.go diff --git a/config/config.go b/config/config.go index d7e0f4637..f25e15f9f 100644 --- a/config/config.go +++ b/config/config.go @@ -163,19 +163,6 @@ func (config *ItemToItemConfig) Hash() string { return string(hash.Sum(nil)) } -type ChatConfig struct { - Name string `mapstructure:"name" json:"name"` - Column string `mapstructure:"column" json:"column" validate:"item_expr"` - Prompt string `mapstructure:"prompt" json:"prompt"` -} - -func (config *ChatConfig) Hash() string { - hash := md5.New() - hash.Write([]byte(config.Name)) - hash.Write([]byte(config.Column)) - return string(hash.Sum(nil)) -} - type CollaborativeConfig struct { ModelFitPeriod time.Duration `mapstructure:"model_fit_period" validate:"gt=0"` ModelSearchPeriod time.Duration `mapstructure:"model_search_period" validate:"gt=0"` diff --git a/config/config.toml b/config/config.toml index 7bbb1a1b3..b617a5aca 100644 --- a/config/config.toml +++ b/config/config.toml @@ 
-321,5 +321,8 @@ auth_token = "ollama" # Name of chat completion model. chat_completion_model = "qwen2.5" -# Name of embeddings model. -embeddings_model = "mxbai-embed-large" \ No newline at end of file +# Name of embedding model. +embedding_model = "mxbai-embed-large" + +# Dimensions of embedding vectors. +embedding_dimensions = 1024 diff --git a/go.mod b/go.mod index 3cb59abb0..8939305a3 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/steinfletcher/apitest v1.5.17 github.com/stretchr/testify v1.10.0 github.com/thoas/go-funk v0.9.2 + github.com/yuin/goldmark v1.7.8 go.mongodb.org/mongo-driver v1.16.1 go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful v0.36.4 go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.55.0 diff --git a/go.sum b/go.sum index 594cd32ab..0930a05c3 100644 --- a/go.sum +++ b/go.sum @@ -654,6 +654,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.mongodb.org/mongo-driver v1.16.1 h1:rIVLL3q0IHM39dvE+z2ulZLp9ENZKThVfuvN/IiN4l8= go.mongodb.org/mongo-driver v1.16.1/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw= diff --git a/logics/chat.go b/logics/chat.go deleted file mode 100644 index a18bb2827..000000000 --- a/logics/chat.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2025 gorse Project Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file 
except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logics - -import ( - "context" - "encoding/json" - "strings" - "time" - - "github.com/expr-lang/expr" - "github.com/nikolalohinski/gonja/v2" - "github.com/nikolalohinski/gonja/v2/exec" - "github.com/samber/lo" - "github.com/sashabaranov/go-openai" - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/ast" - "github.com/yuin/goldmark/text" - "github.com/zhenghaoz/gorse/base/floats" - "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/common/ann" - "github.com/zhenghaoz/gorse/config" - "github.com/zhenghaoz/gorse/storage/cache" - "github.com/zhenghaoz/gorse/storage/data" - "go.uber.org/zap" -) - -type Chat chatItemToItem - -func NewChat(cfg config.ChatConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*Chat, error) { - // Compile column expression - columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ - "item": data.Item{}, - })) - if err != nil { - return nil, err - } - // parse template - template, err := gonja.FromString(cfg.Prompt) - if err != nil { - return nil, err - } - // create openai client - clientConfig := openai.DefaultConfig(openaiConfig.AuthToken) - clientConfig.BaseURL = openaiConfig.BaseURL - return &Chat{ - embeddingItemToItem: &embeddingItemToItem{baseItemToItem: baseItemToItem[[]float32]{ - name: cfg.Name, - n: n, - timestamp: timestamp, - columnFunc: columnFunc, - index: ann.NewHNSW(floats.Euclidean), - }}, - template: template, - client: openai.NewClientWithConfig(clientConfig), - chatModel: openaiConfig.ChatCompletionModel, 
- embeddingModel: openaiConfig.EmbeddingsModel, - }, nil -} - -func (g *Chat) PopAll(indices []int) []cache.Score { - // render template - var buf strings.Builder - ctx := exec.NewContext(map[string]any{ - "items": lo.Map(indices, func(i int, _ int) any { - return g.items[i] - }), - }) - if err := g.template.Execute(&buf, ctx); err != nil { - log.Logger().Error("failed to execute template", zap.Error(err)) - return nil - } - // chat completion - resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: g.chatModel, - Messages: []openai.ChatCompletionMessage{{ - Role: openai.ChatMessageRoleUser, - Content: buf.String(), - }}, - }) - if err != nil { - log.Logger().Error("failed to chat completion", zap.Error(err)) - return nil - } - message := stripThink(resp.Choices[0].Message.Content) - // message embedding - resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ - Input: message, - Model: openai.EmbeddingModel(g.embeddingModel), - }) - if err != nil { - log.Logger().Error("failed to create embeddings", zap.Error(err)) - return nil - } - embedding := resp2.Data[0].Embedding - // search index - scores := g.index.SearchVector(embedding, g.n, true) - return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { - return cache.Score{ - Id: g.items[v.A].ItemId, - Categories: g.items[v.A].Categories, - Score: -float64(v.B), - Timestamp: g.timestamp, - } - }) -} - -// stripThink strips the tag from the message. -func stripThink(s string) string { - if len(s) < 7 || s[:7] != "" { - return s - } - end := strings.Index(s, "") - if end == -1 { - return s - } - return s[end+8:] -} - -// parseMessage parse message from chat completion response. -// If there is any JSON in the message, it returns the JSON. -// Otherwise, it returns the message. 
-func parseMessage(message string) []string { - source := []byte(stripThink(message)) - root := goldmark.DefaultParser().Parse(text.NewReader(source)) - for n := root.FirstChild(); n != nil; n = n.NextSibling() { - if n.Kind() != ast.KindFencedCodeBlock { - continue - } - if codeBlock, ok := n.(*ast.FencedCodeBlock); ok { - if string(codeBlock.Language(source)) == "json" { - bytes := codeBlock.Text(source) - if bytes[0] == '[' { - var temp []any - err := json.Unmarshal(bytes, &temp) - if err != nil { - return []string{string(bytes)} - } - var result []string - for _, v := range temp { - bytes, err := json.Marshal(v) - if err != nil { - return []string{string(bytes)} - } - result = append(result, string(bytes)) - } - return result - } - return []string{string(bytes)} - } - } - } - return []string{string(source)} -} diff --git a/logics/chat_test.go b/logics/chat_test.go deleted file mode 100644 index 302857c90..000000000 --- a/logics/chat_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2025 gorse Project Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package logics - -import ( - "strconv" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/zhenghaoz/gorse/base/floats" - "github.com/zhenghaoz/gorse/common/mock" - "github.com/zhenghaoz/gorse/config" - "github.com/zhenghaoz/gorse/storage/data" -) - -func TestChat(t *testing.T) { - mockAI := mock.NewOpenAIServer() - go func() { - _ = mockAI.Start() - }() - mockAI.Ready() - defer mockAI.Close() - - chat, err := NewChat(config.ChatConfig{ - Column: "item.Labels.embeddings", - Prompt: `You are a ecommerce recommender system. I have purchased: -{%- for item in items %} - {%- if loop.index0 == 0 %} - {{- ' ' + item.ItemId -}} - {%- else %} - {{- ', ' + item.ItemId -}} - {%- endif %} -{%- endfor %}. Please recommend me more items. -`, - }, 10, time.Now(), config.OpenAIConfig{ - BaseURL: mockAI.BaseURL(), - AuthToken: mockAI.AuthToken(), - ChatCompletionModel: "deepseek-r1", - EmbeddingsModel: "text-similarity-ada-001", - }) - assert.NoError(t, err) - - for i := 0; i < 100; i++ { - embedding := mock.Hash("You are a ecommerce recommender system. I have purchased: 3, 1, 5. Please recommend me more items.") - floats.AddConst(embedding, float32(i)) - chat.Push(&data.Item{ - ItemId: strconv.Itoa(i), - Labels: map[string]any{ - "embeddings": embedding, - }, - }, nil) - } - assert.Len(t, chat.Items(), 100) - - scores := chat.PopAll([]int{3, 1, 5}) - assert.Len(t, scores, 10) - for i := 0; i < 10; i++ { - assert.Equal(t, strconv.Itoa(i), scores[i].Id) - } -} - -func TestParseMessage(t *testing.T) { - // parse JSON object - message := "```json\n{\"a\": 1, \"b\": 2}\n```" - contents := parseMessage(message) - assert.Equal(t, []string{"{\"a\": 1, \"b\": 2}\n"}, contents) - - // parse JSON array - message = "```json\n[1, 2]\n```" - contents = parseMessage(message) - assert.Equal(t, []string{"1", "2"}, contents) - - // parse text - message = "Hello, world!" 
- contents = parseMessage(message) - assert.Equal(t, []string{"Hello, world!"}, contents) - - // strip think - message = "helloWorld!" - content := stripThink(message) - assert.Equal(t, "World!", content) -} diff --git a/logics/item_to_item.go b/logics/item_to_item.go index d016e9a2e..e69d95e3e 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -16,6 +16,7 @@ package logics import ( "context" + "encoding/json" "errors" "sort" "strings" @@ -29,6 +30,9 @@ import ( "github.com/nikolalohinski/gonja/v2/exec" "github.com/samber/lo" "github.com/sashabaranov/go-openai" + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/text" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" @@ -358,10 +362,11 @@ func flatten(o any, tSet mapset.Set[dataset.ID]) { type chatItemToItem struct { *embeddingItemToItem - template *exec.Template - client *openai.Client - chatModel string - embeddingModel string + template *exec.Template + client *openai.Client + chatCompletionModel string + embeddingModel string + embeddingDimensions int } func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*chatItemToItem, error) { @@ -382,12 +387,27 @@ func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, embeddingItemToItem: embedding, template: template, client: openai.NewClientWithConfig(clientConfig), - chatModel: openaiConfig.ChatCompletionModel, - embeddingModel: openaiConfig.EmbeddingsModel, + chatCompletionModel: openaiConfig.ChatCompletionModel, + embeddingModel: openaiConfig.EmbeddingModel, + embeddingDimensions: openaiConfig.EmbeddingDimensions, }, nil } func (g *chatItemToItem) PopAll(i int) []cache.Score { + // evaluate column expression and get embedding vector + result, err := expr.Run(g.columnFunc, map[string]any{ + "item": g.items[i], + }) + if err != nil { + log.Logger().Error("failed 
to evaluate column expression", + zap.Any("item", g.items[i]), zap.Error(err)) + return nil + } + embedding0, ok := result.([]float32) + if !ok { + log.Logger().Error("invalid column type", zap.Any("column", result)) + return nil + } // render template var buf strings.Builder ctx := exec.NewContext(map[string]any{ @@ -399,7 +419,7 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { } // chat completion resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: g.chatModel, + Model: g.chatCompletionModel, Messages: []openai.ChatCompletionMessage{{ Role: openai.ChatMessageRoleUser, Content: buf.String(), @@ -413,29 +433,33 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { // message embedding embeddings := make([][]float32, len(messages)) for i, message := range messages { - resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ - Input: message, - Model: openai.EmbeddingModel(g.embeddingModel), + resp, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Input: message, + Model: openai.EmbeddingModel(g.embeddingModel), + Dimensions: g.embeddingDimensions, }) if err != nil { log.Logger().Error("failed to create embeddings", zap.Error(err)) return nil } - embeddings[i] = resp2.Data[0].Embedding + embeddings[i] = resp.Data[0].Embedding } // search index pq := heap.NewPriorityQueue(true) for _, embedding := range embeddings { + score0 := floats.Euclidean(embedding, embedding0) scores := g.index.SearchVector(embedding, g.n+1, true) for _, score := range scores { - pq.Push(int32(score.A), score.B) - if pq.Len() > g.n { - pq.Pop() + if score.A != i { + pq.Push(int32(score.A), score.B*score0) + if pq.Len() > g.n { + pq.Pop() + } } } } scores := make([]cache.Score, pq.Len()) - for i := range scores { + for i := 9; i >= 0; i-- { id, score := pq.Pop() scores[i] = cache.Score{ Id: g.items[id].ItemId, @@ -446,3 +470,51 @@ func (g *chatItemToItem) PopAll(i int) 
[]cache.Score { } return scores } + +// stripThink strips the <think> tag from the message. +func stripThink(s string) string { + if len(s) < 7 || s[:7] != "<think>" { + return s + } + end := strings.Index(s, "</think>") + if end == -1 { + return s + } + return s[end+8:] +} + +// parseMessage parses the message from the chat completion response. +// If there is any JSON in the message, it returns the JSON. +// Otherwise, it returns the message. +func parseMessage(message string) []string { + source := []byte(stripThink(message)) + root := goldmark.DefaultParser().Parse(text.NewReader(source)) + for n := root.FirstChild(); n != nil; n = n.NextSibling() { + if n.Kind() != ast.KindFencedCodeBlock { + continue + } + if codeBlock, ok := n.(*ast.FencedCodeBlock); ok { + if string(codeBlock.Language(source)) == "json" { + bytes := codeBlock.Text(source) + if bytes[0] == '[' { + var temp []any + err := json.Unmarshal(bytes, &temp) + if err != nil { + return []string{string(bytes)} + } + var result []string + for _, v := range temp { + bytes, err := json.Marshal(v) + if err != nil { + return []string{string(bytes)} + } + result = append(result, string(bytes)) + } + return result + } + return []string{string(bytes)} + } + } + } + return []string{string(source)} +} diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go index e1f41e56b..321c35dcc 100644 --- a/logics/item_to_item_test.go +++ b/logics/item_to_item_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/common/mock" @@ -218,7 +219,7 @@ func (suite *ItemToItemTestSuite) TestChat() { for i := 0; i < 100; i++ { embedding := mock.Hash("Please generate similar items for item_0.") - floats.AddConst(embedding, float32(i)) + floats.AddConst(embedding, float32(i+1)) item2item.Push(&data.Item{ ItemId: strconv.Itoa(i), Labels: map[string]any{ @@ -238,3 +239,25 @@ func (suite *ItemToItemTestSuite) 
TestChat() { func TestItemToItem(t *testing.T) { suite.Run(t, new(ItemToItemTestSuite)) } + +func TestParseMessage(t *testing.T) { + // parse JSON object + message := "```json\n{\"a\": 1, \"b\": 2}\n```" + contents := parseMessage(message) + assert.Equal(t, []string{"{\"a\": 1, \"b\": 2}\n"}, contents) + + // parse JSON array + message = "```json\n[1, 2]\n```" + contents = parseMessage(message) + assert.Equal(t, []string{"1", "2"}, contents) + + // parse text + message = "Hello, world!" + contents = parseMessage(message) + assert.Equal(t, []string{"Hello, world!"}, contents) + + // strip think + message = "<think>hello</think>World!" + content := stripThink(message) + assert.Equal(t, "World!", content) +} From 482d053b18a122259839ddc165c1abf66c95e3b4 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Mon, 24 Feb 2025 21:02:12 +0800 Subject: [PATCH 5/5] Add example --- config/config.toml | 20 ++++++++++++++++++++ logics/item_to_item.go | 2 ++ storage/cache/database.go | 4 ---- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/config/config.toml b/config/config.toml index b617a5aca..7f08132d5 100644 --- a/config/config.toml +++ b/config/config.toml @@ -158,11 +158,31 @@ filter = "(now() - item.Timestamp).Hours() < 168" # # embedding: recommend by Euclidean distance of embeddings. # # tags: recommend by number of common tags. # # users: recommend by number of common users. +# # chat: recommend by chat completion model. # type = "embedding" # # The column of the item embeddings. Leave blank if type is "users". # column = "item.Labels.embedding" +# [[recommend.item-to-item]] + +# # The name of the item-to-item recommender. +# name = "chat_recommend" + +# # The type of the item-to-item recommender. +# type = "chat" + +# # The column of the item embeddings. Leave blank if type is "users". +# column = "item.Labels.embedding" + +# # The prompt for the chat completion model. 
+# prompt = """ +# This is the description of GitHub repository https://github.com/{{ item.ItemId | replace(':','/') }}: +# {{ item.Comment }} +# Please find some similar repositories on GitHub and provide a brief description for each of them. +# The output should be a JSON array. +# """ + [recommend.user_neighbors] # The type of neighbors for users. There are three types: diff --git a/logics/item_to_item.go b/logics/item_to_item.go index e69d95e3e..01065d82b 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -430,6 +430,8 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score { return nil } messages := parseMessage(resp.Choices[0].Message.Content) + log.Logger().Debug("chat based item-to-item recommendation", + zap.String("prompt", buf.String()), zap.Strings("response", messages)) // message embedding embeddings := make([][]float32, len(messages)) for i, message := range messages { diff --git a/storage/cache/database.go b/storage/cache/database.go index b10df44d3..00dbece81 100644 --- a/storage/cache/database.go +++ b/storage/cache/database.go @@ -72,10 +72,6 @@ const ( UserToUserUpdateTime = "user-to-user_update_time" Neighbors = "neighbors" - Chat = "chat" - ChatDigest = "chat_digest" - ChatUpdateTime = "chat_update_time" - // ItemCategories is the set of item categories. The format of key: // Global item categories - item_categories ItemCategories = "item_categories"