Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: add some memory tracker in HashJoin #33918

Merged
merged 8 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,6 @@ type baseHashAggWorker struct {
BInMap int // indicate there are 2^BInMap buckets in Golang Map.
}

const (
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// defBucketMemoryUsage = bucketSize*(1+unsafe.Sizeof(string) + unsafe.Sizeof(slice))+2*ptrSize
// The bucket size may be changed by golang implement in the future.
defBucketMemoryUsage = 8*(1+16+24) + 16
)

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc,
maxChunkSize int, memTrack *memory.Tracker) baseHashAggWorker {
baseWorker := baseHashAggWorker{
Expand Down Expand Up @@ -332,7 +325,7 @@ func (e *HashAggExec) initForUnparallelExec() {
e.partialResultMap = make(aggPartialResultMapper)
e.bInMap = 0
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(defBucketMemoryUsage*(1<<e.bInMap) + setSize)
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1<<e.bInMap) + setSize)
e.groupKeyBuffer = make([][]byte, 0, 8)
e.childResult = newFirstChunk(e.children[0])
e.memTracker.Consume(e.childResult.MemoryUsage())
Expand Down Expand Up @@ -395,7 +388,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
}
// There is a bucket in the empty partialResultsMap.
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.PartialStats = append(e.stats.PartialStats, w.stats)
Expand Down Expand Up @@ -425,7 +418,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
groupKeys: make([][]byte, 0, 8),
}
// There is a bucket in the empty partialResultsMap.
e.memTracker.Consume(defBucketMemoryUsage*(1<<w.BInMap) + setSize)
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1<<w.BInMap) + setSize)
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.FinalStats = append(e.stats.FinalStats, w.stats)
Expand Down Expand Up @@ -615,7 +608,7 @@ func (w *baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, group
allMemDelta += int64(len(groupKey[i]))
// Map will expand when count > bucketNum * loadFactor. The memory usage will doubled.
if len(mapper) > (1<<w.BInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
w.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
w.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
w.BInMap++
}
}
Expand Down Expand Up @@ -1084,7 +1077,7 @@ func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResul
allMemDelta += int64(len(groupKey))
// Map will expand when count > bucketNum * loadFactor. The memory usage will doubled.
if len(e.partialResultMap) > (1<<e.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
e.memTracker.Consume(defBucketMemoryUsage * (1 << e.bInMap))
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << e.bInMap))
e.bInMap++
}
}
Expand Down
22 changes: 13 additions & 9 deletions executor/concurrent_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package executor

import (
"sync"

"github.com/pingcap/tidb/util/hack"
)

// ShardCount controls the shard maps within the concurrent map
Expand All @@ -28,14 +30,15 @@ type concurrentMap []*concurrentMapShared
// A "thread" safe string to anything map.
type concurrentMapShared struct {
items map[uint64]*entry
sync.RWMutex // Read Write mutex, guards access to internal map.
sync.RWMutex // Read Write mutex, guards access to internal map.
bInMap int64 // indicate there are 2^bInMap buckets in items
}

// newConcurrentMap creates a new concurrent map.
func newConcurrentMap() concurrentMap {
m := make(concurrentMap, ShardCount)
for i := 0; i < ShardCount; i++ {
m[i] = &concurrentMapShared{items: make(map[uint64]*entry)}
m[i] = &concurrentMapShared{items: make(map[uint64]*entry), bInMap: 0}
}
return m
}
Expand All @@ -46,17 +49,18 @@ func (m concurrentMap) getShard(hashKey uint64) *concurrentMapShared {
}

// Insert inserts a value in a shard safely
func (m concurrentMap) Insert(key uint64, value *entry) {
func (m concurrentMap) Insert(key uint64, value *entry) (memDelta int64) {
shard := m.getShard(key)
shard.Lock()
v, ok := shard.items[key]
if !ok {
shard.items[key] = value
} else {
value.next = v
shard.items[key] = value
oldValue := shard.items[key]
value.next = oldValue
shard.items[key] = value
if len(shard.items) > (1<<shard.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
memDelta = hack.DefBucketMemoryUsageForMapIntToPtr * (1 << shard.bInMap)
shard.bInMap++
}
shard.Unlock()
return memDelta
}

// UpsertCb : Callback to return new element to be inserted into the map
Expand Down
35 changes: 35 additions & 0 deletions executor/concurrent_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ package executor

import (
"sync"
"sync/atomic"
"testing"

"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -65,3 +67,36 @@ func TestConcurrentMap(t *testing.T) {
_, ok = m.Get(uint64(mod + 1))
require.False(t, ok)
}

func TestConcurrentMapMemoryUsage(t *testing.T) {
m := newConcurrentMap()
const iterations = 1024 * hack.LoadFactorNum / hack.LoadFactorDen
var memUsage int64
wg := &sync.WaitGroup{}
wg.Add(2)
// Using go routines insert 1000 entries into the map.
go func() {
defer wg.Done()
var memDelta int64
for i := 0; i < iterations/2; i++ {
// Add entry to map.
memDelta += m.Insert(uint64(i*ShardCount), &entry{chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(i)}, nil})
}
atomic.AddInt64(&memUsage, memDelta)
}()

go func() {
defer wg.Done()
var memDelta int64
for i := iterations / 2; i < iterations; i++ {
// Add entry to map.
memDelta += m.Insert(uint64(i*ShardCount), &entry{chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(i)}, nil})
}
atomic.AddInt64(&memUsage, memDelta)
}()
wg.Wait()

// The first bucket memory usage will be recorded in concurrentMapHashTable, here only test the memory delta.
require.Equal(t, int64(1023)*hack.DefBucketMemoryUsageForMapIntToPtr, memUsage)
require.Equal(t, int64(10), m.getShard(0).bInMap)
}
45 changes: 41 additions & 4 deletions executor/hash_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"hash/fnv"
"sync/atomic"
"time"
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/memory"
)

Expand Down Expand Up @@ -186,6 +188,7 @@ func (c *hashRowContainer) PutChunkSelected(chk *chunk.Chunk, selected, ignoreNu
rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)}
c.hashTable.Put(key, rowPtr)
}
c.GetMemTracker().Consume(c.hashTable.GetMemoryDelta())
return nil
}

Expand Down Expand Up @@ -251,7 +254,7 @@ func newEntryStore() *entryStore {
return es
}

func (es *entryStore) GetStore() (e *entry) {
func (es *entryStore) GetStore() (e *entry, memDelta int64) {
sliceIdx := uint32(len(es.slices) - 1)
slice := es.slices[sliceIdx]
if es.cursor >= cap(slice) {
Expand All @@ -263,6 +266,7 @@ func (es *entryStore) GetStore() (e *entry) {
es.slices = append(es.slices, slice)
sliceIdx++
es.cursor = 0
memDelta = int64(unsafe.Sizeof(entry{})) * int64(size)
}
e = &es.slices[sliceIdx][es.cursor]
es.cursor++
Expand All @@ -273,6 +277,7 @@ type baseHashTable interface {
Put(hashKey uint64, rowPtr chunk.RowPtr)
Get(hashKey uint64) (rowPtrs []chunk.RowPtr)
Len() uint64
GetMemoryDelta() int64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment for this func

}

// TODO (fangzhuhe) remove unsafeHashTable later if it not used anymore
Expand All @@ -283,6 +288,9 @@ type unsafeHashTable struct {
hashMap map[uint64]*entry
entryStore *entryStore
length uint64

bInMap int64 // indicate there are 2^bInMap buckets in hashMap
memDelta int64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment for this

}

// newUnsafeHashTable creates a new unsafeHashTable. estCount means the estimated size of the hashMap.
Expand All @@ -297,11 +305,16 @@ func newUnsafeHashTable(estCount int) *unsafeHashTable {
// Put puts the key/rowPtr pairs to the unsafeHashTable, multiple rowPtrs are stored in a list.
func (ht *unsafeHashTable) Put(hashKey uint64, rowPtr chunk.RowPtr) {
oldEntry := ht.hashMap[hashKey]
newEntry := ht.entryStore.GetStore()
newEntry, memDelta := ht.entryStore.GetStore()
newEntry.ptr = rowPtr
newEntry.next = oldEntry
ht.hashMap[hashKey] = newEntry
if len(ht.hashMap) > (1<<ht.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
memDelta += hack.DefBucketMemoryUsageForMapIntToPtr * (1 << ht.bInMap)
ht.bInMap++
}
ht.length++
ht.memDelta += memDelta
}

// Get gets the values of the "key" and appends them to "values".
Expand All @@ -318,11 +331,19 @@ func (ht *unsafeHashTable) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
// if the same key is put more than once.
func (ht *unsafeHashTable) Len() uint64 { return ht.length }

// GetMemoryDelta gets the memDelta of the concurrentMapHashTable.
func (ht *unsafeHashTable) GetMemoryDelta() int64 {
memDelta := ht.memDelta
ht.memDelta = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why set this to 0?

return memDelta
}

// concurrentMapHashTable is a concurrent hash table built on concurrentMap
type concurrentMapHashTable struct {
hashMap concurrentMap
entryStore *entryStore
length uint64
memDelta int64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment for this

}

// newConcurrentMapHashTable creates a concurrentMapHashTable
Expand All @@ -331,6 +352,7 @@ func newConcurrentMapHashTable() *concurrentMapHashTable {
ht.hashMap = newConcurrentMap()
ht.entryStore = newEntryStore()
ht.length = 0
ht.memDelta = hack.DefBucketMemoryUsageForMapIntToPtr + int64(unsafe.Sizeof(entry{}))*initialEntrySliceLen
return ht
}

Expand All @@ -341,10 +363,13 @@ func (ht *concurrentMapHashTable) Len() uint64 {

// Put puts the key/rowPtr pairs to the concurrentMapHashTable, multiple rowPtrs are stored in a list.
func (ht *concurrentMapHashTable) Put(hashKey uint64, rowPtr chunk.RowPtr) {
newEntry := ht.entryStore.GetStore()
newEntry, memDelta := ht.entryStore.GetStore()
newEntry.ptr = rowPtr
newEntry.next = nil
ht.hashMap.Insert(hashKey, newEntry)
memDelta += ht.hashMap.Insert(hashKey, newEntry)
if memDelta != 0 {
atomic.AddInt64(&ht.memDelta, memDelta)
}
atomic.AddUint64(&ht.length, 1)
}

Expand All @@ -357,3 +382,15 @@ func (ht *concurrentMapHashTable) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
}
return
}

// GetMemoryDelta gets the memDelta of the concurrentMapHashTable. Memory delta will be cleared after each fetch.
func (ht *concurrentMapHashTable) GetMemoryDelta() int64 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetAndCleanMemoryDelta

var memDelta int64
for {
memDelta = atomic.LoadInt64(&ht.memDelta)
if atomic.CompareAndSwapInt64(&ht.memDelta, memDelta, 0) {
break
}
}
return memDelta
}
18 changes: 18 additions & 0 deletions executor/hash_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import (
"fmt"
"hash"
"hash/fnv"
"sync"
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -162,3 +164,19 @@ func testHashRowContainer(t *testing.T, hashFunc func() hash.Hash64, spill bool)
require.Equal(t, chk1.GetRow(1).GetDatumRow(colTypes), matched[1].GetDatumRow(colTypes))
return rowContainer, copiedRC
}

func TestConcurrentMapHashTableMemoryUsage(t *testing.T) {
m := newConcurrentMapHashTable()
const iterations = 1024 * hack.LoadFactorNum / hack.LoadFactorDen // 6656
wg := &sync.WaitGroup{}
wg.Add(2)
// Note: Now concurrentMapHashTable doesn't support inserting in parallel.
for i := 0; i < iterations; i++ {
// Add entry to map.
m.Put(uint64(i*ShardCount), chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(i)})
}
mapMemoryExpected := int64(1024) * hack.DefBucketMemoryUsageForMapIntToPtr
entryMemoryExpected := 16 * int64(64+128+256+512+1024+2048+4096)
require.Equal(t, mapMemoryExpected+entryMemoryExpected, m.GetMemoryDelta())
require.Equal(t, int64(0), m.GetMemoryDelta())
}
9 changes: 9 additions & 0 deletions util/hack/hack.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ const (
// LoadFactorDen is the denominator of load factor
LoadFactorDen = 2
)

const (
// DefBucketMemoryUsageForMapStrToSlice = bucketSize*(1+unsafe.Sizeof(string) + unsafe.Sizeof(slice))+2*ptrSize
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// The bucket size may be changed by golang implement in the future.
DefBucketMemoryUsageForMapStrToSlice = 8*(1+16+24) + 16
// DefBucketMemoryUsageForMapIntToPtr = bucketSize*(1+unsafe.Sizeof(uint64) + unsafe.Sizeof(pointer))+2*ptrSize
DefBucketMemoryUsageForMapIntToPtr = 8*(1+8+8) + 16
)