From 60b62d88f3b42f86ea0dc6df0496f9c47a74e931 Mon Sep 17 00:00:00 2001 From: canstand Date: Sun, 14 Jul 2024 11:04:06 +0800 Subject: [PATCH] fix: data races in some situations --- browser_context.go | 22 ++++++----- channel_owner.go | 4 +- connection.go | 22 ++++++----- event_emitter.go | 43 +++++++++++--------- internal/safe/map.go | 91 +++++++++++++++++++++++++++++++++++++++++++ page.go | 24 +++++++----- tests/binding_test.go | 29 ++++++++++++++ 7 files changed, 186 insertions(+), 49 deletions(-) create mode 100644 internal/safe/map.go diff --git a/browser_context.go b/browser_context.go index 5a1da4fa..2733bad5 100644 --- a/browser_context.go +++ b/browser_context.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/playwright-community/playwright-go/internal/safe" "golang.org/x/exp/slices" ) @@ -23,7 +24,7 @@ type browserContextImpl struct { browser *browserImpl serviceWorkers []Worker backgroundPages []Page - bindings map[string]BindingCallFunction + bindings *safe.SyncMap[string, BindingCallFunction] tracing *tracingImpl request *apiRequestContextImpl harRecorders map[string]harRecordingMetadata @@ -240,18 +241,21 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct needsHandle = handle[0] } for _, page := range b.Pages() { - if _, ok := page.(*pageImpl).bindings[name]; ok { + if _, ok := page.(*pageImpl).bindings.Load(name); ok { return fmt.Errorf("Function '%s' has been already registered in one of the pages", name) } } - if _, ok := b.bindings[name]; ok { + if _, ok := b.bindings.Load(name); ok { return fmt.Errorf("Function '%s' has been already registered", name) } - b.bindings[name] = binding _, err := b.channel.Send("exposeBinding", map[string]interface{}{ "name": name, "needsHandle": needsHandle, }) + if err != nil { + return err + } + b.bindings.Store(name, binding) return err } @@ -533,11 +537,11 @@ func (b *browserContextImpl) StorageState(paths ...string) (*StorageState, error } func (b *browserContextImpl) onBinding(binding *bindingCallImpl) { - function := b.bindings[binding.initializer["name"].(string)] - if function == nil { + function, ok := b.bindings.Load(binding.initializer["name"].(string)) + if !ok || function == nil { return } - go binding.Call(function) + binding.Call(function) } func (b *browserContextImpl) onClose() { @@ -740,7 +744,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini pages: make([]Page, 0), backgroundPages: make([]Page, 0), routes: make([]*routeHandlerEntry, 0), - bindings: make(map[string]BindingCallFunction), + bindings: safe.NewSyncMap[string, BindingCallFunction](), harRecorders: make(map[string]harRecordingMetadata), closed: make(chan struct{}, 1), harRouters: make([]*harRouter, 0), @@ -754,7 +758,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl) bt.clock = newClock(bt) bt.channel.On("bindingCall", func(params map[string]interface{}) { - bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) + go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) }) bt.channel.On("close", bt.onClose) diff --git a/channel_owner.go b/channel_owner.go index d858ac87..007be70e 100644 --- a/channel_owner.go +++ b/channel_owner.go @@ -23,7 +23,7 @@ func (c *channelOwner) dispose(reason ...string) { if c.parent != nil { delete(c.parent.objects, c.guid) } - delete(c.connection.objects, c.guid) + c.connection.objects.Delete(c.guid) if len(reason) > 0 { c.wasCollected = reason[0] == "gc" } @@ -89,7 +89,7 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner c.parent.objects[guid] = c } if c.connection != nil { - c.connection.objects[guid] = c + c.connection.objects.Store(guid, c) } c.channel = newChannel(c, self) c.eventToSubscriptionMapping = map[string]string{} diff --git a/connection.go b/connection.go index c2ca1e1b..65409673 100644 --- a/connection.go +++ b/connection.go @@ -12,6 +12,7 @@ import ( "time" "github.com/go-stack/stack" + "github.com/playwright-community/playwright-go/internal/safe" ) var ( @@ -27,10 +28,10 @@ type result struct { type connection struct { transport transport apiZone sync.Map - objects map[string]*channelOwner + objects *safe.SyncMap[string, *channelOwner] lastID atomic.Uint32 rootObject *rootChannelOwner - callbacks sync.Map + callbacks *safe.SyncMap[uint32, *protocolCallback] afterClose func() onClose func() error isRemote bool @@ -97,21 +98,21 @@ func (c *connection) Dispatch(msg *message) { method := msg.Method if msg.ID != 0 { cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID)) - if cb.(*protocolCallback).noReply { + if cb.noReply { return } if msg.Error != nil { - cb.(*protocolCallback).SetResult(result{ + cb.SetResult(result{ Error: parseError(msg.Error.Error), }) } else { - cb.(*protocolCallback).SetResult(result{ + cb.SetResult(result{ Data: c.replaceGuidsWithChannels(msg.Result), }) } return } - object := c.objects[msg.GUID] + object, _ := c.objects.Load(msg.GUID) if method == "__create__" { c.createRemoteObject( object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"], @@ -122,7 +123,7 @@ func (c *connection) Dispatch(msg *message) { return } if method == "__adopt__" { - child, ok := c.objects[msg.Params["guid"].(string)] + child, ok := c.objects.Load(msg.Params["guid"].(string)) if !ok { return } @@ -205,7 +206,7 @@ func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} { if v.Kind() == reflect.Map { mapV := payload.(map[string]interface{}) if guid, hasGUID := mapV["guid"]; hasGUID { - if channelOwner, ok := c.objects[guid.(string)]; ok { + if channelOwner, ok := c.objects.Load(guid.(string)); ok { return channelOwner.channel } } @@ -254,7 +255,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa return nil, fmt.Errorf("could not send message: %w", err) } - return cb.(*protocolCallback), nil + return cb, nil } func (c *connection) setInTracing(isTracing bool) { @@ -327,7 +328,8 @@ func serializeCallLocation(caller stack.Call) map[string]interface{} { func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection { connection := &connection{ abort: make(chan struct{}, 1), - objects: make(map[string]*channelOwner), + callbacks: safe.NewSyncMap[uint32, *protocolCallback](), + objects: safe.NewSyncMap[string, *channelOwner](), transport: transport, isRemote: false, closedError: &safeValue[error]{}, diff --git a/event_emitter.go b/event_emitter.go index 7e534b93..3bfd9c2f 100644 --- a/event_emitter.go +++ b/event_emitter.go @@ -23,6 +23,7 @@ type ( hasInit bool } eventRegister struct { + sync.Mutex listeners []listener } listener struct { @@ -33,18 +34,15 @@ type ( func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) { e.eventsMutex.Lock() - defer e.eventsMutex.Unlock() e.init() evt, ok := e.events[name] if !ok { + e.eventsMutex.Unlock() return } - - hasListener = evt.count() > 0 - - evt.callHandlers(payload...) - return + e.eventsMutex.Unlock() + return evt.callHandlers(payload...) > 0 } func (e *eventEmitter) Once(name string, handler interface{}) { @@ -60,10 +58,11 @@ func (e *eventEmitter) RemoveListener(name string, handler interface{}) { defer e.eventsMutex.Unlock() e.init() - if _, ok := e.events[name]; !ok { - return + if evt, ok := e.events[name]; ok { + evt.Lock() + defer evt.Unlock() + evt.removeHandler(handler) } - e.events[name].removeHandler(handler) } // ListenerCount count the listeners by name, count all if name is empty @@ -90,6 +89,7 @@ func (e *eventEmitter) ListenerCount(name string) int { func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) { e.eventsMutex.Lock() + defer e.eventsMutex.Unlock() e.init() if _, ok := e.events[name]; !ok { @@ -98,7 +98,6 @@ func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) { } } e.events[name].addHandler(handler, once) - e.eventsMutex.Unlock() } func (e *eventEmitter) init() { @@ -108,23 +107,27 @@ func (e *eventEmitter) init() { } } -func (e *eventRegister) addHandler(handler interface{}, once bool) { - e.listeners = append(e.listeners, listener{handler: handler, once: once}) +func (er *eventRegister) addHandler(handler interface{}, once bool) { + er.Lock() + defer er.Unlock() + er.listeners = append(er.listeners, listener{handler: handler, once: once}) } -func (e *eventRegister) count() int { - return len(e.listeners) +func (er *eventRegister) count() int { + er.Lock() + defer er.Unlock() + return len(er.listeners) } func (e *eventRegister) removeHandler(handler interface{}) { handlerPtr := reflect.ValueOf(handler).Pointer() - e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool { + e.listeners = slices.DeleteFunc(e.listeners, func(l listener) bool { return reflect.ValueOf(l.handler).Pointer() == handlerPtr }) } -func (e *eventRegister) callHandlers(payloads ...interface{}) { +func (er *eventRegister) callHandlers(payloads ...interface{}) int { payloadV := make([]reflect.Value, 0) for _, p := range payloads { @@ -136,10 +139,14 @@ func (e *eventRegister) callHandlers(payloads ...interface{}) { handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))]) } - for _, l := range e.listeners { + er.Lock() + defer er.Unlock() + count := len(er.listeners) + for _, l := range er.listeners { if l.once { - defer e.removeHandler(l.handler) + defer er.removeHandler(l.handler) } handle(l) } + return count } diff --git a/internal/safe/map.go b/internal/safe/map.go new file mode 100644 index 00000000..439af2c2 --- /dev/null +++ b/internal/safe/map.go @@ -0,0 +1,91 @@ +package safe + +import ( + "sync" + + "golang.org/x/exp/maps" +) + +// SyncMap is a thread-safe map +type SyncMap[K comparable, V any] struct { + sync.RWMutex + m map[K]V +} + +// NewSyncMap creates a new thread-safe map +func NewSyncMap[K comparable, V any]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + m: make(map[K]V), + } +} + +func (m *SyncMap[K, V]) Store(k K, v V) { + m.Lock() + defer m.Unlock() + m.m[k] = v +} + +func (m *SyncMap[K, V]) Load(k K) (v V, ok bool) { + m.RLock() + defer m.RUnlock() + v, ok = m.m[k] + return +} + +// LoadOrStore returns the existing value for the key if present. Otherwise, it stores and returns the given value. +func (m *SyncMap[K, V]) LoadOrStore(k K, v V) (actual V, loaded bool) { + m.Lock() + defer m.Unlock() + actual, loaded = m.m[k] + if loaded { + return + } + m.m[k] = v + return v, false +} + +// LoadAndDelete deletes the value for a key, and returns the previous value if any. +func (m *SyncMap[K, V]) LoadAndDelete(k K) (v V, loaded bool) { + m.Lock() + defer m.Unlock() + v, loaded = m.m[k] + if loaded { + delete(m.m, k) + } + return +} + +func (m *SyncMap[K, V]) Delete(k K) { + m.Lock() + defer m.Unlock() + delete(m.m, k) +} + +func (m *SyncMap[K, V]) Clear() { + m.Lock() + defer m.Unlock() + maps.Clear(m.m) +} + +func (m *SyncMap[K, V]) Len() int { + m.RLock() + defer m.RUnlock() + return len(m.m) +} + +func (m *SyncMap[K, V]) Clone() map[K]V { + m.RLock() + defer m.RUnlock() + return maps.Clone(m.m) +} + +func (m *SyncMap[K, V]) Range(f func(k K, v V) bool) { + m.RLock() + defer m.RUnlock() + + for k, v := range m.m { + if !f(k, v) { + break + } + } +} diff --git a/page.go b/page.go index 27741955..72c420e3 100644 --- a/page.go +++ b/page.go @@ -7,6 +7,7 @@ import ( "os" "sync" + "github.com/playwright-community/playwright-go/internal/safe" "golang.org/x/exp/slices" ) @@ -26,7 +27,7 @@ type pageImpl struct { routes []*routeHandlerEntry viewportSize *Size ownedContext BrowserContext - bindings map[string]BindingCallFunction + bindings *safe.SyncMap[string, BindingCallFunction] closeReason *string closeWasCalled bool harRouters []*harRouter @@ -783,7 +784,7 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt := &pageImpl{ workers: make([]Worker, 0), routes: make([]*routeHandlerEntry, 0), - bindings: make(map[string]BindingCallFunction), + bindings: safe.NewSyncMap[string, BindingCallFunction](), viewportSize: viewportSize, harRouters: make([]*harRouter, 0), locatorHandlers: make(map[float64]*locatorHandlerEntry, 0), @@ -799,7 +800,7 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt.keyboard = newKeyboard(bt.channel) bt.touchscreen = newTouchscreen(bt.channel) bt.channel.On("bindingCall", func(params map[string]interface{}) { - bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) + go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) }) bt.channel.On("close", bt.onClose) bt.channel.On("crash", func() { @@ -882,11 +883,11 @@ func (p *pageImpl) closeErrorWithReason() error { } func (p *pageImpl) onBinding(binding *bindingCallImpl) { - function := p.bindings[binding.initializer["name"].(string)] - if function == nil { + function, ok := p.bindings.Load(binding.initializer["name"].(string)) + if !ok || function == nil { return } - go binding.Call(function) + binding.Call(function) } func (p *pageImpl) onFrameAttached(frame *frameImpl) { @@ -1080,18 +1081,21 @@ func (p *pageImpl) ExposeBinding(name string, binding BindingCallFunction, handl if len(handle) == 1 { needsHandle = handle[0] } - if _, ok := p.bindings[name]; ok { + if _, ok := p.bindings.Load(name); ok { return fmt.Errorf("Function '%s' has been already registered", name) } - if _, ok := p.browserContext.bindings[name]; ok { + if _, ok := p.browserContext.bindings.Load(name); ok { return fmt.Errorf("Function '%s' has been already registered in the browser context", name) } - p.bindings[name] = binding _, err := p.channel.Send("exposeBinding", map[string]interface{}{ "name": name, "needsHandle": needsHandle, }) - return err + if err != nil { + return err + } + p.bindings.Store(name, binding) + return nil } func (p *pageImpl) SelectOption(selector string, values SelectOptionValues, options ...PageSelectOptionOptions) ([]string, error) { diff --git a/tests/binding_test.go b/tests/binding_test.go index eebfeb0e..cb65c3b3 100644 --- a/tests/binding_test.go +++ b/tests/binding_test.go @@ -2,7 +2,9 @@ package playwright_test import ( "errors" + "fmt" "strings" + "sync" "testing" "github.com/playwright-community/playwright-go" @@ -123,3 +125,30 @@ func TestPageExposeBindingPanic(t *testing.T) { stack := strings.Split(innerError["stack"].(string), "\n") require.Contains(t, stack[3], "binding_test.go") } + +func TestPageBindingsNoRace(t *testing.T) { + BeforeEach(t) + + wg := &sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + i := i + go func() { + defer wg.Done() + err := page.ExposeBinding(fmt.Sprintf("foo%d", i), func(source *playwright.BindingSource, args ...interface{}) interface{} { + return 42 + }) + require.NoError(t, err) + }() + } + wg.Wait() + ret, err := page.Evaluate(`async () => { + try { + return await window['foo9'](); + } catch (e) { + return {message: e.message, stack: e.stack}; + } + }`) + require.NoError(t, err) + require.Equal(t, 42, ret) +}