From ae247e091d2793f35293adaff12f8a33105e5383 Mon Sep 17 00:00:00 2001 From: Rueian Date: Sun, 11 Aug 2024 00:00:26 +0800 Subject: [PATCH] feat: add TxPipeline to rueidiscompat (#605) * feat: add TxPipeline to rueidiscompat Signed-off-by: Rueian * feat: alias rueidiscompat.Nil to rueidis.Nil * feat: add todos to the rueidiscompat cmdable interface Signed-off-by: Rueian --------- Signed-off-by: Rueian --- valkeycompat/adapter.go | 90 +++++++++++++++++++-- valkeycompat/pipeline.go | 22 +---- valkeycompat/tx.go | 145 +++++++++++++++++++++++++++++++++ valkeycompat/tx_test.go | 168 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 401 insertions(+), 24 deletions(-) create mode 100644 valkeycompat/tx.go create mode 100644 valkeycompat/tx_test.go diff --git a/valkeycompat/adapter.go b/valkeycompat/adapter.go index 91bb586..f5f072e 100644 --- a/valkeycompat/adapter.go +++ b/valkeycompat/adapter.go @@ -49,9 +49,20 @@ const ( BitCountIndexBit = "BIT" ) +var Nil = valkey.Nil + type Cmdable interface { + CoreCmdable Cache(ttl time.Duration) CacheCompat + Subscribe(ctx context.Context, channels ...string) PubSub + PSubscribe(ctx context.Context, patterns ...string) PubSub + SSubscribe(ctx context.Context, channels ...string) PubSub + + Watch(ctx context.Context, fn func(Tx) error, keys ...string) error +} + +type CoreCmdable interface { Command(ctx context.Context) *CommandsInfoCmd CommandList(ctx context.Context, filter FilterBy) *StringSliceCmd CommandGetKeys(ctx context.Context, commands ...any) *StringSliceCmd @@ -127,11 +138,13 @@ type Cmdable interface { BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd BitPosSpan(ctx context.Context, key string, bit int64, start, end int64, span string) *IntCmd BitField(ctx context.Context, key string, args ...any) *IntSliceCmd + // TODO BitFieldRO(ctx context.Context, key string, values ...interface{}) *IntSliceCmd Scan(ctx context.Context, cursor uint64, match string, count int64) *ScanCmd ScanType(ctx context.Context, cursor uint64, match string, count int64, keyType string) *ScanCmd SScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd HScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd + // TODO HScanNoValues(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd ZScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd HDel(ctx context.Context, key string, fields ...string) *IntCmd @@ -149,6 +162,19 @@ type Cmdable interface { HVals(ctx context.Context, key string) *StringSliceCmd HRandField(ctx context.Context, key string, count int64) *StringSliceCmd HRandFieldWithValues(ctx context.Context, key string, count int64) *KeyValueSliceCmd + // TODO HExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd + // TODO HExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd + // TODO HPExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd + // TODO HPExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd + // TODO HExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd + // TODO HExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd + // TODO HPExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd + // TODO HPExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd + // TODO HPersist(ctx context.Context, key string, fields ...string) *IntSliceCmd + // TODO HExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd + // TODO HPExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd + // TODO HTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd + // TODO HPTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd BLPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd BLMPop(ctx context.Context, timeout time.Duration, direction string, count int64, keys ...string) *KeyValuesCmd @@ -375,6 +401,8 @@ type Cmdable interface { ClusterFailover(ctx context.Context) *StatusCmd ClusterAddSlots(ctx context.Context, slots ...int64) *StatusCmd ClusterAddSlotsRange(ctx context.Context, min, max int64) *StatusCmd + // TODO ReadOnly(ctx context.Context) *StatusCmd + // TODO ReadWrite(ctx context.Context) *StatusCmd GeoAdd(ctx context.Context, key string, geoLocation ...GeoLocation) *IntCmd GeoPos(ctx context.Context, key string, members ...string) *GeoPosCmd @@ -389,13 +417,48 @@ type Cmdable interface { GeoHash(ctx context.Context, key string, members ...string) *StringSliceCmd ACLDryRun(ctx context.Context, username string, command ...any) *StringCmd + // TODO ACLLog(ctx context.Context, count int64) *ACLLogCmd + // TODO ACLLogReset(ctx context.Context) *StatusCmd // TODO ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd GearsCmdable ProbabilisticCmdable TimeseriesCmdable JSONCmdable -} + // TODO SearchCmdable +} + +// TODO SearchCmdable +//type SearchCmdable interface { +// FT_List(ctx context.Context) *StringSliceCmd +// FTAggregate(ctx context.Context, index string, query string) *MapStringInterfaceCmd +// FTAggregateWithArgs(ctx context.Context, index string, query string, options *FTAggregateOptions) *AggregateCmd +// FTAliasAdd(ctx context.Context, index string, alias string) *StatusCmd +// FTAliasDel(ctx context.Context, alias string) *StatusCmd +// FTAliasUpdate(ctx context.Context, index string, alias string) *StatusCmd +// FTAlter(ctx context.Context, index string, skipInitalScan bool, definition []interface{}) *StatusCmd +// FTConfigGet(ctx context.Context, option string) *MapMapStringInterfaceCmd +// FTConfigSet(ctx context.Context, option string, value interface{}) *StatusCmd +// FTCreate(ctx context.Context, index string, options *FTCreateOptions, schema ...*FieldSchema) *StatusCmd +// FTCursorDel(ctx context.Context, index string, cursorId int) *StatusCmd +// FTCursorRead(ctx context.Context, index string, cursorId int, count int) *MapStringInterfaceCmd +// FTDictAdd(ctx context.Context, dict string, term ...interface{}) *IntCmd +// FTDictDel(ctx context.Context, dict string, term ...interface{}) *IntCmd +// FTDictDump(ctx context.Context, dict string) *StringSliceCmd +// FTDropIndex(ctx context.Context, index string) *StatusCmd +// FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd +// FTExplain(ctx context.Context, index string, query string) *StringCmd +// FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd +// FTInfo(ctx context.Context, index string) *FTInfoCmd +// FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd +// FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd +// FTSearch(ctx context.Context, index string, query string) *FTSearchCmd +// FTSearchWithArgs(ctx context.Context, index string, query string, options *FTSearchOptions) *FTSearchCmd +// FTSynDump(ctx context.Context, index string) *FTSynDumpCmd +// FTSynUpdate(ctx context.Context, index string, synGroupId interface{}, terms []interface{}) *StatusCmd +// FTSynUpdateWithArgs(ctx context.Context, index string, synGroupId interface{}, options *FTSynUpdateOptions, terms []interface{}) *StatusCmd +// FTTagVals(ctx context.Context, index string, field string) *StringSliceCmd +//} // https://github.com/redis/go-redis/blob/af4872cbd0de349855ce3f0978929c2f56eb995f/probabilistic.go#L10 type ProbabilisticCmdable interface { @@ -470,12 +533,11 @@ type ProbabilisticCmdable interface { TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd TDigestTrimmedMean(ctx context.Context, key string, lowCutQuantile, highCutQuantile float64) *FloatCmd - Subscribe(ctx context.Context, channels ...string) PubSub - PSubscribe(ctx context.Context, patterns ...string) PubSub - SSubscribe(ctx context.Context, channels ...string) PubSub - Pipeline() Pipeliner Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) + + TxPipeline() Pipeliner + TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) } // Align with go-redis @@ -4628,6 +4690,24 @@ func (c *Compat) Pipeline() Pipeliner { return newPipeline(c.client) } +func (c *Compat) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + return newTxPipeline(c.client).Pipelined(ctx, fn) +} + +func (c *Compat) TxPipeline() Pipeliner { + return newTxPipeline(c.client) +} + +func (c *Compat) Watch(ctx context.Context, fn func(Tx) error, keys ...string) error { + dc, cancel := c.client.Dedicate() + defer cancel() + tx := newTx(dc, cancel) + if err := tx.Watch(ctx, keys...).Err(); err != nil { + return err + } + return fn(newTx(dc, cancel)) +} + func (c CacheCompat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd { var resp valkey.ValkeyResult if bitCount == nil { diff --git a/valkeycompat/pipeline.go b/valkeycompat/pipeline.go index 97e2f24..1867cbe 100644 --- a/valkeycompat/pipeline.go +++ b/valkeycompat/pipeline.go @@ -49,7 +49,7 @@ import ( // To avoid this: it is good idea to use reasonable bigger read/write timeouts // depends on your batch size and/or use TxPipeline. type Pipeliner interface { - Cmdable + CoreCmdable // Len is to obtain the number of commands in the pipeline that have not yet been executed. Len() int @@ -96,10 +96,6 @@ type Pipeline struct { rets []Cmder } -func (c *Pipeline) Cache(ttl time.Duration) CacheCompat { - return c.comp.Cache(ttl) -} - func (c *Pipeline) Command(ctx context.Context) *CommandsInfoCmd { ret := c.comp.Command(ctx) c.rets = append(c.rets, ret) @@ -2434,18 +2430,6 @@ func (c *Pipeline) TDigestTrimmedMean(ctx context.Context, key string, lowCutQua return ret } -func (c *Pipeline) Subscribe(ctx context.Context, channels ...string) PubSub { - return c.comp.Subscribe(ctx, channels...) -} - -func (c *Pipeline) PSubscribe(ctx context.Context, patterns ...string) PubSub { - return c.comp.PSubscribe(ctx, patterns...) -} - -func (c *Pipeline) SSubscribe(ctx context.Context, channels ...string) PubSub { - return c.comp.SSubscribe(ctx, channels...) -} - func (c *Pipeline) TSAdd(ctx context.Context, key string, timestamp interface{}, value float64) *IntCmd { ret := c.comp.TSAdd(ctx, key, timestamp, value) c.rets = append(c.rets, ret) @@ -2788,10 +2772,10 @@ func (c *Pipeline) Len() int { } // Do queues the custom command for later execution. -func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd { +func (c *Pipeline) Do(_ context.Context, args ...interface{}) *Cmd { ret := &Cmd{} if len(args) == 0 { - ret.SetErr(errors.New("redis: please enter the command to be executed")) + ret.SetErr(errors.New("valkey: please enter the command to be executed")) return ret } p := c.comp.client.(*proxy) diff --git a/valkeycompat/tx.go b/valkeycompat/tx.go new file mode 100644 index 0000000..15b4b8b --- /dev/null +++ b/valkeycompat/tx.go @@ -0,0 +1,145 @@ +package valkeycompat + +import ( + "context" + "errors" + "time" + "unsafe" + + "github.com/valkey-io/valkey-go" +) + +var TxFailedErr = errors.New("valkey: transaction failed") + +var _ Pipeliner = (*TxPipeline)(nil) + +type rePipeline = Pipeline + +func newTxPipeline(real valkey.Client) *TxPipeline { + return &TxPipeline{rePipeline: newPipeline(real)} +} + +type TxPipeline struct { + *rePipeline +} + +func (c *TxPipeline) Exec(ctx context.Context) ([]Cmder, error) { + p := c.comp.client.(*proxy) + if len(p.cmds) == 0 { + return nil, nil + } + + rets := c.rets + cmds := p.cmds + c.rets = nil + p.cmds = nil + + cmds = append(cmds, c.comp.client.B().Multi().Build(), c.comp.client.B().Exec().Build()) + for i := len(cmds) - 2; i >= 1; i-- { + j := i - 1 + cmds[j], cmds[i] = cmds[i], cmds[j] + } + + resp := p.DoMulti(ctx, cmds...) + results, err := resp[len(resp)-1].ToArray() + if valkey.IsValkeyNil(err) { + err = TxFailedErr + } + for i, r := range results { + rets[i].from(*(*valkey.ValkeyResult)(unsafe.Pointer(&proxyresult{ + err: resp[i+1].NonValkeyError(), + val: r, + }))) + } + return rets, err +} + +func (c *TxPipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + if err := fn(c); err != nil { + return nil, err + } + return c.Exec(ctx) +} + +func (c *TxPipeline) Pipeline() Pipeliner { + return c +} + +func (c *TxPipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + return c.Pipelined(ctx, fn) +} + +func (c *TxPipeline) TxPipeline() Pipeliner { + return c +} + +var _ valkey.Client = (*txproxy)(nil) + +type txproxy struct { + valkey.CoreClient +} + +func (p *txproxy) DoCache(_ context.Context, _ valkey.Cacheable, _ time.Duration) (resp valkey.ValkeyResult) { + panic("not implemented") +} + +func (p *txproxy) DoMultiCache(_ context.Context, _ ...valkey.CacheableTTL) (resp []valkey.ValkeyResult) { + panic("not implemented") +} + +func (p *txproxy) DoStream(_ context.Context, _ valkey.Completed) valkey.ValkeyResultStream { + panic("not implemented") +} + +func (p *txproxy) DoMultiStream(_ context.Context, _ ...valkey.Completed) valkey.MultiValkeyResultStream { + panic("not implemented") +} + +func (p *txproxy) Dedicated(_ func(valkey.DedicatedClient) error) (err error) { + panic("not implemented") +} + +func (p *txproxy) Dedicate() (client valkey.DedicatedClient, cancel func()) { + panic("not implemented") +} + +func (p *txproxy) Nodes() map[string]valkey.Client { + panic("not implemented") +} + +type Tx interface { + CoreCmdable + Watch(ctx context.Context, keys ...string) *StatusCmd + Unwatch(ctx context.Context, keys ...string) *StatusCmd + Close(ctx context.Context) error +} + +func newTx(client valkey.DedicatedClient, cancel func()) *tx { + return &tx{CoreCmdable: NewAdapter(&txproxy{CoreClient: client}), cancel: cancel} +} + +type tx struct { + CoreCmdable + cancel func() +} + +func (t *tx) Watch(ctx context.Context, keys ...string) *StatusCmd { + ret := &StatusCmd{} + if len(keys) != 0 { + client := t.CoreCmdable.(*Compat).client + ret.from(client.Do(ctx, client.B().Watch().Key(keys...).Build())) + } + return ret +} + +func (t *tx) Unwatch(ctx context.Context, _ ...string) *StatusCmd { + ret := &StatusCmd{} + client := t.CoreCmdable.(*Compat).client + ret.from(client.Do(ctx, client.B().Unwatch().Build())) + return ret +} + +func (t *tx) Close(_ context.Context) error { + t.cancel() + return nil +} diff --git a/valkeycompat/tx_test.go b/valkeycompat/tx_test.go new file mode 100644 index 0000000..02c6e41 --- /dev/null +++ b/valkeycompat/tx_test.go @@ -0,0 +1,168 @@ +// Copyright (c) 2013 The github.com/go-redis/redis Authors. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package valkeycompat + +import ( + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/valkey-io/valkey-go" +) + +var _ = Describe("RESP3 TxPipeline Commands", func() { + testAdapterTxPipeline(true) +}) + +var _ = Describe("RESP2 TxPipeline Commands", func() { + testAdapterTxPipeline(false) +}) + +func testAdapterTxPipeline(resp3 bool) { + var adapter Cmdable + + BeforeEach(func() { + if resp3 { + adapter = adapterresp3 + } else { + adapter = adapterresp2 + } + Expect(adapter.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + Expect(adapter.FlushAll(ctx).Err()).NotTo(HaveOccurred()) + }) + + It("should TxPipelined", func() { + var echo, ping *StringCmd + rets, err := adapter.TxPipelined(ctx, func(pipe Pipeliner) error { + echo = pipe.Echo(ctx, "hello") + ping = pipe.Ping(ctx) + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rets).To(HaveLen(2)) + Expect(rets[0]).To(Equal(echo)) + Expect(rets[1]).To(Equal(ping)) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal("hello")) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should TxPipeline", func() { + pipe := adapter.TxPipeline() + echo := pipe.Echo(ctx, "hello") + ping := pipe.Ping(ctx) + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + Expect(pipe.Len()).To(Equal(2)) + + rets, err := pipe.Exec(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(pipe.Len()).To(Equal(0)) + Expect(rets).To(HaveLen(2)) + Expect(rets[0]).To(Equal(echo)) + Expect(rets[1]).To(Equal(ping)) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal("hello")) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should Discard", func() { + pipe := adapter.TxPipeline() + echo := pipe.Echo(ctx, "hello") + ping := pipe.Ping(ctx) + + pipe.Discard() + Expect(pipe.Len()).To(Equal(0)) + + rets, err := pipe.Exec(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(rets).To(HaveLen(0)) + + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + }) + + It("should Watch", func() { + k1 := "_k1_" + k2 := "_k2_" + err := adapter.Watch(ctx, func(t Tx) error { + if t.Get(ctx, k1).Err() != Nil { + return errors.New("unclean") + } + if t.Get(ctx, k2).Err() != Nil { + return errors.New("unclean") + } + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Set(ctx, k1, k1, 0) + pipe.Set(ctx, k2, k2, 0) + return nil + }) + return err + }, k1, k2) + Expect(err).NotTo(HaveOccurred()) + Expect(adapter.Get(ctx, k1).Val()).To(Equal(k1)) + Expect(adapter.Get(ctx, k2).Val()).To(Equal(k2)) + }) + + It("should Watch Abort", func() { + k1 := "_k1_" + ch := make(chan error) + go func() { + ch <- adapter.Watch(ctx, func(t Tx) error { + ch <- nil + <-ch + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Del(ctx, k1) + return nil + }) + return err + }, k1) + }() + <-ch + Expect(adapter.Set(ctx, k1, k1, 0).Err()).NotTo(HaveOccurred()) + ch <- nil + Expect(<-ch).To(MatchError(TxFailedErr)) + }) + + It("should Unwatch and Close", func() { + k1 := "_k1_" + err := adapter.Watch(ctx, func(t Tx) error { + Expect(t.Unwatch(ctx).Err()).NotTo(HaveOccurred()) + Expect(t.Close(ctx)).NotTo(HaveOccurred()) + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Del(ctx, k1) + return nil + }) + return err + }, k1) + Expect(err).To(MatchError(valkey.ErrDedicatedClientRecycled)) + }) +}