Skip to content

Commit

Permalink
Merge pull request #2185 from dolthub/kom0055/main
Browse files Browse the repository at this point in the history
fix panic of concurrent map writes, when using in memory mode
  • Loading branch information
zachmu authored Dec 5, 2023
2 parents 6c70cd8 + 754cfdc commit 4987a03
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 30 deletions.
61 changes: 61 additions & 0 deletions internal/cmap/cmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package cmap

import "sync"

func NewMap[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
m: make(map[K]V),
mu: sync.RWMutex{},
}
}

type Map[K comparable, V any] struct {
m map[K]V
mu sync.RWMutex
}

func (m *Map[K, V]) Get(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
v, exists := m.m[key]
return v, exists
}

func (m *Map[K, V]) Set(key K, v V) {
m.mu.Lock()
defer m.mu.Unlock()
m.m[key] = v

}

func (m *Map[K, V]) Del(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.m, key)
}

func (m *Map[K, V]) Foreach(f func(key K, v V) error) error {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.m {
if err := f(k, v); err != nil {
return err
}
}
return nil
}

func (m *Map[K, V]) FindForeach(f func(key K, v V) bool) (K, V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.m {
if f(k, v) {
return k, v, true
}
}
var (
k K
v V
)
return k, v, false
}
61 changes: 31 additions & 30 deletions memory/table_editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"reflect"
"strings"

"github.com/dolthub/go-mysql-server/internal/cmap"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)
Expand Down Expand Up @@ -380,33 +381,32 @@ func newTableEditAccumulator(t *TableData) tableEditAccumulator {

return &pkTableEditAccumulator{
tableData: t,
adds: make(map[string]sql.Row),
deletes: make(map[string]sql.Row),
adds: cmap.NewMap[string, sql.Row](),
deletes: cmap.NewMap[string, sql.Row](),
}
}

// pkTableEditAccumulator manages the updates of keyed tables. It uses a map to efficiently toggle edits.
type pkTableEditAccumulator struct {
tableData *TableData
adds map[string]sql.Row
deletes map[string]sql.Row
adds *cmap.Map[string, sql.Row]
deletes *cmap.Map[string, sql.Row]
}

var _ tableEditAccumulator = (*pkTableEditAccumulator)(nil)

// Insert implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Insert(value sql.Row) error {
rowKey := pke.getRowKey(value)
pke.adds[rowKey] = value
pke.adds.Set(rowKey, value)
return nil
}

// Delete implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Delete(value sql.Row) error {
rowKey := pke.getRowKey(value)

delete(pke.adds, rowKey)
pke.deletes[rowKey] = value
pke.adds.Del(rowKey)
pke.deletes.Set(rowKey, value)

return nil
}
Expand All @@ -415,12 +415,12 @@ func (pke *pkTableEditAccumulator) Delete(value sql.Row) error {
func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
rowKey := pke.getRowKey(value)

r, exists := pke.adds[rowKey]
r, exists := pke.adds.Get(rowKey)
if exists {
return r, true, nil
}

r, exists = pke.deletes[rowKey]
r, exists = pke.deletes.Get(rowKey)
if exists {
return r, false, nil
}
Expand All @@ -440,16 +440,16 @@ func (pke *pkTableEditAccumulator) Get(value sql.Row) (sql.Row, bool, error) {
// GetByCols finds a row that has the same |cols| values as |value|.
func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLengths []uint16) (sql.Row, bool, error) {
// If we have this row in any delete, bail.
for _, r := range pke.deletes {
if columnsMatch(cols, prefixLengths, r, value) {
return nil, false, nil
}
if _, _, exists := pke.deletes.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
}); exists {
return nil, false, nil
}

for _, r := range pke.adds {
if columnsMatch(cols, prefixLengths, r, value) {
return r, true, nil
}
if _, r, exists := pke.adds.FindForeach(func(key string, r sql.Row) bool {
return columnsMatch(cols, prefixLengths, r, value)
}); exists {
return r, true, nil
}

for _, partition := range pke.tableData.partitions {
Expand All @@ -465,18 +465,19 @@ func (pke *pkTableEditAccumulator) GetByCols(value sql.Row, cols []int, prefixLe

// ApplyEdits implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) ApplyEdits(table *Table) error {
for _, val := range pke.deletes {
err := pke.deleteHelper(pke.tableData, val)
if err != nil {
return err
}

if err := pke.deletes.Foreach(func(key string, val sql.Row) error {
return pke.deleteHelper(pke.tableData, val)

}); err != nil {
return err
}

for _, val := range pke.adds {
err := pke.insertHelper(pke.tableData, val)
if err != nil {
return err
}
if err := pke.adds.Foreach(func(key string, val sql.Row) error {
return pke.insertHelper(pke.tableData, val)

}); err != nil {
return err
}

pke.tableData.sortRows()
Expand All @@ -487,8 +488,8 @@ func (pke *pkTableEditAccumulator) ApplyEdits(table *Table) error {

// Clear implements the tableEditAccumulator interface.
func (pke *pkTableEditAccumulator) Clear() {
pke.adds = make(map[string]sql.Row)
pke.deletes = make(map[string]sql.Row)
pke.adds = cmap.NewMap[string, sql.Row]()
pke.deletes = cmap.NewMap[string, sql.Row]()
}

// pkColumnIndexes returns the indexes of the primary partitionKeys in the initialized tableData.
Expand Down

0 comments on commit 4987a03

Please sign in to comment.