From 742d058cb3a546f9a45597aa1e6a93aac04dc572 Mon Sep 17 00:00:00 2001 From: Rueian Date: Sun, 26 Dec 2021 21:30:53 +0800 Subject: [PATCH] refactor: improve readability --- mux.go | 64 ++++++++++++++++++++++++----------------------------- mux_test.go | 32 +++++++++------------------ 2 files changed, 39 insertions(+), 57 deletions(-) diff --git a/mux.go b/mux.go index c556a08e..bc6d1707 100644 --- a/mux.go +++ b/mux.go @@ -12,7 +12,7 @@ import ( type connFn func(dst string, opt ConnOption) conn type dialFn func(dst string, opt ConnOption) (net.Conn, error) -type wireFn func(conn net.Conn, opt ConnOption, onDisconnected func(err error)) (wire, error) +type wireFn func(onDisconnected func(err error)) (wire, error) type singleconnect struct { w wire @@ -32,33 +32,43 @@ var _ conn = (*mux)(nil) type mux struct { dst string - opt ConnOption pool *pool dead wire wire atomic.Value mu sync.Mutex sc *singleconnect - dialFn dialFn wireFn wireFn onDisconnected atomic.Value } func makeMux(dst string, option ConnOption, dialFn dialFn) *mux { - return newMux(dst, option, (*pipe)(nil), dialFn, func(conn net.Conn, opt ConnOption, onDisconnected func(err error)) (wire, error) { - return newPipe(conn, opt, onDisconnected) + return newMux(dst, option, (*pipe)(nil), func(onDisconnected func(err error)) (w wire, err error) { + conn, err := dialFn(dst, option) + if err == nil { + w, err = newPipe(conn, option, onDisconnected) + } + return w, err }) } -func newMux(dst string, option ConnOption, dead wire, dialFn dialFn, wireFn wireFn) *mux { - m := &mux{dst: dst, opt: option, dead: dead, dialFn: dialFn, wireFn: wireFn} +func newMux(dst string, option ConnOption, dead wire, wireFn wireFn) *mux { + m := &mux{dst: dst, dead: dead, wireFn: wireFn} m.wire.Store(dead) - m.pool = newPool(option.BlockingPoolSize, m.dialRetry) + m.pool = newPool(option.BlockingPoolSize, m._newPooledWire) return m } -func (m *mux) connect() (w wire, err error) { +func (m *mux) _newPooledWire() wire { +retry: + if wire, err := m.wireFn(nil); err == nil { + return wire + } + goto retry +} + +func (m *mux) _pipe() (w wire, err error) { if w = m.wire.Load().(wire); w != m.dead { return w, nil } @@ -77,7 +87,7 @@ func (m *mux) connect() (w wire, err error) { } if w = m.wire.Load().(wire); w == m.dead { - if w, err = m.dial(); err == nil { + if w, err = m.wireFn(m.disconnected); err == nil { m.wire.Store(w) } } @@ -94,14 +104,6 @@ func (m *mux) connect() (w wire, err error) { return w, err } -func (m *mux) dial() (w wire, err error) { - conn, err := m.dialFn(m.dst, m.opt) - if err == nil { - w, err = m.wireFn(conn, m.opt, m.disconnected) - } - return w, err -} - func (m *mux) disconnected(err error) { if fn := m.onDisconnected.Load(); fn != nil { fn.(func(err error))(err) @@ -112,33 +114,25 @@ func (m *mux) OnDisconnected(fn func(err error)) { m.onDisconnected.CompareAndSwap(nil, fn) } -func (m *mux) dialRetry() wire { -retry: - if wire, err := m.dial(); err == nil { - return wire - } - goto retry -} - -func (m *mux) acquire() wire { +func (m *mux) pipe() wire { retry: - if wire, err := m.connect(); err == nil { + if wire, err := m._pipe(); err == nil { return wire } goto retry } func (m *mux) Dial() error { // no retry - _, err := m.connect() + _, err := m._pipe() return err } func (m *mux) Info() map[string]proto.Message { - return m.acquire().Info() + return m.pipe().Info() } func (m *mux) Error() error { - return m.acquire().Error() + return m.pipe().Error() } func (m *mux) Do(cmd cmds.Completed) (resp proto.Result) { @@ -191,7 +185,7 @@ func (m *mux) blockingMulti(cmd []cmds.Completed) (resp []proto.Result) { } func (m *mux) pipeline(cmd cmds.Completed) (resp proto.Result) { - wire := m.acquire() + wire := m.pipe() if resp = wire.Do(cmd); isNetworkErr(resp.NonRedisError()) { m.wire.CompareAndSwap(wire, m.dead) } @@ -199,7 +193,7 @@ func (m *mux) pipeline(cmd cmds.Completed) (resp proto.Result) { } func (m *mux) pipelineMulti(cmd []cmds.Completed) (resp []proto.Result) { - wire := m.acquire() + wire := m.pipe() resp = wire.DoMulti(cmd...) for _, r := range resp { if isNetworkErr(r.NonRedisError()) { @@ -212,7 +206,7 @@ func (m *mux) pipelineMulti(cmd []cmds.Completed) (resp []proto.Result) { func (m *mux) DoCache(cmd cmds.Cacheable, ttl time.Duration) proto.Result { retry: - wire := m.acquire() + wire := m.pipe() resp := wire.DoCache(cmd, ttl) if isNetworkErr(resp.NonRedisError()) { m.wire.CompareAndSwap(wire, m.dead) @@ -230,7 +224,7 @@ func (m *mux) Store(w wire) { } func (m *mux) Close() { - m.acquire().Close() + m.pipe().Close() m.pool.Close() } diff --git a/mux_test.go b/mux_test.go index f9af3d2a..b4eb547c 100644 --- a/mux_test.go +++ b/mux_test.go @@ -18,9 +18,7 @@ import ( func setupMux(wires []*mock.Wire) (conn *mux, checkClean func(t *testing.T)) { var mu sync.Mutex var count = -1 - return newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) { - return nil, nil - }, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) { + return newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) { mu.Lock() defer mu.Unlock() count++ @@ -59,9 +57,7 @@ func TestNewMux(t *testing.T) { func TestMuxOnDisconnected(t *testing.T) { var trigger func(err error) - m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) { - return nil, nil - }, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) { + m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) { trigger = fn return &mock.Wire{}, nil }) @@ -89,12 +85,9 @@ func TestMuxOnDisconnected(t *testing.T) { } func TestMuxDialSuppress(t *testing.T) { - var dials, wires, waits, done int64 + var wires, waits, done int64 blocking := make(chan struct{}) - m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) { - atomic.AddInt64(&dials, 1) - return nil, nil - }, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) { + m := newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) { atomic.AddInt64(&wires, 1) <-blocking return &mock.Wire{}, nil @@ -113,9 +106,6 @@ func TestMuxDialSuppress(t *testing.T) { for atomic.LoadInt64(&done) != 1000 { runtime.Gosched() } - if atomic.LoadInt64(&dials) != 1 { - t.Fatalf("dailFn is not suppressed") - } if atomic.LoadInt64(&wires) != 1 { t.Fatalf("wireFn is not suppressed") } @@ -480,18 +470,16 @@ func TestMuxCMDRetry(t *testing.T) { func TestMuxDialRetry(t *testing.T) { setup := func() (*mux, *int64) { var count int64 - return newMux("", ConnOption{}, (*mock.Wire)(nil), func(dst string, opt ConnOption) (net.Conn, error) { + return newMux("", ConnOption{}, (*mock.Wire)(nil), func(fn func(err error)) (wire, error) { if count == 1 { - return nil, nil + return &mock.Wire{ + DoFn: func(cmd cmds.Completed) proto.Result { + return proto.NewResult(proto.Message{Type: '+', String: "PONG"}, nil) + }, + }, nil } count++ return nil, errors.New("network err") - }, func(conn net.Conn, opt ConnOption, fn func(err error)) (wire, error) { - return &mock.Wire{ - DoFn: func(cmd cmds.Completed) proto.Result { - return proto.NewResult(proto.Message{Type: '+', String: "PONG"}, nil) - }, - }, nil }), &count } t.Run("retry on auto pipeline", func(t *testing.T) {