diff --git a/.travis.yml b/.travis.yml index 4cfe98c..923835b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,6 @@ env: global: - GOTFLAGS="-race" matrix: - - BUILD_DEPTYPE=gx - BUILD_DEPTYPE=gomod diff --git a/conn.go b/conn.go index 168497b..b63a28a 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,6 @@ var _ net.Conn = (*Conn)(nil) type Conn struct { *ws.Conn DefaultMessageType int - done func() reader io.Reader closeOnce sync.Once } @@ -85,14 +84,14 @@ func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Close() error { var err error c.closeOnce.Do(func() { - if c.done != nil { - c.done() - // Be nice to GC - c.done = nil + err1 := c.Conn.WriteControl(ws.CloseMessage, nil, time.Now().Add(GracefulCloseTimeout)) + err2 := c.Conn.Close() + switch { + case err1 != nil: + err = err1 + case err2 != nil: + err = err2 } - - c.Conn.WriteControl(ws.CloseMessage, nil, time.Now().Add(GracefulCloseTimeout)) - err = c.Conn.Close() }) return err } @@ -122,10 +121,9 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { } // NewConn creates a Conn given a regular gorilla/websocket Conn. -func NewConn(raw *ws.Conn, done func()) *Conn { +func NewConn(raw *ws.Conn) *Conn { return &Conn{ Conn: raw, DefaultMessageType: ws.BinaryMessage, - done: done, } } diff --git a/listener.go b/listener.go index 3d774f0..437be98 100644 --- a/listener.go +++ b/listener.go @@ -1,7 +1,6 @@ package websocket import ( - "context" "fmt" "net" "net/http" @@ -21,7 +20,7 @@ type listener struct { func (l *listener) serve() { defer close(l.closed) - http.Serve(l.Listener, l) + _ = http.Serve(l.Listener, l) } func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -31,35 +30,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx, cancel := context.WithCancel(context.Background()) - - var cnCh <-chan bool - if cn, ok := w.(http.CloseNotifier); ok { - cnCh = cn.CloseNotify() - } - - wscon := NewConn(c, cancel) - // Just to make sure. - defer wscon.Close() - select { - case l.incoming <- wscon: + case l.incoming <- NewConn(c): case <-l.closed: c.Close() - return - case <-cnCh: - return - } - - // wait until conn gets closed, otherwise the handler closes it early - select { - case <-ctx.Done(): - case <-l.closed: - c.Close() - return - case <-cnCh: - return } + // The connection has been hijacked, it's safe to return. } func (l *listener) Accept() (manet.Conn, error) { diff --git a/websocket.go b/websocket.go index b9ec1d6..681d676 100644 --- a/websocket.go +++ b/websocket.go @@ -86,7 +86,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon, nil)) + mnc, err := manet.WrapNetConn(NewConn(wscon)) if err != nil { wscon.Close() return nil, err diff --git a/websocket_test.go b/websocket_test.go index 2c9ad32..b9c688d 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -75,8 +75,14 @@ func TestWebsocketListen(t *testing.T) { return } - c.Write(msg) - c.Close() + _, err = c.Write(msg) + if err != nil { + t.Error(err) + } + err = c.Close() + if err != nil { + t.Error(err) + } }() c, err := l.Accept() @@ -120,8 +126,12 @@ func TestConcurrentClose(t *testing.T) { return } - go c.Write(msg) - go c.Close() + go func() { + _, _ = c.Write(msg) + }() + go func() { + _ = c.Close() + }() } }()