diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 60f0ee7cb3..96fa2fca46 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -47,6 +47,7 @@ import ( "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/metrics" "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/util" pd "github.com/tikv/pd/client" "github.com/tikv/pd/client/clients/tso" "go.uber.org/zap" @@ -647,6 +648,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op // waiting for reusing the same result should not be canceled. So pass context.Background() instead of the // current ctx. res, err := o.GetTimestamp(context.Background(), opt) + _, _ = util.EvalFailpoint("getCurrentTSForValidationBeforeReturn") return res, err }) select { @@ -660,7 +662,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op } } -func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) { +func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { if readTS == math.MaxUint64 { if isStaleRead { return oracle.ErrLatestStaleRead{} @@ -668,34 +670,67 @@ func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRea return nil } - latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope) - // If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check. - // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function - // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. 
- if !exists || readTS > latestTSInfo.tso { - currentTS, err := o.getCurrentTSForValidation(ctx, opt) - if err != nil { - return errors.Errorf("fail to validate read timestamp: %v", err) - } - if isStaleRead { - o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) - } - if readTS > currentTS { - return oracle.ErrFutureTSRead{ - ReadTS: readTS, - CurrentTS: currentTS, + retrying := false + for { + latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope) + // If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check. + // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function + // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. + if !exists || readTS > latestTSInfo.tso { + currentTS, err := o.getCurrentTSForValidation(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + if isStaleRead && !retrying { + // Trigger the adjustment at most once in a single invocation. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + } + if readTS > currentTS { + // It's possible that the caller is checking a ts that's legal but not fetched from the current oracle + // object. In this case, it's possible that: + // * The ts is not cached by the low resolution ts (so that readTS > latestTSInfo.tso); + // * ... and then the getCurrentTSForValidation (which uses a singleflight internally) reuses a + // previously-started call and returns an older ts + // so that it may cause a false-positive check failure. + // To handle this case, we do not fail immediately when the check doesn't pass at once; instead, retry one + // more time. In the retry: + // * There may already be some other concurrent GetTimestamp operation that has updated + // the low resolution ts, so check it again. 
If it passes, then no need to get the next ts from PD, + // which is slow. + // * Then, call getCurrentTSForValidation and check again. As the current GetTimestamp operation + // inside getCurrentTSForValidation must be started after finishing the previous one (while the + // latter is finished after starting this invocation to ValidateReadTS), then we can conclude that + // the next ts returned by getCurrentTSForValidation must be greater than any ts allocated by PD + // before the current invocation to ValidateReadTS. + skipRetry := false + if val, err1 := util.EvalFailpoint("validateReadTSRetryGetTS"); err1 == nil { + if str, ok := val.(string); ok { + if str == "skip" { + skipRetry = true + } + } + } + if !retrying && !skipRetry { + retrying = true + continue + } + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } + } + } else if !retrying && isStaleRead { + // Trigger the adjustment at most once in a single invocation. + estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0) + if err != nil { + logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval", + zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) + } else { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now()) } } - } else if isStaleRead { - estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0) - if err != nil { - logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval", - zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) - } else { - o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now()) - } + return nil } - return nil } // adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments the update interval of low 
resolution diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 01c67c5e46..b79bfdeb82 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -36,14 +36,18 @@ package oracles import ( "context" + "fmt" "math" "sync" "sync/atomic" "testing" "time" + "github.com/pingcap/failpoint" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/util" pd "github.com/tikv/pd/client" ) @@ -374,10 +378,15 @@ func TestValidateReadTS(t *testing.T) { // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. err = o.ValidateReadTS(ctx, ts+1, staleRead, opt) assert.NoError(t, err) - // It can't pass if the readTS is newer than previous ts + 2. + // It can also pass if the readTS is previous ts + 2, as it can perform a retry. ts, err = o.GetTimestamp(ctx, opt) assert.NoError(t, err) err = o.ValidateReadTS(ctx, ts+2, staleRead, opt) + assert.NoError(t, err) + // As it retries at most once, it can't pass the check if the readTS is newer than previous ts + 3 + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateReadTS(ctx, ts+3, staleRead, opt) assert.Error(t, err) // Simulate other PD clients requests a timestamp. 
@@ -412,6 +421,8 @@ func (c *MockPDClientWithPause) Resume() { } func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { + util.EnableFailpoints() + pdClient := &MockPDClientWithPause{} o, err := NewPdOracle(pdClient, &PDOracleOptions{ UpdateInterval: time.Second * 2, @@ -420,6 +431,11 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { assert.NoError(t, err) defer o.Close() + assert.NoError(t, failpoint.Enable("tikvclient/validateReadTSRetryGetTS", `return("skip")`)) + defer func() { + assert.NoError(t, failpoint.Disable("tikvclient/validateReadTSRetryGetTS")) + }() + asyncValidate := func(ctx context.Context, readTS uint64) chan error { ch := make(chan error, 1) go func() { @@ -429,21 +445,21 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { return ch } - noResult := func(ch chan error) { + noResult := func(ch chan error, additionalMsg ...interface{}) { select { case <-ch: - assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked") + assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked", additionalMsg...) default: } } cancelIndices := []int{-1, -1, 0, 1} - for i, ts := range []uint64{100, 200, 300, 400} { + for caseIndex, ts := range []uint64{100, 200, 300, 400} { // Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail. // We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls - // doesn't affect other calls that are waiting - cancelIndex := cancelIndices[i] + // doesn't affect other calls that are waiting. + cancelIndex := cancelIndices[caseIndex] pdClient.Pause() @@ -541,13 +557,18 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { assert.NoError(t, err) mustNoNotify() - // It loads `ts + 1` from the mock PD, and the check cannot pass. 
+ // It loads `ts + 1` from the mock PD, and then retries `ts + 2` and passes. err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.NoError(t, err) + mustNoNotify() + + // It loads `ts + 3` and `ts + 4` from the mock PD, and the check cannot pass. + err = o.ValidateReadTS(ctx, ts+5, false, opt) assert.Error(t, err) mustNoNotify() - // Do the check again. It loads `ts + 2` from the mock PD, and the check passes. - err = o.ValidateReadTS(ctx, ts+2, false, opt) + // Do the check again. It loads `ts + 5` from the mock PD, and the check passes. + err = o.ValidateReadTS(ctx, ts+5, false, opt) assert.NoError(t, err) mustNoNotify() } @@ -586,3 +607,68 @@ func TestSetLastTSAlwaysPushTS(t *testing.T) { close(cancel) wg.Wait() } + +func TestValidateReadTSFromDifferentSource(t *testing.T) { + // If a ts is fetched from a different client to the same cluster, the ts might not be cached by the low resolution + // ts. In this case, the validation should not be false positive. + util.EnableFailpoints() + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + defer o.Close() + + // Construct the situation that the low resolution ts is lower than the ts fetched from another client. + ts, err := o.GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + lowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + assert.Equal(t, ts, lowResolutionTS) + + assert.NoError(t, failpoint.Enable("tikvclient/getCurrentTSForValidationBeforeReturn", "pause")) + defer func() { + assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn")) + }() + + // Trigger getting ts from PD for validation, which causes a previously-started concurrent call. We block it during + // getting the ts by the failpoint. 
As a result, when the second call starts, it will reuse the same singleflight + // for getting the ts, which returns an older ts to it. + firstResCh := make(chan error) + go func() { + firstResCh <- o.ValidateReadTS(context.Background(), ts+1, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + }() + + select { + case err = <-firstResCh: + assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err)) + case <-time.After(time.Millisecond * 50): + } + + pdClient.logicalTimestamp.Add(10) + physical, logical, err := pdClient.GetTS(context.Background()) + assert.NoError(t, err) + // The next ts should be the previous `ts + 1 (fetched by the ValidateReadTS call) + 10 (advanced manually) + 1`. + nextTS := oracle.ComposeTS(physical, logical) + // The low resolution ts is not updated since the validation. + nextLowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + assert.Equal(t, ts+1, nextLowResolutionTS) + assert.Equal(t, nextTS-11, nextLowResolutionTS) + + // The second check reuses the singleflight to get the ts and the result can be older than `nextTS`. + secondResCh := make(chan error) + go func() { + secondResCh <- o.ValidateReadTS(context.Background(), nextTS, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + }() + select { + case err = <-firstResCh: + assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err)) + case <-time.After(time.Millisecond * 50): + } + + assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn")) + require.NoError(t, <-firstResCh) + require.NoError(t, <-secondResCh) +}