Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: data races in some situations #476

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions browser_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"sync"

"github.com/playwright-community/playwright-go/internal/safe"
"golang.org/x/exp/slices"
)

Expand All @@ -23,7 +24,7 @@ type browserContextImpl struct {
browser *browserImpl
serviceWorkers []Worker
backgroundPages []Page
bindings map[string]BindingCallFunction
bindings *safe.SyncMap[string, BindingCallFunction]
tracing *tracingImpl
request *apiRequestContextImpl
harRecorders map[string]harRecordingMetadata
Expand Down Expand Up @@ -240,18 +241,21 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct
needsHandle = handle[0]
}
for _, page := range b.Pages() {
if _, ok := page.(*pageImpl).bindings[name]; ok {
if _, ok := page.(*pageImpl).bindings.Load(name); ok {
return fmt.Errorf("Function '%s' has been already registered in one of the pages", name)
}
}
if _, ok := b.bindings[name]; ok {
if _, ok := b.bindings.Load(name); ok {
return fmt.Errorf("Function '%s' has been already registered", name)
}
b.bindings[name] = binding
_, err := b.channel.Send("exposeBinding", map[string]interface{}{
"name": name,
"needsHandle": needsHandle,
})
if err != nil {
return err
}
b.bindings.Store(name, binding)
return err
}

Expand Down Expand Up @@ -533,11 +537,11 @@ func (b *browserContextImpl) StorageState(paths ...string) (*StorageState, error
}

func (b *browserContextImpl) onBinding(binding *bindingCallImpl) {
function := b.bindings[binding.initializer["name"].(string)]
if function == nil {
function, ok := b.bindings.Load(binding.initializer["name"].(string))
if !ok || function == nil {
return
}
go binding.Call(function)
binding.Call(function)
}

func (b *browserContextImpl) onClose() {
Expand Down Expand Up @@ -740,7 +744,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
pages: make([]Page, 0),
backgroundPages: make([]Page, 0),
routes: make([]*routeHandlerEntry, 0),
bindings: make(map[string]BindingCallFunction),
bindings: safe.NewSyncMap[string, BindingCallFunction](),
harRecorders: make(map[string]harRecordingMetadata),
closed: make(chan struct{}, 1),
harRouters: make([]*harRouter, 0),
Expand All @@ -754,7 +758,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl)
bt.clock = newClock(bt)
bt.channel.On("bindingCall", func(params map[string]interface{}) {
bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
})

bt.channel.On("close", bt.onClose)
Expand Down
4 changes: 2 additions & 2 deletions channel_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (c *channelOwner) dispose(reason ...string) {
if c.parent != nil {
delete(c.parent.objects, c.guid)
}
delete(c.connection.objects, c.guid)
c.connection.objects.Delete(c.guid)
if len(reason) > 0 {
c.wasCollected = reason[0] == "gc"
}
Expand Down Expand Up @@ -89,7 +89,7 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner
c.parent.objects[guid] = c
}
if c.connection != nil {
c.connection.objects[guid] = c
c.connection.objects.Store(guid, c)
}
c.channel = newChannel(c, self)
c.eventToSubscriptionMapping = map[string]string{}
Expand Down
22 changes: 12 additions & 10 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/go-stack/stack"
"github.com/playwright-community/playwright-go/internal/safe"
)

var (
Expand All @@ -27,10 +28,10 @@ type result struct {
type connection struct {
transport transport
apiZone sync.Map
objects map[string]*channelOwner
objects *safe.SyncMap[string, *channelOwner]
lastID atomic.Uint32
rootObject *rootChannelOwner
callbacks sync.Map
callbacks *safe.SyncMap[uint32, *protocolCallback]
afterClose func()
onClose func() error
isRemote bool
Expand Down Expand Up @@ -97,21 +98,21 @@ func (c *connection) Dispatch(msg *message) {
method := msg.Method
if msg.ID != 0 {
cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID))
if cb.(*protocolCallback).noReply {
if cb.noReply {
return
}
if msg.Error != nil {
cb.(*protocolCallback).SetResult(result{
cb.SetResult(result{
Error: parseError(msg.Error.Error),
})
} else {
cb.(*protocolCallback).SetResult(result{
cb.SetResult(result{
Data: c.replaceGuidsWithChannels(msg.Result),
})
}
return
}
object := c.objects[msg.GUID]
object, _ := c.objects.Load(msg.GUID)
if method == "__create__" {
c.createRemoteObject(
object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"],
Expand All @@ -122,7 +123,7 @@ func (c *connection) Dispatch(msg *message) {
return
}
if method == "__adopt__" {
child, ok := c.objects[msg.Params["guid"].(string)]
child, ok := c.objects.Load(msg.Params["guid"].(string))
if !ok {
return
}
Expand Down Expand Up @@ -205,7 +206,7 @@ func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} {
if v.Kind() == reflect.Map {
mapV := payload.(map[string]interface{})
if guid, hasGUID := mapV["guid"]; hasGUID {
if channelOwner, ok := c.objects[guid.(string)]; ok {
if channelOwner, ok := c.objects.Load(guid.(string)); ok {
return channelOwner.channel
}
}
Expand Down Expand Up @@ -254,7 +255,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa
return nil, fmt.Errorf("could not send message: %w", err)
}

return cb.(*protocolCallback), nil
return cb, nil
}

func (c *connection) setInTracing(isTracing bool) {
Expand Down Expand Up @@ -327,7 +328,8 @@ func serializeCallLocation(caller stack.Call) map[string]interface{} {
func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection {
connection := &connection{
abort: make(chan struct{}, 1),
objects: make(map[string]*channelOwner),
callbacks: safe.NewSyncMap[uint32, *protocolCallback](),
objects: safe.NewSyncMap[string, *channelOwner](),
transport: transport,
isRemote: false,
closedError: &safeValue[error]{},
Expand Down
43 changes: 25 additions & 18 deletions event_emitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type (
hasInit bool
}
eventRegister struct {
sync.Mutex
listeners []listener
}
listener struct {
Expand All @@ -33,18 +34,15 @@ type (

func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

evt, ok := e.events[name]
if !ok {
e.eventsMutex.Unlock()
return
}

hasListener = evt.count() > 0

evt.callHandlers(payload...)
return
e.eventsMutex.Unlock()
return evt.callHandlers(payload...) > 0
}

func (e *eventEmitter) Once(name string, handler interface{}) {
Expand All @@ -60,10 +58,11 @@ func (e *eventEmitter) RemoveListener(name string, handler interface{}) {
defer e.eventsMutex.Unlock()
e.init()

if _, ok := e.events[name]; !ok {
return
if evt, ok := e.events[name]; ok {
evt.Lock()
defer evt.Unlock()
evt.removeHandler(handler)
}
e.events[name].removeHandler(handler)
}

// ListenerCount count the listeners by name, count all if name is empty
Expand All @@ -90,6 +89,7 @@ func (e *eventEmitter) ListenerCount(name string) int {

func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

if _, ok := e.events[name]; !ok {
Expand All @@ -98,7 +98,6 @@ func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
}
}
e.events[name].addHandler(handler, once)
e.eventsMutex.Unlock()
}

func (e *eventEmitter) init() {
Expand All @@ -108,23 +107,27 @@ func (e *eventEmitter) init() {
}
}

func (e *eventRegister) addHandler(handler interface{}, once bool) {
e.listeners = append(e.listeners, listener{handler: handler, once: once})
func (er *eventRegister) addHandler(handler interface{}, once bool) {
er.Lock()
defer er.Unlock()
er.listeners = append(er.listeners, listener{handler: handler, once: once})
}

func (e *eventRegister) count() int {
return len(e.listeners)
func (er *eventRegister) count() int {
er.Lock()
defer er.Unlock()
return len(er.listeners)
}

func (e *eventRegister) removeHandler(handler interface{}) {
handlerPtr := reflect.ValueOf(handler).Pointer()

e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool {
e.listeners = slices.DeleteFunc(e.listeners, func(l listener) bool {
return reflect.ValueOf(l.handler).Pointer() == handlerPtr
})
}

func (e *eventRegister) callHandlers(payloads ...interface{}) {
func (er *eventRegister) callHandlers(payloads ...interface{}) int {
payloadV := make([]reflect.Value, 0)

for _, p := range payloads {
Expand All @@ -136,10 +139,14 @@ func (e *eventRegister) callHandlers(payloads ...interface{}) {
handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
}

for _, l := range e.listeners {
er.Lock()
defer er.Unlock()
count := len(er.listeners)
for _, l := range er.listeners {
if l.once {
defer e.removeHandler(l.handler)
defer er.removeHandler(l.handler)
}
handle(l)
}
return count
}
91 changes: 91 additions & 0 deletions internal/safe/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package safe

import (
"sync"

"golang.org/x/exp/maps"
)

// SyncMap is a thread-safe map
type SyncMap[K comparable, V any] struct {
sync.RWMutex
m map[K]V
}

// NewSyncMap creates a new thread-safe map
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
m: make(map[K]V),
}
}

func (m *SyncMap[K, V]) Store(k K, v V) {
m.Lock()
defer m.Unlock()
m.m[k] = v
}

func (m *SyncMap[K, V]) Load(k K) (v V, ok bool) {
m.RLock()
defer m.RUnlock()
v, ok = m.m[k]
return
}

// LoadOrStore returns the existing value for the key if present. Otherwise, it stores and returns the given value.
func (m *SyncMap[K, V]) LoadOrStore(k K, v V) (actual V, loaded bool) {
m.Lock()
defer m.Unlock()
actual, loaded = m.m[k]
if loaded {
return
}
m.m[k] = v
return v, false
}

// LoadAndDelete deletes the value for a key, and returns the previous value if any.
func (m *SyncMap[K, V]) LoadAndDelete(k K) (v V, loaded bool) {
m.Lock()
defer m.Unlock()
v, loaded = m.m[k]
if loaded {
delete(m.m, k)
}
return
}

func (m *SyncMap[K, V]) Delete(k K) {
m.Lock()
defer m.Unlock()
delete(m.m, k)
}

func (m *SyncMap[K, V]) Clear() {
m.Lock()
defer m.Unlock()
maps.Clear(m.m)
}

func (m *SyncMap[K, V]) Len() int {
m.RLock()
defer m.RUnlock()
return len(m.m)
}

func (m *SyncMap[K, V]) Clone() map[K]V {
m.RLock()
defer m.RUnlock()
return maps.Clone(m.m)
}

func (m *SyncMap[K, V]) Range(f func(k K, v V) bool) {
m.RLock()
defer m.RUnlock()

for k, v := range m.m {
if !f(k, v) {
break
}
}
}
Loading
Loading