diff --git a/rpcc/conn.go b/rpcc/conn.go index aa210df..4482582 100644 --- a/rpcc/conn.go +++ b/rpcc/conn.go @@ -240,18 +240,17 @@ type Conn struct { compressionLevel func(level int) error mu sync.Mutex // Protects following. - closed bool reqSeq uint64 pending map[uint64]*rpcCall + streams map[string]*streamClients + closed bool + err error // Protected by mu and closed until context is cancelled. reqMu sync.Mutex // Protects following. req Request // Encodes and decodes JSON onto conn. Encoding is // guarded by mutex and decoding is done by recv. codec Codec - - streamMu sync.Mutex // Protects following. - streams map[string]*streamClients } // Response represents an RPC response or notification sent by the server. @@ -378,7 +377,7 @@ func (c *Conn) send(ctx context.Context, call *rpcCall) (err error) { c.mu.Lock() if c.closed { c.mu.Unlock() - return ErrConnClosing + return c.err } c.reqSeq++ reqID := c.reqSeq @@ -421,26 +420,25 @@ func (c *Conn) send(ctx context.Context, call *rpcCall) (err error) { // notify handles RPC notifications and sends them // to the appropriate stream listeners. func (c *Conn) notify(method string, data []byte) { - c.streamMu.Lock() + c.mu.Lock() stream := c.streams[method] - c.streamMu.Unlock() - if stream != nil { // Stream writer must be able to handle incoming writes // even after it has been removed (unsubscribed). stream.write(method, data) } + c.mu.Unlock() } // listen registers a new stream listener (chan) for the RPC notification // method. Returns a function for removing the listener. Error if the // connection is closed. func (c *Conn) listen(method string, w streamWriter) (func(), error) { - c.streamMu.Lock() - defer c.streamMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() if c.streams == nil { - return nil, ErrConnClosing + return nil, c.err } stream, ok := c.streams[method] @@ -454,35 +452,59 @@ func (c *Conn) listen(method string, w streamWriter) (func(), error) { return unsub, nil } -// Close closes the connection. -func (c *Conn) close(err error) error { - c.cancel() +type closeError struct { + msg string + err error +} + +func (e *closeError) Cause() error { + return e.err +} + +func (e *closeError) Error() string { + return fmt.Sprintf("%s: %v", e.msg, e.err) +} +// Close closes the connection. Subsequent calls to Close will return the error +// that closed the connection. +func (c *Conn) close(err error) error { c.mu.Lock() + defer c.mu.Unlock() + if c.closed { - c.mu.Unlock() - return ErrConnClosing + return c.err } c.closed = true if err == nil { err = ErrConnClosing + } else { + err = &closeError{msg: ErrConnClosing.Error(), err: err} } + c.err = err for id, call := range c.pending { delete(c.pending, id) call.done(err) } - c.mu.Unlock() - - // Stop sending on all streams. - c.streamMu.Lock() + // Stop sending on all streams by signaling + // that the connection is closed. c.streams = nil - c.streamMu.Unlock() // Conn can be nil if DialContext did not complete. if c.conn != nil { - err = c.conn.Close() + wserr := c.conn.Close() + if wserr != nil && err == ErrConnClosing { + err = wserr + c.err = &closeError{msg: ErrConnClosing.Error(), err: err} + } } + // Delay cancel until c.err has settled, at this point any active + // streams will be closed. + c.cancel() + + if err == ErrConnClosing { + return nil + } return err } diff --git a/rpcc/conn_test.go b/rpcc/conn_test.go index 0cc4b3b..1450830 100644 --- a/rpcc/conn_test.go +++ b/rpcc/conn_test.go @@ -339,6 +339,60 @@ func TestConn_StreamRecv(t *testing.T) { } } +func TestConn_PropagateError(t *testing.T) { + srv := newTestServer(t, nil) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s1, err := NewStream(ctx, "test.Stream1", srv.conn) + if err != nil { + t.Fatal(err) + } + defer s1.Close() + s2, err := NewStream(ctx, "test.Stream2", srv.conn) + if err != nil { + t.Fatal(err) + } + defer s2.Close() + + errC := make(chan error, 2) + go func() { + errC <- Invoke(ctx, "test.Invoke", nil, nil, srv.conn) + }() + go func() { + var reply string + errC <- s1.RecvMsg(&reply) + }() + + // Give a little time for both Invoke & Recv. + time.Sleep(5 * time.Millisecond) + + srv.wsConn.Close() + + // Give a little time for connection to close. + time.Sleep(5 * time.Millisecond) + + lastErr := Invoke(ctx, "test.Invoke", nil, nil, srv.conn) + if lastErr == nil { + t.Error("RecvMsg on closed connection: got nil, want an error") + } + + var reply string + err = s2.RecvMsg(&reply) + if err != lastErr { + t.Errorf("Error was not repeated, got %v, want %v", err, lastErr) + } + + for i := 0; i < 2; i++ { + err := <-errC + if err != lastErr { + t.Errorf("Error was not repeated, got %v, want %v", err, lastErr) + } + } +} + func TestConn_Context(t *testing.T) { srv := newTestServer(t, nil) defer srv.Close() diff --git a/rpcc/stream.go b/rpcc/stream.go index 507d2c4..367c147 100644 --- a/rpcc/stream.go +++ b/rpcc/stream.go @@ -172,7 +172,7 @@ func (s *streamClient) watch() { case <-s.ctx.Done(): s.close(s.ctx.Err()) case <-s.conn.ctx.Done(): - s.close(ErrConnClosing) + s.close(s.conn.err) case <-s.done: } } diff --git a/rpcc/stream_test.go b/rpcc/stream_test.go index 7d25007..2758e02 100644 --- a/rpcc/stream_test.go +++ b/rpcc/stream_test.go @@ -198,7 +198,7 @@ func TestStream_RecvAfterConnClose(t *testing.T) { conn.notify("test", []byte(`"message2"`)) conn.notify("test", []byte(`"message3"`)) - connCancel() + conn.Close() for i := 0; i < 3; i++ { var reply string @@ -210,7 +210,7 @@ func TestStream_RecvAfterConnClose(t *testing.T) { err = s.RecvMsg(nil) if err != ErrConnClosing { - t.Errorf("err got %v, want ErrConnClosing", err) + t.Errorf("err got %v, want %v", err, ErrConnClosing) } }