Skip to content

Commit

Permalink
refactor: improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
rueian committed Dec 26, 2021
1 parent 7a985cb commit 742d058
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 57 deletions.
64 changes: 29 additions & 35 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

type connFn func(dst string, opt ConnOption) conn
type dialFn func(dst string, opt ConnOption) (net.Conn, error)
type wireFn func(conn net.Conn, opt ConnOption, onDisconnected func(err error)) (wire, error)
type wireFn func(onDisconnected func(err error)) (wire, error)

type singleconnect struct {
w wire
Expand All @@ -32,33 +32,43 @@ var _ conn = (*mux)(nil)

type mux struct {
dst string
opt ConnOption
pool *pool
dead wire
wire atomic.Value
mu sync.Mutex
sc *singleconnect

dialFn dialFn
wireFn wireFn

onDisconnected atomic.Value
}

func makeMux(dst string, option ConnOption, dialFn dialFn) *mux {
return newMux(dst, option, (*pipe)(nil), dialFn, func(conn net.Conn, opt ConnOption, onDisconnected func(err error)) (wire, error) {
return newPipe(conn, opt, onDisconnected)
return newMux(dst, option, (*pipe)(nil), func(onDisconnected func(err error)) (w wire, err error) {
conn, err := dialFn(dst, option)
if err == nil {
w, err = newPipe(conn, option, onDisconnected)
}
return w, err
})
}

func newMux(dst string, option ConnOption, dead wire, dialFn dialFn, wireFn wireFn) *mux {
m := &mux{dst: dst, opt: option, dead: dead, dialFn: dialFn, wireFn: wireFn}
func newMux(dst string, option ConnOption, dead wire, wireFn wireFn) *mux {
m := &mux{dst: dst, dead: dead, wireFn: wireFn}
m.wire.Store(dead)
m.pool = newPool(option.BlockingPoolSize, m.dialRetry)
m.pool = newPool(option.BlockingPoolSize, m._newPooledWire)
return m
}

func (m *mux) connect() (w wire, err error) {
func (m *mux) _newPooledWire() wire {
retry:
if wire, err := m.wireFn(nil); err == nil {
return wire
}
goto retry
}

func (m *mux) _pipe() (w wire, err error) {
if w = m.wire.Load().(wire); w != m.dead {
return w, nil
}
Expand All @@ -77,7 +87,7 @@ func (m *mux) connect() (w wire, err error) {
}

if w = m.wire.Load().(wire); w == m.dead {
if w, err = m.dial(); err == nil {
if w, err = m.wireFn(m.disconnected); err == nil {
m.wire.Store(w)
}
}
Expand All @@ -94,14 +104,6 @@ func (m *mux) connect() (w wire, err error) {
return w, err
}

func (m *mux) dial() (w wire, err error) {
conn, err := m.dialFn(m.dst, m.opt)
if err == nil {
w, err = m.wireFn(conn, m.opt, m.disconnected)
}
return w, err
}

func (m *mux) disconnected(err error) {
if fn := m.onDisconnected.Load(); fn != nil {
fn.(func(err error))(err)
Expand All @@ -112,33 +114,25 @@ func (m *mux) OnDisconnected(fn func(err error)) {
m.onDisconnected.CompareAndSwap(nil, fn)
}

func (m *mux) dialRetry() wire {
retry:
if wire, err := m.dial(); err == nil {
return wire
}
goto retry
}

func (m *mux) acquire() wire {
func (m *mux) pipe() wire {
retry:
if wire, err := m.connect(); err == nil {
if wire, err := m._pipe(); err == nil {
return wire
}
goto retry
}

func (m *mux) Dial() error { // no retry
_, err := m.connect()
_, err := m._pipe()
return err
}

func (m *mux) Info() map[string]proto.Message {
return m.acquire().Info()
return m.pipe().Info()
}

func (m *mux) Error() error {
return m.acquire().Error()
return m.pipe().Error()
}

func (m *mux) Do(cmd cmds.Completed) (resp proto.Result) {
Expand Down Expand Up @@ -191,15 +185,15 @@ func (m *mux) blockingMulti(cmd []cmds.Completed) (resp []proto.Result) {
}

func (m *mux) pipeline(cmd cmds.Completed) (resp proto.Result) {
wire := m.acquire()
wire := m.pipe()
if resp = wire.Do(cmd); isNetworkErr(resp.NonRedisError()) {
m.wire.CompareAndSwap(wire, m.dead)
}
return resp
}

func (m *mux) pipelineMulti(cmd []cmds.Completed) (resp []proto.Result) {
wire := m.acquire()
wire := m.pipe()
resp = wire.DoMulti(cmd...)
for _, r := range resp {
if isNetworkErr(r.NonRedisError()) {
Expand All @@ -212,7 +206,7 @@ func (m *mux) pipelineMulti(cmd []cmds.Completed) (resp []proto.Result) {

func (m *mux) DoCache(cmd cmds.Cacheable, ttl time.Duration) proto.Result {
retry:
wire := m.acquire()
wire := m.pipe()
resp := wire.DoCache(cmd, ttl)
if isNetworkErr(resp.NonRedisError()) {
m.wire.CompareAndSwap(wire, m.dead)
Expand All @@ -230,7 +224,7 @@ func (m *mux) Store(w wire) {
}

func (m *mux) Close() {
m.acquire().Close()
m.pipe().Close()
m.pool.Close()
}

Expand Down
32 changes: 10 additions & 22 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ import (
func setupMux(wires []*mock.Wire) (conn *mux, checkClean func(t *testing.T)) {
var mu sync.Mutex
var count = -1
return newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) {
return nil, nil
}, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) {
return newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) {
mu.Lock()
defer mu.Unlock()
count++
Expand Down Expand Up @@ -59,9 +57,7 @@ func TestNewMux(t *testing.T) {

func TestMuxOnDisconnected(t *testing.T) {
var trigger func(err error)
m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) {
return nil, nil
}, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) {
m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) {
trigger = fn
return &mock.Wire{}, nil
})
Expand Down Expand Up @@ -89,12 +85,9 @@ func TestMuxOnDisconnected(t *testing.T) {
}

func TestMuxDialSuppress(t *testing.T) {
var dials, wires, waits, done int64
var wires, waits, done int64
blocking := make(chan struct{})
m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) {
atomic.AddInt64(&dials, 1)
return nil, nil
}, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) {
m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) {
atomic.AddInt64(&wires, 1)
<-blocking
return &mock.Wire{}, nil
Expand All @@ -113,9 +106,6 @@ func TestMuxDialSuppress(t *testing.T) {
for atomic.LoadInt64(&done) != 1000 {
runtime.Gosched()
}
if atomic.LoadInt64(&dials) != 1 {
t.Fatalf("dailFn is not suppressed")
}
if atomic.LoadInt64(&wires) != 1 {
t.Fatalf("wireFn is not suppressed")
}
Expand Down Expand Up @@ -480,18 +470,16 @@ func TestMuxCMDRetry(t *testing.T) {
func TestMuxDialRetry(t *testing.T) {
setup := func() (*mux, *int64) {
var count int64
return newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) {
return newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) {
if count == 1 {
return nil, nil
return &mock.Wire{
DoFn: func(cmd cmds.Completed) proto.Result {
return proto.NewResult(proto.Message{Type: '+', String: "PONG"}, nil)
},
}, nil
}
count++
return nil, errors.New("network err")
}, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) {
return &mock.Wire{
DoFn: func(cmd cmds.Completed) proto.Result {
return proto.NewResult(proto.Message{Type: '+', String: "PONG"}, nil)
},
}, nil
}), &count
}
t.Run("retry on auto pipeline", func(t *testing.T) {
Expand Down

0 comments on commit 742d058

Please sign in to comment.