Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: race condition when adding new channel to NodeInfo #735

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix: json marshal nodeInfo channels fails
  • Loading branch information
lklimek committed Feb 8, 2024
commit 5b45bf91bdc92606c229983bbe71c67d59fcbea3
100 changes: 61 additions & 39 deletions internal/libs/sync/concurrent_slice.go
Original file line number Diff line number Diff line change
@@ -1,94 +1,116 @@
package sync

import "sync"
import (
"encoding/json"
"sync"
)

type concurrentSlice[T any] struct {
mtx sync.RWMutex `json:"-"`
Items []T `json:"items"`
}

// Slice is a thread-safe slice interface
type Slice[T any] interface {
Append(val ...T)
Reset()
Get(index int) T
Set(index int, val T)
ToSlice() []T
Len() int
Copy() Slice[T]
// ConcurrentSlice is a thread-safe slice.
//
// It is safe to use from multiple goroutines without additional locking.
// It should be referenced by pointer.
//
// Initialize using NewConcurrentSlice().
type ConcurrentSlice[T any] struct {
mtx sync.RWMutex
items []T
}

// NewConcurrentSlice creates a new thread-safe slice.
// It is safe to use from multiple goroutines without additional locking.
// It can be referenced by value, and will behave similarly to a regular slice (which is a reference type).
func NewConcurrentSlice[T any](initial ...T) Slice[T] {
return &concurrentSlice[T]{
Items: initial,
func NewConcurrentSlice[T any](initial ...T) *ConcurrentSlice[T] {
return &ConcurrentSlice[T]{
items: initial,
}
}

// Append adds an element to the slice
func (s *concurrentSlice[T]) Append(val ...T) {
func (s *ConcurrentSlice[T]) Append(val ...T) {
s.mtx.Lock()
defer s.mtx.Unlock()

s.Items = append(s.Items, val...)
s.items = append(s.items, val...)
}

// Reset removes all elements from the slice
func (s *concurrentSlice[T]) Reset() {
func (s *ConcurrentSlice[T]) Reset() {
s.mtx.Lock()
defer s.mtx.Unlock()

s.Items = []T{}
s.items = []T{}
}

// Get returns the value at the given index
func (s *concurrentSlice[T]) Get(index int) T {
func (s *ConcurrentSlice[T]) Get(index int) T {
s.mtx.RLock()
defer s.mtx.RUnlock()

return s.Items[index]
return s.items[index]
}

func (s *concurrentSlice[T]) Set(index int, val T) {
// Set updates the value at the given index.
// If the index is greater than the length of the slice, it panics.
// If the index is equal to the length of the slice, the value is appended.
// Otherwise, the value at the index is updated.
func (s *ConcurrentSlice[T]) Set(index int, val T) {
s.mtx.Lock()
defer s.mtx.Unlock()

if index > len(s.Items) {
if index > len(s.items) {
panic("index out of range")
} else if index == len(s.Items) {
s.Items = append(s.Items, val)
} else if index == len(s.items) {
s.items = append(s.items, val)
return
}

s.Items[index] = val
s.items[index] = val
}

// ToSlice returns a copy of the underlying slice
func (s *concurrentSlice[T]) ToSlice() []T {
func (s *ConcurrentSlice[T]) ToSlice() []T {
s.mtx.RLock()
defer s.mtx.RUnlock()

slice := make([]T, len(s.Items))
copy(slice, s.Items)
slice := make([]T, len(s.items))
copy(slice, s.items)
return slice
}

// Len returns the length of the slice
func (s *concurrentSlice[T]) Len() int {
func (s *ConcurrentSlice[T]) Len() int {
s.mtx.RLock()
defer s.mtx.RUnlock()

return len(s.Items)
return len(s.items)
}

// Copy returns a new deep copy of concurrentSlice with the same elements
func (s *concurrentSlice[T]) Copy() Slice[T] {
func (s *ConcurrentSlice[T]) Copy() ConcurrentSlice[T] {
s.mtx.RLock()
defer s.mtx.RUnlock()

return &concurrentSlice[T]{
Items: s.ToSlice(),
return ConcurrentSlice[T]{
items: s.ToSlice(),
}
}

// MarshalJSON implements the json.Marshaler interface.
func (cs *ConcurrentSlice[T]) MarshalJSON() ([]byte, error) {
cs.mtx.RLock()
defer cs.mtx.RUnlock()

return json.Marshal(cs.items)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (cs *ConcurrentSlice[T]) UnmarshalJSON(data []byte) error {
var items []T
if err := json.Unmarshal(data, &items); err != nil {
return err
}

cs.mtx.Lock()
defer cs.mtx.Unlock()

cs.items = items
return nil
}
19 changes: 14 additions & 5 deletions internal/libs/sync/concurrent_slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,26 @@ func TestConcurrentSlice_Concurrency(t *testing.T) {
}

func TestConcurrentSlice_MarshalUnmarshalJSON(t *testing.T) {
// Create a concurrentSlice
type node struct {
Channels *ConcurrentSlice[uint16]
}
cs := NewConcurrentSlice[uint16](1, 2, 3)

node1 := node{
Channels: cs,
}

// Marshal to JSON
data, err := json.Marshal(cs)
data, err := json.Marshal(node1)
assert.NoError(t, err, "Failed to marshal concurrentSlice")

// Unmarshal from JSON
var cs2 concurrentSlice[uint16]
err = json.Unmarshal(data, &cs2)
node2 := node{
// Channels: NewConcurrentSlice[uint16](),
}

err = json.Unmarshal(data, &node2)
assert.NoError(t, err, "Failed to unmarshal concurrentSlice")

assert.EqualValues(t, cs.ToSlice(), cs2.ToSlice())
assert.EqualValues(t, node1.Channels.ToSlice(), node2.Channels.ToSlice())
}
5 changes: 3 additions & 2 deletions types/node_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type NodeInfo struct {
Network string `json:"network"` // network/chain ID
Version string `json:"version"` // major.minor.revision
// Channels supported by this node. Use GetChannels() as a getter.
Channels tmsync.Slice[uint16] `json:"channels"` // channels this node knows about
Channels *tmsync.ConcurrentSlice[uint16] `json:"channels"` // channels this node knows about

// ASCIIText fields
Moniker string `json:"moniker"` // arbitrary moniker
Expand Down Expand Up @@ -185,13 +185,14 @@ func (info *NodeInfo) AddChannel(channel uint16) {
}

func (info NodeInfo) Copy() NodeInfo {
chans := info.Channels.Copy()
return NodeInfo{
ProtocolVersion: info.ProtocolVersion,
NodeID: info.NodeID,
ListenAddr: info.ListenAddr,
Network: info.Network,
Version: info.Version,
Channels: info.Channels.Copy(),
Channels: &chans,
Moniker: info.Moniker,
Other: info.Other,
ProTxHash: info.ProTxHash.Copy(),
Expand Down
6 changes: 5 additions & 1 deletion types/node_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestNodeInfoValidate(t *testing.T) {
{
"Too Many Channels",
func(ni *NodeInfo) {
ni.Channels = channels.Copy()
ni.Channels = ref(channels.Copy())
ni.Channels.Append(maxNumChannels)
},
true,
Expand Down Expand Up @@ -101,6 +101,10 @@ func TestNodeInfoValidate(t *testing.T) {

}

func ref[T any](t T) *T {
return &t
}

func testNodeID() NodeID {
return NodeIDFromPubKey(ed25519.GenPrivKey().PubKey())
}
Expand Down
Loading