Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix panic of concurrent map writes, when using in memory mode #2179

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions internal/cmap/cmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package cmap

import "sync"

func NewMap[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
m: make(map[K]V),
mu: sync.RWMutex{},
}
}

type Map[K comparable, V any] struct {
m map[K]V
mu sync.RWMutex
}

func (m *Map[K, V]) Get(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
v, exists := m.m[key]
return v, exists
}

func (m *Map[K, V]) Set(key K, v V) {
m.mu.Lock()
defer m.mu.Unlock()
m.m[key] = v

}

func (m *Map[K, V]) Del(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.m, key)
}

func (m *Map[K, V]) Foreach(f func(key K, v V) error) error {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.m {
if err := f(k, v); err != nil {
return err
}
}
return nil
}

func (m *Map[K, V]) FindForeach(f func(key K, v V) bool) (K, V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.m {
if f(k, v) {
return k, v, true
}
}
var (
k K
v V
)
return k, v, false
}
61 changes: 31 additions & 30 deletions memory/table_editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"reflect"
"strings"

"github.com/dolthub/go-mysql-server/internal/cmap"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)
Expand Down Expand Up @@ -380,33 +381,32 @@ func newTableEditAccumulator(t *TableData) tableEditAccumulator {

return &pkTableEditAccumulator{
tableData: t,
adds: make(map[string]sql.Row),
deletes: make(map[string]sql.Row),
adds: cmap.NewMap[string, sql.Row](),
deletes: cmap.NewMap[string, sql.Row](),
}
}

// pkTableEditAccumulator manages the updates of keyed tables. It uses a map to efficiently toggle edits.
type pkTableEditAccumulator struct {
tableData *TableData
adds map[string]sql.Row
deletes map[string]sql.Row
adds *cmap.Map[string, sql.Row]
deletes *cmap.Map[string, sql.Row]
}

var _ tableEditAccumulator = (*pkTableEditAccumulator)(nil)

// Insert implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Insert(value sql.Row) error {
rowKey := pke.getRowKey(value)
pke.adds[rowKey] = value
pke.adds.Set(rowKey, value)
return nil
}

// Delete implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Delete(value sql.Row) error {
rowKey := pke.getRowKey(value)

delete(pke.adds, rowKey)
pke.deletes[rowKey] = value
pke.adds.Del(rowKey)
pke.deletes.Set(rowKey, value)

return nil
}
Expand All @@ -415,12 +415,12 @@ func (pke *pkTableEditAccumulator) Delete(value sql.Row) error {
func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
rowKey := pke.getRowKey(value)

r, exists := pke.adds[rowKey]
r, exists := pke.adds.Get(rowKey)
if exists {
return r, true, nil
}

r, exists = pke.deletes[rowKey]
r, exists = pke.deletes.Get(rowKey)
if exists {
return r, false, nil
}
Expand All @@ -440,16 +440,16 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
// GetByCols finds a row that has the same |cols| values as |value|.
func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) {
// If we have this row in any delete, bail.
for _, r := range pke.deletes {
if columnsMatch(cols, prefixLengths, r, value) {
return nil, false, nil
}
if _, _, exists := pke.deletes.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
}); exists {
return nil, false, nil
}

for _, r := range pke.adds {
if columnsMatch(cols, prefixLengths, r, value) {
return r, true, nil
}
if _, r, exists := pke.adds.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
}); exists {
return r, true, nil
}

for _, partition := range pke.tableData.partitions {
Expand All @@ -465,18 +465,19 @@ func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLe

// ApplyEdits implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) ApplyEdits(table *Table) error {
for _, val := range pke.deletes {
err := pke.deleteHelper(pke.tableData, val)
if err != nil {
return err
}

if err := pke.deletes.Foreach(func(key string, val sql.Row) error {
return pke.deleteHelper(pke.tableData, val)

}); err != nil {
return err
}

for _, val := range pke.adds {
err := pke.insertHelper(pke.tableData, val)
if err != nil {
return err
}
if err := pke.adds.Foreach(func(key string, val sql.Row) error {
return pke.insertHelper(pke.tableData, val)

}); err != nil {
return err
}

pke.tableData.sortRows()
Expand All @@ -487,8 +488,8 @@ func (pke *pkTableEditAccumulator) ApplyEdits(table *Table) error {

// Clear implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Clear() {
pke.adds = make(map[string]sql.Row)
pke.deletes = make(map[string]sql.Row)
pke.adds = cmap.NewMap[string, sql.Row]()
pke.deletes = cmap.NewMap[string, sql.Row]()
}

// pkColumnIndexes returns the indexes of the primary partitionKeys in the initialized tableData.
Expand Down