Skip to content

Commit

Permalink
webrtc: use a single stream mutex, prevent (0, nil) Read return values
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 31, 2023
1 parent 9fcfc62 commit ba129f2
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 135 deletions.
20 changes: 7 additions & 13 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ const (
// Package pion detached data channel into a net.Conn
// and then a network.MuxedStream
type stream struct {
mx sync.Mutex
// pbio.Reader is not thread safe,
// and while our Read is not promised to be thread safe,
// we ourselves internally read from multiple routines...
readMu sync.Mutex
reader pbio.Reader
// this buffer is limited up to a single message. Reason we need it
// is because a reader might read a message midway, and so we need a
Expand All @@ -72,7 +72,6 @@ type stream struct {

// The public Write API is not promised to be thread safe,
// but we need to be able to write control messages.
writeMu sync.Mutex
writer pbio.Writer
sendStateChanged chan struct{}
sendState sendState
Expand Down Expand Up @@ -118,8 +117,8 @@ func newStream(

channel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
channel.OnBufferedAmountLow(func() {
s.writeMu.Lock()
defer s.writeMu.Unlock()
s.mx.Lock()
defer s.mx.Unlock()
// first send out queued control messages
for len(s.controlMsgQueue) > 0 {
msg := s.controlMsgQueue[0]
Expand Down Expand Up @@ -171,16 +170,12 @@ func (s *stream) SetDeadline(t time.Time) error {
// processIncomingFlag process the flag on an incoming message
// It needs to be called with msg.Flag, not msg.GetFlag(),
// otherwise we'd misinterpret the default value.
// It needs to be called while the mutex is locked.
func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if flag == nil {
return
}

s.writeMu.Lock()
defer s.writeMu.Unlock()
s.readMu.Lock()
defer s.readMu.Unlock()

switch *flag {
case pb.Message_FIN:
if s.receiveState == receiveStateReceiving {
Expand Down Expand Up @@ -215,9 +210,8 @@ func (s *stream) maybeDeclareStreamDone() {
}

func (s *stream) setCloseError(e error) {
s.writeMu.Lock()
defer s.writeMu.Unlock()
s.readMu.Lock()
defer s.readMu.Unlock()
s.mx.Lock()
defer s.mx.Unlock()

s.closeErr = e
}
104 changes: 56 additions & 48 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ import (
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
)

// Read from the underlying datachannel.
// This also process SCTP control messages such as DCEP, which is handled internally by pion,
// and stream closure which is signaled by `Read` on the datachannel returning io.EOF.
func (s *stream) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}

s.mx.Lock()
defer s.mx.Unlock()

if s.closeErr != nil {
return 0, s.closeErr
}
Expand All @@ -23,67 +27,71 @@ func (s *stream) Read(b []byte) (int, error) {
return 0, network.ErrReset
}

if s.nextMessage == nil {
// load the next message
var msg pb.Message
if err := s.readMessageFromDataChannel(&msg); err != nil {
if err == io.EOF {
// if the channel was properly closed, return EOF
if s.receiveState == receiveStateDataRead {
return 0, io.EOF
var read int
for {
if s.nextMessage == nil {
// load the next message
s.mx.Unlock()
var msg pb.Message
if err := s.reader.ReadMsg(&msg); err != nil {
s.mx.Lock()
if err == io.EOF {
// if the channel was properly closed, return EOF
if s.receiveState == receiveStateDataRead {
return 0, io.EOF
}
// This case occurs when the remote node closes the stream without writing a FIN message
// There's little we can do here
return 0, errors.New("didn't receive final state for stream")
}
// This case occurs when the remote node closes the stream without writing a FIN message
// There's little we can do here
return 0, errors.New("didn't receive final state for stream")
return 0, err
}
return 0, err
s.mx.Lock()
s.nextMessage = &msg
}
s.nextMessage = &msg
}

n := copy(b, s.nextMessage.Message)
s.nextMessage.Message = s.nextMessage.Message[n:]
if len(s.nextMessage.Message) > 0 {
return n, nil
}
if len(s.nextMessage.Message) > 0 {
n := copy(b, s.nextMessage.Message)
read += n
s.nextMessage.Message = s.nextMessage.Message[n:]
return read, nil
}

// process flags on the message after reading all the data
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
if s.closeErr != nil {
return n, s.closeErr
}
switch s.receiveState {
case receiveStateDataRead:
return n, io.EOF
case receiveStateReset:
return n, network.ErrReset
default:
return n, nil
// process flags on the message after reading all the data
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
if s.closeErr != nil {
return read, s.closeErr
}
switch s.receiveState {
case receiveStateDataRead:
return read, io.EOF
case receiveStateReset:
s.dataChannel.SetReadDeadline(time.Time{})
return read, network.ErrReset
}
}
}

func (s *stream) readMessageFromDataChannel(msg *pb.Message) error {
s.readMu.Lock()
defer s.readMu.Unlock()
return s.reader.ReadMsg(msg)
}

func (s *stream) SetReadDeadline(t time.Time) error { return s.dataChannel.SetReadDeadline(t) }

func (s *stream) CloseRead() error {
s.readMu.Lock()
defer s.readMu.Unlock()
s.mx.Lock()
defer s.mx.Unlock()

s.receiveState = receiveStateReset
if s.nextMessage != nil {
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
}
var err error
if s.receiveState == receiveStateReceiving && s.closeErr == nil {
err := s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
s.maybeDeclareStreamDone()
return err
err = s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
}
return nil
s.receiveState = receiveStateReset
s.maybeDeclareStreamDone()

// make any calls to Read blocking on ReadMsg return immediately
s.dataChannel.SetReadDeadline(time.Now())

return err
}
83 changes: 81 additions & 2 deletions p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"testing"
"time"

"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"

"github.com/libp2p/go-libp2p/core/network"

"github.com/pion/datachannel"
Expand Down Expand Up @@ -100,8 +102,9 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
serverStr := newStream(server.dc, server.rwc, nil, nil, func() { serverDone = true })

// send a foobar from the client
_, err := clientStr.Write([]byte("foobar"))
n, err := clientStr.Write([]byte("foobar"))
require.NoError(t, err)
require.Equal(t, 6, n)
require.NoError(t, clientStr.CloseWrite())
// writing after closing should error
_, err = clientStr.Write([]byte("foobar"))
Expand All @@ -113,7 +116,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte("foobar"), b)
// reading again should give another io.EOF
n, err := serverStr.Read(make([]byte, 10))
n, err = serverStr.Read(make([]byte, 10))
require.Zero(t, n)
require.ErrorIs(t, err, io.EOF)
require.False(t, serverDone)
Expand All @@ -131,6 +134,82 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
require.True(t, clientDone)
}

func TestStreamPartialReads(t *testing.T) {
client, server := getDetachedDataChannels(t)

clientStr := newStream(client.dc, client.rwc, nil, nil, func() {})
serverStr := newStream(server.dc, server.rwc, nil, nil, func() {})

_, err := serverStr.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, serverStr.CloseWrite())

n, err := clientStr.Read([]byte{}) // empty read
require.NoError(t, err)
require.Zero(t, n)
b := make([]byte, 3)
n, err = clientStr.Read(b)
require.NoError(t, err)
require.Equal(t, []byte("foo"), b)
b, err = io.ReadAll(clientStr)
require.NoError(t, err)
require.Equal(t, []byte("bar"), b)
}

func TestStreamSkipEmptyFrames(t *testing.T) {
client, server := getDetachedDataChannels(t)

clientStr := newStream(client.dc, client.rwc, nil, nil, func() {})
serverStr := newStream(server.dc, server.rwc, nil, nil, func() {})

for i := 0; i < 10; i++ {
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
}
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Message: []byte("foo")}))
for i := 0; i < 10; i++ {
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
}
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Message: []byte("bar")}))
for i := 0; i < 10; i++ {
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
}
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}))

var read []byte
var count int
for i := 0; i < 100; i++ {
b := make([]byte, 10)
count++
n, err := clientStr.Read(b)
read = append(read, b[:n]...)
if err == io.EOF {
break
}
require.NoError(t, err)
}
require.LessOrEqual(t, count, 3, "should've taken a maximum of 3 reads")
require.Equal(t, []byte("foobar"), read)
}

func TestStreamReadReturnsOnClose(t *testing.T) {
client, _ := getDetachedDataChannels(t)

clientStr := newStream(client.dc, client.rwc, nil, nil, func() {})
// serverStr := newStream(server.dc, server.rwc, nil, nil, func() {})
errChan := make(chan error, 1)
go func() {
_, err := clientStr.Read([]byte{0})
errChan <- err
}()
require.NoError(t, clientStr.Close())
select {
case err := <-errChan:
require.ErrorIs(t, err, network.ErrReset)
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
}
}

func TestStreamResets(t *testing.T) {
client, server := getDetachedDataChannels(t)

Expand Down
28 changes: 17 additions & 11 deletions p2p/transport/webrtc/stream_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ var errWriteAfterClose = errors.New("write after close")
const minMessageSize = 1 << 10

func (s *stream) Write(b []byte) (int, error) {
s.writeMu.Lock()
defer s.writeMu.Unlock()
s.mx.Lock()
defer s.mx.Unlock()

if s.closeErr != nil {
return 0, s.closeErr
Expand Down Expand Up @@ -75,16 +75,16 @@ func (s *stream) Write(b []byte) (int, error) {

availableSpace := s.availableSendSpace()
if availableSpace < minMessageSize {
s.writeMu.Unlock()
s.mx.Unlock()
select {
case <-s.writeAvailable:
case <-writeDeadlineChan:
s.writeMu.Lock()
s.mx.Lock()
return n, os.ErrDeadlineExceeded
case <-s.sendStateChanged:
case <-s.writeDeadlineUpdated:
}
s.writeMu.Lock()
s.mx.Lock()
continue
}
end := maxMessageSize
Expand Down Expand Up @@ -113,22 +113,25 @@ func (s *stream) spawnControlMessageReader() {
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
}

go func() {
// no deadline needed, Read will return once there's a new message, or an error occurred
_ = s.dataChannel.SetReadDeadline(time.Time{})
for {
var msg pb.Message
if err := s.readMessageFromDataChannel(&msg); err != nil {
if err := s.reader.ReadMsg(&msg); err != nil {
return
}
s.mx.Lock()
s.processIncomingFlag(msg.Flag)
s.mx.Unlock()
}
}()
}

func (s *stream) SetWriteDeadline(t time.Time) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()
s.mx.Lock()
defer s.mx.Unlock()
s.writeDeadline = t
select {
case s.writeDeadlineUpdated <- struct{}{}:
Expand All @@ -149,9 +152,6 @@ func (s *stream) availableSendSpace() int {
const controlMsgSize = 100 // TODO: use actual message size

func (s *stream) sendControlMessage(msg *pb.Message) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()

available := s.availableSendSpace()
if controlMsgSize < available {
return s.writer.WriteMsg(msg)
Expand All @@ -161,6 +161,9 @@ func (s *stream) sendControlMessage(msg *pb.Message) error {
}

func (s *stream) cancelWrite() error {
s.mx.Lock()
defer s.mx.Unlock()

if s.sendState != sendStateSending {
return nil
}
Expand All @@ -177,6 +180,9 @@ func (s *stream) cancelWrite() error {
}

func (s *stream) CloseWrite() error {
s.mx.Lock()
defer s.mx.Unlock()

if s.sendState != sendStateSending {
return nil
}
Expand Down
Loading

0 comments on commit ba129f2

Please sign in to comment.