Skip to content

Commit

Permalink
feat: improve closing processes on mux and pool
Browse files Browse the repository at this point in the history
  • Loading branch information
rueian committed Jan 19, 2022
1 parent a83d473 commit 2e19c27
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 52 deletions.
5 changes: 2 additions & 3 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@ func (c *clusterClient) _refresh() (err error) {
retry:
c.mu.RLock()
for addr, cc := range c.conns {
if reply, err = cc.Do(cmds.SlotCmd).ToMessage(); err != nil {
dead = append(dead, addr)
} else {
if reply, err = cc.Do(cmds.SlotCmd).ToMessage(); err == nil {
break
}
dead = append(dead, addr)
}
c.mu.RUnlock()

Expand Down
43 changes: 23 additions & 20 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var _ conn = (*mux)(nil)
type mux struct {
dst string
pool *pool
init wire
dead wire
wire atomic.Value
mu sync.Mutex
Expand All @@ -43,23 +44,23 @@ type mux struct {
}

func makeMux(dst string, option ClientOption, dialFn dialFn, retryOnRefuse bool) *mux {
return newMux(dst, option, (*pipe)(nil), func(onDisconnected func(err error)) (w wire, err error) {
return newMux(dst, option, (*pipe)(nil), dead, func(onDisconnected func(err error)) (w wire, err error) {
conn, err := dialFn(dst, option)
if err == nil {
w, err = newPipe(conn, option, onDisconnected)
} else if !retryOnRefuse {
if e, ok := err.(*net.OpError); ok && !e.Timeout() && !e.Temporary() {
if e, ok := err.(net.Error); ok && !e.Timeout() && !e.Temporary() {
return dead, nil
}
}
return w, err
})
}

func newMux(dst string, option ClientOption, dead wire, wireFn wireFn) *mux {
m := &mux{dst: dst, dead: dead, wireFn: wireFn}
m.wire.Store(dead)
m.pool = newPool(option.BlockingPoolSize, m._newPooledWire)
func newMux(dst string, option ClientOption, init, dead wire, wireFn wireFn) *mux {
m := &mux{dst: dst, init: init, dead: dead, wireFn: wireFn}
m.wire.Store(init)
m.pool = newPool(option.BlockingPoolSize, dead, m._newPooledWire)
return m
}

Expand All @@ -71,8 +72,16 @@ retry:
goto retry
}

func (m *mux) pipe() wire {
retry:
if wire, err := m._pipe(); err == nil {
return wire
}
goto retry
}

func (m *mux) _pipe() (w wire, err error) {
if w = m.wire.Load().(wire); w != m.dead {
if w = m.wire.Load().(wire); w != m.init {
return w, nil
}

Expand All @@ -89,7 +98,7 @@ func (m *mux) _pipe() (w wire, err error) {
return sc.w, sc.e
}

if w = m.wire.Load().(wire); w == m.dead {
if w = m.wire.Load().(wire); w == m.init {
if w, err = m.wireFn(m.disconnected); err == nil {
m.wire.Store(w)
}
Expand Down Expand Up @@ -117,14 +126,6 @@ func (m *mux) OnDisconnected(fn func(err error)) {
m.onDisconnected.CompareAndSwap(nil, fn)
}

func (m *mux) pipe() wire {
retry:
if wire, err := m._pipe(); err == nil {
return wire
}
goto retry
}

func (m *mux) Dial() error { // no retry
_, err := m._pipe()
return err
Expand Down Expand Up @@ -190,7 +191,7 @@ func (m *mux) blockingMulti(cmd []cmds.Completed) (resp []RedisResult) {
func (m *mux) pipeline(cmd cmds.Completed) (resp RedisResult) {
wire := m.pipe()
if resp = wire.Do(cmd); isNetworkErr(resp.NonRedisError()) {
m.wire.CompareAndSwap(wire, m.dead)
m.wire.CompareAndSwap(wire, m.init)
}
return resp
}
Expand All @@ -200,7 +201,7 @@ func (m *mux) pipelineMulti(cmd []cmds.Completed) (resp []RedisResult) {
resp = wire.DoMulti(cmd...)
for _, r := range resp {
if isNetworkErr(r.NonRedisError()) {
m.wire.CompareAndSwap(wire, m.dead)
m.wire.CompareAndSwap(wire, m.init)
return resp
}
}
Expand All @@ -212,7 +213,7 @@ retry:
wire := m.pipe()
resp := wire.DoCache(cmd, ttl)
if isNetworkErr(resp.NonRedisError()) {
m.wire.CompareAndSwap(wire, m.dead)
m.wire.CompareAndSwap(wire, m.init)
goto retry
}
return resp
Expand All @@ -227,7 +228,9 @@ func (m *mux) Store(w wire) {
}

func (m *mux) Close() {
m.pipe().Close()
if prev := m.wire.Swap(m.dead).(wire); prev != m.init {
prev.Close()
}
m.pool.Close()
}

Expand Down
14 changes: 10 additions & 4 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func setupMux(wires []*mockWire) (conn *mux, checkClean func(t *testing.T)) {
var mu sync.Mutex
var count = -1
return newMux("", ClientOption{}, (*mockWire)(nil), func(fn func(err error)) (wire, error) {
return newMux("", ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(fn func(err error)) (wire, error) {
mu.Lock()
defer mu.Unlock()
count++
Expand Down Expand Up @@ -55,7 +55,7 @@ func TestNewMux(t *testing.T) {

func TestMuxOnDisconnected(t *testing.T) {
var trigger func(err error)
m := newMux("", ClientOption{}, (*mockWire)(nil), func(fn func(err error)) (wire, error) {
m := newMux("", ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(fn func(err error)) (wire, error) {
trigger = fn
return &mockWire{}, nil
})
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestMuxOnDisconnected(t *testing.T) {
func TestMuxDialSuppress(t *testing.T) {
var wires, waits, done int64
blocking := make(chan struct{})
m := newMux("", ClientOption{}, (*mockWire)(nil), func(fn func(err error)) (wire, error) {
m := newMux("", ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(fn func(err error)) (wire, error) {
atomic.AddInt64(&wires, 1)
<-blocking
return &mockWire{}, nil
Expand Down Expand Up @@ -468,7 +468,7 @@ func TestMuxCMDRetry(t *testing.T) {
func TestMuxDialRetry(t *testing.T) {
setup := func() (*mux, *int64) {
var count int64
return newMux("", ClientOption{}, (*mockWire)(nil), func(fn func(err error)) (wire, error) {
return newMux("", ClientOption{}, (*mockWire)(nil), (*mockWire)(nil), func(fn func(err error)) (wire, error) {
if count == 1 {
return &mockWire{
DoFn: func(cmd cmds.Completed) RedisResult {
Expand Down Expand Up @@ -577,13 +577,19 @@ func (m *mockWire) Info() map[string]RedisMessage {
}

func (m *mockWire) Error() error {
if m == nil {
return ErrClosing
}
if m.ErrorFn != nil {
return m.ErrorFn()
}
return nil
}

func (m *mockWire) Close() {
if m == nil {
return
}
if m.CloseFn != nil {
m.CloseFn()
}
Expand Down
24 changes: 9 additions & 15 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package rueidis

import "sync"

func newPool(cap int, makeFn func() wire) *pool {
func newPool(cap int, dead wire, makeFn func() wire) *pool {
if cap <= 0 {
cap = DefaultPoolSize
}

return &pool{
size: 0,
dead: dead,
make: makeFn,
list: make([]wire, 0, cap),
cond: sync.NewCond(&sync.Mutex{}),
Expand All @@ -17,6 +18,7 @@ func newPool(cap int, makeFn func() wire) *pool {

type pool struct {
list []wire
dead wire
cond *sync.Cond
make func() wire
size int
Expand All @@ -28,34 +30,26 @@ func (p *pool) Acquire() (v wire) {
for len(p.list) == 0 && p.size == cap(p.list) {
p.cond.Wait()
}
if len(p.list) == 0 {
v = p.make()
if p.down {
v = p.dead
} else if len(p.list) == 0 {
p.size++
if p.down {
v.Close()
p.list = append(p.list, v)
}
v = p.make()
} else {
i := len(p.list) - 1
v = p.list[i]
if p.down {
v.Close()
} else {
p.list = p.list[:i]
}
p.list = p.list[:i]
}
p.cond.L.Unlock()
return v
}

func (p *pool) Store(v wire) {
p.cond.L.Lock()
if v.Error() == nil {
if !p.down && v.Error() == nil {
p.list = append(p.list, v)
} else {
p.size--
}
if p.down {
v.Close()
}
p.cond.L.Unlock()
Expand Down
20 changes: 10 additions & 10 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
func TestPool(t *testing.T) {
setup := func(size int) (*pool, *int32) {
var count int32
return newPool(size, func() wire {
return newPool(size, dead, func() wire {
atomic.AddInt32(&count, 1)
closed := false
return &mockWire{
Expand All @@ -27,7 +27,7 @@ func TestPool(t *testing.T) {
}

t.Run("DefaultPoolSize", func(t *testing.T) {
p := newPool(0, func() wire { return nil })
p := newPool(0, dead, func() wire { return nil })
if cap(p.list) == 0 {
t.Fatalf("DefaultPoolSize is not applied")
}
Expand Down Expand Up @@ -103,8 +103,8 @@ func TestPool(t *testing.T) {
t.Fatalf("pool does not close exsiting wire after Close()")
}
for i := 0; i < 100; i++ {
if rw := pool.Acquire(); rw != w1 {
t.Fatalf("pool does not return the same wire after Close()")
if rw := pool.Acquire(); rw != dead {
t.Fatalf("pool does not return the dead wire after Close()")
}
}
pool.Store(w2)
Expand All @@ -122,14 +122,14 @@ func TestPool(t *testing.T) {
pool.Close()
w2 := pool.Acquire()
if w2.Error() != ErrClosing {
t.Fatalf("pool does not close new wire after Close()")
t.Fatalf("pool does not close wire after Close()")
}
if atomic.LoadInt32(count) != 2 {
t.Fatalf("pool does not make new wire")
if atomic.LoadInt32(count) != 1 {
t.Fatalf("pool should not make new wire")
}
for i := 0; i < 100; i++ {
if rw := pool.Acquire(); rw != w2 {
t.Fatalf("pool does not return the same wire after Close()")
if rw := pool.Acquire(); rw != dead {
t.Fatalf("pool does not return the dead wire after Close()")
}
}
pool.Store(w1)
Expand All @@ -142,7 +142,7 @@ func TestPool(t *testing.T) {
func TestPoolError(t *testing.T) {
setup := func(size int) (*pool, *int32) {
var count int32
return newPool(size, func() wire {
return newPool(size, dead, func() wire {
w := &pipe{}
c := atomic.AddInt32(&count, 1)
if c%2 == 0 {
Expand Down

0 comments on commit 2e19c27

Please sign in to comment.