Skip to content

Commit

Permalink
feat: Add Lua Locking for Redis < 7 Compatibility (#731)
Browse files Browse the repository at this point in the history
* feat: Add Lua Locking for Redis < 7 Compatibility

* feat: address comment

* feat: add LuaLock flag to clientOption

* feat: add tests for client with lua lock

---------

Co-authored-by: Anuragkillswitch <[email protected]>
  • Loading branch information
SoulPancake and SoulPancake authored Feb 1, 2025
1 parent 34c5717 commit 5127a2c
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 12 deletions.
38 changes: 26 additions & 12 deletions rueidisaside/aside.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type ClientOption struct {
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error)
ClientOption rueidis.ClientOption
ClientTTL time.Duration // TTL for the client marker, refreshed every 1/2 TTL. Defaults to 10s. The marker allows other client to know if this client is still alive.
UseLuaLock bool
}

type CacheAsideClient interface {
Expand All @@ -33,8 +34,9 @@ func NewClient(option ClientOption) (cc CacheAsideClient, err error) {
option.ClientTTL = 10 * time.Second
}
ca := &Client{
waits: make(map[string]chan struct{}),
ttl: option.ClientTTL,
waits: make(map[string]chan struct{}),
ttl: option.ClientTTL,
useLuaLock: option.UseLuaLock,
}
option.ClientOption.OnInvalidations = ca.onInvalidation
if option.ClientBuilder != nil {
Expand All @@ -50,13 +52,14 @@ func NewClient(option ClientOption) (cc CacheAsideClient, err error) {
}

type Client struct {
client rueidis.Client
ctx context.Context
waits map[string]chan struct{}
cancel context.CancelFunc
id string
ttl time.Duration
mu sync.Mutex
client rueidis.Client
ctx context.Context
waits map[string]chan struct{}
cancel context.CancelFunc
id string
ttl time.Duration
mu sync.Mutex
useLuaLock bool
}

func (c *Client) onInvalidation(messages []rueidis.RedisMessage) {
Expand Down Expand Up @@ -144,14 +147,21 @@ func randStr() string {
func (c *Client) Get(ctx context.Context, ttl time.Duration, key string, fn func(ctx context.Context, key string) (val string, err error)) (string, error) {
ctx, cancel := context.WithTimeout(ctx, ttl)
defer cancel()

retry:
wait := c.register(key)
resp := c.client.DoCache(ctx, c.client.B().Get().Key(key).Cache(), ttl)
val, err := resp.ToString()

if rueidis.IsRedisNil(err) && fn != nil { // cache miss, prepare to populate the value by fn()
var id string
if id, err = c.keepalive(); err == nil { // acquire client id
val, err = c.client.Do(ctx, c.client.B().Set().Key(key).Value(id).Nx().Get().Px(ttl).Build()).ToString()
if c.useLuaLock {
val, err = acquireLock.Exec(ctx, c.client, []string{key}, []string{id, strconv.FormatInt(ttl.Milliseconds(), 10)}).ToString()
} else {
val, err = c.client.Do(ctx, c.client.B().Set().Key(key).Value(id).Nx().Get().Px(ttl).Build()).ToString()
}

if rueidis.IsRedisNil(err) { // successfully set client id on the key as a lock
if val, err = fn(ctx, key); err == nil {
err = setkey.Exec(ctx, c.client, []string{key}, []string{id, val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error()
Expand All @@ -162,9 +172,11 @@ retry:
}
}
}

if err != nil {
return val, err
}

if strings.HasPrefix(val, PlaceholderPrefix) {
ph := c.register(val)
err = c.client.DoCache(ctx, c.client.B().Get().Key(val).Cache(), c.ttl).Error()
Expand All @@ -184,6 +196,7 @@ retry:
goto retry
}
}

return val, err
}

Expand All @@ -210,6 +223,7 @@ func (c *Client) Close() {
const PlaceholderPrefix = "rueidisid:"

var (
delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`)
setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`)
delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`)
setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`)
acquireLock = rueidis.NewLuaScript(`if redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2]) then return nil else return redis.call("GET", KEYS[1]) end`)
)
238 changes: 238 additions & 0 deletions rueidisaside/aside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ func makeClient(t *testing.T, addr []string) CacheAsideClient {
return client
}

func makeClientWithLuaLock(t *testing.T, addr []string) CacheAsideClient {
client, err := NewClient(ClientOption{
UseLuaLock: true,
ClientOption: rueidis.ClientOption{InitAddress: addr, PipelineMultiplex: -1},
ClientTTL: time.Second,
})
if err != nil {
t.Fatal(err)
}
return client
}

func TestClientErr(t *testing.T) {
if _, err := NewClient(ClientOption{}); err == nil {
t.Error(err)
Expand Down Expand Up @@ -72,6 +84,29 @@ func TestCacheFilled(t *testing.T) {
}
}

func TestCacheFilledLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())
for i := 0; i < 2; i++ {
val, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
return "1", nil
})
if err != nil || val != "1" {
t.Fatal(err)
}
val, err = client.Get(context.Background(), time.Millisecond*500, key, nil)
if err != nil || val != "1" {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 600)
val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) // should miss
if !rueidis.IsRedisNil(err) {
t.Fatal(err)
}
}
}

func TestCacheDel(t *testing.T) {
client := makeClient(t, addr)
defer client.Close()
Expand All @@ -98,6 +133,32 @@ func TestCacheDel(t *testing.T) {
}
}

func TestCacheDelLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())
for i := 0; i < 2; i++ {
val, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
return "1", nil
})
if err != nil || val != "1" {
t.Fatal(err)
}
val, err = client.Get(context.Background(), time.Millisecond*500, key, nil)
if err != nil || val != "1" {
t.Fatal(err)
}
if err = client.Del(context.Background(), key); err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
val, err = client.Get(context.Background(), time.Millisecond*500, key, nil) // should miss
if !rueidis.IsRedisNil(err) {
t.Fatal(err)
}
}
}

func TestClientRefresh(t *testing.T) {
client := makeClient(t, addr).(*Client)
defer client.Close()
Expand All @@ -118,6 +179,26 @@ func TestClientRefresh(t *testing.T) {
})
}

func TestClientRefreshLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr).(*Client)
defer client.Close()
key := strconv.Itoa(rand.Int())
_, _ = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString()
if err != nil {
t.Error(err)
}
for i := 0; i < 2; i++ {
err = client.client.Do(context.Background(), client.client.B().Get().Key(id).Build()).Error()
if err != nil {
t.Error(err)
}
time.Sleep(client.ttl)
}
return "1", nil
})
}

func TestCloseCleanup(t *testing.T) {
client := makeClient(t, addr).(*Client)
key := strconv.Itoa(rand.Int())
Expand All @@ -143,6 +224,31 @@ func TestCloseCleanup(t *testing.T) {
}
}

func TestCloseCleanupLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr).(*Client)
key := strconv.Itoa(rand.Int())
ch := make(chan string, 1)
_, _ = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString()
if err != nil {
t.Error(err)
}
err = client.client.Do(context.Background(), client.client.B().Get().Key(id).Build()).Error()
if err != nil {
t.Error(err)
}
ch <- id
return "1", nil
})
client.Close()
client = makeClient(t, addr).(*Client)
defer client.Close()
err := client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error()
if !rueidis.IsRedisNil(err) {
t.Error(err)
}
}

func TestWriteCancel(t *testing.T) {
client := makeClient(t, addr).(*Client)
defer client.Close()
Expand Down Expand Up @@ -170,6 +276,33 @@ func TestWriteCancel(t *testing.T) {
}
}

func TestWriteCancelLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr).(*Client)
defer client.Close()
key := strconv.Itoa(rand.Int())
ch := make(chan string, 1)
ctx, cancel := context.WithCancel(context.Background())
val, err := client.Get(ctx, time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
id, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString()
if err != nil {
t.Error(err)
}
cancel()
ch <- id
return "1", nil
})
if val != "1" {
t.Fatal(err)
}
if err != context.Canceled {
t.Fatal(err)
}
err = client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).Error()
if !rueidis.IsRedisNil(err) {
t.Error(err)
}
}

func TestTimeout(t *testing.T) {
client := makeClient(t, addr).(*Client)
defer client.Close()
Expand All @@ -188,6 +321,24 @@ func TestTimeout(t *testing.T) {
}
}

func TestTimeoutLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr).(*Client)
defer client.Close()
key := strconv.Itoa(rand.Int())
_, err := client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
_, err = client.Get(context.Background(), time.Millisecond*500, key, func(ctx context.Context, key string) (val string, err error) {
return "1", nil
})
if err != context.DeadlineExceeded {
t.Error(err)
}
return "", err
})
if err != context.DeadlineExceeded {
t.Fatal(err)
}
}

func TestDisconnect(t *testing.T) {
client := makeClient(t, addr).(*Client)
defer client.Close()
Expand Down Expand Up @@ -238,6 +389,56 @@ func TestDisconnect(t *testing.T) {
time.Sleep(client.ttl) // wait old refresh goroutine exit
}

func TestDisconnectLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr).(*Client)
defer client.Close()
key := strconv.Itoa(rand.Int())
ch := make(chan string, 2)
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
id1, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString()
if err != nil {
t.Error(err)
}
go func() {
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
id2, err := client.client.Do(context.Background(), client.client.B().Get().Key(key).Build()).ToString()
if err != nil {
t.Error(err)
}
ch <- id2
return "2", nil
})
if val != "2" {
t.Error(err)
}
}()
client.onInvalidation(nil) // simulate disconnection
id2 := <-ch
if id1 == id2 {
t.Error("id not changed")
}
ch <- id1
ch <- id2
return "1", nil
})
if val != "1" {
t.Fatal(err)
}
val, err = client.Get(context.Background(), time.Millisecond*500, key, nil)
if val != "2" {
t.Error(err)
}
err = client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error() // id1
if !rueidis.IsRedisNil(err) {
t.Error(err)
}
err = client.client.Do(context.Background(), client.client.B().Get().Key(<-ch).Build()).Error() // id2
if err != nil {
t.Error(err)
}
time.Sleep(client.ttl) // wait old refresh goroutine exit
}

func TestMultipleClient(t *testing.T) {
clients := make([]CacheAsideClient, 10)
for i := 0; i < len(clients); i++ {
Expand Down Expand Up @@ -274,3 +475,40 @@ func TestMultipleClient(t *testing.T) {
}
}
}

func TestMultipleClientLL(t *testing.T) {
clients := make([]CacheAsideClient, 10)
for i := 0; i < len(clients); i++ {
clients[i] = makeClientWithLuaLock(t, addr)
}
defer func() {
for _, client := range clients {
client.Close()
}
}()
cnt := 1000
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(len(clients))
key := strconv.Itoa(rand.Int())
sum := int64(0)
for i, c := range clients {
go func(i int, c CacheAsideClient) {
defer wg.Done()
for j := 0; j < cnt; j++ {
v, err := c.Get(context.Background(), time.Second, key, func(ctx context.Context, key string) (val string, err error) {
atomic.AddInt64(&sum, 1)
return "1", nil
})
if err != nil || v != "1" {
t.Error(err)
}
}
}(i, c)
}
wg.Wait()
if atomic.LoadInt64(&sum) != 1 {
t.Fatalf("unexpected sum")
}
}
}

0 comments on commit 5127a2c

Please sign in to comment.