Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tighten lock around appending new chunks of read data in stream #28

Merged
merged 6 commits into from
May 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *Stream) LocalAddr() net.Addr {
return s.session.LocalAddr()
}

// LocalAddr returns the remote address
// RemoteAddr returns the remote address
func (s *Stream) RemoteAddr() net.Addr {
return s.session.RemoteAddr()
}
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ module github.com/libp2p/go-yamux

go 1.12

require github.com/libp2p/go-buffer-pool v0.0.2
require (
github.com/libp2p/go-buffer-pool v0.0.2
)
20 changes: 20 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,22 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs=
github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
2 changes: 0 additions & 2 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,6 @@ func TestSendData_Large(t *testing.T) {
t.Errorf("err: %v", err)
return
}

t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()

go func() {
Expand Down
43 changes: 10 additions & 33 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-buffer-pool"
)

type streamState int
Expand All @@ -25,7 +23,6 @@ const (
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
recvWindow uint32
sendWindow uint32

id uint32
Expand All @@ -35,7 +32,7 @@ type Stream struct {
stateLock sync.Mutex

recvLock sync.Mutex
recvBuf pool.Buffer
recvBuf segmentedBuffer

sendLock sync.Mutex

Expand All @@ -52,10 +49,10 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
id: id,
session: session,
state: state,
recvWindow: initialStreamWindow,
sendWindow: initialStreamWindow,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
recvBuf: NewSegmentedBuffer(initialStreamWindow),
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
}
Expand Down Expand Up @@ -84,9 +81,7 @@ START:
case streamRemoteClose:
fallthrough
case streamClosed:
s.recvLock.Lock()
empty := s.recvBuf.Len() == 0
s.recvLock.Unlock()
if empty {
return 0, io.EOF
}
Expand Down Expand Up @@ -213,19 +208,13 @@ func (s *Stream) sendWindowUpdate() error {

// Determine the delta update
max := s.session.config.MaxStreamWindowSize
s.recvLock.Lock()
delta := (max - uint32(s.recvBuf.Len())) - s.recvWindow

// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
// Update our window
needed, delta := s.recvBuf.GrowTo(max, flags != 0)
if !needed {
return nil
}

// Update our window
s.recvWindow += delta
s.recvLock.Unlock()

// Send the header
hdr := encode(typeWindowUpdate, flags, s.id, delta)
if err := s.session.sendMsg(hdr, nil, nil); err != nil {
Expand Down Expand Up @@ -406,29 +395,17 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
return nil
}

// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}

// Copy into buffer
s.recvLock.Lock()

if length > s.recvWindow {
s.recvLock.Unlock()
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
// Validate it's okay to copy
if !s.recvBuf.TryReserve(length) {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvBuf.Cap(), length)
return ErrRecvWindowExceeded
}

s.recvBuf.Grow(int(length))
if _, err := io.Copy(&s.recvBuf, conn); err != nil {
s.recvLock.Unlock()
// Copy into buffer
if err := s.recvBuf.Append(conn, int(length)); err != nil {
s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
return err
}

// Decrement the receive window
s.recvWindow -= length
s.recvLock.Unlock()

// Unblock any readers
asyncNotify(s.recvNotifyCh)
return nil
Expand Down
116 changes: 116 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package yamux

import (
"io"
"sync"
"sync/atomic"

pool "github.com/libp2p/go-buffer-pool"
)

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
Expand Down Expand Up @@ -29,3 +37,111 @@ func min(values ...uint32) uint32 {
}
return m
}

type segmentedBuffer struct {
cap uint32
pending uint32
len uint32
bm sync.Mutex
b [][]byte
}

// NewSegmentedBuffer allocates a ring buffer.
func NewSegmentedBuffer(initialCapacity uint32) segmentedBuffer {
return segmentedBuffer{cap: initialCapacity, b: make([][]byte, 0)}
}

func (s *segmentedBuffer) Len() int {
return int(atomic.LoadUint32(&s.len))
}

func (s *segmentedBuffer) Cap() uint32 {
return atomic.LoadUint32(&s.cap)
}

// If the space to write into + current buffer size has grown to half of the window size,
// grow up to that max size, and indicate how much additional space was reserved.
func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) {
s.bm.Lock()
defer s.bm.Unlock()

currentWindow := atomic.LoadUint32(&s.len) + atomic.LoadUint32(&s.cap) + s.pending
if currentWindow > max {
// somewhat counter-intuitively not an error.
// note that len+cap is the 'window' that shouldn't exceed max or a reservation
// would fail, triggering an error.
// We pre-count 'pending' data where we've read a header and are working on
// reading it into available data here, so that we don't undercount the remaining
// window size, but that can mean this sum ends up larger than max.
return false, 0
}
delta := max - currentWindow

if delta < (max/2) && !force {
return false, 0
}

atomic.AddUint32(&s.cap, delta)
return true, delta
}

func (s *segmentedBuffer) TryReserve(space uint32) bool {
// It is noticable that the check-and-set of pending is not atomic,
// Due to this, accesses to pending are protected by bm.
s.bm.Lock()
defer s.bm.Unlock()
if atomic.LoadUint32(&s.cap) < s.pending+space {
return false
}
s.pending += space
return true
}

func (s *segmentedBuffer) Read(b []byte) (int, error) {
s.bm.Lock()
defer s.bm.Unlock()
if len(s.b) == 0 {
return 0, io.EOF
}
n := copy(b, s.b[0])
if n == len(s.b[0]) {
pool.Put(s.b[0])
s.b[0] = nil
s.b = s.b[1:]
} else {
s.b[0] = s.b[0][n:]
}
if n > 0 {
atomic.AddUint32(&s.len, ^uint32(n-1))
}
return n, nil
}

func (s *segmentedBuffer) Append(input io.Reader, length int) error {
dst := pool.Get(length)
n := 0
read := 0
var err error
for n < length && err == nil {
read, err = input.Read(dst[n:])
n += read
}
if err == io.EOF {
if length == n {
err = nil
willscott marked this conversation as resolved.
Show resolved Hide resolved
} else {
err = ErrStreamReset
}
}

s.bm.Lock()
defer s.bm.Unlock()
if n > 0 {
atomic.AddUint32(&s.len, uint32(n))
// cap -= n
atomic.AddUint32(&s.cap, ^uint32(n-1))
s.pending = s.pending - uint32(length)
s.b = append(s.b, dst[0:n])
}
return err
}