Skip to content

Commit

Permalink
feat: add TxPipeline to rueidiscompat (#605)
Browse files Browse the repository at this point in the history
* feat: add TxPipeline to rueidiscompat

Signed-off-by: Rueian <[email protected]>

* feat: alias rueidiscompat.Nil to rueidis.Nil

* feat: add todos to the rueidiscompat cmdable interface

Signed-off-by: Rueian <[email protected]>

---------

Signed-off-by: Rueian <[email protected]>
  • Loading branch information
rueian committed Aug 11, 2024
1 parent 44c7724 commit ae247e0
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 24 deletions.
90 changes: 85 additions & 5 deletions valkeycompat/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,20 @@ const (
BitCountIndexBit = "BIT"
)

var Nil = valkey.Nil

type Cmdable interface {
CoreCmdable
Cache(ttl time.Duration) CacheCompat

Subscribe(ctx context.Context, channels ...string) PubSub
PSubscribe(ctx context.Context, patterns ...string) PubSub
SSubscribe(ctx context.Context, channels ...string) PubSub

Watch(ctx context.Context, fn func(Tx) error, keys ...string) error
}

type CoreCmdable interface {
Command(ctx context.Context) *CommandsInfoCmd
CommandList(ctx context.Context, filter FilterBy) *StringSliceCmd
CommandGetKeys(ctx context.Context, commands ...any) *StringSliceCmd
Expand Down Expand Up @@ -127,11 +138,13 @@ type Cmdable interface {
BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd
BitPosSpan(ctx context.Context, key string, bit int64, start, end int64, span string) *IntCmd
BitField(ctx context.Context, key string, args ...any) *IntSliceCmd
// TODO BitFieldRO(ctx context.Context, key string, values ...interface{}) *IntSliceCmd

Scan(ctx context.Context, cursor uint64, match string, count int64) *ScanCmd
ScanType(ctx context.Context, cursor uint64, match string, count int64, keyType string) *ScanCmd
SScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
HScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
// TODO HScanNoValues(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
ZScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd

HDel(ctx context.Context, key string, fields ...string) *IntCmd
Expand All @@ -149,6 +162,19 @@ type Cmdable interface {
HVals(ctx context.Context, key string) *StringSliceCmd
HRandField(ctx context.Context, key string, count int64) *StringSliceCmd
HRandFieldWithValues(ctx context.Context, key string, count int64) *KeyValueSliceCmd
// TODO HExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd
// TODO HExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd
// TODO HPExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd
// TODO HExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd
// TODO HPExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPersist(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HPExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HPTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd

BLPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd
BLMPop(ctx context.Context, timeout time.Duration, direction string, count int64, keys ...string) *KeyValuesCmd
Expand Down Expand Up @@ -375,6 +401,8 @@ type Cmdable interface {
ClusterFailover(ctx context.Context) *StatusCmd
ClusterAddSlots(ctx context.Context, slots ...int64) *StatusCmd
ClusterAddSlotsRange(ctx context.Context, min, max int64) *StatusCmd
// TODO ReadOnly(ctx context.Context) *StatusCmd
// TODO ReadWrite(ctx context.Context) *StatusCmd

GeoAdd(ctx context.Context, key string, geoLocation ...GeoLocation) *IntCmd
GeoPos(ctx context.Context, key string, members ...string) *GeoPosCmd
Expand All @@ -389,13 +417,48 @@ type Cmdable interface {
GeoHash(ctx context.Context, key string, members ...string) *StringSliceCmd

ACLDryRun(ctx context.Context, username string, command ...any) *StringCmd
// TODO ACLLog(ctx context.Context, count int64) *ACLLogCmd
// TODO ACLLogReset(ctx context.Context) *StatusCmd

// TODO ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd
GearsCmdable
ProbabilisticCmdable
TimeseriesCmdable
JSONCmdable
}
// TODO SearchCmdable
}

// TODO SearchCmdable
//type SearchCmdable interface {
// FT_List(ctx context.Context) *StringSliceCmd
// FTAggregate(ctx context.Context, index string, query string) *MapStringInterfaceCmd
// FTAggregateWithArgs(ctx context.Context, index string, query string, options *FTAggregateOptions) *AggregateCmd
// FTAliasAdd(ctx context.Context, index string, alias string) *StatusCmd
// FTAliasDel(ctx context.Context, alias string) *StatusCmd
// FTAliasUpdate(ctx context.Context, index string, alias string) *StatusCmd
// FTAlter(ctx context.Context, index string, skipInitalScan bool, definition []interface{}) *StatusCmd
// FTConfigGet(ctx context.Context, option string) *MapMapStringInterfaceCmd
// FTConfigSet(ctx context.Context, option string, value interface{}) *StatusCmd
// FTCreate(ctx context.Context, index string, options *FTCreateOptions, schema ...*FieldSchema) *StatusCmd
// FTCursorDel(ctx context.Context, index string, cursorId int) *StatusCmd
// FTCursorRead(ctx context.Context, index string, cursorId int, count int) *MapStringInterfaceCmd
// FTDictAdd(ctx context.Context, dict string, term ...interface{}) *IntCmd
// FTDictDel(ctx context.Context, dict string, term ...interface{}) *IntCmd
// FTDictDump(ctx context.Context, dict string) *StringSliceCmd
// FTDropIndex(ctx context.Context, index string) *StatusCmd
// FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd
// FTExplain(ctx context.Context, index string, query string) *StringCmd
// FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd
// FTInfo(ctx context.Context, index string) *FTInfoCmd
// FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd
// FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd
// FTSearch(ctx context.Context, index string, query string) *FTSearchCmd
// FTSearchWithArgs(ctx context.Context, index string, query string, options *FTSearchOptions) *FTSearchCmd
// FTSynDump(ctx context.Context, index string) *FTSynDumpCmd
// FTSynUpdate(ctx context.Context, index string, synGroupId interface{}, terms []interface{}) *StatusCmd
// FTSynUpdateWithArgs(ctx context.Context, index string, synGroupId interface{}, options *FTSynUpdateOptions, terms []interface{}) *StatusCmd
// FTTagVals(ctx context.Context, index string, field string) *StringSliceCmd
//}

// https://github.com/redis/go-redis/blob/af4872cbd0de349855ce3f0978929c2f56eb995f/probabilistic.go#L10
type ProbabilisticCmdable interface {
Expand Down Expand Up @@ -470,12 +533,11 @@ type ProbabilisticCmdable interface {
TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd
TDigestTrimmedMean(ctx context.Context, key string, lowCutQuantile, highCutQuantile float64) *FloatCmd

Subscribe(ctx context.Context, channels ...string) PubSub
PSubscribe(ctx context.Context, patterns ...string) PubSub
SSubscribe(ctx context.Context, channels ...string) PubSub

Pipeline() Pipeliner
Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error)

TxPipeline() Pipeliner
TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error)
}

// Align with go-redis
Expand Down Expand Up @@ -4628,6 +4690,24 @@ func (c *Compat) Pipeline() Pipeliner {
return newPipeline(c.client)
}

func (c *Compat) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return newTxPipeline(c.client).Pipelined(ctx, fn)
}

func (c *Compat) TxPipeline() Pipeliner {
return newTxPipeline(c.client)
}

func (c *Compat) Watch(ctx context.Context, fn func(Tx) error, keys ...string) error {
dc, cancel := c.client.Dedicate()
defer cancel()
tx := newTx(dc, cancel)
if err := tx.Watch(ctx, keys...).Err(); err != nil {
return err
}
return fn(newTx(dc, cancel))
}

func (c CacheCompat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd {
var resp valkey.ValkeyResult
if bitCount == nil {
Expand Down
22 changes: 3 additions & 19 deletions valkeycompat/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import (
// To avoid this: it is good idea to use reasonable bigger read/write timeouts
// depends on your batch size and/or use TxPipeline.
type Pipeliner interface {
Cmdable
CoreCmdable

// Len is to obtain the number of commands in the pipeline that have not yet been executed.
Len() int
Expand Down Expand Up @@ -96,10 +96,6 @@ type Pipeline struct {
rets []Cmder
}

func (c *Pipeline) Cache(ttl time.Duration) CacheCompat {
return c.comp.Cache(ttl)
}

func (c *Pipeline) Command(ctx context.Context) *CommandsInfoCmd {
ret := c.comp.Command(ctx)
c.rets = append(c.rets, ret)
Expand Down Expand Up @@ -2434,18 +2430,6 @@ func (c *Pipeline) TDigestTrimmedMean(ctx context.Context, key string, lowCutQua
return ret
}

func (c *Pipeline) Subscribe(ctx context.Context, channels ...string) PubSub {
return c.comp.Subscribe(ctx, channels...)
}

func (c *Pipeline) PSubscribe(ctx context.Context, patterns ...string) PubSub {
return c.comp.PSubscribe(ctx, patterns...)
}

func (c *Pipeline) SSubscribe(ctx context.Context, channels ...string) PubSub {
return c.comp.SSubscribe(ctx, channels...)
}

func (c *Pipeline) TSAdd(ctx context.Context, key string, timestamp interface{}, value float64) *IntCmd {
ret := c.comp.TSAdd(ctx, key, timestamp, value)
c.rets = append(c.rets, ret)
Expand Down Expand Up @@ -2788,10 +2772,10 @@ func (c *Pipeline) Len() int {
}

// Do queues the custom command for later execution.
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
func (c *Pipeline) Do(_ context.Context, args ...interface{}) *Cmd {
ret := &Cmd{}
if len(args) == 0 {
ret.SetErr(errors.New("redis: please enter the command to be executed"))
ret.SetErr(errors.New("valkey: please enter the command to be executed"))
return ret
}
p := c.comp.client.(*proxy)
Expand Down
145 changes: 145 additions & 0 deletions valkeycompat/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package valkeycompat

import (
"context"
"errors"
"time"
"unsafe"

"github.com/valkey-io/valkey-go"
)

var TxFailedErr = errors.New("valkey: transaction failed")

var _ Pipeliner = (*TxPipeline)(nil)

type rePipeline = Pipeline

func newTxPipeline(real valkey.Client) *TxPipeline {
return &TxPipeline{rePipeline: newPipeline(real)}
}

type TxPipeline struct {
*rePipeline
}

func (c *TxPipeline) Exec(ctx context.Context) ([]Cmder, error) {
p := c.comp.client.(*proxy)
if len(p.cmds) == 0 {
return nil, nil
}

rets := c.rets
cmds := p.cmds
c.rets = nil
p.cmds = nil

cmds = append(cmds, c.comp.client.B().Multi().Build(), c.comp.client.B().Exec().Build())
for i := len(cmds) - 2; i >= 1; i-- {
j := i - 1
cmds[j], cmds[i] = cmds[i], cmds[j]
}

resp := p.DoMulti(ctx, cmds...)
results, err := resp[len(resp)-1].ToArray()
if valkey.IsValkeyNil(err) {
err = TxFailedErr
}
for i, r := range results {
rets[i].from(*(*valkey.ValkeyResult)(unsafe.Pointer(&proxyresult{
err: resp[i+1].NonValkeyError(),
val: r,
})))
}
return rets, err
}

func (c *TxPipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil {
return nil, err
}
return c.Exec(ctx)
}

func (c *TxPipeline) Pipeline() Pipeliner {
return c
}

func (c *TxPipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(ctx, fn)
}

func (c *TxPipeline) TxPipeline() Pipeliner {
return c
}

var _ valkey.Client = (*txproxy)(nil)

type txproxy struct {
valkey.CoreClient
}

func (p *txproxy) DoCache(_ context.Context, _ valkey.Cacheable, _ time.Duration) (resp valkey.ValkeyResult) {
panic("not implemented")
}

func (p *txproxy) DoMultiCache(_ context.Context, _ ...valkey.CacheableTTL) (resp []valkey.ValkeyResult) {
panic("not implemented")
}

func (p *txproxy) DoStream(_ context.Context, _ valkey.Completed) valkey.ValkeyResultStream {
panic("not implemented")
}

func (p *txproxy) DoMultiStream(_ context.Context, _ ...valkey.Completed) valkey.MultiValkeyResultStream {
panic("not implemented")
}

func (p *txproxy) Dedicated(_ func(valkey.DedicatedClient) error) (err error) {
panic("not implemented")
}

func (p *txproxy) Dedicate() (client valkey.DedicatedClient, cancel func()) {
panic("not implemented")
}

func (p *txproxy) Nodes() map[string]valkey.Client {
panic("not implemented")
}

type Tx interface {
CoreCmdable
Watch(ctx context.Context, keys ...string) *StatusCmd
Unwatch(ctx context.Context, keys ...string) *StatusCmd
Close(ctx context.Context) error
}

func newTx(client valkey.DedicatedClient, cancel func()) *tx {
return &tx{CoreCmdable: NewAdapter(&txproxy{CoreClient: client}), cancel: cancel}
}

type tx struct {
CoreCmdable
cancel func()
}

func (t *tx) Watch(ctx context.Context, keys ...string) *StatusCmd {
ret := &StatusCmd{}
if len(keys) != 0 {
client := t.CoreCmdable.(*Compat).client
ret.from(client.Do(ctx, client.B().Watch().Key(keys...).Build()))
}
return ret
}

func (t *tx) Unwatch(ctx context.Context, _ ...string) *StatusCmd {
ret := &StatusCmd{}
client := t.CoreCmdable.(*Compat).client
ret.from(client.Do(ctx, client.B().Unwatch().Build()))
return ret
}

func (t *tx) Close(_ context.Context) error {
t.cancel()
return nil
}
Loading

0 comments on commit ae247e0

Please sign in to comment.