From aa976062fe599a7e343c2c9d0045a71b991c8e56 Mon Sep 17 00:00:00 2001 From: Konstantin Burkalev Date: Tue, 7 Nov 2023 15:35:30 +0200 Subject: [PATCH] Fixes subprotocol selection (aling with rfc6455) (#823) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #822 **Summary of Changes** 1. Changed the order of subprotocol selection to prefer client one --------- Co-authored-by: Corey Daley --- server.go | 4 ++-- server_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 1e720e1d..29d0deb6 100644 --- a/server.go +++ b/server.go @@ -102,8 +102,8 @@ func checkSameOrigin(r *http.Request) bool { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(r) - for _, serverProtocol := range u.Subprotocols { - for _, clientProtocol := range clientProtocols { + for _, clientProtocol := range clientProtocols { + for _, serverProtocol := range u.Subprotocols { if clientProtocol == serverProtocol { return clientProtocol } diff --git a/server_test.go b/server_test.go index 5804be13..0c341ac7 100644 --- a/server_test.go +++ b/server_test.go @@ -54,6 +54,36 @@ func TestIsWebSocketUpgrade(t *testing.T) { } } +func TestSubProtocolSelection(t *testing.T) { + upgrader := Upgrader{ + Subprotocols: []string{"foo", "bar", "baz"}, + } + + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} + s := upgrader.selectSubprotocol(&r, nil) + if s != "foo" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "bar" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "baz" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") + } +} + var checkSameOriginTests = []struct { ok bool r *http.Request