Skip to content

Commit

Permalink
Add a retry when getting ts from PD for validating read ts (#1600)
Browse files Browse the repository at this point in the history
 

Signed-off-by: MyonKeminta <[email protected]>
  • Loading branch information
MyonKeminta authored Mar 3, 2025
1 parent 5ac118b commit 10a84d0
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 35 deletions.
87 changes: 61 additions & 26 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/util"
pd "github.com/tikv/pd/client"
"github.com/tikv/pd/client/clients/tso"
"go.uber.org/zap"
Expand Down Expand Up @@ -647,6 +648,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
// waiting for reusing the same result should not be canceled. So pass context.Background() instead of the
// current ctx.
res, err := o.GetTimestamp(context.Background(), opt)
_, _ = util.EvalFailpoint("getCurrentTSForValidationBeforeReturn")
return res, err
})
select {
Expand All @@ -660,42 +662,75 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
}
}

func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) {
func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
if !exists || readTS > latestTSInfo.tso {
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if isStaleRead {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
}
if readTS > currentTS {
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
retrying := false
for {
latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
if !exists || readTS > latestTSInfo.tso {
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if isStaleRead && !retrying {
// Trigger the adjustment at most once in a single invocation.
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
}
if readTS > currentTS {
// It's possible that the caller is checking a ts that's legal but not fetched from the current oracle
// object. In this case, it's possible that:
// * The ts is not be cached by the low resolution ts (so that readTS > latestTSInfo.TSO);
// * ... and then the getCurrentTSForValidation (which uses a singleflight internally) reuse a
// previously-started call and returns an older ts
// so that it may cause the check false-positive.
// To handle this case, we do not fail immediately when the check doesn't at once; instead, retry one
// more time. In the retry:
// * Considering that there can already be some other concurrent GetTimestamp operation that may have updated
// the low resolution ts, so check it again. If it passes, then no need to get the next ts from PD,
// which is slow.
// * Then, call getCurrentTSForValidation and check again. As the current GetTimestamp operation
// inside getCurrentTSForValidation must be started after finishing the previous one (while the
// latter is finished after starting this invocation to ValidateReadTS), then we can conclude that
// the next ts returned by getCurrentTSForValidation must be greater than any ts allocated by PD
// before the current invocation to ValidateReadTS.
skipRetry := false
if val, err1 := util.EvalFailpoint("validateReadTSRetryGetTS"); err1 == nil {
if str, ok := val.(string); ok {
if str == "skip" {
skipRetry = true
}
}
}
if !retrying && !skipRetry {
retrying = true
continue
}
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
} else if !retrying && isStaleRead {
// Trigger the adjustment at most once in a single invocation.
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
if err != nil {
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
} else {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
}
}
} else if isStaleRead {
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
if err != nil {
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
} else {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
}
return nil
}
return nil
}

// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments the update interval of low resolution
Expand Down
104 changes: 95 additions & 9 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ package oracles

import (
"context"
"fmt"
"math"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/pingcap/failpoint"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/util"
pd "github.com/tikv/pd/client"
)

Expand Down Expand Up @@ -374,10 +378,15 @@ func TestValidateReadTS(t *testing.T) {
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
err = o.ValidateReadTS(ctx, ts+1, staleRead, opt)
assert.NoError(t, err)
// It can't pass if the readTS is newer than previous ts + 2.
// It can also pass if the readTS is previous ts + 2, as it can perform a retry.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateReadTS(ctx, ts+2, staleRead, opt)
assert.NoError(t, err)
// As it retries at most once, it can't pass the check if the readTS is newer than previous ts + 3
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateReadTS(ctx, ts+3, staleRead, opt)
assert.Error(t, err)

// Simulate other PD clients requests a timestamp.
Expand Down Expand Up @@ -412,6 +421,8 @@ func (c *MockPDClientWithPause) Resume() {
}

func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
util.EnableFailpoints()

pdClient := &MockPDClientWithPause{}
o, err := NewPdOracle(pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
Expand All @@ -420,6 +431,11 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
assert.NoError(t, err)
defer o.Close()

assert.NoError(t, failpoint.Enable("tikvclient/validateReadTSRetryGetTS", `return("skip")`))
defer func() {
assert.NoError(t, failpoint.Disable("tikvclient/validateReadTSRetryGetTS"))
}()

asyncValidate := func(ctx context.Context, readTS uint64) chan error {
ch := make(chan error, 1)
go func() {
Expand All @@ -429,21 +445,21 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
return ch
}

noResult := func(ch chan error) {
noResult := func(ch chan error, additionalMsg ...interface{}) {
select {
case <-ch:
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked")
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked", additionalMsg...)
default:
}
}

cancelIndices := []int{-1, -1, 0, 1}
for i, ts := range []uint64{100, 200, 300, 400} {
for caseIndex, ts := range []uint64{100, 200, 300, 400} {
// Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail.

// We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls
// doesn't affect other calls that are waiting
cancelIndex := cancelIndices[i]
// doesn't affect other calls that are waiting.
cancelIndex := cancelIndices[caseIndex]

pdClient.Pause()

Expand Down Expand Up @@ -541,13 +557,18 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) {
assert.NoError(t, err)
mustNoNotify()

// It loads `ts + 1` from the mock PD, and the check cannot pass.
// It loads `ts + 1` from the mock PD, and then retries `ts + 2` and passes.
err = o.ValidateReadTS(ctx, ts+2, false, opt)
assert.NoError(t, err)
mustNoNotify()

// It loads `ts + 3` and `ts + 4` from the mock PD, and the check cannot pass.
err = o.ValidateReadTS(ctx, ts+5, false, opt)
assert.Error(t, err)
mustNoNotify()

// Do the check again. It loads `ts + 2` from the mock PD, and the check passes.
err = o.ValidateReadTS(ctx, ts+2, false, opt)
// Do the check again. It loads `ts + 5` from the mock PD, and the check passes.
err = o.ValidateReadTS(ctx, ts+5, false, opt)
assert.NoError(t, err)
mustNoNotify()
}
Expand Down Expand Up @@ -586,3 +607,68 @@ func TestSetLastTSAlwaysPushTS(t *testing.T) {
close(cancel)
wg.Wait()
}

func TestValidateReadTSFromDifferentSource(t *testing.T) {
// If a ts is fetched from a different client to the same cluster, the ts might not be cached by the low resolution
// ts. In this case, the validation should not be false positive.
util.EnableFailpoints()
pdClient := MockPdClient{}
o, err := NewPdOracle(&pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
NoUpdateTS: true,
})
assert.NoError(t, err)
defer o.Close()

// Construct the situation that the low resolution ts is lower than the ts fetched from another client.
ts, err := o.GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
assert.NoError(t, err)
lowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
assert.NoError(t, err)
assert.Equal(t, ts, lowResolutionTS)

assert.NoError(t, failpoint.Enable("tikvclient/getCurrentTSForValidationBeforeReturn", "pause"))
defer func() {
assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
}()

// Trigger getting ts from PD for validation, which causes a previously-started concurrent call. We block it during
// getting the ts by the failpoint. So that when the second call starts, it will reuse the same singleflight
// for getting the ts, which return a older ts to it.
firstResCh := make(chan error)
go func() {
firstResCh <- o.ValidateReadTS(context.Background(), ts+1, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
}()

select {
case err = <-firstResCh:
assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
case <-time.After(time.Millisecond * 50):
}

pdClient.logicalTimestamp.Add(10)
physical, logical, err := pdClient.GetTS(context.Background())
assert.NoError(t, err)
// The next ts should be the previous `ts + 1 (fetched by the ValidateReadTS call) + 10 (advanced manually) + 1`.
nextTS := oracle.ComposeTS(physical, logical)
// The low resolution ts is not updated since the validation.
nextLowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
assert.NoError(t, err)
assert.Equal(t, ts+1, nextLowResolutionTS)
assert.Equal(t, nextTS-11, nextLowResolutionTS)

// The second check reuses the singleflight to get the ts and the result can be older than `nextTS`.
secondResCh := make(chan error)
go func() {
secondResCh <- o.ValidateReadTS(context.Background(), nextTS, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
}()
select {
case err = <-firstResCh:
assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
case <-time.After(time.Millisecond * 50):
}

assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
require.NoError(t, <-firstResCh)
require.NoError(t, <-secondResCh)
}

0 comments on commit 10a84d0

Please sign in to comment.