Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a retry when getting ts from PD for validating read ts #1600

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/util"
pd "github.com/tikv/pd/client"
"github.com/tikv/pd/client/clients/tso"
"go.uber.org/zap"
Expand Down Expand Up @@ -647,6 +648,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
// waiting for reusing the same result should not be canceled. So pass context.Background() instead of the
// current ctx.
res, err := o.GetTimestamp(context.Background(), opt)
_, _ = util.EvalFailpoint("getCurrentTSForValidationBeforeReturn")
return res, err
})
select {
Expand All @@ -660,42 +662,75 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
}
}

func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) {
func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
if !exists || readTS > latestTSInfo.tso {
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if isStaleRead {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
}
if readTS > currentTS {
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
retrying := false
for {
latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
if !exists || readTS > latestTSInfo.tso {
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if isStaleRead && !retrying {
// Trigger the adjustment at most once in a single invocation.
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
}
if readTS > currentTS {
// It's possible that the caller is checking a ts that's legal but not fetched from the current oracle
// object. In this case, it's possible that:
// * The ts is not cached by the low resolution ts (so that readTS > latestTSInfo.tso);
// * ... and then the getCurrentTSForValidation (which uses a singleflight internally) reuses a
// previously-started call and returns an older ts,
// so that the check may fail as a false positive.
// To handle this case, we do not fail immediately when the check doesn't pass at once; instead, retry one
// more time. In the retry:
// * Since some other concurrent GetTimestamp operation may have already updated
// the low resolution ts, check it again. If it passes, then there is no need to get the next ts from PD,
// which is slow.
// * Then, call getCurrentTSForValidation and check again. As the current GetTimestamp operation
// inside getCurrentTSForValidation must be started after finishing the previous one (while the
// latter is finished after starting this invocation to ValidateReadTS), then we can conclude that
// the next ts returned by getCurrentTSForValidation must be greater than any ts allocated by PD
// before the current invocation to ValidateReadTS.
skipRetry := false
if val, err1 := util.EvalFailpoint("validateReadTSRetryGetTS"); err1 == nil {
if str, ok := val.(string); ok {
if str == "skip" {
skipRetry = true
}
}
}
if !retrying && !skipRetry {
retrying = true
continue
}
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
} else if !retrying && isStaleRead {
// Trigger the adjustment at most once in a single invocation.
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
if err != nil {
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
} else {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
}
}
} else if isStaleRead {
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
if err != nil {
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
} else {
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
}
return nil
}
return nil
}

// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments to the update interval of low resolution
Expand Down
104 changes: 95 additions & 9 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ package oracles

import (
"context"
"fmt"
"math"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/pingcap/failpoint"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/util"
pd "github.com/tikv/pd/client"
)

Expand Down Expand Up @@ -374,10 +378,15 @@ func TestValidateReadTS(t *testing.T) {
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
err = o.ValidateReadTS(ctx, ts+1, staleRead, opt)
assert.NoError(t, err)
// It can't pass if the readTS is newer than previous ts + 2.
// It can also pass if the readTS is previous ts + 2, as it can perform a retry.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateReadTS(ctx, ts+2, staleRead, opt)
assert.NoError(t, err)
// As it retries at most once, it can't pass the check if the readTS is previous ts + 3 or newer.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateReadTS(ctx, ts+3, staleRead, opt)
assert.Error(t, err)

// Simulate other PD clients requests a timestamp.
Expand Down Expand Up @@ -412,6 +421,8 @@ func (c *MockPDClientWithPause) Resume() {
}

func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
util.EnableFailpoints()

pdClient := &MockPDClientWithPause{}
o, err := NewPdOracle(pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
Expand All @@ -420,6 +431,11 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
assert.NoError(t, err)
defer o.Close()

assert.NoError(t, failpoint.Enable("tikvclient/validateReadTSRetryGetTS", `return("skip")`))
defer func() {
assert.NoError(t, failpoint.Disable("tikvclient/validateReadTSRetryGetTS"))
}()

asyncValidate := func(ctx context.Context, readTS uint64) chan error {
ch := make(chan error, 1)
go func() {
Expand All @@ -429,21 +445,21 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
return ch
}

noResult := func(ch chan error) {
noResult := func(ch chan error, additionalMsg ...interface{}) {
select {
case <-ch:
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked")
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked", additionalMsg...)
default:
}
}

cancelIndices := []int{-1, -1, 0, 1}
for i, ts := range []uint64{100, 200, 300, 400} {
for caseIndex, ts := range []uint64{100, 200, 300, 400} {
// Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail.

// We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls
// doesn't affect other calls that are waiting
cancelIndex := cancelIndices[i]
// doesn't affect other calls that are waiting.
cancelIndex := cancelIndices[caseIndex]

pdClient.Pause()

Expand Down Expand Up @@ -541,13 +557,18 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) {
assert.NoError(t, err)
mustNoNotify()

// It loads `ts + 1` from the mock PD, and the check cannot pass.
// It loads `ts + 1` from the mock PD, and then retries `ts + 2` and passes.
err = o.ValidateReadTS(ctx, ts+2, false, opt)
assert.NoError(t, err)
mustNoNotify()

// It loads `ts + 3` and `ts + 4` from the mock PD, and the check cannot pass.
err = o.ValidateReadTS(ctx, ts+5, false, opt)
assert.Error(t, err)
mustNoNotify()

// Do the check again. It loads `ts + 2` from the mock PD, and the check passes.
err = o.ValidateReadTS(ctx, ts+2, false, opt)
// Do the check again. It loads `ts + 5` from the mock PD, and the check passes.
err = o.ValidateReadTS(ctx, ts+5, false, opt)
assert.NoError(t, err)
mustNoNotify()
}
Expand Down Expand Up @@ -586,3 +607,68 @@ func TestSetLastTSAlwaysPushTS(t *testing.T) {
close(cancel)
wg.Wait()
}

// TestValidateReadTSFromDifferentSource checks that ValidateReadTS does not report a false positive for a
// ts that was fetched from a different client of the same cluster and is therefore not yet covered by the
// low resolution ts cached in this oracle.
//
// The scenario is built as follows:
//  1. A first ValidateReadTS call starts fetching a ts from PD and is paused (via a failpoint) right before
//     returning from the singleflight function.
//  2. The mock PD is advanced and a newer ts (nextTS) is allocated directly from it, bypassing the oracle.
//  3. A second ValidateReadTS call validates nextTS; it joins the paused singleflight call and thus first
//     observes a ts older than nextTS.
//  4. After the failpoint is released, both validations are expected to succeed.
func TestValidateReadTSFromDifferentSource(t *testing.T) {
	util.EnableFailpoints()

	ctx := context.Background()
	opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope}

	pdClient := MockPdClient{}
	o, err := NewPdOracle(&pdClient, &PDOracleOptions{
		UpdateInterval: time.Second * 2,
		NoUpdateTS:     true,
	})
	assert.NoError(t, err)
	defer o.Close()

	// Construct the situation that the low resolution ts is lower than the ts fetched from another client.
	ts, err := o.GetTimestamp(ctx, opt)
	assert.NoError(t, err)
	lowResolutionTS, err := o.GetLowResolutionTimestamp(ctx, opt)
	assert.NoError(t, err)
	assert.Equal(t, ts, lowResolutionTS)

	assert.NoError(t, failpoint.Enable("tikvclient/getCurrentTSForValidationBeforeReturn", "pause"))
	defer func() {
		assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
	}()

	// Trigger getting a ts from PD for validation; this becomes the previously-started concurrent call. It is
	// blocked by the failpoint while getting the ts, so that when the second call starts, it reuses the same
	// singleflight for getting the ts, which returns an older ts to it.
	// The channel is buffered so that the goroutine can always finish, even if the test aborts early.
	firstResCh := make(chan error, 1)
	go func() {
		firstResCh <- o.ValidateReadTS(ctx, ts+1, false, opt)
	}()

	select {
	case err = <-firstResCh:
		assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
	case <-time.After(time.Millisecond * 50):
	}

	pdClient.logicalTimestamp.Add(10)
	physical, logical, err := pdClient.GetTS(ctx)
	assert.NoError(t, err)
	// The next ts should be the previous `ts + 1 (fetched by the ValidateReadTS call) + 10 (advanced manually) + 1`.
	nextTS := oracle.ComposeTS(physical, logical)
	// The low resolution ts has not advanced beyond the ts fetched by the first validation.
	nextLowResolutionTS, err := o.GetLowResolutionTimestamp(ctx, opt)
	assert.NoError(t, err)
	assert.Equal(t, ts+1, nextLowResolutionTS)
	assert.Equal(t, nextTS-11, nextLowResolutionTS)

	// The second check reuses the singleflight to get the ts and the result can be older than `nextTS`.
	secondResCh := make(chan error, 1)
	go func() {
		secondResCh <- o.ValidateReadTS(ctx, nextTS, false, opt)
	}()
	// The second call must be blocked as well, since it waits for the paused singleflight call.
	select {
	case err = <-secondResCh:
		assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
	case <-time.After(time.Millisecond * 50):
	}

	// Release the paused call; both validations are expected to pass.
	assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
	require.NoError(t, <-firstResCh)
	require.NoError(t, <-secondResCh)
}
Loading