diff --git a/base/parallel/future.go b/base/parallel/future.go
deleted file mode 100644
index 1d3986e64..000000000
--- a/base/parallel/future.go
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2020 gorse Project Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package parallel
-
-type Future struct {
-    done chan struct{}
-}
-
-func Async(f func()) *Future {
-    future := &Future{done: make(chan struct{})}
-    go func() {
-        f()
-        close(future.done)
-    }()
-    return future
-}
-
-func (f *Future) Wait() {
-    <-f.done
-}
diff --git a/base/parallel/future_test.go b/base/parallel/future_test.go
deleted file mode 100644
index 90ab9fb70..000000000
--- a/base/parallel/future_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2020 gorse Project Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package parallel
-
-import (
-    "testing"
-
-    "github.com/stretchr/testify/assert"
-)
-
-func TestAsync(t *testing.T) {
-    var a int
-    future := Async(func() {
-        a = 1
-    })
-    future.Wait()
-    assert.Equal(t, 1, a)
-}
diff --git a/base/parallel/condition_channel.go b/common/parallel/condition_channel.go
similarity index 100%
rename from base/parallel/condition_channel.go
rename to common/parallel/condition_channel.go
diff --git a/base/parallel/condition_channel_test.go b/common/parallel/condition_channel_test.go
similarity index 100%
rename from base/parallel/condition_channel_test.go
rename to common/parallel/condition_channel_test.go
diff --git a/base/parallel/parallel.go b/common/parallel/parallel.go
similarity index 100%
rename from base/parallel/parallel.go
rename to common/parallel/parallel.go
diff --git a/base/parallel/parallel_test.go b/common/parallel/parallel_test.go
similarity index 100%
rename from base/parallel/parallel_test.go
rename to common/parallel/parallel_test.go
diff --git a/common/parallel/pool.go b/common/parallel/pool.go
new file mode 100644
index 000000000..48a3ba703
--- /dev/null
+++ b/common/parallel/pool.go
@@ -0,0 +1,47 @@
+package parallel
+
+import "sync"
+
+type Pool interface {
+    Run(runner func())
+    Wait()
+}
+
+type SequentialPool struct{}
+
+func NewSequentialPool() *SequentialPool {
+    return &SequentialPool{}
+}
+
+func (p *SequentialPool) Run(runner func()) {
+    runner()
+}
+
+func (p *SequentialPool) Wait() {}
+
+type ConcurrentPool struct {
+    wg   sync.WaitGroup
+    pool chan struct{}
+}
+
+func NewConcurrentPool(size int) *ConcurrentPool {
+    return &ConcurrentPool{
+        pool: make(chan struct{}, size),
+    }
+}
+
+func (p *ConcurrentPool) Run(runner func()) {
+    p.wg.Add(1)
+    go func() {
+        p.pool <- struct{}{}
+        defer func() {
+            <-p.pool
+            p.wg.Done()
+        }()
+        runner()
+    }()
+}
+
+func (p *ConcurrentPool) Wait() {
+    p.wg.Wait()
+}
diff --git a/common/parallel/pool_test.go b/common/parallel/pool_test.go
new file mode 100644
index 000000000..396be68ce
--- /dev/null
+++ b/common/parallel/pool_test.go
@@ -0,0 +1,32 @@
+package parallel
+
+import (
+    "sync/atomic"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestSequentialPool(t *testing.T) {
+    pool := NewSequentialPool()
+    count := 0
+    for i := 0; i < 100; i++ {
+        pool.Run(func() {
+            count++
+        })
+    }
+    pool.Wait()
+    assert.Equal(t, 100, count)
+}
+
+func TestConcurrentPool(t *testing.T) {
+    pool := NewConcurrentPool(100)
+    count := atomic.Int64{}
+    for i := 0; i < 100; i++ {
+        pool.Run(func() {
+            count.Add(1)
+        })
+    }
+    pool.Wait()
+    assert.Equal(t, int64(100), count.Load())
+}
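Note: the new Pool interface lets a caller switch between in-place execution (SequentialPool) and bounded fan-out (ConcurrentPool) without changing the submission code. A minimal usage sketch under assumed inputs (the job count and the pool size of 4 are illustrative, not taken from this patch):

package main

import (
    "fmt"

    "github.com/zhenghaoz/gorse/common/parallel"
)

func main() {
    // SequentialPool runs each job inline; ConcurrentPool(4) allows
    // at most 4 jobs to run at the same time.
    var pool parallel.Pool = parallel.NewConcurrentPool(4)

    results := make([]int, 10)
    for i := 0; i < 10; i++ {
        i := i // capture the loop variable for the closure
        pool.Run(func() {
            results[i] = i * i // each goroutine writes its own slot
        })
    }
    pool.Wait() // block until every job submitted via Run has finished
    fmt.Println(results)
}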
diff --git a/common/parallel/ratelimit.go b/common/parallel/ratelimit.go
new file mode 100644
index 000000000..36cdf80f5
--- /dev/null
+++ b/common/parallel/ratelimit.go
@@ -0,0 +1,46 @@
+package parallel
+
+import (
+    "time"
+
+    "github.com/juju/ratelimit"
+)
+
+var (
+    ChatCompletionBackoff                     = time.Duration(0)
+    ChatCompletionRequestsLimiter RateLimiter = &Unlimited{}
+    ChatCompletionTokensLimiter   RateLimiter = &Unlimited{}
+    EmbeddingBackoff                          = time.Duration(0)
+    EmbeddingRequestsLimiter      RateLimiter = &Unlimited{}
+    EmbeddingTokensLimiter        RateLimiter = &Unlimited{}
+)
+
+func InitChatCompletionLimiters(rpm, tpm int) {
+    if rpm > 0 {
+        ChatCompletionBackoff = time.Minute / time.Duration(rpm)
+        ChatCompletionRequestsLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(rpm/60), int64(rpm/60))
+    }
+    if tpm > 0 {
+        ChatCompletionTokensLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(tpm/60), int64(tpm/60))
+    }
+}
+
+func InitEmbeddingLimiters(rpm, tpm int) {
+    if rpm > 0 {
+        EmbeddingBackoff = time.Minute / time.Duration(rpm)
+        EmbeddingRequestsLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(rpm/60), int64(rpm/60))
+    }
+    if tpm > 0 {
+        EmbeddingTokensLimiter = ratelimit.NewBucketWithQuantum(time.Second, int64(tpm/60), int64(tpm/60))
+    }
+}
+
+type RateLimiter interface {
+    Take(count int64) time.Duration
+}
+
+type Unlimited struct{}
+
+func (n *Unlimited) Take(count int64) time.Duration {
+    return 0
+}
diff --git a/common/parallel/ratelimit_test.go b/common/parallel/ratelimit_test.go
new file mode 100644
index 000000000..0158bc10b
--- /dev/null
+++ b/common/parallel/ratelimit_test.go
@@ -0,0 +1,29 @@
+package parallel
+
+import (
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestUnlimited(t *testing.T) {
+    rateLimiter := &Unlimited{}
+    assert.Zero(t, rateLimiter.Take(1))
+}
+
+func TestInitEmbeddingLimiters(t *testing.T) {
+    InitEmbeddingLimiters(120, 180)
+    assert.Equal(t, time.Duration(0), EmbeddingRequestsLimiter.Take(1))
+    assert.InDelta(t, time.Second, EmbeddingRequestsLimiter.Take(2), float64(time.Millisecond))
+    assert.Equal(t, time.Duration(0), EmbeddingTokensLimiter.Take(2))
+    assert.InDelta(t, 2*time.Second, EmbeddingTokensLimiter.Take(5), float64(time.Millisecond))
+}
+
+func TestInitChatCompletionLimiters(t *testing.T) {
+    InitChatCompletionLimiters(120, 180)
+    assert.Equal(t, time.Duration(0), ChatCompletionRequestsLimiter.Take(1))
+    assert.InDelta(t, time.Second, ChatCompletionRequestsLimiter.Take(2), float64(time.Millisecond))
+    assert.Equal(t, time.Duration(0), ChatCompletionTokensLimiter.Take(2))
+    assert.InDelta(t, 2*time.Second, ChatCompletionTokensLimiter.Take(5), float64(time.Millisecond))
+}
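Note: the limiters are plain token buckets from juju/ratelimit. InitEmbeddingLimiters(rpm, tpm) builds buckets that refill rpm/60 requests and tpm/60 tokens per second, callers sleep for whatever Take returns before issuing a request, and leaving rpm/tpm at zero keeps the Unlimited fallback, whose Take is a no-op. A sketch of that calling pattern (sendRequest, the quota numbers, and the per-request token count are illustrative placeholders):

package main

import (
    "fmt"
    "time"

    "github.com/zhenghaoz/gorse/common/parallel"
)

func sendRequest(tokens int64) {
    // Sleep until both the request budget and the token budget allow the call.
    time.Sleep(parallel.EmbeddingRequestsLimiter.Take(1))
    time.Sleep(parallel.EmbeddingTokensLimiter.Take(tokens))
    fmt.Println("request sent at", time.Now().Format(time.RFC3339))
}

func main() {
    // 120 requests/minute -> the request bucket refills 2 per second;
    // 60000 tokens/minute -> the token bucket refills 1000 per second.
    parallel.InitEmbeddingLimiters(120, 60000)
    for i := 0; i < 5; i++ {
        sendRequest(512) // assumed token cost of one request
    }
}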
+log_file = "openai.log" diff --git a/config/config_test.go b/config/config_test.go index 4ee92399c..74f25df86 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -161,8 +161,12 @@ func TestUnmarshal(t *testing.T) { assert.Equal(t, "http://localhost:11434/v1", config.OpenAI.BaseURL) assert.Equal(t, "ollama", config.OpenAI.AuthToken) assert.Equal(t, "qwen2.5", config.OpenAI.ChatCompletionModel) + assert.Equal(t, 15000, config.OpenAI.ChatCompletionRPM) + assert.Equal(t, 1200000, config.OpenAI.ChatCompletionTPM) assert.Equal(t, "mxbai-embed-large", config.OpenAI.EmbeddingModel) assert.Equal(t, 1024, config.OpenAI.EmbeddingDimensions) + assert.Equal(t, 1800, config.OpenAI.EmbeddingRPM) + assert.Equal(t, 1200000, config.OpenAI.EmbeddingTPM) }) } } diff --git a/go.mod b/go.mod index 8939305a3..11c5a88f7 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de github.com/benhoyt/goawk v1.20.0 github.com/bits-and-blooms/bitset v1.2.1 + github.com/cenkalti/backoff/v5 v5.0.2 github.com/chewxy/math32 v1.11.1 github.com/coreos/go-oidc/v3 v3.11.0 github.com/deckarep/golang-set/v2 v2.3.1 @@ -29,6 +30,7 @@ require ( github.com/jellydator/ttlcache/v3 v3.3.0 github.com/json-iterator/go v1.1.12 github.com/juju/errors v1.0.0 + github.com/juju/ratelimit v1.0.2 github.com/klauspost/cpuid/v2 v2.2.3 github.com/lafikl/consistent v0.0.0-20220512074542-bdd3606bfc3e github.com/lib/pq v1.10.6 @@ -52,6 +54,7 @@ require ( github.com/steinfletcher/apitest v1.5.17 github.com/stretchr/testify v1.10.0 github.com/thoas/go-funk v0.9.2 + github.com/tiktoken-go/tokenizer v0.5.1 github.com/yuin/goldmark v1.7.8 go.mongodb.org/mongo-driver v1.16.1 go.opentelemetry.io/contrib/instrumentation/github.com/emicklei/go-restful/otelrestful v0.36.4 @@ -96,6 +99,7 @@ require ( github.com/chewxy/hm v1.0.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect diff --git a/go.sum b/go.sum index 0930a05c3..e270c00ea 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= +github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -126,6 +128,8 @@ github.com/deckarep/golang-set/v2 v2.3.1 h1:vjmkvJt/IV27WXPyYQpAh4bRyWJc5Y435D17 github.com/deckarep/golang-set/v2 v2.3.1/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod 
diff --git a/go.sum b/go.sum
index 0930a05c3..e270c00ea 100644
--- a/go.sum
+++ b/go.sum
@@ -80,6 +80,8 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
 github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
 github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
+github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
+github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -126,6 +128,8 @@ github.com/deckarep/golang-set/v2 v2.3.1 h1:vjmkvJt/IV27WXPyYQpAh4bRyWJc5Y435D17
 github.com/deckarep/golang-set/v2 v2.3.1/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
+github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
@@ -405,6 +409,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
 github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
 github.com/juju/errors v1.0.0 h1:yiq7kjCLll1BiaRuNY53MGI0+EQ3rF6GB+wvboZDefM=
 github.com/juju/errors v1.0.0/go.mod h1:B5x9thDqx0wIMH3+aLIMP9HjItInYWObRovoCFM5Qe8=
+github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI=
+github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
 github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@@ -637,6 +643,8 @@ github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSW
 github.com/takama/daemon v0.0.0-20180403113744-aa76b0035d12/go.mod h1:So5Nv647d/sgbZNAfiWtw6egowH8vNNrPXAwooWeElk=
 github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk=
 github.com/thoas/go-funk v0.9.2/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
+github.com/tiktoken-go/tokenizer v0.5.1 h1:EOpjlSAVLPX+6ioMUufTI9xmzHU4SI4ARK0DgkBdz+g=
+github.com/tiktoken-go/tokenizer v0.5.1/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
 github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
 github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
 github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
diff --git a/logics/item_to_item.go b/logics/item_to_item.go
index c6e223b63..f23c68d7b 100644
--- a/logics/item_to_item.go
+++ b/logics/item_to_item.go
@@ -22,6 +22,7 @@ import (
     "strings"
     "time"
 
+    "github.com/cenkalti/backoff/v5"
     "github.com/chewxy/math32"
     mapset "github.com/deckarep/golang-set/v2"
     "github.com/expr-lang/expr"
@@ -30,6 +31,7 @@ import (
     "github.com/nikolalohinski/gonja/v2/exec"
     "github.com/samber/lo"
     "github.com/sashabaranov/go-openai"
+    "github.com/tiktoken-go/tokenizer"
     "github.com/yuin/goldmark"
     "github.com/yuin/goldmark/ast"
     "github.com/yuin/goldmark/text"
@@ -37,6 +39,7 @@ import (
     "github.com/zhenghaoz/gorse/base/heap"
     "github.com/zhenghaoz/gorse/base/log"
     "github.com/zhenghaoz/gorse/common/ann"
+    "github.com/zhenghaoz/gorse/common/parallel"
     "github.com/zhenghaoz/gorse/config"
     "github.com/zhenghaoz/gorse/dataset"
     "github.com/zhenghaoz/gorse/storage/cache"
@@ -44,6 +47,16 @@ import (
     "github.com/zhenghaoz/gorse/storage/data"
     "go.uber.org/zap"
 )
 
+var cl100kBaseTokenizer tokenizer.Codec
+
+func init() {
+    var err error
+    cl100kBaseTokenizer, err = tokenizer.Get(tokenizer.Cl100kBase)
+    if err != nil {
+        panic(err)
+    }
+}
+
 type ItemToItemOptions struct {
     TagsIDF  []float32
     UsersIDF []float32
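Note: prompt length is measured with the cl100k_base tokenizer so the tokens-per-minute bucket can be charged before each request; for non-OpenAI models such as qwen2.5 the count is an approximation, which is sufficient for throttling. A small sketch of that counting step (the prompt string is illustrative):

package main

import (
    "fmt"

    "github.com/tiktoken-go/tokenizer"
)

func main() {
    // Same codec the patch initializes once at package load time.
    codec, err := tokenizer.Get(tokenizer.Cl100kBase)
    if err != nil {
        panic(err)
    }
    // Encode returns the token ids alongside the token strings;
    // the patch only needs len(ids) to charge the TPM bucket.
    ids, _, err := codec.Encode("Please generate similar items for item_0.")
    if err != nil {
        panic(err)
    }
    fmt.Println("prompt costs", len(ids), "tokens")
}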
@@ -55,6 +68,7 @@
     Items() []*data.Item
     Push(item *data.Item, feedback []dataset.ID)
     PopAll(i int) []cache.Score
+    Pool() parallel.Pool
 }
 
 func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time,
     opts *ItemToItemOptions) (ItemToItem, error) {
@@ -119,6 +133,10 @@ func (b *baseItemToItem[T]) PopAll(i int) []cache.Score {
     })
 }
 
+func (b *baseItemToItem[T]) Pool() parallel.Pool {
+    return parallel.NewSequentialPool()
+}
+
 type embeddingItemToItem struct {
     baseItemToItem[[]float32]
     dimension int
@@ -367,6 +385,7 @@ type chatItemToItem struct {
     chatCompletionModel string
     embeddingModel      string
     embeddingDimensions int
+    poolSize            int
 }
 
 func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*chatItemToItem, error) {
@@ -390,6 +409,7 @@ func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time,
         chatCompletionModel: openaiConfig.ChatCompletionModel,
         embeddingModel:      openaiConfig.EmbeddingModel,
         embeddingDimensions: openaiConfig.EmbeddingDimensions,
+        poolSize:            min(openaiConfig.ChatCompletionRPM, openaiConfig.EmbeddingRPM),
     }, nil
 }
 
@@ -419,15 +439,27 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score {
     }
     // chat completion
     start := time.Now()
-    resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
-        Model: g.chatCompletionModel,
-        Messages: []openai.ChatCompletionMessage{{
-            Role:    openai.ChatMessageRoleUser,
-            Content: buf.String(),
-        }},
-    })
+    ids, _, _ := cl100kBaseTokenizer.Encode(buf.String())
+    resp, err := backoff.Retry(context.Background(), func() (openai.ChatCompletionResponse, error) {
+        time.Sleep(parallel.ChatCompletionRequestsLimiter.Take(1))
+        time.Sleep(parallel.ChatCompletionTokensLimiter.Take(int64(len(ids))))
+        resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+            Model: g.chatCompletionModel,
+            Messages: []openai.ChatCompletionMessage{{
+                Role:    openai.ChatMessageRoleUser,
+                Content: buf.String(),
+            }},
+        })
+        if err == nil {
+            return resp, nil
+        }
+        if throttled(err) {
+            return openai.ChatCompletionResponse{}, err
+        }
+        return openai.ChatCompletionResponse{}, backoff.Permanent(err)
+    }, backoff.WithBackOff(backoff.NewExponentialBackOff()))
     if err != nil {
-        log.Logger().Error("failed to chat completion", zap.Error(err))
+        log.Logger().Error("failed to chat completion", zap.String("item_id", g.items[i].ItemId), zap.Error(err))
         return nil
     }
     duration := time.Since(start)
@@ -443,13 +475,25 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score {
     // message embedding
     embeddings := make([][]float32, len(parsed))
     for i, message := range parsed {
-        resp, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
-            Input:      message,
-            Model:      openai.EmbeddingModel(g.embeddingModel),
-            Dimensions: g.embeddingDimensions,
-        })
+        ids, _, _ := cl100kBaseTokenizer.Encode(message)
+        resp, err := backoff.Retry(context.Background(), func() (openai.EmbeddingResponse, error) {
+            time.Sleep(parallel.EmbeddingRequestsLimiter.Take(1))
+            time.Sleep(parallel.EmbeddingTokensLimiter.Take(int64(len(ids))))
+            resp, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
+                Input:      message,
+                Model:      openai.EmbeddingModel(g.embeddingModel),
+                Dimensions: g.embeddingDimensions,
+            })
+            if err == nil {
+                return resp, nil
+            }
+            if throttled(err) {
+                return openai.EmbeddingResponse{}, err
+            }
+            return openai.EmbeddingResponse{}, backoff.Permanent(err)
+        }, backoff.WithBackOff(backoff.NewExponentialBackOff()))
         if err != nil {
-            log.Logger().Error("failed to create embeddings", zap.Error(err))
+            log.Logger().Error("failed to create embeddings", zap.String("item_id", g.items[i].ItemId), zap.Error(err))
             return nil
         }
         embeddings[i] = resp.Data[0].Embedding
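Note: both OpenAI calls now follow the same shape: charge the request and token buckets, issue the call, retry throttled (HTTP 429) responses with exponential backoff, and treat every other error as permanent. A condensed sketch of that pattern with backoff/v5 (callLLM, errThrottled, and the quota numbers are placeholders, not part of the patch):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"

    "github.com/cenkalti/backoff/v5"
    "github.com/zhenghaoz/gorse/common/parallel"
)

// errThrottled stands in for the 429 check done by throttled() in the patch.
var errThrottled = errors.New("rate limited")

// callLLM is a placeholder for CreateChatCompletion or CreateEmbeddings.
func callLLM(prompt string) (string, error) {
    return "ok: " + prompt, nil
}

func complete(ctx context.Context, prompt string, promptTokens int64) (string, error) {
    return backoff.Retry(ctx, func() (string, error) {
        // Respect both budgets before issuing the request.
        time.Sleep(parallel.ChatCompletionRequestsLimiter.Take(1))
        time.Sleep(parallel.ChatCompletionTokensLimiter.Take(promptTokens))
        resp, err := callLLM(prompt)
        if err == nil {
            return resp, nil
        }
        if errors.Is(err, errThrottled) {
            return "", err // retryable: back off and try again
        }
        return "", backoff.Permanent(err) // anything else fails immediately
    }, backoff.WithBackOff(backoff.NewExponentialBackOff()))
}

func main() {
    parallel.InitChatCompletionLimiters(600, 60000) // illustrative quotas
    out, err := complete(context.Background(), "hello", 3)
    fmt.Println(out, err)
}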
            return nil
        }
        embeddings[i] = resp.Data[0].Embedding
@@ -481,6 +525,10 @@ func (g *chatItemToItem) PopAll(i int) []cache.Score {
     return scores
 }
 
+func (g *chatItemToItem) Pool() parallel.Pool {
+    return parallel.NewConcurrentPool(g.poolSize)
+}
+
 func stripThinkInCompletion(s string) string {
     if len(s) < 7 || s[:7] != "<think>" {
         return s
     }
@@ -534,3 +582,12 @@ func parseJSONArrayFromCompletion(completion string) []string {
     }
     return []string{string(source)}
 }
+
+func throttled(err error) bool {
+    if requestErr, ok := err.(*openai.APIError); ok {
+        if requestErr.HTTPStatusCode == 429 {
+            return true
+        }
+    }
+    return false
+}
diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go
index ab87c30a3..755fd8fae 100644
--- a/logics/item_to_item_test.go
+++ b/logics/item_to_item_test.go
@@ -23,6 +23,7 @@ import (
     "github.com/stretchr/testify/suite"
     "github.com/zhenghaoz/gorse/base/floats"
     "github.com/zhenghaoz/gorse/common/mock"
+    "github.com/zhenghaoz/gorse/common/parallel"
     "github.com/zhenghaoz/gorse/config"
     "github.com/zhenghaoz/gorse/dataset"
     "github.com/zhenghaoz/gorse/storage/data"
@@ -89,6 +90,7 @@ func (suite *ItemToItemTestSuite) TestEmbedding() {
         Column: "item.Labels.description",
     }, 10, timestamp)
     suite.NoError(err)
+    suite.IsType(item2item.Pool(), &parallel.SequentialPool{})
 
     for i := 0; i < 100; i++ {
         item2item.Push(&data.Item{
@@ -116,6 +118,7 @@ func (suite *ItemToItemTestSuite) TestTags() {
         Column: "item.Labels",
     }, 10, timestamp, idf)
     suite.NoError(err)
+    suite.IsType(item2item.Pool(), &parallel.SequentialPool{})
 
     for i := 0; i < 100; i++ {
         labels := make(map[string]any)
@@ -143,6 +146,7 @@ func (suite *ItemToItemTestSuite) TestUsers() {
     }
     item2item, err := newUsersItemToItem(config.ItemToItemConfig{}, 10, timestamp, idf)
     suite.NoError(err)
+    suite.IsType(item2item.Pool(), &parallel.SequentialPool{})
 
     for i := 0; i < 100; i++ {
         feedback := make([]dataset.ID, 0, 100-i)
@@ -167,6 +171,7 @@ func (suite *ItemToItemTestSuite) TestAuto() {
     }
     item2item, err := newAutoItemToItem(config.ItemToItemConfig{}, 10, timestamp, idf, idf)
     suite.NoError(err)
+    suite.IsType(item2item.Pool(), &parallel.SequentialPool{})
 
     for i := 0; i < 100; i++ {
         item := &data.Item{ItemId: strconv.Itoa(i)}
@@ -216,6 +221,7 @@ func (suite *ItemToItemTestSuite) TestChat() {
         EmbeddingModel: "text-similarity-ada-001",
     })
     suite.NoError(err)
+    suite.IsType(item2item.Pool(), &parallel.ConcurrentPool{})
 
     for i := 0; i < 100; i++ {
         embedding := mock.Hash("Please generate similar items for item_0.")
diff --git a/master/master.go b/master/master.go
index 1924da92b..efa407444 100644
--- a/master/master.go
+++ b/master/master.go
@@ -33,9 +33,9 @@ import (
     "github.com/zhenghaoz/gorse/base"
     "github.com/zhenghaoz/gorse/base/encoding"
     "github.com/zhenghaoz/gorse/base/log"
-    "github.com/zhenghaoz/gorse/base/parallel"
     "github.com/zhenghaoz/gorse/base/progress"
     "github.com/zhenghaoz/gorse/base/task"
+    "github.com/zhenghaoz/gorse/common/parallel"
     "github.com/zhenghaoz/gorse/common/sizeof"
     "github.com/zhenghaoz/gorse/common/util"
     "github.com/zhenghaoz/gorse/config"
@@ -133,6 +133,9 @@ func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master {
     clientConfig.BaseURL = cfg.OpenAI.BaseURL
     // setup OpenAI logger
     log.InitOpenAILogger(cfg.OpenAI.LogFile)
+    // setup OpenAI rate limiter
+    parallel.InitChatCompletionLimiters(cfg.OpenAI.ChatCompletionRPM, cfg.OpenAI.ChatCompletionTPM)
+    parallel.InitEmbeddingLimiters(cfg.OpenAI.EmbeddingRPM, cfg.OpenAI.EmbeddingTPM)
 
     m := &Master{
         // create task monitor
diff --git a/master/tasks.go b/master/tasks.go
index 5133bd6e4..3111e406b 100644
--- a/master/tasks.go
+++ b/master/tasks.go
@@ -28,9 +28,9 @@ import (
     "github.com/zhenghaoz/gorse/base"
     "github.com/zhenghaoz/gorse/base/encoding"
     "github.com/zhenghaoz/gorse/base/log"
-    "github.com/zhenghaoz/gorse/base/parallel"
     "github.com/zhenghaoz/gorse/base/progress"
     "github.com/zhenghaoz/gorse/base/task"
+    "github.com/zhenghaoz/gorse/common/parallel"
     "github.com/zhenghaoz/gorse/common/sizeof"
     "github.com/zhenghaoz/gorse/config"
     "github.com/zhenghaoz/gorse/dataset"
@@ -995,18 +995,22 @@
 func (m *Master) updateItemToItem(dataset *dataset.Dataset) error {
     // Add built-in item-to-item recommenders
     itemToItemConfigs := m.Config.Recommend.ItemToItem
-    builtInConfig := config.ItemToItemConfig{}
-    builtInConfig.Name = cache.Neighbors
-    switch m.Config.Recommend.ItemNeighbors.NeighborType {
-    case config.NeighborTypeSimilar:
-        builtInConfig.Type = "tags"
-        builtInConfig.Column = "item.Labels"
-    case config.NeighborTypeRelated:
-        builtInConfig.Type = "users"
-    case config.NeighborTypeAuto:
-        builtInConfig.Type = "auto"
+    if !lo.ContainsBy(itemToItemConfigs, func(item config.ItemToItemConfig) bool {
+        return item.Name == cache.Neighbors
+    }) {
+        builtInConfig := config.ItemToItemConfig{}
+        builtInConfig.Name = cache.Neighbors
+        switch m.Config.Recommend.ItemNeighbors.NeighborType {
+        case config.NeighborTypeSimilar:
+            builtInConfig.Type = "tags"
+            builtInConfig.Column = "item.Labels"
+        case config.NeighborTypeRelated:
+            builtInConfig.Type = "users"
+        case config.NeighborTypeAuto:
+            builtInConfig.Type = "auto"
+        }
+        itemToItemConfigs = append(itemToItemConfigs, builtInConfig)
     }
-    itemToItemConfigs = append(itemToItemConfigs, builtInConfig)
 
     // Build item-to-item recommenders
     itemToItemRecommenders := make([]logics.ItemToItem, 0, len(itemToItemConfigs))
item-to-item recommendation", - zap.String("item_id", item.ItemId), zap.Error(err)) - continue - } + pool.Run(func() { + defer span.Add(1) + score := recommender.PopAll(j) + if score == nil { + return + } + log.Logger().Debug("update item-to-item recommendation", + zap.String("item_id", item.ItemId), + zap.String("name", itemToItemConfig.Name), + zap.Int("n_recommendations", len(score))) + // Save item-to-item recommendation to cache + if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, item.ItemId), score); err != nil { + log.Logger().Error("failed to save item-to-item recommendation to cache", + zap.String("item_id", item.ItemId), zap.Error(err)) + return + } + // Save item-to-item digest and last update time to cache + if err := m.CacheClient.Set(ctx, + cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, item.ItemId), itemToItemConfig.Hash()), + cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, item.ItemId), time.Now()), + ); err != nil { + log.Logger().Error("failed to save item-to-item digest to cache", + zap.String("item_id", item.ItemId), zap.Error(err)) + return + } + // Remove stale item-to-item recommendation + if err := m.CacheClient.DeleteScores(ctx, []string{cache.ItemToItem}, cache.ScoreCondition{ + Subset: lo.ToPtr(cache.Key(itemToItemConfig.Name, item.ItemId)), + Before: lo.ToPtr(recommender.Timestamp()), + }); err != nil { + log.Logger().Error("failed to remove stale item-to-item recommendation", + zap.String("item_id", item.ItemId), zap.Error(err)) + return + } + }) + } else { + span.Add(1) } - span.Add(1) } + pool.Wait() } return nil } diff --git a/model/click/model.go b/model/click/model.go index b28281eae..43730caec 100644 --- a/model/click/model.go +++ b/model/click/model.go @@ -31,9 +31,9 @@ import ( "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/base/parallel" "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" + "github.com/zhenghaoz/gorse/common/parallel" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" ) diff --git a/model/ranking/evaluator.go b/model/ranking/evaluator.go index 09da130cd..cd1af2f01 100644 --- a/model/ranking/evaluator.go +++ b/model/ranking/evaluator.go @@ -20,7 +20,7 @@ import ( "github.com/thoas/go-funk" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/heap" - "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/common/parallel" ) /* Evaluate Item Ranking */ diff --git a/model/ranking/model.go b/model/ranking/model.go index 0c6efa58d..ae1f4dd9d 100644 --- a/model/ranking/model.go +++ b/model/ranking/model.go @@ -31,9 +31,9 @@ import ( "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/base/parallel" "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" + "github.com/zhenghaoz/gorse/common/parallel" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" ) diff --git a/worker/worker.go b/worker/worker.go index 0171122cc..bdaf2748f 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -37,10 +37,10 @@ import ( "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/base/parallel" "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/cmd/version" encoding2 
"github.com/zhenghaoz/gorse/common/encoding" + "github.com/zhenghaoz/gorse/common/parallel" "github.com/zhenghaoz/gorse/common/sizeof" "github.com/zhenghaoz/gorse/common/util" "github.com/zhenghaoz/gorse/config" diff --git a/worker/worker_test.go b/worker/worker_test.go index 6980b59ab..57fc4e1d5 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -35,8 +35,8 @@ import ( "github.com/stretchr/testify/suite" "github.com/thoas/go-funk" "github.com/zhenghaoz/gorse/base" - "github.com/zhenghaoz/gorse/base/parallel" "github.com/zhenghaoz/gorse/base/progress" + "github.com/zhenghaoz/gorse/common/parallel" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model" "github.com/zhenghaoz/gorse/model/click"