Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go: remotestorage: Rework how dictionary fetching and dictionary cache is populated. #8859

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 115 additions & 150 deletions go/libraries/doltcore/remotestorage/chunk_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ import (
"sync/atomic"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/dolthub/gozstd"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

Expand Down Expand Up @@ -72,14 +70,7 @@ const (
reliableCallDeliverRespTimeout = 15 * time.Second
)

var globalDictCache *dictionaryCache
var once sync.Once

func NewChunkFetcher(ctx context.Context, dcs *DoltChunkStore) *ChunkFetcher {
once.Do(func() {
globalDictCache = NewDictionaryCache(dcs.csClient)
})

eg, ctx := errgroup.WithContext(ctx)
ret := &ChunkFetcher{
eg: eg,
Expand Down Expand Up @@ -345,8 +336,11 @@ func getMissingChunks(req *remotesapi.GetDownloadLocsRequest, resp *remotesapi.G
}

type fetchResp struct {
get *GetRange
refresh func(ctx context.Context, err error, client remotesapi.ChunkStoreServiceClient) (string, error)
get *GetRange
refresh func(ctx context.Context, err error, client remotesapi.ChunkStoreServiceClient) (string, error)
rangeType rangeType
dictCache *dictionaryCache
path string
}

type fetchReq struct {
Expand All @@ -357,52 +351,72 @@ type fetchReq struct {
// A simple structure to keep track of *GetRange requests along with
// |locationRefreshes| for the URL paths we have seen.
type downloads struct {
ranges *ranges.Tree
refreshes map[string]*locationRefresh
chunkRanges *ranges.Tree
dictRanges *ranges.Tree
dictCache *dictionaryCache
refreshes map[string]*locationRefresh
}

func newDownloads() downloads {
return downloads{
ranges: ranges.NewTree(chunkAggDistance),
refreshes: make(map[string]*locationRefresh),
chunkRanges: ranges.NewTree(chunkAggDistance),
dictRanges: ranges.NewTree(chunkAggDistance),
dictCache: &dictionaryCache{},
refreshes: make(map[string]*locationRefresh),
}
}

func (d downloads) Add(resp *remotesapi.DownloadLoc) {
gr := (*GetRange)(resp.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange)
path := gr.ResourcePath()
hgr := resp.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange
path := ResourcePath(hgr.Url)
if v, ok := d.refreshes[path]; ok {
v.Add(resp)
} else {
refresh := new(locationRefresh)
refresh.Add(resp)
d.refreshes[path] = refresh
}
for _, r := range gr.Ranges {
d.ranges.Insert(gr.Url, r.Hash[:], r.Offset, r.Length, r.DictionaryOffset, r.DictionaryLength)
for _, r := range hgr.Ranges {
var getDict func() (any, error)
if r.DictionaryLength != 0 {
var first bool
getDict, first = d.dictCache.get(path, r.DictionaryOffset, r.DictionaryLength)
if first {
d.dictRanges.Insert(hgr.Url, nil, r.DictionaryOffset, r.DictionaryLength, nil)
}
}
d.chunkRanges.Insert(hgr.Url, r.Hash[:], r.Offset, r.Length, getDict)
}
}

func toGetRange(rs []*ranges.GetRange) *GetRange {
ret := new(GetRange)
for _, r := range rs {
ret.Url = r.Url
ret.Ranges = append(ret.Ranges, &remotesapi.RangeChunk{
Hash: r.Hash,
Offset: r.Offset,
Length: r.Length,
DictionaryOffset: r.DictionaryOffset,
DictionaryLength: r.DictionaryLength,
ret.Ranges = append(ret.Ranges, &Range{
Hash: r.Hash,
Offset: r.Offset,
Length: r.Length,
GetDict: r.GetDict,
})
}
return ret
}

// rangeType classifies a queued download range: either chunk bytes or a
// zstd dictionary span that chunks in the same archive depend on.
type rangeType int

const (
// rangeType_Chunk is a range holding chunk records; delivered to chunkCh
// as nbs.ToChunker values after download.
rangeType_Chunk rangeType = iota
// rangeType_Dictionary is a range holding a zstd-compressed dictionary;
// decompressed into a *gozstd.DDict and recorded in the dictionaryCache.
rangeType_Dictionary
)

// Reads off |locCh| and assembles DownloadLocs into download ranges.
func fetcherDownloadRangesThread(ctx context.Context, locCh chan []*remotesapi.DownloadLoc, fetchReqCh chan fetchReq, doneCh chan struct{}) error {
downloads := newDownloads()
pending := make([]fetchReq, 0)
var toSend *GetRange
var toSendType rangeType

for {
// pending is our slice of request threads that showed up
// asking for a download. We range through it and try to send
Expand All @@ -413,11 +427,16 @@ func fetcherDownloadRangesThread(ctx context.Context, locCh chan []*remotesapi.D
// can get the next range to download from
// |downloads.ranges|.
if toSend == nil {
max := downloads.ranges.DeleteMaxRegion()
max := downloads.dictRanges.DeleteMaxRegion()
if len(max) == 0 {
break
max = downloads.chunkRanges.DeleteMaxRegion()
if len(max) == 0 {
break
}
toSend, toSendType = toGetRange(max), rangeType_Chunk
} else {
toSend, toSendType = toGetRange(max), rangeType_Dictionary
}
toSend = toGetRange(max)
}
path := toSend.ResourcePath()
refresh := downloads.refreshes[path]
Expand All @@ -427,6 +446,9 @@ func fetcherDownloadRangesThread(ctx context.Context, locCh chan []*remotesapi.D
refresh: func(ctx context.Context, err error, client remotesapi.ChunkStoreServiceClient) (string, error) {
return refresh.GetURL(ctx, err, client)
},
rangeType: toSendType,
path: path,
dictCache: downloads.dictCache,
}

select {
Expand Down Expand Up @@ -462,7 +484,7 @@ func fetcherDownloadRangesThread(ctx context.Context, locCh chan []*remotesapi.D
// nil and our ranges Tree is empty, then we have delivered
// every download we will ever see to a download thread. We can
// close |doneCh| and return nil.
if locCh == nil && downloads.ranges.Len() == 0 {
if locCh == nil && downloads.chunkRanges.Len() == 0 && downloads.dictRanges.Len() == 0 {
close(doneCh)
return nil
}
Expand Down Expand Up @@ -595,7 +617,50 @@ func fetcherDownloadURLThread(ctx context.Context, fetchReqCh chan fetchReq, don
case <-ctx.Done():
return context.Cause(ctx)
case fetchResp := <-respCh:
f := fetchResp.get.GetDownloadFunc(ctx, stats, health, fetcher, params, chunkCh, func(ctx context.Context, lastError error, resourcePath string) (string, error) {
var i int
var cb func(context.Context, []byte) error
if fetchResp.rangeType == rangeType_Chunk {
cb = func(ctx context.Context, bs []byte) error {
rng := fetchResp.get.Ranges[i]
i += 1
h := hash.New(rng.Hash[:])
var cc nbs.ToChunker
if rng.GetDict != nil {
dictRes, err := rng.GetDict()
if err != nil {
return err
}
cc = nbs.NewArchiveToChunker(h, dictRes.(*gozstd.DDict), bs)
} else {
var err error
cc, err = nbs.NewCompressedChunk(h, bs)
if err != nil {
return err
}
}
select {
case chunkCh <- cc:
case <-ctx.Done():
return context.Cause(ctx)
}
return nil
}
} else {
cb = func(ctx context.Context, bs []byte) error {
rng := fetchResp.get.Ranges[i]
i += 1
var ddict *gozstd.DDict
decompressed, err := gozstd.Decompress(nil, bs)
if err == nil {
ddict, err = gozstd.NewDDict(decompressed)
}
fetchResp.dictCache.set(fetchResp.path, rng.Offset, rng.Length, ddict, err)
// XXX: For now, we fail here on any error, instead of when we try to use the dictionary...
// For now, the record in the cache will be terminally failed and is never removed.
return err
}
}
f := fetchResp.get.GetDownloadFunc(ctx, stats, health, fetcher, params, cb, func(ctx context.Context, lastError error, resourcePath string) (string, error) {
return fetchResp.refresh(ctx, lastError, client)
})
err := f()
Expand All @@ -607,17 +672,8 @@ func fetcherDownloadURLThread(ctx context.Context, fetchReqCh chan fetchReq, don
}
}

// dictionaryCache caches dictionaries for the chunks in an archive store. When we fetch from a database with an archive,
// we get back the path/offset/length of the dictionary for each chunk. These, by definition, are repeatedly used
// and we don't want to request the same dictionary multiple times.
//
// Currently (feb '25), archives generally have only a default dictionary, so this is kind of overkill. Mainly planning
// for the future when chunk grouping is the default and we could have thousands of dictionaries.
type dictionaryCache struct {
cache *lru.TwoQueueCache[DictionaryKey, *gozstd.DDict]
pending sync.Map
client remotesapi.ChunkStoreServiceClient
dlds downloads
dictionaries sync.Map
}

// DictionaryKey is the globally unique identifier for an archive dictionary.
Expand All @@ -629,118 +685,27 @@ type DictionaryKey struct {
len uint32
}

func NewDictionaryCache(client remotesapi.ChunkStoreServiceClient) *dictionaryCache {
c, err := lru.New2Q[DictionaryKey, *gozstd.DDict](1024)
if err != nil {
panic(err)
}

return &dictionaryCache{
cache: c,
client: client,
dlds: newDownloads(),
}
// DictionaryPayload is a dictionaryCache entry for a single dictionary
// fetch. Waiters block on |done| before reading |res| or |err|.
type DictionaryPayload struct {
// done is closed once res and err have been populated by set.
done chan struct{}
// res holds the decoded dictionary on success (a *gozstd.DDict in
// practice; stored as any).
res any
// err records a fetch/decode failure; per the note at the call site, a
// failed entry is terminal and is never removed from the cache.
err error
}

func (dc *dictionaryCache) get(rang *GetRange, idx int, stats StatsRecorder, recorder reliable.HealthRecorder) (*gozstd.DDict, error) {
path := rang.ResourcePath()
off := rang.Ranges[idx].DictionaryOffset
ln := rang.Ranges[idx].DictionaryLength

key := DictionaryKey{path, off, ln}
if dict, ok := dc.cache.Get(key); ok {
return dict, nil
}

// Check for an in-flight request. Default dictionary will be requested many times, so we want to avoid
// making multiple requests for the same resource.
if ch, loaded := dc.pending.LoadOrStore(key, make(chan struct{})); loaded {
// There's an ongoing fetch, wait for its completion
<-ch.(chan struct{})
if dict, ok := dc.cache.Get(key); ok {
return dict, nil
}
return nil, errors.New("failed to fetch dictionary due to in-flight request")
}
// When update is done, regardless of success or failure, we need to unblock anyone waiting.
defer func() {
if ch, found := dc.pending.LoadAndDelete(key); found {
close(ch.(chan struct{}))
}
}()

// Fetch the dictionary
ddict, err := dc.fetchDictionary(path, rang.Url, off, ln, stats, recorder)
if err != nil {
return nil, err
}

// Store the dictionary in the cache
dc.cache.Add(key, ddict)

return ddict, nil
// get returns a function that blocks until the dictionary identified by
// (path, offset, length) has been resolved by a matching set call, then
// yields the decoded dictionary or its terminal error. The boolean result
// is true iff this call created the entry — i.e. the caller is the first
// requester and is responsible for scheduling the dictionary download.
func (dc *dictionaryCache) get(path string, offset uint64, length uint32) (func() (any, error), bool) {
	fresh := &DictionaryPayload{done: make(chan struct{})}
	entry, loaded := dc.dictionaries.LoadOrStore(DictionaryKey{path, offset, length}, fresh)
	payload := entry.(*DictionaryPayload)
	wait := func() (any, error) {
		<-payload.done
		return payload.res, payload.err
	}
	return wait, !loaded
}

// fetchDictionary performs an GET request for a single span which is for a zstd dictionary.
func (dc *dictionaryCache) fetchDictionary(path, url string, off uint64, ln uint32, stats StatsRecorder, recorder reliable.HealthRecorder) (*gozstd.DDict, error) {
ctx := context.Background()
pathToUrl := dc.dlds.refreshes[path]
if pathToUrl == nil {
// We manually construct the RangeChunk and DownloadLoc in this case because we are retrieving the dictionary span.
// We'll make a single span request, and consume the entire response to create the dictionary.
sRang := &remotesapi.HttpGetRange{}
sRang.Url = url
sRang.Ranges = append(sRang.Ranges, &remotesapi.RangeChunk{Offset: off, Length: ln})
rang := &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: sRang}
dl := &remotesapi.DownloadLoc{Location: rang}

refresh := new(locationRefresh)
refresh.Add(dl)
dc.dlds.refreshes[path] = refresh
pathToUrl = refresh
}

urlF := func(lastError error) (string, error) {
earl, err := pathToUrl.GetURL(ctx, lastError, dc.client)
if err != nil {
return "", err
}
if earl == "" {
earl = path
}
return earl, nil
}

resp := reliable.StreamingRangeDownload(ctx, reliable.StreamingRangeRequest{
Fetcher: globalHttpFetcher,
Offset: off,
Length: uint64(ln),
UrlFact: urlF,
Stats: stats,
Health: recorder,
BackOffFact: func(ctx context.Context) backoff.BackOff {
return downloadBackOff(ctx, defaultRequestParams.DownloadRetryCount)
},
Throughput: reliable.MinimumThroughputCheck{
CheckInterval: defaultRequestParams.ThroughputMinimumCheckInterval,
BytesPerCheck: defaultRequestParams.ThroughputMinimumBytesPerCheck,
NumIntervals: defaultRequestParams.ThroughputMinimumNumIntervals,
},
RespHeadersTimeout: defaultRequestParams.RespHeadersTimeout,
})
defer resp.Close()

buf := make([]byte, ln)
_, err := io.ReadFull(resp.Body, buf)
if err != nil {
return nil, err
}

// Dictionaries are compressed, but with vanilla zstd, so there is no dictionary.
rawDict, err := gozstd.Decompress(nil, buf)
if err != nil {
return nil, err
}

return gozstd.NewDDict(rawDict)
// set records the outcome of fetching/decoding the dictionary identified
// by (path, offset, length) and unblocks every waiter returned by get.
// The res/err writes happen-before close(done), so waiters observe them.
// NOTE(review): this must run at most once per key — a second call for
// the same key would close an already-closed channel and panic.
func (dc *dictionaryCache) set(path string, offset uint64, length uint32, res any, err error) {
	entry, _ := dc.dictionaries.LoadOrStore(
		DictionaryKey{path, offset, length},
		&DictionaryPayload{done: make(chan struct{})},
	)
	payload := entry.(*DictionaryPayload)
	payload.res, payload.err = res, err
	close(payload.done)
}
Loading
Loading