Skip to content

Commit

Permalink
When a stream is closed, cancel pending writes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ichbinjoe committed Dec 14, 2018
1 parent f6e0e0f commit 4b1acc2
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 33 deletions.
53 changes: 35 additions & 18 deletions multiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package multiplex

import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -52,7 +53,7 @@ type Multiplex struct {
shutdownErr error
shutdownLock sync.Mutex

wrLock sync.Mutex
wrTkn chan struct{}

nstreams chan *Stream

Expand All @@ -68,21 +69,25 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
channels: make(map[streamID]*Stream),
closed: make(chan struct{}),
shutdown: make(chan struct{}),
wrTkn: make(chan struct{}, 1),
nstreams: make(chan *Stream, 16),
}

go mp.handleIncoming()

mp.wrTkn <- struct{}{}

return mp
}

func (mp *Multiplex) newStream(id streamID, name string) *Stream {
return &Stream{
id: id,
name: name,
dataIn: make(chan []byte, 8),
reset: make(chan struct{}),
mp: mp,
id: id,
name: name,
dataIn: make(chan []byte, 8),
reset: make(chan struct{}),
closedLocal: make(chan struct{}),
mp: mp,
}
}

Expand Down Expand Up @@ -127,10 +132,16 @@ func (mp *Multiplex) IsClosed() bool {
}
}

func (mp *Multiplex) sendMsg(header uint64, data []byte, dl time.Time) error {
mp.wrLock.Lock()
defer mp.wrLock.Unlock()
if !dl.IsZero() {
func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) error {
select {
case tkn := <-mp.wrTkn:
defer func() { mp.wrTkn <- tkn }()
case <-ctx.Done():
return ctx.Err()
}

dl, hasDl := ctx.Deadline()
if hasDl {
if err := mp.con.SetWriteDeadline(dl); err != nil {
return err
}
Expand All @@ -151,7 +162,7 @@ func (mp *Multiplex) sendMsg(header uint64, data []byte, dl time.Time) error {
return err
}

if !dl.IsZero() {
if hasDl {
// only return this error if we don't *already* have an error from the write.
if err2 := mp.con.SetWriteDeadline(time.Time{}); err == nil && err2 != nil {
return err2
Expand Down Expand Up @@ -193,7 +204,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
mp.channels[s.id] = s
mp.chLock.Unlock()

err := mp.sendMsg(header, []byte(name), time.Time{})
err := mp.sendMsg(context.Background(), header, []byte(name))
if err != nil {
return nil, err
}
Expand All @@ -212,7 +223,9 @@ func (mp *Multiplex) cleanup() {
// Cancel readers
close(msch.reset)
}
msch.closedLocal = true
if !msch.isClosed() {
close(msch.closedLocal)
}
msch.clLock.Unlock()
}
// Don't remove this nil assignment. We check if this is nil to check if
Expand Down Expand Up @@ -296,7 +309,8 @@ func (mp *Multiplex) handleIncoming() {

// Honestly, this check should never be true... It means we've leaked.
// However, this is an error on *our* side so we shouldn't just bail.
if msch.closedLocal && msch.closedRemote {
isClosed := msch.isClosed()
if isClosed && msch.closedRemote {
msch.clLock.Unlock()
log.Errorf("leaked a completely closed stream")
continue
Expand All @@ -306,7 +320,10 @@ func (mp *Multiplex) handleIncoming() {
close(msch.reset)
}
msch.closedRemote = true
msch.closedLocal = true
if !isClosed {
close(msch.closedLocal)
}

msch.clLock.Unlock()

mp.chLock.Lock()
Expand All @@ -329,7 +346,7 @@ func (mp *Multiplex) handleIncoming() {
close(msch.dataIn)
msch.closedRemote = true

cleanup := msch.closedLocal
cleanup := msch.isClosed()

msch.clLock.Unlock()

Expand All @@ -346,7 +363,7 @@ func (mp *Multiplex) handleIncoming() {
// This is a perfectly valid case when we reset
// and forget about the stream.
log.Debugf("message for non-existant stream, dropping data: %d", ch)
go mp.sendMsg(ch.header(resetTag), nil, time.Time{})
go mp.sendMsg(context.Background(), ch.header(resetTag), nil)
continue
}

Expand All @@ -358,7 +375,7 @@ func (mp *Multiplex) handleIncoming() {
pool.Put(b)

log.Errorf("Received data from remote after stream was closed by them. (len = %d)", len(b))
go mp.sendMsg(msch.id.header(resetTag), nil, time.Time{})
go mp.sendMsg(context.Background(), msch.id.header(resetTag), nil)
continue
}

Expand Down
64 changes: 49 additions & 15 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package multiplex

import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"

pool "github.com/libp2p/go-buffer-pool"
"github.com/libp2p/go-buffer-pool"
)

// streamID is a convenience type for operating on stream IDs
Expand Down Expand Up @@ -41,11 +42,13 @@ type Stream struct {
rDeadline time.Time

clLock sync.Mutex
closedLocal bool
closedRemote bool

// Closed when the connection is reset.
reset chan struct{}

// Closed when the writer is closed (reset will also be closed)
closedLocal chan struct{}
}

func (s *Stream) Name() string {
Expand Down Expand Up @@ -139,35 +142,62 @@ func (s *Stream) Write(b []byte) (int, error) {

func (s *Stream) write(b []byte) (int, error) {
if s.isClosed() {
return 0, fmt.Errorf("cannot write to closed stream")
return 0, errors.New("cannot write to closed stream")
}

err := s.mp.sendMsg(s.id.header(messageTag), b, s.wDeadline)
wDeadlineCtx, cancelFn := func(deadline time.Time) (context.Context, context.CancelFunc) {
if deadline.IsZero() {
return context.WithCancel(context.Background())
} else {
return context.WithDeadline(context.Background(), deadline)
}
}(s.wDeadline)

w := make(chan struct{})

go func() {
select {
case <-s.closedLocal:
case <-w:
}
cancelFn()
}()

err := s.mp.sendMsg(wDeadlineCtx, s.id.header(messageTag), b)
close(w)

if err != nil {
if err == context.Canceled {
err = errors.New("cannot write to closed stream")
}
return 0, err
}

return len(b), nil
}

func (s *Stream) isClosed() bool {
s.clLock.Lock()
defer s.clLock.Unlock()
return s.closedLocal
select {
case <-s.closedLocal:
return true
default:
return false
}
}

func (s *Stream) Close() error {
err := s.mp.sendMsg(s.id.header(closeTag), nil, time.Time{})
err := s.mp.sendMsg(context.Background(), s.id.header(closeTag), nil)

s.clLock.Lock()
if s.closedLocal {
s.clLock.Unlock()
if s.isClosed() {
return nil
}

s.clLock.Lock()
remote := s.closedRemote
s.closedLocal = true
s.clLock.Unlock()
if !s.isClosed() {
close(s.closedLocal)
}

if remote {
s.mp.chLock.Lock()
Expand All @@ -180,18 +210,22 @@ func (s *Stream) Close() error {

func (s *Stream) Reset() error {
s.clLock.Lock()
if s.closedRemote && s.closedLocal {
isClosed := s.isClosed()
if s.closedRemote && isClosed {
s.clLock.Unlock()
return nil
}

if !s.closedRemote {
close(s.reset)
// We generally call this to tell the other side to go away. No point in waiting around.
go s.mp.sendMsg(s.id.header(resetTag), nil, time.Time{})
go s.mp.sendMsg(context.Background(), s.id.header(resetTag), nil)
}

if !isClosed {
close(s.closedLocal)
}

s.closedLocal = true
s.closedRemote = true

s.clLock.Unlock()
Expand Down

0 comments on commit 4b1acc2

Please sign in to comment.