-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
- Use io.ReadFull instead of similar function in package. - Return from Read with partial data. Don't attempt to fill buffer. - Do not return net.Error with Temporary() == true
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,6 +95,13 @@ const ( | |
writeWait = time.Second | ||
) | ||
|
||
func hideTempErr(err error) error { | ||
if e, ok := err.(net.Error); ok && e.Temporary() { | ||
err = struct{ error }{err} | ||
} | ||
return err | ||
} | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
garyburd
Author
Contributor
|
||
|
||
func isControl(frameType int) bool { | ||
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage | ||
} | ||
|
@@ -501,7 +508,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { | |
// SetWriteDeadline sets the write deadline on the underlying network | ||
// connection. After a write has timed out, the websocket state is corrupt and | ||
// all future writes will return an error. A zero value for t means writes will | ||
// not time out | ||
// not time out | ||
func (c *Conn) SetWriteDeadline(t time.Time) error { | ||
c.writeDeadline = t | ||
return nil | ||
|
@@ -522,7 +529,7 @@ func (c *Conn) advanceFrame() (int, error) { | |
// 2. Read and parse first two bytes of frame header. | ||
|
||
var b [8]byte | ||
if err := c.read(b[:2]); err != nil { | ||
if _, err := io.ReadFull(c.br, b[:2]); err != nil { | ||
return noFrame, err | ||
} | ||
|
||
|
@@ -562,12 +569,12 @@ func (c *Conn) advanceFrame() (int, error) { | |
|
||
switch c.readRemaining { | ||
case 126: | ||
if err := c.read(b[:2]); err != nil { | ||
if _, err := io.ReadFull(c.br, b[:2]); err != nil { | ||
return noFrame, err | ||
} | ||
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) | ||
case 127: | ||
if err := c.read(b[:8]); err != nil { | ||
if _, err := io.ReadFull(c.br, b[:8]); err != nil { | ||
return noFrame, err | ||
} | ||
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) | ||
|
@@ -581,7 +588,7 @@ func (c *Conn) advanceFrame() (int, error) { | |
|
||
if mask { | ||
c.readMaskPos = 0 | ||
if err := c.read(c.readMaskKey[:]); err != nil { | ||
if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil { | ||
return noFrame, err | ||
} | ||
} | ||
|
@@ -601,12 +608,15 @@ func (c *Conn) advanceFrame() (int, error) { | |
|
||
// 6. Read control frame payload. | ||
|
||
payload := make([]byte, c.readRemaining) | ||
c.readRemaining = 0 | ||
if err := c.read(payload); err != nil { | ||
return noFrame, err | ||
var payload []byte | ||
if c.readRemaining > 0 { | ||
payload = make([]byte, c.readRemaining) | ||
c.readRemaining = 0 | ||
if _, err := io.ReadFull(c.br, payload); err != nil { | ||
return noFrame, err | ||
} | ||
maskBytes(c.readMaskKey, 0, payload) | ||
} | ||
maskBytes(c.readMaskKey, 0, payload) | ||
|
||
// 7. Process control frame payload. | ||
|
||
|
@@ -643,23 +653,6 @@ func (c *Conn) handleProtocolError(message string) error { | |
return errors.New("websocket: " + message) | ||
} | ||
|
||
func (c *Conn) read(buf []byte) error { | ||
var err error | ||
for len(buf) > 0 && err == nil { | ||
var nn int | ||
nn, err = c.br.Read(buf) | ||
buf = buf[nn:] | ||
} | ||
if err == io.EOF { | ||
if len(buf) == 0 { | ||
err = nil | ||
} else { | ||
err = io.ErrUnexpectedEOF | ||
} | ||
} | ||
return err | ||
} | ||
|
||
// NextReader returns the next data message received from the peer. The | ||
// returned messageType is either TextMessage or BinaryMessage. | ||
// | ||
|
@@ -674,8 +667,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { | |
c.readLength = 0 | ||
|
||
for c.readErr == nil { | ||
var frameType int | ||
frameType, c.readErr = c.advanceFrame() | ||
frameType, err := c.advanceFrame() | ||
if err != nil { | ||
c.readErr = hideTempErr(err) | ||
break | ||
} | ||
if frameType == TextMessage || frameType == BinaryMessage { | ||
return frameType, messageReader{c, c.readSeq}, nil | ||
} | ||
|
@@ -700,21 +696,22 @@ func (r messageReader) Read(b []byte) (n int, err error) { | |
if int64(len(b)) > r.c.readRemaining { | ||
b = b[:r.c.readRemaining] | ||
} | ||
r.c.readErr = r.c.read(b) | ||
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) | ||
r.c.readRemaining -= int64(len(b)) | ||
return len(b), r.c.readErr | ||
n, err := r.c.br.Read(b) | ||
r.c.readErr = hideTempErr(err) | ||
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) | ||
r.c.readRemaining -= int64(n) | ||
return n, r.c.readErr | ||
} | ||
|
||
if r.c.readFinal { | ||
r.c.readSeq++ | ||
return 0, io.EOF | ||
} | ||
|
||
var frameType int | ||
frameType, r.c.readErr = r.c.advanceFrame() | ||
|
||
if frameType == TextMessage || frameType == BinaryMessage { | ||
frameType, err := r.c.advanceFrame() | ||
if err != nil { | ||
r.c.readErr = hideTempErr(err) | ||
} else if frameType == TextMessage || frameType == BinaryMessage { | ||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") | ||
} | ||
} | ||
|
Hey this is messing with my unit tests in a weird way. Used to be:
But that stopped working:
It seems that the errors returned after timeout have the Temporary flag set. Do you really want to hide that from users of the lib? Is my unit test now obsolete, do I have to just test for "err != nil"?