Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use correct address for CloseAddr #1140

Merged
merged 7 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 64 additions & 9 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,45 @@ type Client interface {
SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error)
}

// ClientExt is a client that has extended interfaces.
type ClientExt interface {
	// CloseAddrVer closes gRPC connections to the address with an additional `ver` parameter.
	// Each new connection will have an incremented `ver` value, and attempts to close a previous `ver` will be ignored.
	// Passing `math.MaxUint64` as the `ver` parameter will forcefully close all connections to the address.
	CloseAddrVer(addr string, ver uint64) error
}

// ErrConn wraps error with target address and version of the connection.
type ErrConn struct {
Err error
Addr string
Ver uint64
}

func (e *ErrConn) Error() string {
return fmt.Sprintf("[%s](%d) %s", e.Addr, e.Ver, e.Err.Error())
}

func (e *ErrConn) Unwrap() error {
return e.Err
}

// WrapErrConn wraps err into an *ErrConn carrying conn's target address and
// version. A nil err passes through unchanged so call sites can wrap
// unconditionally.
func WrapErrConn(err error, conn *connArray) error {
	if err != nil {
		return &ErrConn{
			Err:  err,
			Addr: conn.target,
			Ver:  conn.ver,
		}
	}
	return nil
}

type connArray struct {
// The target host.
target string
// version of the connection array, increase by 1 when reconnect.
ver uint64

index uint32
v []*monitoredConn
Expand All @@ -125,9 +161,10 @@ type connArray struct {
monitor *connMonitor
}

func newConnArray(maxSize uint, addr string, security config.Security,
func newConnArray(maxSize uint, addr string, ver uint64, security config.Security,
idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, opts []grpc.DialOption) (*connArray, error) {
a := &connArray{
ver: ver,
index: 0,
v: make([]*monitoredConn, maxSize),
streamTimeout: make(chan *tikvrpc.Lease, 1024),
Expand Down Expand Up @@ -390,6 +427,7 @@ type RPCClient struct {
sync.RWMutex

conns map[string]*connArray
vers map[string]uint64
option *option

idleNotify uint32
Expand All @@ -405,6 +443,7 @@ type RPCClient struct {
func NewRPCClient(opts ...Opt) *RPCClient {
cli := &RPCClient{
conns: make(map[string]*connArray),
vers: make(map[string]uint64),
option: &option{
dialTimeout: dialTimeout,
},
Expand Down Expand Up @@ -452,9 +491,11 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
for _, opt := range opts {
opt(&client)
}
ver := c.vers[addr] + 1
array, err = newConnArray(
client.GrpcConnectionCount,
addr,
ver,
c.option.security,
&c.idleNotify,
enableBatch,
Expand All @@ -466,6 +507,7 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
return nil, err
}
c.conns[addr] = array
c.vers[addr] = ver
}
return array, nil
}
Expand Down Expand Up @@ -603,6 +645,10 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
return nil, err
}

wrapErrConn := func(resp *tikvrpc.Response, err error) (*tikvrpc.Response, error) {
return resp, WrapErrConn(err, connArray)
}

start := time.Now()
staleRead := req.GetStaleRead()
defer func() {
Expand All @@ -625,7 +671,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
if batchReq := req.ToBatchCommandsRequest(); batchReq != nil {
defer trace.StartRegion(ctx, req.Type.String()).End()
return sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout)
return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout))
}
}

Expand All @@ -639,7 +685,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
client := debugpb.NewDebugClient(clientConn)
ctx1, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tikvrpc.CallDebugRPC(ctx1, client, req)
return wrapErrConn(tikvrpc.CallDebugRPC(ctx1, client, req))
}

client := tikvpb.NewTikvClient(clientConn)
Expand All @@ -650,16 +696,16 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
}
switch req.Type {
case tikvrpc.CmdBatchCop:
return c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray))
case tikvrpc.CmdCopStream:
return c.getCopStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connArray))
case tikvrpc.CmdMPPConn:
return c.getMPPStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connArray))
}
// Or else it's a unary call.
ctx1, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tikvrpc.CallRPC(ctx1, client, req)
return wrapErrConn(tikvrpc.CallRPC(ctx1, client, req))
}

// SendRequest sends a Request to server and receives Response.
Expand Down Expand Up @@ -793,11 +839,20 @@ func (c *RPCClient) Close() error {

// CloseAddr closes gRPC connections to the address.
// It delegates to CloseAddrVer with math.MaxUint64, i.e. it closes the
// connections regardless of their version (per the ClientExt contract,
// MaxUint64 forcefully closes all connections to the address).
func (c *RPCClient) CloseAddr(addr string) error {
	return c.CloseAddrVer(addr, math.MaxUint64)
}

func (c *RPCClient) CloseAddrVer(addr string, ver uint64) error {
c.Lock()
conn, ok := c.conns[addr]
if ok {
delete(c.conns, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr))
if conn.ver <= ver {
delete(c.conns, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
} else {
logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
conn = nil
}
}
c.Unlock()

Expand Down
6 changes: 4 additions & 2 deletions internal/client/client_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,16 +823,18 @@ func (c *RPCClient) recycleIdleConnArray() {
start := time.Now()

var addrs []string
var vers []uint64
c.RLock()
for _, conn := range c.conns {
if conn.batchConn != nil && conn.isIdle() {
addrs = append(addrs, conn.target)
vers = append(vers, conn.ver)
}
}
c.RUnlock()

for _, addr := range addrs {
c.CloseAddr(addr)
for i, addr := range addrs {
c.CloseAddrVer(addr, vers[i])
}

metrics.TiKVBatchClientRecycle.Observe(time.Since(start).Seconds())
Expand Down
41 changes: 39 additions & 2 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,18 @@ func TestConn(t *testing.T) {
assert.Nil(t, err)
assert.False(t, conn2.Get() == conn1.Get())

assert.Nil(t, client.CloseAddr(addr))
ver := conn2.ver
assert.Nil(t, client.CloseAddrVer(addr, ver-1))
_, ok := client.conns[addr]
assert.True(t, ok)
assert.Nil(t, client.CloseAddrVer(addr, ver))
_, ok = client.conns[addr]
assert.False(t, ok)

conn3, err := client.getConnArray(addr, true)
assert.Nil(t, err)
assert.NotNil(t, conn3)
assert.Equal(t, ver+1, conn3.ver)

client.Close()
conn4, err := client.getConnArray(addr, true)
Expand Down Expand Up @@ -135,7 +141,7 @@ func TestSendWhenReconnect(t *testing.T) {

req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
_, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second)
assert.True(t, err.Error() == "no available connections")
assert.EqualError(t, err, fmt.Sprintf("[%s](%d) no available connections", addr, 1))
server.Stop()
}

Expand Down Expand Up @@ -723,3 +729,34 @@ func TestBatchClientRecoverAfterServerRestart(t *testing.T) {
require.NoError(t, err)
}
}

// TestErrConn verifies that ErrConn interoperates with errors.Is / errors.As
// and that its Addr/Ver/Err fields are preserved through wrapping.
func TestErrConn(t *testing.T) {
	e := errors.New("conn error")
	err1 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}
	err2 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}

	e3 := errors.New("conn error 3")
	err3 := &ErrConn{Err: e3}

	err4 := errors.New("not ErrConn")

	// ErrConn defines no Is method, so errors.Is falls back to == comparison:
	// an ErrConn only matches the same pointer, possibly through %w wrapping.
	assert.True(t, errors.Is(err1, err1))
	assert.True(t, errors.Is(fmt.Errorf("%w", err1), err1))
	assert.False(t, errors.Is(fmt.Errorf("%w", err2), err1)) // err2 != err1
	assert.False(t, errors.Is(fmt.Errorf("%w", err4), err1))

	// errors.As extracts the *ErrConn from the chain together with its fields.
	var errConn *ErrConn
	assert.True(t, errors.As(err1, &errConn))
	assert.Equal(t, "127.0.0.1", errConn.Addr)
	assert.EqualValues(t, 10, errConn.Ver)
	assert.EqualError(t, errConn.Err, "conn error")

	assert.True(t, errors.As(err3, &errConn))
	assert.EqualError(t, e3, "conn error 3")

	// A plain error does not match the *ErrConn target type.
	assert.False(t, errors.As(err4, &errConn))

	// With a *error target, errors.As matches the first error in the chain,
	// overwriting errMsg with err1 itself.
	errMsg := errors.New("unknown")
	assert.True(t, errors.As(err1, &errMsg))
	assert.EqualError(t, err1, errMsg.Error())
}
17 changes: 16 additions & 1 deletion internal/locate/region_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ func (s *RegionRequestSender) GetClient() client.Client {
return s.client
}

// getClientExt returns s.client as a client.ClientExt when it implements that
// interface, and nil otherwise.
// Don't use in critical path.
func (s *RegionRequestSender) getClientExt() client.ClientExt {
	if ext, ok := s.client.(client.ClientExt); ok {
		return ext
	}
	return nil
}

// SetStoreAddr specifies the dest store address.
func (s *RegionRequestSender) SetStoreAddr(addr string) {
s.storeAddr = addr
Expand Down Expand Up @@ -1837,7 +1845,14 @@ func (s *RegionRequestSender) onSendFail(bo *retry.Backoffer, ctx *RPCContext, r
// Canceled by gRPC remote may happen when tikv is killed and exiting.
// Close the connection, backoff, and retry.
logutil.Logger(bo.GetCtx()).Warn("receive a grpc cancel signal", zap.Error(err))
s.client.CloseAddr(ctx.Addr)
var errConn *client.ErrConn
if errors.As(err, &errConn) {
if ext := s.getClientExt(); ext != nil {
ext.CloseAddrVer(errConn.Addr, errConn.Ver)
} else {
s.client.CloseAddr(errConn.Addr)
}
}
}
}

Expand Down
26 changes: 26 additions & 0 deletions internal/locate/region_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ package locate
import (
"context"
"fmt"
"math"
"math/rand"
"net"
"sync"
Expand Down Expand Up @@ -99,14 +100,20 @@ func (s *testRegionRequestToSingleStoreSuite) TearDownTest() {
// fnClient is a stub client whose SendRequest behavior is supplied by fn.
// It records the arguments of the last CloseAddrVer call so tests can assert
// which connection (address and version) was closed.
type fnClient struct {
	fn func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error)
	// closedAddr is the address passed to the most recent CloseAddrVer call.
	closedAddr string
	// closedVer is the version passed to the most recent CloseAddrVer call.
	closedVer uint64
}

// Close implements client.Client; the stub holds no resources.
func (f *fnClient) Close() error {
	return nil
}

// CloseAddr forwards to CloseAddrVer with math.MaxUint64, mirroring
// RPCClient.CloseAddr ("close regardless of connection version").
func (f *fnClient) CloseAddr(addr string) error {
	return f.CloseAddrVer(addr, math.MaxUint64)
}

// CloseAddrVer implements client.ClientExt; it only records its arguments.
func (f *fnClient) CloseAddrVer(addr string, ver uint64) error {
	f.closedAddr = addr
	f.closedVer = ver
	return nil
}

Expand Down Expand Up @@ -664,6 +671,8 @@ func (s *testRegionRequestToSingleStoreSuite) TestCloseConnectionOnStoreNotMatch
regionErr, _ := resp.GetRegionError()
s.NotNil(regionErr)
s.Equal(target, client.closedAddr)
var expected uint64 = math.MaxUint64
s.Equal(expected, client.closedVer)
}

func (s *testRegionRequestToSingleStoreSuite) TestStaleReadRetry() {
Expand Down Expand Up @@ -824,3 +833,20 @@ func (s *testRegionRequestToSingleStoreSuite) TestCountReplicaNumber() {
s.Equal(4, s.regionRequestSender.countReplicaNumber(peers)) // Only count 1 tiflash replica for tiflash write-nodes.
}
}

// emptyClient embeds client.Client without implementing client.ClientExt,
// used to check that getClientExt returns nil for such clients.
type emptyClient struct {
	client.Client
}

// TestClientExt checks getClientExt: non-nil for RPCClient (which implements
// ClientExt) and nil for a client that does not implement it.
func (s *testRegionRequestToSingleStoreSuite) TestClientExt() {
	// RPCClient implements ClientExt, so the extension must be visible.
	var c client.Client = client.NewRPCClient()
	sender := NewRegionRequestSender(s.cache, c)
	s.NotNil(sender.client)
	s.NotNil(sender.getClientExt())
	c.Close()

	// emptyClient lacks CloseAddrVer, so the extension must be nil.
	c = &emptyClient{}
	sender = NewRegionRequestSender(s.cache, c)
	s.NotNil(sender.client)
	s.Nil(sender.getClientExt())
}
Loading