Skip to content

Commit

Permalink
br: use atomic.Pointer instead of atomic.Value (pingcap#49359)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshikipom authored Dec 13, 2023
1 parent 161107a commit ced8af2
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 28 deletions.
8 changes: 4 additions & 4 deletions br/pkg/restore/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ func (importer *FileImporter) downloadSST(
logutil.Leader(regionInfo.Leader),
)

var atomicResp atomic.Value
var atomicResp atomic.Pointer[import_sstpb.DownloadResponse]
eg, ectx := errgroup.WithContext(ctx)
for _, p := range regionInfo.Region.GetPeers() {
peer := p
Expand Down Expand Up @@ -746,7 +746,7 @@ func (importer *FileImporter) downloadSST(
return nil, err
}

downloadResp := atomicResp.Load().(*import_sstpb.DownloadResponse)
downloadResp := atomicResp.Load()
sstMeta.Range.Start = TruncateTS(downloadResp.Range.GetStart())
sstMeta.Range.End = TruncateTS(downloadResp.Range.GetEnd())
sstMeta.ApiVersion = apiVersion
Expand Down Expand Up @@ -799,7 +799,7 @@ func (importer *FileImporter) downloadRawKVSST(
}
log.Debug("download SST", logutil.SSTMeta(sstMeta), logutil.Region(regionInfo.Region))

var atomicResp atomic.Value
var atomicResp atomic.Pointer[import_sstpb.DownloadResponse]
eg, ectx := errgroup.WithContext(ctx)
for _, p := range regionInfo.Region.GetPeers() {
peer := p
Expand All @@ -824,7 +824,7 @@ func (importer *FileImporter) downloadRawKVSST(
return nil, err
}

downloadResp := atomicResp.Load().(*import_sstpb.DownloadResponse)
downloadResp := atomicResp.Load()
sstMeta.Range.Start = downloadResp.Range.GetStart()
sstMeta.Range.End = downloadResp.Range.GetEnd()
sstMeta.ApiVersion = apiVersion
Expand Down
15 changes: 7 additions & 8 deletions br/pkg/storage/memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ import (
)

type memFile struct {
Data atomic.Value // the atomic value is a byte slice, which can only be get/set atomically
Data atomic.Pointer[[]byte]
}

// GetData gets the underlying byte slice of the atomic value
// GetData gets the underlying byte slice of the atomic pointer
func (f *memFile) GetData() []byte {
var fileData []byte
fileDataVal := f.Data.Load()
if fileDataVal != nil {
fileData = fileDataVal.([]byte)
if p := f.Data.Load(); p != nil {
fileData = *p
}
return fileData
}
Expand Down Expand Up @@ -110,10 +109,10 @@ func (s *MemStorage) WriteFile(ctx context.Context, name string, data []byte) er
defer s.rwm.Unlock()
theFile, ok := s.dataStore[name]
if ok {
theFile.Data.Store(fileData)
theFile.Data.Store(&fileData)
} else {
theFile := new(memFile)
theFile.Data.Store(fileData)
theFile.Data.Store(&fileData)
s.dataStore[name] = theFile
}
return nil
Expand Down Expand Up @@ -352,7 +351,7 @@ func (w *memFileWriter) Close(ctx context.Context) error {
// continue on
}
fileData := append([]byte{}, w.buf.Bytes()...)
w.file.Data.Store(fileData)
w.file.Data.Store(&fileData)
w.isClosed.Store(true)
return nil
}
4 changes: 2 additions & 2 deletions br/pkg/storage/memstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ func TestMemStoreManipulateBytes(t *testing.T) {
testBytes := []byte(testStr)
require.Nil(t, store.WriteFile(ctx, "/aaa.txt", testBytes))
testBytes[3] = '2'
require.Equal(t, testStr, string(store.dataStore["/aaa.txt"].Data.Load().([]byte)))
require.Equal(t, testStr, string(*store.dataStore["/aaa.txt"].Data.Load()))

readBytes, err := store.ReadFile(ctx, "/aaa.txt")
require.Nil(t, err)
require.Equal(t, testStr, string(readBytes))
readBytes[3] = '2'
require.Equal(t, testStr, string(store.dataStore["/aaa.txt"].Data.Load().([]byte)))
require.Equal(t, testStr, string(*store.dataStore["/aaa.txt"].Data.Load()))
}

func TestMemStoreWriteDuringWalkDir(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions br/pkg/streamhelper/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ go_library(
"@org_golang_google_grpc//keepalive",
"@org_golang_google_grpc//status",
"@org_golang_x_sync//errgroup",
"@org_uber_go_atomic//:atomic",
"@org_uber_go_multierr//:multierr",
"@org_uber_go_zap//:zap",
],
Expand Down
20 changes: 6 additions & 14 deletions br/pkg/streamhelper/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"

"github.com/pingcap/errors"
logbackup "github.com/pingcap/kvproto/pkg/logbackuppb"
Expand All @@ -17,6 +16,7 @@ import (
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"go.uber.org/atomic"
"go.uber.org/zap"
)

Expand All @@ -38,7 +38,7 @@ type storeCollector struct {

input chan RegionWithLeader
// the oneshot error reporter.
err *atomic.Value
err *atomic.Error
// whether the recv and send loop has exited.
doneMessenger chan struct{}
onSuccess onSuccessHook
Expand All @@ -58,28 +58,20 @@ func newStoreCollector(storeID uint64, srv LogBackupService) *storeCollector {
batchSize: defaultBatchSize,
service: srv,
input: make(chan RegionWithLeader, defaultBatchSize),
err: new(atomic.Value),
err: new(atomic.Error),
doneMessenger: make(chan struct{}),
regionMap: make(map[uint64]kv.KeyRange),
}
}

func (c *storeCollector) reportErr(err error) {
if oldErr := c.Err(); oldErr != nil {
if oldErr := c.err.Load(); oldErr != nil {
log.Warn("reporting error twice, ignoring", logutil.AShortError("old", err), logutil.AShortError("new", oldErr))
return
}
c.err.Store(err)
}

func (c *storeCollector) Err() error {
err, ok := c.err.Load().(error)
if !ok {
return nil
}
return err
}

func (c *storeCollector) setOnSuccessHook(hook onSuccessHook) {
c.onSuccess = hook
}
Expand Down Expand Up @@ -166,7 +158,7 @@ func (c *storeCollector) spawn(ctx context.Context) func(context.Context) (Store
return StoreCheckpoints{}, cx.Err()
case <-c.doneMessenger:
}
if err := c.Err(); err != nil {
if err := c.err.Load(); err != nil {
return StoreCheckpoints{}, err
}
sc := StoreCheckpoints{
Expand Down Expand Up @@ -302,7 +294,7 @@ func (c *clusterCollector) CollectRegion(r RegionWithLeader) error {
case sc.input <- r:
return nil
case <-sc.doneMessenger:
err := sc.Err()
err := sc.err.Load()
if err != nil {
c.cancel()
}
Expand Down

0 comments on commit ced8af2

Please sign in to comment.