From a8483e9e625cea1bc5057c9348ba26d04d4f0fb1 Mon Sep 17 00:00:00 2001 From: David <8039876+AmoebaProtozoa@users.noreply.github.com> Date: Tue, 27 Dec 2022 10:08:14 +0800 Subject: [PATCH 1/8] reformat codec and add keyspace support Signed-off-by: David <8039876+AmoebaProtozoa@users.noreply.github.com> --- config/config.go | 9 +- config/config_test.go | 8 +- go.mod | 2 +- integration_tests/2pc_test.go | 2 +- integration_tests/go.sum | 2 - integration_tests/lock_test.go | 2 +- integration_tests/raw/api_test.go | 6 +- integration_tests/split_test.go | 2 +- integration_tests/util_test.go | 64 +- internal/apicodec/codec.go | 59 ++ internal/apicodec/codec_v1.go | 200 +++++ internal/apicodec/codec_v2.go | 941 ++++++++++++++++++++++++ internal/apicodec/codec_v2_test.go | 268 +++++++ internal/apicodec/mem_codec.go | 57 ++ internal/client/api_version.go | 226 ------ internal/client/api_version_test.go | 68 -- internal/client/client.go | 17 +- internal/locate/pd_codec.go | 133 ++-- internal/locate/pd_codec_v2.go | 110 --- internal/locate/region_cache.go | 32 +- internal/locate/region_cache_test.go | 5 +- internal/locate/region_request.go | 4 +- internal/locate/region_request3_test.go | 3 +- internal/locate/region_request_test.go | 3 +- rawkv/rawkv.go | 38 +- tikv/client.go | 6 + tikv/kv.go | 8 +- tikv/region.go | 32 +- tikv/test_util.go | 32 +- txnkv/client.go | 60 +- txnkv/txnsnapshot/scan.go | 3 +- 31 files changed, 1857 insertions(+), 545 deletions(-) create mode 100644 internal/apicodec/codec.go create mode 100644 internal/apicodec/codec_v1.go create mode 100644 internal/apicodec/codec_v2.go create mode 100644 internal/apicodec/codec_v2_test.go create mode 100644 internal/apicodec/mem_codec.go delete mode 100644 internal/client/api_version.go delete mode 100644 internal/client/api_version_test.go delete mode 100644 internal/locate/pd_codec_v2.go diff --git a/config/config.go b/config/config.go index 2435288ea8..b4f0d6d31e 100644 --- a/config/config.go +++ b/config/config.go @@ -179,8 +179,8 @@ func GetTxnScopeFromConfig() string { } // ParsePath parses this path. -// Path example: tikv://etcd-node1:port,etcd-node2:port?cluster=1&disableGC=false -func ParsePath(path string) (etcdAddrs []string, disableGC bool, err error) { +// Path example: tikv://etcd-node1:port,etcd-node2:port?cluster=1&disableGC=false&keyspaceName=SomeKeyspace +func ParsePath(path string) (etcdAddrs []string, disableGC bool, keyspaceName string, err error) { var u *url.URL u, err = url.Parse(path) if err != nil { @@ -192,7 +192,10 @@ func ParsePath(path string) (etcdAddrs []string, disableGC bool, err error) { logutil.BgLogger().Error("parsePath error", zap.Error(err)) return } - switch strings.ToLower(u.Query().Get("disableGC")) { + + query := u.Query() + keyspaceName = query.Get("keyspaceName") + switch strings.ToLower(query.Get("disableGC")) { case "true": disableGC = true case "false", "": diff --git a/config/config_test.go b/config/config_test.go index 259292bebe..3058a29e59 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -42,18 +42,20 @@ import ( ) func TestParsePath(t *testing.T) { - etcdAddrs, disableGC, err := ParsePath("tikv://node1:2379,node2:2379") + etcdAddrs, disableGC, keyspaceName, err := ParsePath("tikv://node1:2379,node2:2379") assert.Nil(t, err) assert.Equal(t, []string{"node1:2379", "node2:2379"}, etcdAddrs) assert.False(t, disableGC) + assert.Empty(t, keyspaceName) - _, _, err = ParsePath("tikv://node1:2379") + _, _, _, err = ParsePath("tikv://node1:2379") assert.Nil(t, err) - _, disableGC, err = ParsePath("tikv://node1:2379?disableGC=true") + _, disableGC, keyspaceName, err = ParsePath("tikv://node1:2379?disableGC=true&keyspaceName=DEFAULT") assert.Nil(t, err) assert.True(t, disableGC) + assert.Equal(t, "DEFAULT", keyspaceName) } func TestTxnScopeValue(t *testing.T) { diff --git a/go.mod b/go.mod index 89517c4ace..b62611645c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/google/uuid v1.1.2 github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 github.com/opentracing/opentracing-go v1.2.0 + github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a @@ -43,7 +44,6 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/gomega v1.18.1 // indirect - github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect diff --git a/integration_tests/2pc_test.go b/integration_tests/2pc_test.go index 178dcefd04..607ba23c04 100644 --- a/integration_tests/2pc_test.go +++ b/integration_tests/2pc_test.go @@ -95,7 +95,7 @@ func (s *testCommitterSuite) SetupTest() { s.Require().Nil(err) testutils.BootstrapWithMultiRegions(cluster, []byte("a"), []byte("b"), []byte("c")) s.cluster = cluster - pdCli := &tikv.CodecPDClient{Client: pdClient} + pdCli := tikv.NewCodecPDClient(tikv.ModeTxn, pdClient) spkv := tikv.NewMockSafePointKV() store, err := tikv.NewKVStore("mocktikv-store", pdCli, spkv, client) store.EnableTxnLocalLatches(8096) diff --git a/integration_tests/go.sum b/integration_tests/go.sum index c7121451fb..664504ee96 100644 --- a/integration_tests/go.sum +++ b/integration_tests/go.sum @@ -505,8 +505,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= -github.com/tiancaiamao/gp v0.0.0-20221214071713-abacb15f16f1 h1:iffZXeHZTd35tTOS3nJ2OyMUmn40eNkLHCeQXMs6KYI= -github.com/tiancaiamao/gp v0.0.0-20221214071713-abacb15f16f1/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= github.com/tiancaiamao/gp v0.0.0-20221221095600-1a473d1f9b4b h1:4RNtqw1/tW67qP9fFgfQpTVd7DrfkaAWu4vsC18QmBo= github.com/tiancaiamao/gp v0.0.0-20221221095600-1a473d1f9b4b/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= github.com/tidwall/gjson v1.14.1 h1:iymTbGkQBhveq21bEvAQ81I0LEBork8BFe1CUZXdyuo= diff --git a/integration_tests/lock_test.go b/integration_tests/lock_test.go index e79f47585a..7dc5a349c2 100644 --- a/integration_tests/lock_test.go +++ b/integration_tests/lock_test.go @@ -515,7 +515,7 @@ func (s *testLockSuite) TestBatchResolveLocks() { s.Nil(err) committer.SetUseAsyncCommit() committer.SetLockTTL(20000) - committer.PrewriteAllMutations(context.Background()) + err = committer.PrewriteAllMutations(context.Background()) s.Nil(err) var locks []*txnkv.Lock diff --git a/integration_tests/raw/api_test.go b/integration_tests/raw/api_test.go index 415efff441..bf6ccaa79f 100644 --- a/integration_tests/raw/api_test.go +++ b/integration_tests/raw/api_test.go @@ -75,9 +75,11 @@ func (s *apiTestSuite) newRawKVClient(pdCli pd.Client, addrs []string) *rawkv.Cl } func (s *apiTestSuite) wrapPDClient(pdCli pd.Client, addrs []string) pd.Client { - if s.apiVersion == kvrpcpb.APIVersion_V2 { - return tikv.NewCodecPDClientV2(pdCli, tikv.ModeRaw) + var err error + if s.getApiVersion(pdCli) == kvrpcpb.APIVersion_V2 { + pdCli, err = tikv.NewCodecPDClientWithKeyspace(tikv.ModeRaw, pdCli, tikv.DefaultKeyspaceName) } + s.Nil(err) return pdCli } diff --git a/integration_tests/split_test.go b/integration_tests/split_test.go index 7e2170291c..aa125b8c59 100644 --- a/integration_tests/split_test.go +++ b/integration_tests/split_test.go @@ -36,10 +36,10 @@ package tikv_test import ( "context" - "github.com/pingcap/kvproto/pkg/keyspacepb" "sync" "testing" + "github.com/pingcap/kvproto/pkg/keyspacepb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pkg/errors" diff --git a/integration_tests/util_test.go b/integration_tests/util_test.go index 0a97d37c91..3825622353 100644 --- a/integration_tests/util_test.go +++ b/integration_tests/util_test.go @@ -38,14 +38,18 @@ import ( "context" "flag" "fmt" + "io/ioutil" + "net/http" "strings" "testing" "unsafe" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/kv" txndriver "github.com/pingcap/tidb/store/driver/txn" "github.com/pingcap/tidb/store/mockstore/unistore" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/tikv/client-go/v2/config" "github.com/tikv/client-go/v2/testutils" "github.com/tikv/client-go/v2/tikv" @@ -95,21 +99,69 @@ func NewTestUniStore(t *testing.T) *tikv.KVStore { } func newTiKVStore(t *testing.T) *tikv.KVStore { + re := require.New(t) addrs := strings.Split(*pdAddrs, ",") pdClient, err := pd.NewClient(addrs, pd.SecurityOption{}) - require.Nil(t, err) + re.Nil(err) + var opt tikv.ClientOpt + switch mustGetApiVersion(re, pdClient) { + case kvrpcpb.APIVersion_V1: + pdClient = tikv.NewCodecPDClient(tikv.ModeTxn, pdClient) + opt = tikv.WithCodec(tikv.NewCodecV1(tikv.ModeTxn)) + case kvrpcpb.APIVersion_V2: + codecCli, err := tikv.NewCodecPDClientWithKeyspace(tikv.ModeTxn, pdClient, tikv.DefaultKeyspaceName) + pdClient = codecCli + re.Nil(err) + opt = tikv.WithCodec(codecCli.GetCodec()) + default: + re.Fail("unknown api version") + } var securityConfig config.Security tlsConfig, err := securityConfig.ToTLSConfig() - require.Nil(t, err) + re.Nil(err) spKV, err := tikv.NewEtcdSafePointKV(addrs, tlsConfig) - require.Nil(t, err) - store, err := tikv.NewKVStore("test-store", &tikv.CodecPDClient{Client: pdClient}, spKV, tikv.NewRPCClient()) - require.Nil(t, err) + re.Nil(err) + store, err := tikv.NewKVStore( + "test-store", + pdClient, + spKV, + tikv.NewRPCClient(opt), + ) + re.Nil(err) err = clearStorage(store) - require.Nil(t, err) + re.Nil(err) return store } +func mustGetApiVersion(re *require.Assertions, pdCli pd.Client) kvrpcpb.APIVersion { + stores, err := pdCli.GetAllStores(context.Background()) + re.NoError(err) + + for _, store := range stores { + resp := mustGetConfig(re, fmt.Sprintf("http://%s/config", store.StatusAddress)) + v := gjson.Get(resp, "storage.api-version") + if v.Type == gjson.Null || v.Uint() != 2 { + return kvrpcpb.APIVersion_V1 + } + } + return kvrpcpb.APIVersion_V2 +} + +func mustGetConfig(re *require.Assertions, url string) string { + transport := &http.Transport{} + client := http.Client{ + Transport: transport, + } + defer transport.CloseIdleConnections() + resp, err := client.Get(url) + re.NoError(err) + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + re.NoError(err) + return string(body) +} + func clearStorage(store *tikv.KVStore) error { txn, err := store.Begin() if err != nil { diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go new file mode 100644 index 0000000000..78c1fc35b8 --- /dev/null +++ b/internal/apicodec/codec.go @@ -0,0 +1,59 @@ +package apicodec + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/tikv/client-go/v2/tikvrpc" +) + +// Mode represents the operation mode of a request. +type Mode int + +const ( + // ModeRaw represent a raw operation in TiKV + ModeRaw = iota + // ModeTxn represent a transaction operation in TiKV + ModeTxn +) + +// Codec is responsible for encode/decode requests. +type Codec interface { + // GetAPIVersion returns the api version of the codec. + GetAPIVersion() kvrpcpb.APIVersion + // GetKeyspace return the keyspace id of the codec. + GetKeyspace() []byte + // EncodeRequest encodes with the given Codec. + // NOTE: req is reused on retry. MUST encode on cloned request, other than overwrite the original. + EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) + // DecodeResponse decode the resp with the given codec. + DecodeResponse(req *tikvrpc.Request, resp *tikvrpc.Response) (*tikvrpc.Response, error) + // EncodeRegionKey encode region's key. + EncodeRegionKey(key []byte) []byte + // DecodeRegionKey decode region's key + DecodeRegionKey(encodedKey []byte) ([]byte, error) + // EncodeRegionRange encode region's start and end. + EncodeRegionRange(start, end []byte) ([]byte, []byte) + // DecodeRegionRange decode region's start and end. + DecodeRegionRange(encodedStart, encodedEnd []byte) ([]byte, []byte, error) + // EncodeRange encode a key range. + EncodeRange(start, end []byte) ([]byte, []byte) + // DecodeRange decode a key range. + DecodeRange(encodedStart, encodedEnd []byte) ([]byte, []byte, error) + // EncodeKey encode a key. + EncodeKey(key []byte) []byte + // DecodeKey decode a key. + DecodeKey(encoded []byte) ([]byte, error) +} + +func DecodeKey(encoded []byte, version kvrpcpb.APIVersion) ([]byte, []byte, error) { + switch version { + case kvrpcpb.APIVersion_V1: + return nil, encoded, nil + case kvrpcpb.APIVersion_V2: + if len(encoded) < keyspacePrefixLen { + return nil, nil, errors.Errorf("invalid V2 key: %s", encoded) + } + return encoded[:keyspacePrefixLen], encoded[keyspacePrefixLen:], nil + } + return nil, nil, errors.Errorf("unsupported api version %s", version.String()) +} diff --git a/internal/apicodec/codec_v1.go b/internal/apicodec/codec_v1.go new file mode 100644 index 0000000000..b21ef8802a --- /dev/null +++ b/internal/apicodec/codec_v1.go @@ -0,0 +1,200 @@ +package apicodec + +import ( + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/tikv/client-go/v2/tikvrpc" +) + +type codecV1 struct { + memCodec memCodec +} + +// NewCodecV1 returns a codec that can be used to encode/decode +// keys and requests to and from APIv1 format. +func NewCodecV1(mode Mode) Codec { + switch mode { + case ModeRaw: + return &codecV1{memCodec: &defaultMemCodec{}} + case ModeTxn: + return &codecV1{memCodec: &memComparableCodec{}} + } + panic("unknown mode") +} + +func (c *codecV1) GetAPIVersion() kvrpcpb.APIVersion { + return kvrpcpb.APIVersion_V1 +} + +func (c *codecV1) GetKeyspace() []byte { + return nil +} + +func (c *codecV1) EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) { + return req, nil +} + +func (c *codecV1) DecodeResponse(req *tikvrpc.Request, resp *tikvrpc.Response) (*tikvrpc.Response, error) { + regionError, err := resp.GetRegionError() + // If GetRegionError returns error, it means the response does not contain region error to decode, + // therefore we skip decoding and return the response as is. + if err != nil { + return resp, nil + } + decodeRegionError, err := c.decodeRegionError(regionError) + if err != nil { + return nil, err + } + switch req.Type { + case tikvrpc.CmdGet: + r := resp.Resp.(*kvrpcpb.GetResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdScan: + r := resp.Resp.(*kvrpcpb.ScanResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdPrewrite: + r := resp.Resp.(*kvrpcpb.PrewriteResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdCommit: + r := resp.Resp.(*kvrpcpb.CommitResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdCleanup: + r := resp.Resp.(*kvrpcpb.CleanupResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdBatchGet: + r := resp.Resp.(*kvrpcpb.BatchGetResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdBatchRollback: + r := resp.Resp.(*kvrpcpb.BatchRollbackResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdScanLock: + r := resp.Resp.(*kvrpcpb.ScanLockResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdResolveLock: + r := resp.Resp.(*kvrpcpb.ResolveLockResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdGC: + r := resp.Resp.(*kvrpcpb.GCResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdDeleteRange: + r := resp.Resp.(*kvrpcpb.DeleteRangeResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdPessimisticLock: + r := resp.Resp.(*kvrpcpb.PessimisticLockResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdPessimisticRollback: + r := resp.Resp.(*kvrpcpb.PessimisticRollbackResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdTxnHeartBeat: + r := resp.Resp.(*kvrpcpb.TxnHeartBeatResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdCheckTxnStatus: + r := resp.Resp.(*kvrpcpb.CheckTxnStatusResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdCheckSecondaryLocks: + r := resp.Resp.(*kvrpcpb.CheckSecondaryLocksResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawGet: + r := resp.Resp.(*kvrpcpb.RawGetResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawBatchGet: + r := resp.Resp.(*kvrpcpb.RawBatchGetResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawPut: + r := resp.Resp.(*kvrpcpb.RawPutResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawBatchPut: + r := resp.Resp.(*kvrpcpb.RawBatchPutResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawDelete: + r := resp.Resp.(*kvrpcpb.RawDeleteResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawBatchDelete: + r := resp.Resp.(*kvrpcpb.RawBatchDeleteResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawDeleteRange: + r := resp.Resp.(*kvrpcpb.RawDeleteRangeResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawScan: + r := resp.Resp.(*kvrpcpb.RawScanResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdGetKeyTTL: + r := resp.Resp.(*kvrpcpb.RawGetKeyTTLResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawCompareAndSwap: + r := resp.Resp.(*kvrpcpb.RawCASResponse) + r.RegionError = decodeRegionError + case tikvrpc.CmdRawChecksum: + r := resp.Resp.(*kvrpcpb.RawChecksumResponse) + r.RegionError = decodeRegionError + } + return resp, nil +} + +func (c *codecV1) EncodeRegionKey(key []byte) []byte { + return c.memCodec.encodeKey(key) +} + +func (c *codecV1) DecodeRegionKey(encodedKey []byte) ([]byte, error) { + if len(encodedKey) == 0 { + return encodedKey, nil + } + return c.memCodec.decodeKey(encodedKey) +} + +func (c *codecV1) EncodeRegionRange(start, end []byte) ([]byte, []byte) { + if len(end) > 0 { + return c.EncodeRegionKey(start), c.EncodeRegionKey(end) + } + return c.EncodeRegionKey(start), end +} + +func (c *codecV1) DecodeRegionRange(encodedStart, encodedEnd []byte) ([]byte, []byte, error) { + start, err := c.DecodeRegionKey(encodedStart) + if err != nil { + return nil, nil, err + } + end, err := c.DecodeRegionKey(encodedEnd) + if err != nil { + return nil, nil, err + } + return start, end, nil +} + +func (c *codecV1) decodeRegionError(regionError *errorpb.Error) (*errorpb.Error, error) { + if regionError == nil { + return nil, nil + } + var err error + if errInfo := regionError.KeyNotInRegion; errInfo != nil { + errInfo.StartKey, errInfo.EndKey, err = c.DecodeRegionRange(errInfo.StartKey, errInfo.EndKey) + if err != nil { + return nil, err + } + } + if errInfo := regionError.EpochNotMatch; errInfo != nil { + for _, meta := range errInfo.CurrentRegions { + meta.StartKey, meta.EndKey, err = c.DecodeRegionRange(meta.StartKey, meta.EndKey) + if err != nil { + return nil, err + } + } + } + return regionError, nil +} + +func (c *codecV1) EncodeKey(key []byte) []byte { + return key +} + +func (c *codecV1) EncodeRange(start, end []byte) ([]byte, []byte) { + return start, end +} + +func (c *codecV1) DecodeRange(start, end []byte) ([]byte, []byte, error) { + return start, end, nil +} + +func (c *codecV1) DecodeKey(key []byte) ([]byte, error) { + return key, nil +} diff --git a/internal/apicodec/codec_v2.go b/internal/apicodec/codec_v2.go new file mode 100644 index 0000000000..71fbd4984f --- /dev/null +++ b/internal/apicodec/codec_v2.go @@ -0,0 +1,941 @@ +package apicodec + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pkg/errors" + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/tikv/client-go/v2/tikvrpc" + "go.uber.org/zap" +) + +var ( + // DefaultKeyspaceID is the keyspaceID of the default keyspace. + DefaultKeyspaceID uint32 = 0 + // DefaultKeyspaceName is the name of the default keyspace. + DefaultKeyspaceName = "DEFAULT" + + rawModePrefix byte = 'r' + txnModePrefix byte = 'x' + keyspacePrefixLen int = 4 + + // errKeyOutOfBound happens when key to be decoded lies outside the keyspace's range. + errKeyOutOfBound = errors.New("given key does not belong to the keyspace.") +) + +// BuildKeyspaceName builds a keyspace name +func BuildKeyspaceName(name string) string { + if name == "" { + return DefaultKeyspaceName + } + return name +} + +// codecV2 is used to encode/decode keys and request into APIv2 format. +type codecV2 struct { + prefix []byte + endKey []byte + memCodec memCodec +} + +// NewCodecV2 returns a codec that can be used to encode/decode +// keys and requests to and from APIv2 format. +func NewCodecV2(mode Mode, keyspaceID uint32) (Codec, error) { + prefix, err := getIDByte(keyspaceID) + if err != nil { + return nil, err + } + + // Region keys in CodecV2 are always encoded in memory comparable form. + codec := &codecV2{memCodec: &memComparableCodec{}} + codec.prefix = make([]byte, 4) + codec.endKey = make([]byte, 4) + switch mode { + case ModeRaw: + codec.prefix[0] = rawModePrefix + case ModeTxn: + codec.prefix[0] = txnModePrefix + default: + return nil, errors.Errorf("unknown mode") + } + copy(codec.prefix[1:], prefix) + prefixVal := binary.BigEndian.Uint32(codec.prefix) + binary.BigEndian.PutUint32(codec.endKey, prefixVal+1) + return codec, nil +} + +func getIDByte(keyspaceID uint32) ([]byte, error) { + // PutUint32 requires 4 bytes to operate, so must use buffer with size 4 here. + b := make([]byte, 4) + // Use BigEndian to put the least significant byte to last array position. + // For example, keyspaceID 1 should result in []byte{0, 0, 1} + binary.BigEndian.PutUint32(b, keyspaceID) + // When keyspaceID can't fit in 3 bytes, first byte of buffer will be non-zero. + // So return error. + if b[0] != 0 { + return nil, errors.Errorf("illegal keyspaceID: %v, keyspaceID must be 3 byte", b) + } + // Remove the first byte to make keyspace ID 3 bytes. + return b[1:], nil +} + +func (c *codecV2) GetKeyspace() []byte { + return c.prefix +} + +func (c *codecV2) GetAPIVersion() kvrpcpb.APIVersion { + return kvrpcpb.APIVersion_V2 +} + +// EncodeRequest encodes with the given Codec. +// NOTE: req is reused on retry. MUST encode on cloned request, other than overwrite the original. +func (c *codecV2) EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) { + newReq := *req + // Encode requests based on command type. + switch req.Type { + // Transaction Request Types. + case tikvrpc.CmdGet: + r := *req.Get() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdScan: + r := *req.Scan() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, r.Reverse) + newReq.Req = &r + case tikvrpc.CmdPrewrite: + r := *req.Prewrite() + r.Mutations = c.encodeMutations(r.Mutations) + r.PrimaryLock = c.EncodeKey(r.PrimaryLock) + r.Secondaries = c.encodeKeys(r.Secondaries) + newReq.Req = &r + case tikvrpc.CmdCommit: + r := *req.Commit() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdCleanup: + r := *req.Cleanup() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdBatchGet: + r := *req.BatchGet() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdBatchRollback: + r := *req.BatchRollback() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdScanLock: + r := *req.ScanLock() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, false) + newReq.Req = &r + case tikvrpc.CmdResolveLock: + r := *req.ResolveLock() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdGC: + // TODO: Deprecate Central GC Mode. + case tikvrpc.CmdDeleteRange: + r := *req.DeleteRange() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, false) + newReq.Req = &r + case tikvrpc.CmdPessimisticLock: + r := *req.PessimisticLock() + r.Mutations = c.encodeMutations(r.Mutations) + r.PrimaryLock = c.EncodeKey(r.PrimaryLock) + newReq.Req = &r + case tikvrpc.CmdPessimisticRollback: + r := *req.PessimisticRollback() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdTxnHeartBeat: + r := *req.TxnHeartBeat() + r.PrimaryLock = c.EncodeKey(r.PrimaryLock) + newReq.Req = &r + case tikvrpc.CmdCheckTxnStatus: + r := *req.CheckTxnStatus() + r.PrimaryKey = c.EncodeKey(r.PrimaryKey) + newReq.Req = &r + case tikvrpc.CmdCheckSecondaryLocks: + r := *req.CheckSecondaryLocks() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + + // Raw Request Types. + case tikvrpc.CmdRawGet: + r := *req.RawGet() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdRawBatchGet: + r := *req.RawBatchGet() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdRawPut: + r := *req.RawPut() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdRawBatchPut: + r := *req.RawBatchPut() + r.Pairs = c.encodeParis(r.Pairs) + newReq.Req = &r + case tikvrpc.CmdRawDelete: + r := *req.RawDelete() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdRawBatchDelete: + r := *req.RawBatchDelete() + r.Keys = c.encodeKeys(r.Keys) + newReq.Req = &r + case tikvrpc.CmdRawDeleteRange: + r := *req.RawDeleteRange() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, false) + newReq.Req = &r + case tikvrpc.CmdRawScan: + r := *req.RawScan() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, r.Reverse) + newReq.Req = &r + case tikvrpc.CmdGetKeyTTL: + r := *req.RawGetKeyTTL() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdRawCompareAndSwap: + r := *req.RawCompareAndSwap() + r.Key = c.EncodeKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdRawChecksum: + r := *req.RawChecksum() + r.Ranges = c.encodeKeyRanges(r.Ranges) + newReq.Req = &r + + // TiFlash Requests + case tikvrpc.CmdBatchCop: + r := *req.BatchCop() + r.Regions = c.encodeRegionInfos(r.Regions) + r.TableRegions = c.encodeTableRegions(r.TableRegions) + newReq.Req = &r + case tikvrpc.CmdMPPTask: + r := *req.DispatchMPPTask() + r.Regions = c.encodeRegionInfos(r.Regions) + r.TableRegions = c.encodeTableRegions(r.TableRegions) + newReq.Req = &r + + // Other requests. + case tikvrpc.CmdUnsafeDestroyRange: + r := *req.UnsafeDestroyRange() + r.StartKey, r.EndKey = c.encodeRange(r.StartKey, r.EndKey, false) + newReq.Req = &r + case tikvrpc.CmdPhysicalScanLock: + r := *req.PhysicalScanLock() + r.StartKey = c.EncodeKey(r.StartKey) + newReq.Req = &r + case tikvrpc.CmdStoreSafeTS: + r := *req.StoreSafeTS() + r.KeyRange = c.encodeKeyRange(r.KeyRange) + newReq.Req = &r + case tikvrpc.CmdCop: + r := *req.Cop() + r.Ranges = c.encodeCopRanges(r.Ranges) + newReq.Req = &r + case tikvrpc.CmdCopStream: + r := *req.Cop() + r.Ranges = c.encodeCopRanges(r.Ranges) + newReq.Req = &r + case tikvrpc.CmdMvccGetByKey: + r := *req.MvccGetByKey() + r.Key = c.EncodeRegionKey(r.Key) + newReq.Req = &r + case tikvrpc.CmdSplitRegion: + r := *req.SplitRegion() + r.SplitKeys = c.encodeKeys(r.SplitKeys) + newReq.Req = &r + } + + return &newReq, nil +} + +// DecodeResponse decode the resp with the given codec. +func (c *codecV2) DecodeResponse(req *tikvrpc.Request, resp *tikvrpc.Response) (*tikvrpc.Response, error) { + var err error + // Decode response based on command type. + switch req.Type { + // Transaction KV responses. + // Keys that need to be decoded lies in RegionError, KeyError and LockInfo. + case tikvrpc.CmdGet: + r := resp.Resp.(*kvrpcpb.GetResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdScan: + r := resp.Resp.(*kvrpcpb.ScanResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Pairs, err = c.decodePairs(r.Pairs) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdPrewrite: + r := resp.Resp.(*kvrpcpb.PrewriteResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Errors, err = c.decodeKeyErrors(r.Errors) + if err != nil { + return nil, err + } + case tikvrpc.CmdCommit: + r := resp.Resp.(*kvrpcpb.CommitResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdCleanup: + r := resp.Resp.(*kvrpcpb.CleanupResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdBatchGet: + r := resp.Resp.(*kvrpcpb.BatchGetResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Pairs, err = c.decodePairs(r.Pairs) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdBatchRollback: + r := resp.Resp.(*kvrpcpb.BatchRollbackResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdScanLock: + r := resp.Resp.(*kvrpcpb.ScanLockResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + r.Locks, err = c.decodeLockInfos(r.Locks) + if err != nil { + return nil, err + } + case tikvrpc.CmdResolveLock: + r := resp.Resp.(*kvrpcpb.ResolveLockResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdGC: + // TODO: Deprecate Central GC Mode. + r := resp.Resp.(*kvrpcpb.GCResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdDeleteRange: + r := resp.Resp.(*kvrpcpb.DeleteRangeResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdPessimisticLock: + r := resp.Resp.(*kvrpcpb.PessimisticLockResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Errors, err = c.decodeKeyErrors(r.Errors) + if err != nil { + return nil, err + } + case tikvrpc.CmdPessimisticRollback: + r := resp.Resp.(*kvrpcpb.PessimisticRollbackResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Errors, err = c.decodeKeyErrors(r.Errors) + if err != nil { + return nil, err + } + case tikvrpc.CmdTxnHeartBeat: + r := resp.Resp.(*kvrpcpb.TxnHeartBeatResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + case tikvrpc.CmdCheckTxnStatus: + r := resp.Resp.(*kvrpcpb.CheckTxnStatusResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + r.LockInfo, err = c.decodeLockInfo(r.LockInfo) + if err != nil { + return nil, err + } + case tikvrpc.CmdCheckSecondaryLocks: + r := resp.Resp.(*kvrpcpb.CheckSecondaryLocksResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Error, err = c.decodeKeyError(r.Error) + if err != nil { + return nil, err + } + r.Locks, err = c.decodeLockInfos(r.Locks) + if err != nil { + return nil, err + } + // RawKV Responses. + // Most of these responses does not require treatment aside from Region Error decoding. + // Exceptions are Response with keys attach to them, like RawScan and RawBatchGet, + // which need have their keys decoded. + case tikvrpc.CmdRawGet: + r := resp.Resp.(*kvrpcpb.RawGetResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawBatchGet: + r := resp.Resp.(*kvrpcpb.RawBatchGetResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Pairs, err = c.decodePairs(r.Pairs) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawPut: + r := resp.Resp.(*kvrpcpb.RawPutResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawBatchPut: + r := resp.Resp.(*kvrpcpb.RawBatchPutResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawDelete: + r := resp.Resp.(*kvrpcpb.RawDeleteResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawBatchDelete: + r := resp.Resp.(*kvrpcpb.RawBatchDeleteResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawDeleteRange: + r := resp.Resp.(*kvrpcpb.RawDeleteRangeResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawScan: + r := resp.Resp.(*kvrpcpb.RawScanResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Kvs, err = c.decodePairs(r.Kvs) + if err != nil { + return nil, err + } + case tikvrpc.CmdGetKeyTTL: + r := resp.Resp.(*kvrpcpb.RawGetKeyTTLResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawCompareAndSwap: + r := resp.Resp.(*kvrpcpb.RawCASResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdRawChecksum: + r := resp.Resp.(*kvrpcpb.RawChecksumResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + + // Other requests. + case tikvrpc.CmdUnsafeDestroyRange: + r := resp.Resp.(*kvrpcpb.UnsafeDestroyRangeResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdPhysicalScanLock: + r := resp.Resp.(*kvrpcpb.PhysicalScanLockResponse) + r.Locks, err = c.decodeLockInfos(r.Locks) + if err != nil { + return nil, err + } + case tikvrpc.CmdCop: + r := resp.Resp.(*coprocessor.Response) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Locked, err = c.decodeLockInfo(r.Locked) + if err != nil { + return nil, err + } + r.Range, err = c.decodeCopRange(r.Range) + if err != nil { + return nil, err + } + case tikvrpc.CmdCopStream: + return nil, errors.New("streaming coprocessor is not supported yet") + case tikvrpc.CmdBatchCop, tikvrpc.CmdMPPTask: + // There aren't range infos in BatchCop and MPPTask responses. + case tikvrpc.CmdMvccGetByKey: + r := resp.Resp.(*kvrpcpb.MvccGetByKeyResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + case tikvrpc.CmdSplitRegion: + r := resp.Resp.(*kvrpcpb.SplitRegionResponse) + r.RegionError, err = c.decodeRegionError(r.RegionError) + if err != nil { + return nil, err + } + r.Regions, err = c.decodeRegions(r.Regions) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +func (c *codecV2) EncodeRegionKey(key []byte) []byte { + encodeKey := c.EncodeKey(key) + return c.memCodec.encodeKey(encodeKey) +} + +func (c *codecV2) DecodeRegionKey(encodedKey []byte) ([]byte, error) { + memDecoded, err := c.memCodec.decodeKey(encodedKey) + if err != nil { + return nil, err + } + return c.DecodeKey(memDecoded) +} + +// EncodeRegionRange first append appropriate prefix to start and end, +// then pass them to memCodec to encode them to appropriate memory format. +func (c *codecV2) EncodeRegionRange(start, end []byte) ([]byte, []byte) { + encodedStart, encodedEnd := c.encodeRange(start, end, false) + encodedStart = c.memCodec.encodeKey(encodedStart) + encodedEnd = c.memCodec.encodeKey(encodedEnd) + return encodedStart, encodedEnd +} + +// DecodeRegionRange first decode key from memory compatible format, +// then pass decode them with DecodeRange to map them to correct range. +// Note that empty byte slice/ nil slice requires special treatment. +func (c *codecV2) DecodeRegionRange(encodedStart, encodedEnd []byte) ([]byte, []byte, error) { + var err error + if len(encodedStart) != 0 { + encodedStart, err = c.memCodec.decodeKey(encodedStart) + if err != nil { + return nil, nil, err + } + } + if len(encodedEnd) != 0 { + encodedEnd, err = c.memCodec.decodeKey(encodedEnd) + if err != nil { + return nil, nil, err + } + } + + return c.DecodeRange(encodedStart, encodedEnd) +} + +func (c *codecV2) EncodeRange(start, end []byte) ([]byte, []byte) { + return c.encodeRange(start, end, false) +} + +// encodeRange encodes start and end to correct range in APIv2. +// Note that if end is nil/ empty byte slice, it means no end. +// So we use endKey of the keyspace directly. +func (c *codecV2) encodeRange(start, end []byte, reverse bool) ([]byte, []byte) { + // If reverse, scan from end to start. + // Corresponding start and end encode needs to be reversed. + if reverse { + end, start = c.encodeRange(end, start, false) + return start, end + } + var encodedEnd []byte + if len(end) > 0 { + encodedEnd = c.EncodeKey(end) + } else { + encodedEnd = c.endKey + } + return c.EncodeKey(start), encodedEnd +} + +// DecodeRange maps encodedStart and end back to normal start and +// end without APIv2 prefixes. +func (c *codecV2) DecodeRange(encodedStart, encodedEnd []byte) (start []byte, end []byte, err error) { + if bytes.Compare(encodedStart, c.endKey) >= 0 || + (len(encodedEnd) > 0 && bytes.Compare(encodedEnd, c.prefix) <= 0) { + return nil, nil, errors.WithStack(errKeyOutOfBound) + } + + start, end = []byte{}, []byte{} + + if bytes.HasPrefix(encodedStart, c.prefix) { + start = encodedStart[len(c.prefix):] + } + + if bytes.HasPrefix(encodedEnd, c.prefix) { + end = encodedEnd[len(c.prefix):] + } + + return +} + +func (c *codecV2) EncodeKey(key []byte) []byte { + return append(c.prefix, key...) +} + +func (c *codecV2) DecodeKey(encodedKey []byte) ([]byte, error) { + // If the given key does not start with the correct prefix, + // return out of bound error. + if !bytes.HasPrefix(encodedKey, c.prefix) { + logutil.BgLogger().Warn("key not in keyspace", + zap.String("keyspacePrefix", hex.EncodeToString(c.prefix)), + zap.String("key", hex.EncodeToString(encodedKey)), + zap.Stack("stack")) + return nil, errKeyOutOfBound + } + return encodedKey[len(c.prefix):], nil +} + +func (c *codecV2) encodeKeyRange(keyRange *kvrpcpb.KeyRange) *kvrpcpb.KeyRange { + encodedRange := &kvrpcpb.KeyRange{} + encodedRange.StartKey, encodedRange.EndKey = c.encodeRange(keyRange.StartKey, keyRange.EndKey, false) + return encodedRange +} + +func (c *codecV2) encodeKeyRanges(keyRanges []*kvrpcpb.KeyRange) []*kvrpcpb.KeyRange { + encodedRanges := make([]*kvrpcpb.KeyRange, 0, len(keyRanges)) + for _, keyRange := range keyRanges { + encodedRanges = append(encodedRanges, c.encodeKeyRange(keyRange)) + } + return encodedRanges +} + +func (c *codecV2) encodeCopRange(r *coprocessor.KeyRange) *coprocessor.KeyRange { + newRange := &coprocessor.KeyRange{} + newRange.Start, newRange.End = c.encodeRange(r.Start, r.End, false) + return newRange +} + +func (c *codecV2) decodeCopRange(r *coprocessor.KeyRange) (*coprocessor.KeyRange, error) { + var err error + if r != nil { + r.Start, r.End, err = c.DecodeRange(r.Start, r.End) + } + if err != nil { + return nil, err + } + return r, nil +} + +func (c *codecV2) encodeCopRanges(ranges []*coprocessor.KeyRange) []*coprocessor.KeyRange { + newRanges := make([]*coprocessor.KeyRange, 0, len(ranges)) + for _, r := range ranges { + newRanges = append(newRanges, c.encodeCopRange(r)) + } + return newRanges +} + +func (c *codecV2) decodeRegions(regions []*metapb.Region) ([]*metapb.Region, error) { + var err error + for _, region := range regions { + region.StartKey, region.EndKey, err = c.DecodeRegionRange(region.StartKey, region.EndKey) + if err != nil { + return nil, err + } + } + return regions, nil +} + +func (c *codecV2) encodeKeys(keys [][]byte) [][]byte { + var encodedKeys [][]byte + for _, key := range keys { + encodedKeys = append(encodedKeys, c.EncodeKey(key)) + } + return encodedKeys +} + +func (c *codecV2) encodeParis(pairs []*kvrpcpb.KvPair) []*kvrpcpb.KvPair { + var encodedPairs []*kvrpcpb.KvPair + for _, pair := range pairs { + p := *pair + p.Key = c.EncodeKey(p.Key) + encodedPairs = append(encodedPairs, &p) + } + return encodedPairs +} + +func (c *codecV2) decodePairs(encodedPairs []*kvrpcpb.KvPair) ([]*kvrpcpb.KvPair, error) { + var pairs []*kvrpcpb.KvPair + for _, encodedPair := range encodedPairs { + var err error + p := *encodedPair + if p.Error != nil { + p.Error, err = c.decodeKeyError(p.Error) + if err != nil { + return nil, err + } + } + if len(p.Key) > 0 { + p.Key, err = c.DecodeKey(p.Key) + if err != nil { + return nil, err + } + } + pairs = append(pairs, &p) + } + return pairs, nil +} + +func (c *codecV2) encodeMutations(mutations []*kvrpcpb.Mutation) []*kvrpcpb.Mutation { + var encodedMutations []*kvrpcpb.Mutation + for _, mutation := range mutations { + m := *mutation + m.Key = c.EncodeKey(m.Key) + encodedMutations = append(encodedMutations, &m) + } + return encodedMutations +} + +func (c *codecV2) encodeRegionInfo(info *coprocessor.RegionInfo) *coprocessor.RegionInfo { + i := *info + i.Ranges = c.encodeCopRanges(info.Ranges) + return &i +} + +func (c *codecV2) encodeRegionInfos(infos []*coprocessor.RegionInfo) []*coprocessor.RegionInfo { + var encodedInfos []*coprocessor.RegionInfo + for _, info := range infos { + encodedInfos = append(encodedInfos, c.encodeRegionInfo(info)) + } + return encodedInfos +} + +func (c *codecV2) encodeTableRegions(infos []*coprocessor.TableRegions) []*coprocessor.TableRegions { + var encodedInfos []*coprocessor.TableRegions + for _, info := range infos { + i := *info + i.Regions = c.encodeRegionInfos(info.Regions) + encodedInfos = append(encodedInfos, &i) + } + return encodedInfos +} + +func (c *codecV2) decodeRegionError(regionError *errorpb.Error) (*errorpb.Error, error) { + if regionError == nil { + return nil, nil + } + var err error + if errInfo := regionError.KeyNotInRegion; errInfo != nil { + errInfo.Key, err = c.DecodeKey(errInfo.Key) + if err != nil { + return nil, err + } + errInfo.StartKey, errInfo.EndKey, err = c.DecodeRegionRange(errInfo.StartKey, errInfo.EndKey) + if err != nil { + return nil, err + } + } + + if errInfo := regionError.EpochNotMatch; errInfo != nil { + decodedRegions := make([]*metapb.Region, 0, len(errInfo.CurrentRegions)) + for _, meta := range errInfo.CurrentRegions { + meta.StartKey, meta.EndKey, err = c.DecodeRegionRange(meta.StartKey, meta.EndKey) + if err != nil { + // skip out of keyspace range's region + if errors.Is(err, errKeyOutOfBound) { + continue + } + return nil, err + } + decodedRegions = append(decodedRegions, meta) + } + errInfo.CurrentRegions = decodedRegions + } + + return regionError, nil +} + +func (c *codecV2) decodeKeyError(keyError *kvrpcpb.KeyError) (*kvrpcpb.KeyError, error) { + if keyError == nil { + return nil, nil + } + var err error + if keyError.Locked != nil { + keyError.Locked, err = c.decodeLockInfo(keyError.Locked) + if err != nil { + return nil, err + } + } + if keyError.Conflict != nil { + keyError.Conflict.Key, err = c.DecodeKey(keyError.Conflict.Key) + if err != nil { + return nil, err + } + keyError.Conflict.Primary, err = c.DecodeKey(keyError.Conflict.Primary) + if err != nil { + return nil, err + } + } + if keyError.AlreadyExist != nil { + keyError.AlreadyExist.Key, err = c.DecodeKey(keyError.AlreadyExist.Key) + if err != nil { + return nil, err + } + } + if keyError.Deadlock != nil { + keyError.Deadlock.LockKey, err = c.DecodeKey(keyError.Deadlock.LockKey) + if err != nil { + return nil, err + } + for _, wait := range keyError.Deadlock.WaitChain { + wait.Key, err = c.DecodeKey(wait.Key) + if err != nil { + return nil, err + } + } + } + if keyError.CommitTsExpired != nil { + keyError.CommitTsExpired.Key, err = c.DecodeKey(keyError.CommitTsExpired.Key) + if err != nil { + return nil, err + } + } + if keyError.TxnNotFound != nil { + keyError.TxnNotFound.PrimaryKey, err = c.DecodeKey(keyError.TxnNotFound.PrimaryKey) + if err != nil { + return nil, err + } + } + if keyError.AssertionFailed != nil { + keyError.AssertionFailed.Key, err = c.DecodeKey(keyError.AssertionFailed.Key) + if err != nil { + return nil, err + } + } + return keyError, nil +} + +func (c *codecV2) decodeKeyErrors(keyErrors []*kvrpcpb.KeyError) ([]*kvrpcpb.KeyError, error) { + var err error + for i := range keyErrors { + keyErrors[i], err = c.decodeKeyError(keyErrors[i]) + if err != nil { + return nil, err + } + } + return keyErrors, nil +} + +func (c *codecV2) decodeLockInfo(info *kvrpcpb.LockInfo) (*kvrpcpb.LockInfo, error) { + if info == nil { + return nil, nil + } + var err error + info.Key, err = c.DecodeKey(info.Key) + if err != nil { + return nil, err + } + info.PrimaryLock, err = c.DecodeKey(info.PrimaryLock) + if err != nil { + return nil, err + } + for i := range info.Secondaries { + info.Secondaries[i], err = c.DecodeKey(info.Secondaries[i]) + if err != nil { + return nil, err + } + } + return info, nil +} + +func (c *codecV2) decodeLockInfos(locks []*kvrpcpb.LockInfo) ([]*kvrpcpb.LockInfo, error) { + var err error + for i := range locks { + locks[i], err = c.decodeLockInfo(locks[i]) + if err != nil { + return nil, err + } + } + return locks, nil +} diff --git a/internal/apicodec/codec_v2_test.go b/internal/apicodec/codec_v2_test.go new file mode 100644 index 0000000000..fe58cecec4 --- /dev/null +++ b/internal/apicodec/codec_v2_test.go @@ -0,0 +1,268 @@ +package apicodec + +import ( + "math" + "testing" + + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/tikvrpc" +) + +var ( + testKeyspaceID = uint32(4242) + // Keys below are ordered as following: + // beforePrefix, keyspacePrefix, insideLeft, insideRight, keyspaceEndKey, afterEndKey + // where valid keyspace range is [keyspacePrefix, keyspaceEndKey) + keyspacePrefix = []byte{'r', 0, 16, 146} + keyspaceEndKey = []byte{'r', 0, 16, 147} + beforePrefix = []byte{'r', 0, 0, 1} + afterEndKey = []byte{'r', 1, 0, 0} + insideLeft = []byte{'r', 0, 16, 146, 100} + insideRight = []byte{'r', 0, 16, 146, 200} +) + +type testCodecV2Suite struct { + suite.Suite + codec *codecV2 +} + +func TestCodecV2(t *testing.T) { + suite.Run(t, new(testCodecV2Suite)) +} + +func (suite *testCodecV2Suite) SetupSuite() { + codec, err := NewCodecV2(ModeRaw, testKeyspaceID) + suite.NoError(err) + suite.Equal(keyspacePrefix, codec.GetKeyspace()) + suite.codec = codec.(*codecV2) +} + +func (suite *testCodecV2Suite) TestEncodeRequest() { + re := suite.Require() + req := &tikvrpc.Request{ + Type: tikvrpc.CmdRawGet, + Req: &kvrpcpb.RawGetRequest{ + Key: []byte("key"), + }, + } + req.ApiVersion = kvrpcpb.APIVersion_V2 + + r, err := suite.codec.EncodeRequest(req) + re.NoError(err) + re.Equal(append(keyspacePrefix, []byte("key")...), r.RawGet().Key) + + r, err = suite.codec.EncodeRequest(req) + re.NoError(err) + re.Equal(append(keyspacePrefix, []byte("key")...), r.RawGet().Key) +} + +func (suite *testCodecV2Suite) TestEncodeV2KeyRanges() { + re := suite.Require() + keyRanges := []*kvrpcpb.KeyRange{ + { + StartKey: []byte{}, + EndKey: []byte{}, + }, + { + StartKey: []byte{}, + EndKey: []byte{'z'}, + }, + { + StartKey: []byte{'a'}, + EndKey: []byte{}, + }, + { + StartKey: []byte{'a'}, + EndKey: []byte{'z'}, + }, + } + expect := []*kvrpcpb.KeyRange{ + { + StartKey: keyspacePrefix, + EndKey: keyspaceEndKey, + }, + { + StartKey: keyspacePrefix, + EndKey: append(keyspacePrefix, 'z'), + }, + { + StartKey: append(keyspacePrefix, 'a'), + EndKey: keyspaceEndKey, + }, + { + StartKey: append(keyspacePrefix, 'a'), + EndKey: append(keyspacePrefix, 'z'), + }, + } + + encodedKeyRanges := suite.codec.encodeKeyRanges(keyRanges) + re.Equal(expect, encodedKeyRanges) +} + +func (suite *testCodecV2Suite) TestNewCodecV2() { + re := suite.Require() + testCases := []struct { + mode Mode + spaceID uint32 + shouldErr bool + expectedPrefix []byte + expectedEnd []byte + }{ + { + mode: ModeRaw, + // A too large keyspaceID should result in error. + spaceID: math.MaxUint32, + shouldErr: true, + }, + { + // Bad mode should result in error. + mode: Mode(99), + spaceID: DefaultKeyspaceID, + shouldErr: true, + }, + { + mode: ModeRaw, + spaceID: 1<<24 - 2, + expectedPrefix: []byte{'r', 255, 255, 254}, + expectedEnd: []byte{'r', 255, 255, 255}, + }, + { + // EndKey should be able to carry over increment from lower byte. + mode: ModeTxn, + spaceID: 1<<8 - 1, + expectedPrefix: []byte{'x', 0, 0, 255}, + expectedEnd: []byte{'x', 0, 1, 0}, + }, + { + // EndKey should be able to carry over increment from lower byte. + mode: ModeTxn, + spaceID: 1<<16 - 1, + expectedPrefix: []byte{'x', 0, 255, 255}, + expectedEnd: []byte{'x', 1, 0, 0}, + }, + { + // If prefix is the last keyspace, then end should change the mode byte. + mode: ModeRaw, + spaceID: 1<<24 - 1, + expectedPrefix: []byte{'r', 255, 255, 255}, + expectedEnd: []byte{'s', 0, 0, 0}, + }, + } + for _, testCase := range testCases { + if testCase.shouldErr { + _, err := NewCodecV2(testCase.mode, testCase.spaceID) + re.Error(err) + continue + } + codec, err := NewCodecV2(testCase.mode, testCase.spaceID) + re.NoError(err) + + v2Codec, ok := codec.(*codecV2) + re.True(ok) + re.Equal(testCase.expectedPrefix, v2Codec.prefix) + re.Equal(testCase.expectedEnd, v2Codec.endKey) + } +} + +func (suite *testCodecV2Suite) TestDecodeEpochNotMatch() { + re := suite.Require() + codec := suite.codec + regionErr := &errorpb.Error{ + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: []*metapb.Region{ + { + // Region 1: + // keyspace range: ------[------)------ + // region range: ------[------)------ + // after decode: ------[------)------ + Id: 1, + StartKey: codec.memCodec.encodeKey(keyspacePrefix), + EndKey: codec.memCodec.encodeKey(keyspaceEndKey), + }, + { + // Region 2: + // keyspace range: ------[------)------ + // region range: ---[-------------)-- + // after decode: ------[------)------ + Id: 2, + StartKey: codec.memCodec.encodeKey(beforePrefix), + EndKey: codec.memCodec.encodeKey(afterEndKey), + }, + { + // Region 3: + // keyspace range: ------[------)------ + // region range: ---[----)----------- + // after decode: ------[-)----------- + Id: 3, + StartKey: codec.memCodec.encodeKey(beforePrefix), + EndKey: codec.memCodec.encodeKey(insideLeft), + }, + { + // Region 4: + // keyspace range: ------[------)------ + // region range: --------[--)-------- + // after decode: --[--)-- + Id: 4, + StartKey: codec.memCodec.encodeKey(insideLeft), + EndKey: codec.memCodec.encodeKey(insideRight), + }, + { + // Region 5: + // keyspace range: ------[------)------ + // region range: ---[--)------------- + // after decode: StartKey out of bound, should be removed. + Id: 5, + StartKey: codec.memCodec.encodeKey(beforePrefix), + EndKey: codec.memCodec.encodeKey(keyspacePrefix), + }, + { + // Region 6: + // keyspace range: ------[------)------ + // region range: -------------[--)--- + // after decode: EndKey out of bound, should be removed. + Id: 6, + StartKey: codec.memCodec.encodeKey(keyspaceEndKey), + EndKey: codec.memCodec.encodeKey(afterEndKey), + }, + }, + }, + } + + expected := &errorpb.Error{ + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: []*metapb.Region{ + { + Id: 1, + StartKey: []byte{}, + EndKey: []byte{}, + }, + { + Id: 2, + StartKey: []byte{}, + EndKey: []byte{}, + }, + { + Id: 3, + StartKey: []byte{}, + EndKey: insideLeft[len(keyspacePrefix):], + }, + { + Id: 4, + StartKey: insideLeft[len(keyspacePrefix):], + EndKey: insideRight[len(keyspacePrefix):], + }, + // Region 5 should be removed. + // Region 6 should be removed. + }, + }, + } + + result, err := codec.decodeRegionError(regionErr) + re.NoError(err) + for i := range result.EpochNotMatch.CurrentRegions { + re.Equal(expected.EpochNotMatch.CurrentRegions[i], result.EpochNotMatch.CurrentRegions[i], "index: %d", i) + } +} diff --git a/internal/apicodec/mem_codec.go b/internal/apicodec/mem_codec.go new file mode 100644 index 0000000000..394e08ff94 --- /dev/null +++ b/internal/apicodec/mem_codec.go @@ -0,0 +1,57 @@ +package apicodec + +import ( + "github.com/pkg/errors" + "github.com/tikv/client-go/v2/util/codec" +) + +// memCodec is used by Codec to encode/decode keys to +// memory comparable format. +type memCodec interface { + encodeKey(key []byte) []byte + decodeKey(encodedKey []byte) ([]byte, error) +} + +// decodeError happens if the region range key is not well-formed. +// It indicates TiKV has bugs and the client can't handle such a case, +// so it should report the error to users soon. +type decodeError struct { + error +} + +// IsDecodeError is used to determine if error is decode error. +func IsDecodeError(err error) bool { + _, ok := errors.Cause(err).(*decodeError) + if !ok { + _, ok = errors.Cause(err).(decodeError) + } + return ok +} + +// defaultMemCodec is used by RawKV client under APIv1, +// It returns the key as given. +type defaultMemCodec struct{} + +func (c *defaultMemCodec) encodeKey(key []byte) []byte { + return key +} + +func (c *defaultMemCodec) decodeKey(encodedKey []byte) ([]byte, error) { + return encodedKey, nil +} + +// memComparableCodec encode/decode key to/from mem comparable form. +// It throws decodeError on decode failure. +type memComparableCodec struct{} + +func (c *memComparableCodec) encodeKey(key []byte) []byte { + return codec.EncodeBytes([]byte(nil), key) +} + +func (c *memComparableCodec) decodeKey(encodedKey []byte) ([]byte, error) { + _, key, err := codec.DecodeBytes(encodedKey, nil) + if err != nil { + return nil, errors.WithStack(&decodeError{err}) + } + return key, nil +} diff --git a/internal/client/api_version.go b/internal/client/api_version.go deleted file mode 100644 index 6379aed226..0000000000 --- a/internal/client/api_version.go +++ /dev/null @@ -1,226 +0,0 @@ -package client - -import ( - "bytes" - - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pkg/errors" - "github.com/tikv/client-go/v2/tikvrpc" -) - -// Mode represents the operation mode of a request. -type Mode int - -const ( - // ModeRaw represent a raw operation in TiKV - ModeRaw = iota - - // ModeTxn represent a transaction operation in TiKV - ModeTxn -) - -var ( - // APIV2RawKeyPrefix is prefix of raw key in API V2. - APIV2RawKeyPrefix = []byte{'r', 0, 0, 0} - - // APIV2RawEndKey is max key of raw key in API V2. - APIV2RawEndKey = []byte{'r', 0, 0, 1} - - // APIV2TxnKeyPrefix is prefix of txn key in API V2. - APIV2TxnKeyPrefix = []byte{'x', 0, 0, 0} - - // APIV2TxnEndKey is max key of txn key in API V2. - APIV2TxnEndKey = []byte{'x', 0, 0, 1} -) - -func getV2Prefix(mode Mode) []byte { - switch mode { - case ModeRaw: - return APIV2RawKeyPrefix - case ModeTxn: - return APIV2TxnKeyPrefix - } - panic("unreachable") -} - -func getV2EndKey(mode Mode) []byte { - switch mode { - case ModeRaw: - return APIV2RawEndKey - case ModeTxn: - return APIV2TxnEndKey - } - panic("unreachable") -} - -// EncodeV2Key encode a user key into API V2 format. -func EncodeV2Key(mode Mode, key []byte) []byte { - return append(getV2Prefix(mode), key...) -} - -// EncodeV2Range encode a range into API V2 format. -func EncodeV2Range(mode Mode, start, end []byte) ([]byte, []byte) { - var b []byte - if len(end) > 0 { - b = EncodeV2Key(mode, end) - } else { - b = getV2EndKey(mode) - } - return EncodeV2Key(mode, start), b -} - -// EncodeV2KeyRanges encode KeyRange slice into API V2 formatted new slice. -func EncodeV2KeyRanges(mode Mode, keyRanges []*kvrpcpb.KeyRange) []*kvrpcpb.KeyRange { - encodedRanges := make([]*kvrpcpb.KeyRange, 0, len(keyRanges)) - for i := 0; i < len(keyRanges); i++ { - keyRange := kvrpcpb.KeyRange{} - keyRange.StartKey, keyRange.EndKey = EncodeV2Range(mode, keyRanges[i].StartKey, keyRanges[i].EndKey) - encodedRanges = append(encodedRanges, &keyRange) - } - return encodedRanges -} - -// MapV2RangeToV1 maps a range in API V2 format into V1 range. -// This function forbid the user seeing other keyspace. -func MapV2RangeToV1(mode Mode, start []byte, end []byte) ([]byte, []byte) { - var a, b []byte - minKey := getV2Prefix(mode) - if bytes.Compare(start, minKey) < 0 { - a = []byte{} - } else { - a = start[len(minKey):] - } - - maxKey := getV2EndKey(mode) - if len(end) == 0 || bytes.Compare(end, maxKey) >= 0 { - b = []byte{} - } else { - b = end[len(maxKey):] - } - - return a, b -} - -// EncodeV2Keys encodes keys into API V2 format. -func EncodeV2Keys(mode Mode, keys [][]byte) [][]byte { - var ks [][]byte - for _, key := range keys { - ks = append(ks, EncodeV2Key(mode, key)) - } - return ks -} - -// EncodeV2Pairs encodes pairs into API V2 format. -func EncodeV2Pairs(mode Mode, pairs []*kvrpcpb.KvPair) []*kvrpcpb.KvPair { - var ps []*kvrpcpb.KvPair - for _, pair := range pairs { - p := *pair - p.Key = EncodeV2Key(mode, p.Key) - ps = append(ps, &p) - } - return ps -} - -// EncodeRequest encodes req into specified API version format. -// NOTE: req is reused on retry. MUST encode on cloned request, other than overwrite the original. -func EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) { - if req.GetApiVersion() == kvrpcpb.APIVersion_V1 { - return req, nil - } - - newReq := *req - - // TODO(iosmanthus): support transaction request types - switch req.Type { - case tikvrpc.CmdRawGet: - r := *req.RawGet() - r.Key = EncodeV2Key(ModeRaw, r.Key) - newReq.Req = &r - case tikvrpc.CmdRawBatchGet: - r := *req.RawBatchGet() - r.Keys = EncodeV2Keys(ModeRaw, r.Keys) - newReq.Req = &r - case tikvrpc.CmdRawPut: - r := *req.RawPut() - r.Key = EncodeV2Key(ModeRaw, r.Key) - newReq.Req = &r - case tikvrpc.CmdRawBatchPut: - r := *req.RawBatchPut() - r.Pairs = EncodeV2Pairs(ModeRaw, r.Pairs) - newReq.Req = &r - case tikvrpc.CmdRawDelete: - r := *req.RawDelete() - r.Key = EncodeV2Key(ModeRaw, r.Key) - newReq.Req = &r - case tikvrpc.CmdRawBatchDelete: - r := *req.RawBatchDelete() - r.Keys = EncodeV2Keys(ModeRaw, r.Keys) - newReq.Req = &r - case tikvrpc.CmdRawDeleteRange: - r := *req.RawDeleteRange() - r.StartKey, r.EndKey = EncodeV2Range(ModeRaw, r.StartKey, r.EndKey) - newReq.Req = &r - case tikvrpc.CmdRawScan: - r := *req.RawScan() - r.StartKey, r.EndKey = EncodeV2Range(ModeRaw, r.StartKey, r.EndKey) - newReq.Req = &r - case tikvrpc.CmdGetKeyTTL: - r := *req.RawGetKeyTTL() - r.Key = EncodeV2Key(ModeRaw, r.Key) - newReq.Req = &r - case tikvrpc.CmdRawCompareAndSwap: - r := *req.RawCompareAndSwap() - r.Key = EncodeV2Key(ModeRaw, r.Key) - newReq.Req = &r - case tikvrpc.CmdRawChecksum: - r := *req.RawChecksum() - r.Ranges = EncodeV2KeyRanges(ModeRaw, r.Ranges) - newReq.Req = &r - } - - return &newReq, nil -} - -// DecodeV2Key decodes API V2 encoded key into a normal user key. -func DecodeV2Key(mode Mode, key []byte) ([]byte, error) { - prefix := getV2Prefix(mode) - if !bytes.HasPrefix(key, prefix) { - return nil, errors.Errorf("invalid encoded key prefix: %q", key) - } - return key[len(prefix):], nil -} - -// DecodeV2Pairs decodes API V2 encoded pairs into normal user pairs. -func DecodeV2Pairs(mode Mode, pairs []*kvrpcpb.KvPair) ([]*kvrpcpb.KvPair, error) { - var ps []*kvrpcpb.KvPair - for _, pair := range pairs { - var err error - p := *pair - p.Key, err = DecodeV2Key(mode, p.Key) - if err != nil { - return nil, err - } - ps = append(ps, &p) - } - return ps, nil -} - -// DecodeResponse decode the resp in specified API version format. -func DecodeResponse(req *tikvrpc.Request, resp *tikvrpc.Response) (*tikvrpc.Response, error) { - if req.GetApiVersion() == kvrpcpb.APIVersion_V1 { - return resp, nil - } - - var err error - - switch req.Type { - case tikvrpc.CmdRawBatchGet: - r := resp.Resp.(*kvrpcpb.RawBatchGetResponse) - r.Pairs, err = DecodeV2Pairs(ModeRaw, r.Pairs) - case tikvrpc.CmdRawScan: - r := resp.Resp.(*kvrpcpb.RawScanResponse) - r.Kvs, err = DecodeV2Pairs(ModeRaw, r.Kvs) - } - - return resp, err -} diff --git a/internal/client/api_version_test.go b/internal/client/api_version_test.go deleted file mode 100644 index c68d88234a..0000000000 --- a/internal/client/api_version_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package client - -import ( - "testing" - - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/tikvrpc" -) - -func TestEncodeRequest(t *testing.T) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawGet, - Req: &kvrpcpb.RawGetRequest{ - Key: []byte("key"), - }, - } - req.ApiVersion = kvrpcpb.APIVersion_V2 - - r, err := EncodeRequest(req) - require.Nil(t, err) - require.Equal(t, append(APIV2RawKeyPrefix, []byte("key")...), r.RawGet().Key) - - r, err = EncodeRequest(req) - require.Nil(t, err) - require.Equal(t, append(APIV2RawKeyPrefix, []byte("key")...), r.RawGet().Key) -} - -func TestEncodeEncodeV2KeyRanges(t *testing.T) { - keyRanges := []*kvrpcpb.KeyRange{ - { - StartKey: []byte{}, - EndKey: []byte{}, - }, - { - StartKey: []byte{}, - EndKey: []byte{'z'}, - }, - { - StartKey: []byte{'a'}, - EndKey: []byte{}, - }, - { - StartKey: []byte{'a'}, - EndKey: []byte{'z'}, - }, - } - expect := []*kvrpcpb.KeyRange{ - { - StartKey: getV2Prefix(ModeRaw), - EndKey: getV2EndKey(ModeRaw), - }, - { - StartKey: getV2Prefix(ModeRaw), - EndKey: append(getV2Prefix(ModeRaw), 'z'), - }, - { - StartKey: append(getV2Prefix(ModeRaw), 'a'), - EndKey: getV2EndKey(ModeRaw), - }, - { - StartKey: append(getV2Prefix(ModeRaw), 'a'), - EndKey: append(getV2Prefix(ModeRaw), 'z'), - }, - } - encodedKeyRanges := EncodeV2KeyRanges(ModeRaw, keyRanges) - require.Equal(t, expect, encodedKeyRanges) -} diff --git a/internal/client/client.go b/internal/client/client.go index bd834a54ed..4a81024f64 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -57,6 +57,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/tikv/client-go/v2/config" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/metrics" "github.com/tikv/client-go/v2/tikvrpc" @@ -258,6 +259,7 @@ type option struct { gRPCDialOptions []grpc.DialOption security config.Security dialTimeout time.Duration + codec apicodec.Codec } // Opt is the option for the client. @@ -277,6 +279,13 @@ func WithGRPCDialOptions(grpcDialOptions ...grpc.DialOption) Opt { } } +// WithCodec is used to set RPCClient's codec. +func WithCodec(codec apicodec.Codec) Opt { + return func(c *option) { + c.codec = codec + } +} + // RPCClient is RPC client struct. // TODO: Add flow control between RPC clients in TiDB ond RPC servers in TiKV. // Since we use shared client connection to communicate to the same TiKV, it's possible @@ -533,7 +542,11 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R // SendRequest sends a Request to server and receives Response. func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { - req, err := EncodeRequest(req) + if c.option == nil || c.option.codec == nil { + return c.sendRequest(ctx, addr, req, timeout) + } + + req, err := c.option.codec.EncodeRequest(req) if err != nil { return nil, err } @@ -541,7 +554,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R if err != nil { return nil, err } - return DecodeResponse(req, resp) + return c.option.codec.DecodeResponse(req, resp) } func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) { diff --git a/internal/locate/pd_codec.go b/internal/locate/pd_codec.go index 0b7dcea70e..848659fcfa 100644 --- a/internal/locate/pd_codec.go +++ b/internal/locate/pd_codec.go @@ -37,9 +37,10 @@ package locate import ( "context" - "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/keyspacepb" + "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pkg/errors" - "github.com/tikv/client-go/v2/util/codec" + "github.com/tikv/client-go/v2/internal/apicodec" pd "github.com/tikv/pd/client" ) @@ -48,51 +49,80 @@ var _ pd.Client = &CodecPDClient{} // CodecPDClient wraps a PD Client to decode the encoded keys in region meta. type CodecPDClient struct { pd.Client + codec apicodec.Codec } -// NewCodeCPDClient creates a CodecPDClient. -func NewCodeCPDClient(client pd.Client) *CodecPDClient { - return &CodecPDClient{client} +// NewCodecPDClient creates a CodecPDClient in API v1. +func NewCodecPDClient(mode apicodec.Mode, client pd.Client) *CodecPDClient { + codec := apicodec.NewCodecV1(mode) + return &CodecPDClient{client, codec} +} + +// NewCodecPDClientWithKeyspace creates a CodecPDClient in API v2 with keyspace name. +func NewCodecPDClientWithKeyspace(mode apicodec.Mode, client pd.Client, keyspace string) (*CodecPDClient, error) { + id, err := GetKeyspaceID(client, keyspace) + if err != nil { + return nil, err + } + codec, err := apicodec.NewCodecV2(mode, id) + if err != nil { + return nil, err + } + + return &CodecPDClient{client, codec}, nil +} + +func GetKeyspaceID(client pd.Client, name string) (uint32, error) { + meta, err := client.LoadKeyspace(context.Background(), apicodec.BuildKeyspaceName(name)) + if err != nil { + return 0, err + } + // If keyspace is not enabled, user should not be able to connect. + if meta.State != keyspacepb.KeyspaceState_ENABLED { + return 0, errors.Errorf("keyspace %s not enabled", name) + } + return meta.Id, nil +} + +// GetCodec returns CodecPDClient's codec. +func (c *CodecPDClient) GetCodec() apicodec.Codec { + return c.codec } // GetRegion encodes the key before send requests to pd-server and decodes the // returned StartKey && EndKey from pd-server. func (c *CodecPDClient) GetRegion(ctx context.Context, key []byte, opts ...pd.GetRegionOption) (*pd.Region, error) { - encodedKey := codec.EncodeBytes([]byte(nil), key) + encodedKey := c.codec.EncodeRegionKey(key) region, err := c.Client.GetRegion(ctx, encodedKey, opts...) - return processRegionResult(region, err) + return c.processRegionResult(region, err) } // GetPrevRegion encodes the key before send requests to pd-server and decodes the // returned StartKey && EndKey from pd-server. func (c *CodecPDClient) GetPrevRegion(ctx context.Context, key []byte, opts ...pd.GetRegionOption) (*pd.Region, error) { - encodedKey := codec.EncodeBytes([]byte(nil), key) + encodedKey := c.codec.EncodeRegionKey(key) region, err := c.Client.GetPrevRegion(ctx, encodedKey, opts...) - return processRegionResult(region, err) + return c.processRegionResult(region, err) } // GetRegionByID encodes the key before send requests to pd-server and decodes the // returned StartKey && EndKey from pd-server. func (c *CodecPDClient) GetRegionByID(ctx context.Context, regionID uint64, opts ...pd.GetRegionOption) (*pd.Region, error) { region, err := c.Client.GetRegionByID(ctx, regionID, opts...) - return processRegionResult(region, err) + return c.processRegionResult(region, err) } // ScanRegions encodes the key before send requests to pd-server and decodes the // returned StartKey && EndKey from pd-server. func (c *CodecPDClient) ScanRegions(ctx context.Context, startKey []byte, endKey []byte, limit int) ([]*pd.Region, error) { - startKey = codec.EncodeBytes([]byte(nil), startKey) - if len(endKey) > 0 { - endKey = codec.EncodeBytes([]byte(nil), endKey) - } - + startKey, endKey = c.codec.EncodeRegionRange(startKey, endKey) regions, err := c.Client.ScanRegions(ctx, startKey, endKey, limit) if err != nil { return nil, errors.WithStack(err) } for _, region := range regions { if region != nil { - err = decodeRegionKeyInPlace(region) + err = c.decodeRegionKeyInPlace(region) if err != nil { return nil, err } @@ -101,80 +131,47 @@ func (c *CodecPDClient) ScanRegions(ctx context.Context, startKey []byte, endKey return regions, nil } -func processRegionResult(region *pd.Region, err error) (*pd.Region, error) { +// SplitRegions split regions by given split keys +func (c *CodecPDClient) SplitRegions(ctx context.Context, splitKeys [][]byte, opts ...pd.RegionsOption) (*pdpb.SplitRegionsResponse, error) { + var keys [][]byte + for i := range splitKeys { + keys = append(keys, c.codec.EncodeRegionKey(splitKeys[i])) + } + return c.Client.SplitRegions(ctx, keys, opts...) +} + +func (c *CodecPDClient) processRegionResult(region *pd.Region, err error) (*pd.Region, error) { if err != nil { return nil, errors.WithStack(err) } if region == nil || region.Meta == nil { return nil, nil } - err = decodeRegionKeyInPlace(region) + err = c.decodeRegionKeyInPlace(region) if err != nil { return nil, err } return region, nil } -// decodeError happens if the region range key is not well-formed. -// It indicates TiKV has bugs and the client can't handle such a case, -// so it should report the error to users soon. -type decodeError struct { - error -} - -func isDecodeError(err error) bool { - _, ok := errors.Cause(err).(*decodeError) - if !ok { - _, ok = errors.Cause(err).(decodeError) - } - return ok -} - -func decodeRegionKeyInPlace(r *pd.Region) error { - if len(r.Meta.StartKey) != 0 { - _, decoded, err := codec.DecodeBytes(r.Meta.StartKey, nil) - if err != nil { - return errors.WithStack(&decodeError{err}) - } - r.Meta.StartKey = decoded - } - if len(r.Meta.EndKey) != 0 { - _, decoded, err := codec.DecodeBytes(r.Meta.EndKey, nil) - if err != nil { - return errors.WithStack(&decodeError{err}) - } - r.Meta.EndKey = decoded +func (c *CodecPDClient) decodeRegionKeyInPlace(r *pd.Region) error { + decodedStart, decodedEnd, err := c.codec.DecodeRegionRange(r.Meta.StartKey, r.Meta.EndKey) + if err != nil { + return err } + r.Meta.StartKey = decodedStart + r.Meta.EndKey = decodedEnd if r.Buckets != nil { for i, k := range r.Buckets.Keys { if len(k) == 0 { continue } - _, decoded, err := codec.DecodeBytes(k, nil) + decoded, err := c.codec.DecodeRegionKey(k) if err != nil { - return errors.WithStack(&decodeError{err}) + return errors.WithStack(err) } r.Buckets.Keys[i] = decoded } } return nil } - -func decodeRegionMetaKeyWithShallowCopy(r *metapb.Region) (*metapb.Region, error) { - nr := *r - if len(r.StartKey) != 0 { - _, decoded, err := codec.DecodeBytes(r.StartKey, nil) - if err != nil { - return nil, err - } - nr.StartKey = decoded - } - if len(r.EndKey) != 0 { - _, decoded, err := codec.DecodeBytes(r.EndKey, nil) - if err != nil { - return nil, err - } - nr.EndKey = decoded - } - return &nr, nil -} diff --git a/internal/locate/pd_codec_v2.go b/internal/locate/pd_codec_v2.go deleted file mode 100644 index 0b5953dd1f..0000000000 --- a/internal/locate/pd_codec_v2.go +++ /dev/null @@ -1,110 +0,0 @@ -package locate - -import ( - "context" - - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/tikv/client-go/v2/internal/client" - "github.com/tikv/client-go/v2/util/codec" - pd "github.com/tikv/pd/client" -) - -// CodecPDClientV2 wraps a PD Client to decode the region meta in API v2 manner. -type CodecPDClientV2 struct { - *CodecPDClient - mode client.Mode -} - -// NewCodecPDClientV2 create a CodecPDClientV2. -func NewCodecPDClientV2(client pd.Client, mode client.Mode) *CodecPDClientV2 { - codecClient := NewCodeCPDClient(client) - return &CodecPDClientV2{codecClient, mode} -} - -// GetRegion encodes the key before send requests to pd-server and decodes the -// returned StartKey && EndKey from pd-server. -func (c *CodecPDClientV2) GetRegion(ctx context.Context, key []byte, opts ...pd.GetRegionOption) (*pd.Region, error) { - queryKey := client.EncodeV2Key(c.mode, key) - region, err := c.CodecPDClient.GetRegion(ctx, queryKey, opts...) - return c.processRegionResult(region, err) -} - -// GetPrevRegion encodes the key before send requests to pd-server and decodes the -// returned StartKey && EndKey from pd-server. -func (c *CodecPDClientV2) GetPrevRegion(ctx context.Context, key []byte, opts ...pd.GetRegionOption) (*pd.Region, error) { - queryKey := client.EncodeV2Key(c.mode, key) - region, err := c.CodecPDClient.GetPrevRegion(ctx, queryKey, opts...) - return c.processRegionResult(region, err) -} - -// GetRegionByID encodes the key before send requests to pd-server and decodes the -// returned StartKey && EndKey from pd-server. -func (c *CodecPDClientV2) GetRegionByID(ctx context.Context, regionID uint64, opts ...pd.GetRegionOption) (*pd.Region, error) { - region, err := c.CodecPDClient.GetRegionByID(ctx, regionID, opts...) - return c.processRegionResult(region, err) -} - -// ScanRegions encodes the key before send requests to pd-server and decodes the -// returned StartKey && EndKey from pd-server. -func (c *CodecPDClientV2) ScanRegions(ctx context.Context, startKey []byte, endKey []byte, limit int) ([]*pd.Region, error) { - start, end := client.EncodeV2Range(c.mode, startKey, endKey) - regions, err := c.CodecPDClient.ScanRegions(ctx, start, end, limit) - if err != nil { - return nil, err - } - for i := range regions { - region, _ := c.processRegionResult(regions[i], nil) - regions[i] = region - } - return regions, nil -} - -// SplitRegions split regions by given split keys -func (c *CodecPDClientV2) SplitRegions(ctx context.Context, splitKeys [][]byte, opts ...pd.RegionsOption) (*pdpb.SplitRegionsResponse, error) { - var keys [][]byte - for i := range splitKeys { - withPrefix := client.EncodeV2Key(c.mode, splitKeys[i]) - keys = append(keys, codec.EncodeBytes(nil, withPrefix)) - } - return c.CodecPDClient.SplitRegions(ctx, keys, opts...) -} - -func (c *CodecPDClientV2) processRegionResult(region *pd.Region, err error) (*pd.Region, error) { - if err != nil { - return nil, err - } - - if region != nil { - // TODO(@iosmanthus): enable buckets support. - region.Buckets = nil - - region.Meta.StartKey, region.Meta.EndKey = - client.MapV2RangeToV1(c.mode, region.Meta.StartKey, region.Meta.EndKey) - } - - return region, nil -} - -func (c *CodecPDClientV2) decodeRegionWithShallowCopy(region *metapb.Region) (*metapb.Region, error) { - var err error - newRegion := *region - - if len(region.StartKey) > 0 { - _, newRegion.StartKey, err = codec.DecodeBytes(region.StartKey, nil) - } - if err != nil { - return nil, err - } - - if len(region.EndKey) > 0 { - _, newRegion.EndKey, err = codec.DecodeBytes(region.EndKey, nil) - } - if err != nil { - return nil, err - } - - newRegion.StartKey, newRegion.EndKey = client.MapV2RangeToV1(c.mode, newRegion.StartKey, newRegion.EndKey) - - return &newRegion, nil -} diff --git a/internal/locate/region_cache.go b/internal/locate/region_cache.go index a71d274816..06c0ebe7e0 100644 --- a/internal/locate/region_cache.go +++ b/internal/locate/region_cache.go @@ -51,12 +51,12 @@ import ( "github.com/gogo/protobuf/proto" "github.com/google/btree" "github.com/opentracing/opentracing-go" - "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pkg/errors" "github.com/stathat/consistent" "github.com/tikv/client-go/v2/config" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/internal/retry" @@ -363,7 +363,7 @@ func (r *Region) isValid() bool { // purposes only. type RegionCache struct { pdClient pd.Client - apiVersion kvrpcpb.APIVersion + codec apicodec.Codec enableForwarding bool mu struct { @@ -400,11 +400,9 @@ func NewRegionCache(pdClient pd.Client) *RegionCache { pdClient: pdClient, } - switch pdClient.(type) { - case *CodecPDClientV2: - c.apiVersion = kvrpcpb.APIVersion_V2 - default: - c.apiVersion = kvrpcpb.APIVersion_V1 + c.codec = apicodec.NewCodecV1(apicodec.ModeRaw) + if codecPDClient, ok := pdClient.(*CodecPDClient); ok { + c.codec = codecPDClient.GetCodec() } c.mu.regions = make(map[RegionVerID]*Region) @@ -1456,7 +1454,7 @@ func (c *RegionCache) loadRegion(bo *retry.Backoffer, key []byte, isEndKey bool) metrics.RegionCacheCounterWithGetCacheMissOK.Inc() } if err != nil { - if isDecodeError(err) { + if apicodec.IsDecodeError(err) { return nil, errors.Errorf("failed to decode region range key, key: %q, err: %v", util.HexRegionKeyStr(key), err) } backoffErr = errors.Errorf("loadRegion from PD failed, key: %q, err: %v", util.HexRegionKeyStr(key), err) @@ -1503,7 +1501,7 @@ func (c *RegionCache) loadRegionByID(bo *retry.Backoffer, regionID uint64) (*Reg metrics.RegionCacheCounterWithGetRegionByIDOK.Inc() } if err != nil { - if isDecodeError(err) { + if apicodec.IsDecodeError(err) { return nil, errors.Errorf("failed to decode region range key, regionID: %q, err: %v", regionID, err) } backoffErr = errors.Errorf("loadRegion from PD failed, regionID: %v, err: %v", regionID, err) @@ -1564,7 +1562,7 @@ func (c *RegionCache) scanRegions(bo *retry.Backoffer, startKey, endKey []byte, regionsInfo, err := c.pdClient.ScanRegions(ctx, startKey, endKey, limit) metrics.LoadRegionCacheHistogramWithRegions.Observe(time.Since(start).Seconds()) if err != nil { - if isDecodeError(err) { + if apicodec.IsDecodeError(err) { return nil, errors.Errorf("failed to decode region range key, startKey: %q, limit: %q, err: %v", util.HexRegionKeyStr(startKey), limit, err) } metrics.RegionCacheCounterWithScanRegionsError.Inc() @@ -1755,20 +1753,6 @@ func (c *RegionCache) OnRegionEpochNotMatch(bo *retry.Backoffer, ctx *RPCContext newRegions := make([]*Region, 0, len(currentRegions)) // If the region epoch is not ahead of TiKV's, replace region meta in region cache. for _, meta := range currentRegions { - var err error - oldMeta := meta - switch c.pdClient.(type) { - case *CodecPDClient: - // Can't modify currentRegions in this function because it can be shared by - // multiple goroutines, refer to https://github.com/pingcap/tidb/pull/16962. - if meta, err = decodeRegionMetaKeyWithShallowCopy(meta); err != nil { - return false, errors.Errorf("newRegion's range key is not encoded: %v, %v", oldMeta, err) - } - case *CodecPDClientV2: - if meta, err = c.pdClient.(*CodecPDClientV2).decodeRegionWithShallowCopy(meta); err != nil { - return false, errors.Errorf("newRegion's range key is not encoded: %v, %v", oldMeta, err) - } - } // TODO(youjiali1995): new regions inherit old region's buckets now. Maybe we should make EpochNotMatch error // carry buckets information. Can it bring much overhead? region, err := newRegion(bo, c, &pd.Region{Meta: meta, Buckets: buckets}) diff --git a/internal/locate/region_cache_test.go b/internal/locate/region_cache_test.go index 3e92ac94ea..8d239b5e26 100644 --- a/internal/locate/region_cache_test.go +++ b/internal/locate/region_cache_test.go @@ -47,6 +47,7 @@ import ( "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" @@ -79,7 +80,7 @@ func (s *testRegionCacheSuite) SetupTest() { s.store2 = storeIDs[1] s.peer1 = peerIDs[0] s.peer2 = peerIDs[1] - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)} + pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} s.cache = NewRegionCache(pdCli) s.bo = retry.NewBackofferWithVars(context.Background(), 5000, nil) } @@ -956,7 +957,7 @@ func (s *testRegionCacheSuite) TestReconnect() { func (s *testRegionCacheSuite) TestRegionEpochAheadOfTiKV() { // Create a separated region cache to do this test. - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)} + pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} cache := NewRegionCache(pdCli) defer cache.Close() diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index 6b6f9b8fd1..f02c6ab291 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -192,7 +192,7 @@ func RecordRegionRequestRuntimeStats(stats map[tikvrpc.CmdType]*RPCRuntimeStats, func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { return &RegionRequestSender{ regionCache: regionCache, - apiVersion: regionCache.apiVersion, + apiVersion: regionCache.codec.GetAPIVersion(), client: client, } } @@ -1482,7 +1482,7 @@ func (s *RegionRequestSender) onRegionError(bo *retry.Backoffer, ctx *RPCContext } if regionErr.GetKeyNotInRegion() != nil { - logutil.BgLogger().Debug("tikv reports `KeyNotInRegion`", zap.Stringer("req", req), zap.Stringer("ctx", ctx)) + logutil.BgLogger().Error("tikv reports `KeyNotInRegion`", zap.Stringer("req", req), zap.Stringer("ctx", ctx)) s.regionCache.InvalidateCachedRegion(ctx.Region) return false, nil } diff --git a/internal/locate/region_request3_test.go b/internal/locate/region_request3_test.go index a9f495bf5f..af5e3b3de8 100644 --- a/internal/locate/region_request3_test.go +++ b/internal/locate/region_request3_test.go @@ -47,6 +47,7 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/suite" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" @@ -75,7 +76,7 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() { s.mvccStore = mocktikv.MustNewMVCCStore() s.cluster = mocktikv.NewCluster(s.mvccStore) s.storeIDs, s.peerIDs, s.regionID, s.leaderPeer = mocktikv.BootstrapWithMultiStores(s.cluster, 3) - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)} + pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go index 8692acc705..b34eb43335 100644 --- a/internal/locate/region_request_test.go +++ b/internal/locate/region_request_test.go @@ -51,6 +51,7 @@ import ( "github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pkg/errors" "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" @@ -78,7 +79,7 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() { s.mvccStore = mocktikv.MustNewMVCCStore() s.cluster = mocktikv.NewCluster(s.mvccStore) s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster) - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)} + pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index ed8a6092e3..4be7a3fc8d 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -43,11 +43,13 @@ import ( "github.com/pkg/errors" "github.com/tikv/client-go/v2/config" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/kvrpc" "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" "google.golang.org/grpc" @@ -123,6 +125,7 @@ type Client struct { apiVersion kvrpcpb.APIVersion clusterID uint64 regionCache *locate.RegionCache + codec apicodec.Codec pdClient pd.Client rpcClient client.Client cf string @@ -134,6 +137,7 @@ type option struct { security config.Security gRPCDialOptions []grpc.DialOption pdOptions []pd.ClientOption + keyspace string } // ClientOpt is factory to set the client options. @@ -167,6 +171,13 @@ func WithAPIVersion(apiVersion kvrpcpb.APIVersion) ClientOpt { } } +// WithKeyspace is used to set the keyspace Name. +func WithKeyspace(name string) ClientOpt { + return func(o *option) { + o.keyspace = name + } +} + // SetAtomicForCAS sets atomic mode for CompareAndSwap func (c *Client) SetAtomicForCAS(b bool) *Client { c.atomic = b @@ -191,26 +202,45 @@ func NewClientWithOpts(ctx context.Context, pdAddrs []string, opts ...ClientOpt) o(opt) } + // Use an unwrapped PDClient to obtain keyspace meta. pdCli, err := pd.NewClient(pdAddrs, pd.SecurityOption{ CAPath: opt.security.ClusterSSLCA, CertPath: opt.security.ClusterSSLCert, KeyPath: opt.security.ClusterSSLKey, }, opt.pdOptions...) - if err != nil { return nil, errors.WithStack(err) } - if opt.apiVersion == kvrpcpb.APIVersion_V2 { - pdCli = locate.NewCodecPDClientV2(pdCli, client.ModeRaw) + // Build a CodecPDClient + var codecCli *tikv.CodecPDClient + + switch opt.apiVersion { + case kvrpcpb.APIVersion_V1, kvrpcpb.APIVersion_V1TTL: + codecCli = locate.NewCodecPDClient(tikv.ModeRaw, pdCli) + case kvrpcpb.APIVersion_V2: + codecCli, err = tikv.NewCodecPDClientWithKeyspace(tikv.ModeRaw, pdCli, opt.keyspace) + if err != nil { + return nil, err + } + default: + return nil, errors.Errorf("unknown api version: %d", opt.apiVersion) } + pdCli = codecCli + + rpcCli := client.NewRPCClient( + client.WithSecurity(opt.security), + client.WithGRPCDialOptions(opt.gRPCDialOptions...), + client.WithCodec(codecCli.GetCodec()), + ) + return &Client{ apiVersion: opt.apiVersion, clusterID: pdCli.GetClusterID(ctx), regionCache: locate.NewRegionCache(pdCli), pdClient: pdCli, - rpcClient: client.NewRPCClient(client.WithSecurity(opt.security), client.WithGRPCDialOptions(opt.gRPCDialOptions...)), + rpcClient: rpcCli, }, nil } diff --git a/tikv/client.go b/tikv/client.go index d25fed10ca..90e10bc1ad 100644 --- a/tikv/client.go +++ b/tikv/client.go @@ -36,6 +36,7 @@ package tikv import ( "github.com/tikv/client-go/v2/config" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" ) @@ -51,6 +52,11 @@ func WithSecurity(security config.Security) ClientOpt { return client.WithSecurity(security) } +// WithCodec is used to set client codec. +func WithCodec(codec apicodec.Codec) ClientOpt { + return client.WithCodec(codec) +} + // Timeout durations. const ( ReadTimeoutMedium = client.ReadTimeoutMedium diff --git a/tikv/kv.go b/tikv/kv.go index 89a0ed5626..55a7e659fc 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -202,7 +202,7 @@ func NewKVStore(uuid string, pdClient pd.Client, spkv SafePointKV, tikvclient Cl return store, nil } -// NewPDClient creates pd.Client with pdAddrs. +// NewPDClient returns an unwrapped pd client. func NewPDClient(pdAddrs []string) (pd.Client, error) { cfg := config.GetGlobalConfig() // init pd-client @@ -222,8 +222,7 @@ func NewPDClient(pdAddrs []string) (pd.Client, error) { if err != nil { return nil, errors.WithStack(err) } - pdClient := &CodecPDClient{Client: util.InterceptedPDClient{Client: pdCli}} - return pdClient, nil + return pdCli, nil } // EnableTxnLocalLatches enables txn latch. It should be called before using @@ -592,6 +591,7 @@ var _ = NewLockResolver // NewLockResolver creates a LockResolver. // It is exported for other pkg to use. For instance, binlog service needs // to determine a transaction's commit state. +// TODO(iosmanthus): support api v2 func NewLockResolver(etcdAddrs []string, security config.Security, opts ...pd.ClientOption) (*txnlock.LockResolver, error) { pdCli, err := pd.NewClient(etcdAddrs, pd.SecurityOption{ CAPath: security.ClusterSSLCA, @@ -614,7 +614,7 @@ func NewLockResolver(etcdAddrs []string, security config.Security, opts ...pd.Cl return nil, err } - s, err := NewKVStore(uuid, locate.NewCodeCPDClient(pdCli), spkv, client.NewRPCClient(WithSecurity(security))) + s, err := NewKVStore(uuid, locate.NewCodecPDClient(ModeTxn, pdCli), spkv, client.NewRPCClient(WithSecurity(security))) if err != nil { return nil, err } diff --git a/tikv/region.go b/tikv/region.go index 4fd0a25642..177fc8e610 100644 --- a/tikv/region.go +++ b/tikv/region.go @@ -38,6 +38,7 @@ import ( "time" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/tikvrpc" @@ -91,21 +92,38 @@ type RPCRuntimeStats = locate.RPCRuntimeStats // CodecPDClient wraps a PD Client to decode the encoded keys in region meta. type CodecPDClient = locate.CodecPDClient -// CodecPDClientV2 wraps a PD Client to decode the region meta in API v2 manner. -type CodecPDClientV2 = locate.CodecPDClientV2 +// NewCodecPDClient is a constructor for CodecPDClient +var NewCodecPDClient = locate.NewCodecPDClient -// NewCodecPDClientV2 is a constructor for CodecPDClientV2 -var NewCodecPDClientV2 = locate.NewCodecPDClientV2 +// NewCodecPDClientWithKeyspace creates a CodecPDClient in API v2 with keyspace name. +var NewCodecPDClientWithKeyspace = locate.NewCodecPDClientWithKeyspace + +// NewCodecV1 is a constructor for v1 Codec. +var NewCodecV1 = apicodec.NewCodecV1 + +// NewCodecV2 is a constructor for v2 Codec. +var NewCodecV2 = apicodec.NewCodecV2 + +// Codec is responsible for encode/decode requests. +type Codec = apicodec.Codec + +var DecodeKey = apicodec.DecodeKey + +// DefaultKeyspaceID is the keyspaceID of the default keyspace. +var DefaultKeyspaceID = apicodec.DefaultKeyspaceID + +// DefaultKeyspaceName is the name of the default keyspace. +var DefaultKeyspaceName = apicodec.DefaultKeyspaceName // Mode represents the operation mode of a request, export client.Mode -type Mode = client.Mode +type Mode = apicodec.Mode var ( // ModeRaw represent a raw operation in TiKV, export client.ModeRaw - ModeRaw Mode = client.ModeRaw + ModeRaw Mode = apicodec.ModeRaw // ModeTxn represent a transaction operation in TiKV, export client.ModeTxn - ModeTxn Mode = client.ModeTxn + ModeTxn Mode = apicodec.ModeTxn ) // RecordRegionRequestRuntimeStats records request runtime stats. diff --git a/tikv/test_util.go b/tikv/test_util.go index f283b9511c..805c2a7aad 100644 --- a/tikv/test_util.go +++ b/tikv/test_util.go @@ -35,18 +35,48 @@ package tikv import ( + "context" + "time" + "github.com/google/uuid" + "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/locate" + "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" ) +// CodecClient warps Client to provide codec encode and decode. +type CodecClient struct { + Client + codec apicodec.Codec +} + +// SendRequest uses codec to encode request before send, and decode response before return. +func (c *CodecClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { + req, err := c.codec.EncodeRequest(req) + if err != nil { + return nil, err + } + resp, err := c.Client.SendRequest(ctx, addr, req, timeout) + if err != nil { + return nil, err + } + return c.codec.DecodeResponse(req, resp) +} + // NewTestTiKVStore creates a test store with Option func NewTestTiKVStore(client Client, pdClient pd.Client, clientHijack func(Client) Client, pdClientHijack func(pd.Client) pd.Client, txnLocalLatches uint) (*KVStore, error) { + codec := apicodec.NewCodecV1(apicodec.ModeTxn) + client = &CodecClient{ + Client: client, + codec: codec, + } + pdCli := pd.Client(locate.NewCodecPDClient(ModeTxn, pdClient)) + if clientHijack != nil { client = clientHijack(client) } - pdCli := pd.Client(locate.NewCodeCPDClient(pdClient)) if pdClientHijack != nil { pdCli = pdClientHijack(pdCli) } diff --git a/txnkv/client.go b/txnkv/client.go index 3b3e734881..cfb40f57ea 100644 --- a/txnkv/client.go +++ b/txnkv/client.go @@ -18,11 +18,14 @@ import ( "context" "fmt" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pkg/errors" "github.com/tikv/client-go/v2/config" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/txnkv/transaction" + "github.com/tikv/client-go/v2/util" ) // Client is a txn client. @@ -30,13 +33,60 @@ type Client struct { *tikv.KVStore } +type option struct { + apiVersion kvrpcpb.APIVersion + keyspaceName string +} + +// ClientOpt is factory to set the client options. +type ClientOpt func(*option) + +// WithKeyspace is used to set client's keyspace. +func WithKeyspace(keyspaceName string) ClientOpt { + return func(opt *option) { + opt.keyspaceName = keyspaceName + } +} + +// WithAPIVersion is used to set client's apiVersion. +func WithAPIVersion(apiVersion kvrpcpb.APIVersion) ClientOpt { + return func(opt *option) { + opt.apiVersion = apiVersion + } +} + // NewClient creates a txn client with pdAddrs. -func NewClient(pdAddrs []string) (*Client, error) { - cfg := config.GetGlobalConfig() +func NewClient(pdAddrs []string, opts ...ClientOpt) (*Client, error) { + // Apply options. + opt := &option{} + for _, o := range opts { + o(opt) + } + // Use an unwrapped PDClient to obtain keyspace meta. pdClient, err := tikv.NewPDClient(pdAddrs) if err != nil { - return nil, err + return nil, errors.WithStack(err) + } + + pdClient = util.InterceptedPDClient{Client: pdClient} + + // Construct codec from options. + var codecCli *tikv.CodecPDClient + switch opt.apiVersion { + case kvrpcpb.APIVersion_V1: + codecCli = tikv.NewCodecPDClient(tikv.ModeTxn, pdClient) + case kvrpcpb.APIVersion_V2: + codecCli, err = tikv.NewCodecPDClientWithKeyspace(tikv.ModeTxn, pdClient, opt.keyspaceName) + if err != nil { + return nil, err + } + default: + return nil, errors.Errorf("unknown api version: %d", opt.apiVersion) } + + pdClient = codecCli + + cfg := config.GetGlobalConfig() // init uuid uuid := fmt.Sprintf("tikv-%v", pdClient.GetClusterID(context.TODO())) tlsConfig, err := cfg.Security.ToTLSConfig() @@ -49,7 +99,9 @@ func NewClient(pdAddrs []string) (*Client, error) { return nil, err } - s, err := tikv.NewKVStore(uuid, pdClient, spkv, tikv.NewRPCClient(tikv.WithSecurity(cfg.Security))) + rpcClient := tikv.NewRPCClient(tikv.WithSecurity(cfg.Security), tikv.WithCodec(codecCli.GetCodec())) + + s, err := tikv.NewKVStore(uuid, pdClient, spkv, rpcClient) if err != nil { return nil, err } diff --git a/txnkv/txnsnapshot/scan.go b/txnkv/txnsnapshot/scan.go index 2e24d52fce..eee3c0e1e6 100644 --- a/txnkv/txnsnapshot/scan.go +++ b/txnkv/txnsnapshot/scan.go @@ -214,7 +214,8 @@ func (s *Scanner) getData(bo *retry.Backoffer) error { if !s.reverse { reqEndKey = s.endKey - if len(reqEndKey) > 0 && len(loc.EndKey) > 0 && bytes.Compare(loc.EndKey, reqEndKey) < 0 { + if len(reqEndKey) == 0 || + (len(loc.EndKey) > 0 && bytes.Compare(loc.EndKey, reqEndKey) < 0) { reqEndKey = loc.EndKey } } else { From 1c928737661ffd97cbc5193b8cb557dc8b9cc7d6 Mon Sep 17 00:00:00 2001 From: David <8039876+AmoebaProtozoa@users.noreply.github.com> Date: Tue, 27 Dec 2022 11:42:04 +0800 Subject: [PATCH 2/8] fix lint Signed-off-by: David <8039876+AmoebaProtozoa@users.noreply.github.com> --- internal/apicodec/codec.go | 1 + internal/apicodec/codec_v2.go | 2 +- internal/locate/pd_codec.go | 1 + rawkv/rawkv.go | 2 -- tikv/region.go | 1 + 5 files changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go index 78c1fc35b8..0267467ac9 100644 --- a/internal/apicodec/codec.go +++ b/internal/apicodec/codec.go @@ -45,6 +45,7 @@ type Codec interface { DecodeKey(encoded []byte) ([]byte, error) } +// DecodeKey split a key to it's keyspace prefix and actual key. func DecodeKey(encoded []byte, version kvrpcpb.APIVersion) ([]byte, []byte, error) { switch version { case kvrpcpb.APIVersion_V1: diff --git a/internal/apicodec/codec_v2.go b/internal/apicodec/codec_v2.go index 71fbd4984f..0ae4f8c4ff 100644 --- a/internal/apicodec/codec_v2.go +++ b/internal/apicodec/codec_v2.go @@ -26,7 +26,7 @@ var ( keyspacePrefixLen int = 4 // errKeyOutOfBound happens when key to be decoded lies outside the keyspace's range. - errKeyOutOfBound = errors.New("given key does not belong to the keyspace.") + errKeyOutOfBound = errors.New("given key does not belong to the keyspace") ) // BuildKeyspaceName builds a keyspace name diff --git a/internal/locate/pd_codec.go b/internal/locate/pd_codec.go index 848659fcfa..de995b6b6c 100644 --- a/internal/locate/pd_codec.go +++ b/internal/locate/pd_codec.go @@ -72,6 +72,7 @@ func NewCodecPDClientWithKeyspace(mode apicodec.Mode, client pd.Client, keyspace return &CodecPDClient{client, codec}, nil } +// GetKeyspaceID attempts to retrieve keyspace ID corresponding to the given keyspace name from PD. func GetKeyspaceID(client pd.Client, name string) (uint32, error) { meta, err := client.LoadKeyspace(context.Background(), apicodec.BuildKeyspaceName(name)) if err != nil { diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index 4be7a3fc8d..e9e8bd546a 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -43,7 +43,6 @@ import ( "github.com/pkg/errors" "github.com/tikv/client-go/v2/config" tikverr "github.com/tikv/client-go/v2/error" - "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/kvrpc" "github.com/tikv/client-go/v2/internal/locate" @@ -125,7 +124,6 @@ type Client struct { apiVersion kvrpcpb.APIVersion clusterID uint64 regionCache *locate.RegionCache - codec apicodec.Codec pdClient pd.Client rpcClient client.Client cf string diff --git a/tikv/region.go b/tikv/region.go index 177fc8e610..188096e6b2 100644 --- a/tikv/region.go +++ b/tikv/region.go @@ -107,6 +107,7 @@ var NewCodecV2 = apicodec.NewCodecV2 // Codec is responsible for encode/decode requests. type Codec = apicodec.Codec +// DecodeKey is used to split a given key to it's APIv2 prefix and actual key. var DecodeKey = apicodec.DecodeKey // DefaultKeyspaceID is the keyspaceID of the default keyspace. From 9d9b98f98d9d4ecefeb3a57b4b6f1a23bf667b2c Mon Sep 17 00:00:00 2001 From: iosmanthus Date: Wed, 4 Jan 2023 18:03:55 +0800 Subject: [PATCH 3/8] add KeyspaceID parsing utils for codec_v2 Signed-off-by: iosmanthus --- internal/apicodec/codec.go | 31 +++++++++++++++++++++++++++++-- internal/apicodec/codec_v1.go | 4 ++++ internal/apicodec/codec_v2.go | 8 +++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go index 0267467ac9..d77ca0ec4e 100644 --- a/internal/apicodec/codec.go +++ b/internal/apicodec/codec.go @@ -1,13 +1,18 @@ package apicodec import ( + "encoding/binary" + "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/tikv/client-go/v2/tikvrpc" ) // Mode represents the operation mode of a request. -type Mode int +type ( + Mode int + KeyspaceID uint32 +) const ( // ModeRaw represent a raw operation in TiKV @@ -16,12 +21,34 @@ const ( ModeTxn ) +const ( + // NulSpaceID is a special keyspace id that represents no keyspace exist. + NulSpaceID KeyspaceID = 0xffffffff +) + +func ParseKeyspaceID(b []byte) (KeyspaceID, error) { + if len(b) < keyspacePrefixLen { + return NulSpaceID, nil + } + + if b[0] != rawModePrefix && b[0] != txnModePrefix { + return 0, errors.Errorf("unsupported key %s", b) + } + + buf := append([]byte{}, b[:keyspacePrefixLen]...) + buf[0] = 0 + + return KeyspaceID(binary.BigEndian.Uint32(buf)), nil +} + // Codec is responsible for encode/decode requests. type Codec interface { // GetAPIVersion returns the api version of the codec. GetAPIVersion() kvrpcpb.APIVersion - // GetKeyspace return the keyspace id of the codec. + // GetKeyspace return the keyspace id of the codec in bytes. GetKeyspace() []byte + // GetKeyspaceID return the keyspace id of the codec. + GetKeyspaceID() KeyspaceID // EncodeRequest encodes with the given Codec. // NOTE: req is reused on retry. MUST encode on cloned request, other than overwrite the original. EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) diff --git a/internal/apicodec/codec_v1.go b/internal/apicodec/codec_v1.go index b21ef8802a..18fd5b6c62 100644 --- a/internal/apicodec/codec_v1.go +++ b/internal/apicodec/codec_v1.go @@ -30,6 +30,10 @@ func (c *codecV1) GetKeyspace() []byte { return nil } +func (c *codecV1) GetKeyspaceID() KeyspaceID { + return NulSpaceID +} + func (c *codecV1) EncodeRequest(req *tikvrpc.Request) (*tikvrpc.Request, error) { return req, nil } diff --git a/internal/apicodec/codec_v2.go b/internal/apicodec/codec_v2.go index 0ae4f8c4ff..b02b067845 100644 --- a/internal/apicodec/codec_v2.go +++ b/internal/apicodec/codec_v2.go @@ -23,7 +23,7 @@ var ( rawModePrefix byte = 'r' txnModePrefix byte = 'x' - keyspacePrefixLen int = 4 + keyspacePrefixLen = 4 // errKeyOutOfBound happens when key to be decoded lies outside the keyspace's range. errKeyOutOfBound = errors.New("given key does not belong to the keyspace") @@ -89,6 +89,12 @@ func (c *codecV2) GetKeyspace() []byte { return c.prefix } +func (c *codecV2) GetKeyspaceID() KeyspaceID { + prefix := append([]byte{}, c.prefix...) + prefix[0] = 0 + return KeyspaceID(binary.BigEndian.Uint32(prefix)) +} + func (c *codecV2) GetAPIVersion() kvrpcpb.APIVersion { return kvrpcpb.APIVersion_V2 } From 3facdf26e1174656464d8fb7ea10616ecad352f5 Mon Sep 17 00:00:00 2001 From: iosmanthus Date: Wed, 4 Jan 2023 18:15:18 +0800 Subject: [PATCH 4/8] add unit tests Signed-off-by: iosmanthus --- internal/apicodec/codec.go | 6 +----- internal/apicodec/codec_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 internal/apicodec/codec_test.go diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go index d77ca0ec4e..a7e8b7a732 100644 --- a/internal/apicodec/codec.go +++ b/internal/apicodec/codec.go @@ -27,11 +27,7 @@ const ( ) func ParseKeyspaceID(b []byte) (KeyspaceID, error) { - if len(b) < keyspacePrefixLen { - return NulSpaceID, nil - } - - if b[0] != rawModePrefix && b[0] != txnModePrefix { + if len(b) < keyspacePrefixLen || (b[0] != rawModePrefix && b[0] != txnModePrefix) { return 0, errors.Errorf("unsupported key %s", b) } diff --git a/internal/apicodec/codec_test.go b/internal/apicodec/codec_test.go new file mode 100644 index 0000000000..6222d99738 --- /dev/null +++ b/internal/apicodec/codec_test.go @@ -0,0 +1,25 @@ +package apicodec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseKeyspaceID(t *testing.T) { + id, err := ParseKeyspaceID([]byte{'x', 1, 2, 3, 1, 2, 3}) + assert.Nil(t, err) + assert.Equal(t, KeyspaceID(0x010203), id) + + id, err = ParseKeyspaceID([]byte{'r', 1, 2, 3, 1, 2, 3, 4}) + assert.Nil(t, err) + assert.Equal(t, KeyspaceID(0x010203), id) + + id, err = ParseKeyspaceID([]byte{'t', 0, 0}) + assert.NotNil(t, err) + assert.Equal(t, KeyspaceID(0), id) + + id, err = ParseKeyspaceID([]byte{'t', 0, 0, 1, 1, 2, 3}) + assert.NotNil(t, err) + assert.Equal(t, KeyspaceID(0), id) +} From b2d910e4eb1640eace91d6f702b78c2abb97876e Mon Sep 17 00:00:00 2001 From: iosmanthus Date: Wed, 4 Jan 2023 18:17:56 +0800 Subject: [PATCH 5/8] add unit tests for GetKeyspaceID Signed-off-by: iosmanthus --- internal/apicodec/codec_v2_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/apicodec/codec_v2_test.go b/internal/apicodec/codec_v2_test.go index fe58cecec4..104bc00950 100644 --- a/internal/apicodec/codec_v2_test.go +++ b/internal/apicodec/codec_v2_test.go @@ -266,3 +266,7 @@ func (suite *testCodecV2Suite) TestDecodeEpochNotMatch() { re.Equal(expected.EpochNotMatch.CurrentRegions[i], result.EpochNotMatch.CurrentRegions[i], "index: %d", i) } } + +func (suite *testCodecV2Suite) TestGetKeyspaceID() { + suite.Equal(KeyspaceID(testKeyspaceID), suite.codec.GetKeyspaceID()) +} From a0fb5bb39f24513c24b81cf38661dcc5d5bb131c Mon Sep 17 00:00:00 2001 From: David <8039876+AmoebaProtozoa@users.noreply.github.com> Date: Fri, 6 Jan 2023 12:16:43 +0800 Subject: [PATCH 6/8] fix lint Signed-off-by: David <8039876+AmoebaProtozoa@users.noreply.github.com> --- internal/apicodec/codec.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go index a7e8b7a732..63ef6c7b97 100644 --- a/internal/apicodec/codec.go +++ b/internal/apicodec/codec.go @@ -8,16 +8,17 @@ import ( "github.com/tikv/client-go/v2/tikvrpc" ) -// Mode represents the operation mode of a request. type ( - Mode int + // Mode represents the operation mode of a request. + Mode int + // KeyspaceID denotes the target keyspace of the request. KeyspaceID uint32 ) const ( - // ModeRaw represent a raw operation in TiKV + // ModeRaw represent a raw operation in TiKV. ModeRaw = iota - // ModeTxn represent a transaction operation in TiKV + // ModeTxn represent a transaction operation in TiKV. ModeTxn ) @@ -26,6 +27,8 @@ const ( NulSpaceID KeyspaceID = 0xffffffff ) +// ParseKeyspaceID retrieves the keyspaceID from the given keyspace-encoded key. +// It returns error if the given key is not in proper api-v2 format. func ParseKeyspaceID(b []byte) (KeyspaceID, error) { if len(b) < keyspacePrefixLen || (b[0] != rawModePrefix && b[0] != txnModePrefix) { return 0, errors.Errorf("unsupported key %s", b) From cc14806a861ca4141151c0ceb2e27f0ad490fae4 Mon Sep 17 00:00:00 2001 From: David <8039876+AmoebaProtozoa@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:55:25 +0800 Subject: [PATCH 7/8] parseKeyspaceID returns nulSpaceID when failed to decode Signed-off-by: David <8039876+AmoebaProtozoa@users.noreply.github.com> --- internal/apicodec/codec.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go index 63ef6c7b97..e43c842449 100644 --- a/internal/apicodec/codec.go +++ b/internal/apicodec/codec.go @@ -31,7 +31,7 @@ const ( // It returns error if the given key is not in proper api-v2 format. func ParseKeyspaceID(b []byte) (KeyspaceID, error) { if len(b) < keyspacePrefixLen || (b[0] != rawModePrefix && b[0] != txnModePrefix) { - return 0, errors.Errorf("unsupported key %s", b) + return NulSpaceID, errors.Errorf("unsupported key %s", b) } buf := append([]byte{}, b[:keyspacePrefixLen]...) From 5c53ec2b712a20a969844a7cc9e80a4acde71953 Mon Sep 17 00:00:00 2001 From: David <8039876+AmoebaProtozoa@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:58:30 +0800 Subject: [PATCH 8/8] fix unit test Signed-off-by: David <8039876+AmoebaProtozoa@users.noreply.github.com> --- internal/apicodec/codec_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/apicodec/codec_test.go b/internal/apicodec/codec_test.go index 6222d99738..640014f5c3 100644 --- a/internal/apicodec/codec_test.go +++ b/internal/apicodec/codec_test.go @@ -17,9 +17,9 @@ func TestParseKeyspaceID(t *testing.T) { id, err = ParseKeyspaceID([]byte{'t', 0, 0}) assert.NotNil(t, err) - assert.Equal(t, KeyspaceID(0), id) + assert.Equal(t, NulSpaceID, id) id, err = ParseKeyspaceID([]byte{'t', 0, 0, 1, 1, 2, 3}) assert.NotNil(t, err) - assert.Equal(t, KeyspaceID(0), id) + assert.Equal(t, NulSpaceID, id) }