Skip to content

Commit

Permalink
refactor user neighbors and item neighbors (#932)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 25, 2025
1 parent a9f6a0a commit 265529d
Show file tree
Hide file tree
Showing 24 changed files with 1,130 additions and 1,246 deletions.
16 changes: 9 additions & 7 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package client
import (
"context"
"encoding/base64"
"testing"
"time"

client "github.com/gorse-io/gorse-go"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/suite"
"testing"
"time"
"github.com/zhenghaoz/gorse/storage/cache"
)

const (
Expand Down Expand Up @@ -108,22 +110,22 @@ func (suite *GorseClientTestSuite) TestRecommend() {

func (suite *GorseClientTestSuite) TestSessionRecommend() {
ctx := context.Background()
suite.hSet("item_neighbors", "1", []client.Score{
suite.hSet("item-to-item", cache.Key(cache.Neighbors, "1"), []client.Score{
{Id: "2", Score: 100000},
{Id: "9", Score: 1},
})
suite.hSet("item_neighbors", "2", []client.Score{
suite.hSet("item-to-item", cache.Key(cache.Neighbors, "2"), []client.Score{
{Id: "3", Score: 100000},
{Id: "8", Score: 1},
{Id: "9", Score: 1},
})
suite.hSet("item_neighbors", "3", []client.Score{
suite.hSet("item-to-item", cache.Key(cache.Neighbors, "3"), []client.Score{
{Id: "4", Score: 100000},
{Id: "7", Score: 1},
{Id: "8", Score: 1},
{Id: "9", Score: 1},
})
suite.hSet("item_neighbors", "4", []client.Score{
suite.hSet("item-to-item", cache.Key(cache.Neighbors, "4"), []client.Score{
{Id: "1", Score: 100000},
{Id: "6", Score: 1},
{Id: "7", Score: 1},
Expand Down Expand Up @@ -179,7 +181,7 @@ func (suite *GorseClientTestSuite) TestSessionRecommend() {

func (suite *GorseClientTestSuite) TestNeighbors() {
ctx := context.Background()
suite.hSet("item_neighbors", "100", []client.Score{
suite.hSet("item-to-item", cache.Key(cache.Neighbors, "100"), []client.Score{
{Id: "1", Score: 1},
{Id: "2", Score: 2},
{Id: "3", Score: 3},
Expand Down
10 changes: 5 additions & 5 deletions common/ann/bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ import (

// Bruteforce is a naive implementation of vector index.
type Bruteforce[T any] struct {
distanceFunc func(a, b []T) float32
vectors [][]T
distanceFunc func(a, b T) float32
vectors []T
}

func NewBruteforce[T any](distanceFunc func(a, b []T) float32) *Bruteforce[T] {
func NewBruteforce[T any](distanceFunc func(a, b T) float32) *Bruteforce[T] {
return &Bruteforce[T]{distanceFunc: distanceFunc}
}

func (b *Bruteforce[T]) Add(v []T) (int, error) {
func (b *Bruteforce[T]) Add(v T) (int, error) {
// Add vector
b.vectors = append(b.vectors, v)
return len(b.vectors), nil
Expand Down Expand Up @@ -62,7 +62,7 @@ func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, flo
return scores, nil
}

func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] {
func (b *Bruteforce[T]) SearchVector(q T, k int, prune0 bool) []lo.Tuple2[int, float32] {
// Search
pq := heap.NewPriorityQueue(true)
for i, vec := range b.vectors {
Expand Down
18 changes: 9 additions & 9 deletions common/ann/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (

// HNSW is a vector index based on Hierarchical Navigable Small Worlds.
type HNSW[T any] struct {
distanceFunc func(a, b []T) float32
vectors [][]T
distanceFunc func(a, b T) float32
vectors []T
bottomNeighbors []*heap.PriorityQueue
upperNeighbors []map[int32]*heap.PriorityQueue
enterPoint int32
Expand All @@ -40,7 +40,7 @@ type HNSW[T any] struct {
efConstruction int
}

func NewHNSW[T any](distanceFunc func(a, b []T) float32) *HNSW[T] {
func NewHNSW[T any](distanceFunc func(a, b T) float32) *HNSW[T] {
return &HNSW[T]{
distanceFunc: distanceFunc,
levelFactor: 1.0 / math32.Log(48),
Expand All @@ -50,7 +50,7 @@ func NewHNSW[T any](distanceFunc func(a, b []T) float32) *HNSW[T] {
}
}

func (h *HNSW[T]) Add(v []T) (int, error) {
func (h *HNSW[T]) Add(v T) (int, error) {
// Add vector
h.vectors = append(h.vectors, v)
h.bottomNeighbors = append(h.bottomNeighbors, heap.NewPriorityQueue(false))
Expand All @@ -70,7 +70,7 @@ func (h *HNSW[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32],
return scores, nil
}

func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] {
func (h *HNSW[T]) SearchVector(q T, k int, prune0 bool) []lo.Tuple2[int, float32] {
w := h.knnSearch(q, k, h.efSearchValue(k))
scores := make([]lo.Tuple2[int, float32], 0)
for w.Len() > 0 {
Expand All @@ -82,7 +82,7 @@ func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float
return scores
}

func (h *HNSW[T]) knnSearch(q []T, k, ef int) *heap.PriorityQueue {
func (h *HNSW[T]) knnSearch(q T, k, ef int) *heap.PriorityQueue {
var (
w *heap.PriorityQueue // set for the current the nearest element
enterPoints = h.distance(q, []int32{h.enterPoint}) // get enter point for hnsw
Expand Down Expand Up @@ -157,7 +157,7 @@ func (h *HNSW[T]) insert(q int32) {
}
}

func (h *HNSW[T]) searchLayer(q []T, enterPoints *heap.PriorityQueue, ef, currentLayer int) *heap.PriorityQueue {
func (h *HNSW[T]) searchLayer(q T, enterPoints *heap.PriorityQueue, ef, currentLayer int) *heap.PriorityQueue {
var (
v = mapset.NewSet(enterPoints.Values()...) // set of visited elements
candidates = enterPoints.Clone() // set of candidates
Expand Down Expand Up @@ -210,15 +210,15 @@ func (h *HNSW[T]) getNeighbourhood(e int32, currentLayer int) *heap.PriorityQueu
}
}

func (h *HNSW[T]) selectNeighbors(_ []T, candidates *heap.PriorityQueue, m int) *heap.PriorityQueue {
func (h *HNSW[T]) selectNeighbors(_ T, candidates *heap.PriorityQueue, m int) *heap.PriorityQueue {
pq := candidates.Reverse()
for pq.Len() > m {
pq.Pop()
}
return pq.Reverse()
}

func (h *HNSW[T]) distance(q []T, points []int32) *heap.PriorityQueue {
func (h *HNSW[T]) distance(q T, points []int32) *heap.PriorityQueue {
pq := heap.NewPriorityQueue(false)
for _, point := range points {
pq.Push(point, h.distanceFunc(h.vectors[point], q))
Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type NeighborsConfig struct {

type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
}

Expand Down
16 changes: 15 additions & 1 deletion config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ score = "count(feedback, .FeedbackType == 'star')"
# The filter for items in the leaderboard.
filter = "(now() - item.Timestamp).Hours() < 168"

# [[recommend.item-to-item]]

# # The name of the item-to-item recommender.
# name = "similar_embedding"

# # The type of the item-to-item recommender. There are three types:
# # embedding: recommend by Euclidean distance of embeddings.
# # tags: recommend by number of common tags.
# # users: recommend by number of common users.
# type = "embedding"

# # The column of the item embeddings. Leave blank if type is "users".
# column = "item.Labels.embedding"

[recommend.user_neighbors]

# The type of neighbors for users. There are three types:
Expand All @@ -157,7 +171,7 @@ filter = "(now() - item.Timestamp).Hours() < 168"
# auto: If a user has labels, neighbors are found by the number of common labels.
# If this user has no labels, neighbors are found by the number of common liked items.
# The default value is "auto".
neighbor_type = "similar"
neighbor_type = "related"

[recommend.item_neighbors]

Expand Down
2 changes: 1 addition & 1 deletion config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestUnmarshal(t *testing.T) {
assert.Equal(t, "count(feedback, .FeedbackType == 'star')", config.Recommend.NonPersonalized[0].Score)
assert.Equal(t, "(now() - item.Timestamp).Hours() < 168", config.Recommend.NonPersonalized[0].Filter)
// [recommend.user_neighbors]
assert.Equal(t, "similar", config.Recommend.UserNeighbors.NeighborType)
assert.Equal(t, "related", config.Recommend.UserNeighbors.NeighborType)
// [recommend.item_neighbors]
assert.Equal(t, "similar", config.Recommend.ItemNeighbors.NeighborType)
// [recommend.collaborative]
Expand Down
126 changes: 112 additions & 14 deletions dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,65 +15,163 @@
package dataset

import (
"time"

"github.com/chewxy/math32"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/storage/data"
"modernc.org/strutil"
"time"
)

type ID int

type Dataset struct {
timestamp time.Time
users []data.User
items []data.Item
columnNames *strutil.Pool
columnValues *FreqDict
userLabels *Labels
itemLabels *Labels
userFeedback [][]ID
itemFeedback [][]ID
userDict *FreqDict
itemDict *FreqDict
}

func NewDataset(timestamp time.Time, itemCount int) *Dataset {
func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset {
return &Dataset{
timestamp: timestamp,
users: make([]data.User, 0, userCount),
items: make([]data.Item, 0, itemCount),
columnNames: strutil.NewPool(),
columnValues: NewFreqDict(),
userLabels: NewLabels(),
itemLabels: NewLabels(),
userFeedback: make([][]ID, userCount),
itemFeedback: make([][]ID, itemCount),
userDict: NewFreqDict(),
itemDict: NewFreqDict(),
}
}

// GetTimestamp returns the timestamp the dataset was created with.
func (d *Dataset) GetTimestamp() time.Time {
	return d.timestamp
}

// GetUsers returns all users added to the dataset, in insertion order.
func (d *Dataset) GetUsers() []data.User {
	return d.users
}

// GetItems returns all items added to the dataset, in insertion order.
func (d *Dataset) GetItems() []data.Item {
	return d.items
}

// GetUserFeedback returns, for each user index, the indices of items
// that user has feedback on (populated by AddFeedback).
func (d *Dataset) GetUserFeedback() [][]ID {
	return d.userFeedback
}

// GetItemFeedback returns, for each item index, the indices of users
// that have feedback on that item (populated by AddFeedback).
func (d *Dataset) GetItemFeedback() [][]ID {
	return d.itemFeedback
}

// GetUserIDF returns the inverse document frequency of each user:
//
//	IDF(u) = log(I/freq(u))
//
// where I is the number of items and freq(u) is the frequency of user u
// in all feedback. The result is indexed by user dictionary ID.
func (d *Dataset) GetUserIDF() []float32 {
	numItems := float32(len(d.items))
	idf := make([]float32, d.userDict.Count())
	for i := range idf {
		// Clamp to 1e-3: a zero IDF would later propagate as NaN.
		idf[i] = max(math32.Log(numItems/float32(d.userDict.Freq(i))), 1e-3)
	}
	return idf
}

// GetItemIDF returns the inverse document frequency of each item:
//
//	IDF(i) = log(U/freq(i))
//
// where U is the number of users and freq(i) is the frequency of item i
// in all feedback. The result is indexed by item dictionary ID.
func (d *Dataset) GetItemIDF() []float32 {
	numUsers := float32(len(d.users))
	idf := make([]float32, d.itemDict.Count())
	for i := range idf {
		// Clamp to 1e-3: a zero IDF would later propagate as NaN.
		idf[i] = max(math32.Log(numUsers/float32(d.itemDict.Freq(i))), 1e-3)
	}
	return idf
}

// GetUserColumnValuesIDF returns the inverse document frequency of each
// user label value, indexed by the value's dictionary ID.
func (d *Dataset) GetUserColumnValuesIDF() []float32 {
	numUsers := float32(len(d.users))
	values := d.userLabels.values
	idf := make([]float32, values.Count())
	for i := range idf {
		// Clamp to 1e-3: a zero IDF would later propagate as NaN.
		idf[i] = max(math32.Log(numUsers/float32(values.Freq(i))), 1e-3)
	}
	return idf
}

func (d *Dataset) GetItemColumnValuesIDF() []float32 {
idf := make([]float32, d.columnValues.Count())
for i := 0; i < d.columnValues.Count(); i++ {
idf := make([]float32, d.itemLabels.values.Count())
for i := 0; i < d.itemLabels.values.Count(); i++ {
// Since zero IDF will cause NaN in the future, we set the minimum value to 1e-3.
idf[i] = max(math32.Log(float32(len(d.items)/(d.columnValues.Freq(i)))), 1e-3)
idf[i] = max(math32.Log(float32(len(d.items))/float32(d.itemLabels.values.Freq(i))), 1e-3)
}
return idf
}

// AddUser appends a user to the dataset, interning its labels and
// registering its ID in the user dictionary.
func (d *Dataset) AddUser(user data.User) {
	interned := data.User{
		UserId:    user.UserId,
		Labels:    d.userLabels.processLabels(user.Labels, ""),
		Subscribe: user.Subscribe,
		Comment:   user.Comment,
	}
	d.users = append(d.users, interned)
	// Register the user ID in the dictionary (NotCount: presumably
	// without bumping its frequency — confirm in FreqDict).
	d.userDict.NotCount(user.UserId)
	// Keep one feedback slot per user.
	if len(d.userFeedback) < len(d.users) {
		d.userFeedback = append(d.userFeedback, nil)
	}
}

func (d *Dataset) AddItem(item data.Item) {
d.items = append(d.items, data.Item{
ItemId: item.ItemId,
IsHidden: item.IsHidden,
Categories: item.Categories,
Timestamp: item.Timestamp,
Labels: d.processLabels(item.Labels, ""),
Labels: d.itemLabels.processLabels(item.Labels, ""),
Comment: item.Comment,
})
d.itemDict.NotCount(item.ItemId)
if len(d.itemFeedback) < len(d.items) {
d.itemFeedback = append(d.itemFeedback, nil)
}
}

// AddFeedback records a feedback edge between a user and an item,
// updating both the user-to-items and item-to-users adjacency lists.
func (d *Dataset) AddFeedback(userId, itemId string) {
	u := d.userDict.Id(userId)
	i := d.itemDict.Id(itemId)
	d.userFeedback[u] = append(d.userFeedback[u], ID(i))
	d.itemFeedback[i] = append(d.itemFeedback[i], ID(u))
}

// Labels interns label field names and label values so repeated strings
// share storage and values can be referenced by integer ID.
type Labels struct {
	fields *strutil.Pool // interned field-name strings
	values *FreqDict     // value-string -> ID dictionary
}

// NewLabels creates an empty Labels with fresh pools.
func NewLabels() *Labels {
	return &Labels{
		fields: strutil.NewPool(),
		values: NewFreqDict(),
	}
}

func (d *Dataset) processLabels(labels any, parent string) any {
func (l *Labels) processLabels(labels any, parent string) any {
switch typed := labels.(type) {
case map[string]any:
o := make(map[string]any)
for k, v := range typed {
o[d.columnNames.Align(k)] = d.processLabels(v, parent+"."+k)
o[l.fields.Align(k)] = l.processLabels(v, parent+"."+k)
}
return o
case []any:
Expand All @@ -83,12 +181,12 @@ func (d *Dataset) processLabels(labels any, parent string) any {
})
} else if isSliceOf[string](typed) {
return lo.Map(typed, func(e any, _ int) ID {
return ID(d.columnValues.Id(parent + ":" + e.(string)))
return ID(l.values.Id(parent + ":" + e.(string)))
})
}
return typed
case string:
return ID(d.columnValues.Id(parent + ":" + typed))
return ID(l.values.Id(parent + ":" + typed))
default:
return labels
}
Expand Down
Loading

0 comments on commit 265529d

Please sign in to comment.