Skip to content

Commit

Permalink
When a stream is closed, cancel pending writes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ichbinjoe committed Dec 14, 2018
1 parent f6e0e0f commit f0d2c32
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 35 deletions.
56 changes: 36 additions & 20 deletions multiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package multiplex

import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -52,7 +53,7 @@ type Multiplex struct {
shutdownErr error
shutdownLock sync.Mutex

wrLock sync.Mutex
wrTkn chan struct{}

nstreams chan *Stream

Expand All @@ -68,22 +69,28 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
channels: make(map[streamID]*Stream),
closed: make(chan struct{}),
shutdown: make(chan struct{}),
wrTkn: make(chan struct{}, 1),
nstreams: make(chan *Stream, 16),
}

go mp.handleIncoming()

mp.wrTkn <- struct{}{}

return mp
}

func (mp *Multiplex) newStream(id streamID, name string) *Stream {
return &Stream{
id: id,
name: name,
dataIn: make(chan []byte, 8),
reset: make(chan struct{}),
mp: mp,
func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
s = &Stream{
id: id,
name: name,
dataIn: make(chan []byte, 8),
reset: make(chan struct{}),
mp: mp,
}

s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background())
return
}

func (m *Multiplex) Accept() (*Stream, error) {
Expand Down Expand Up @@ -127,10 +134,16 @@ func (mp *Multiplex) IsClosed() bool {
}
}

func (mp *Multiplex) sendMsg(header uint64, data []byte, dl time.Time) error {
mp.wrLock.Lock()
defer mp.wrLock.Unlock()
if !dl.IsZero() {
func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) error {
select {
case tkn := <-mp.wrTkn:
defer func() { mp.wrTkn <- tkn }()
case <-ctx.Done():
return ctx.Err()
}

dl, hasDl := ctx.Deadline()
if hasDl {
if err := mp.con.SetWriteDeadline(dl); err != nil {
return err
}
Expand All @@ -151,7 +164,7 @@ func (mp *Multiplex) sendMsg(header uint64, data []byte, dl time.Time) error {
return err
}

if !dl.IsZero() {
if hasDl {
// only return this error if we don't *already* have an error from the write.
if err2 := mp.con.SetWriteDeadline(time.Time{}); err == nil && err2 != nil {
return err2
Expand Down Expand Up @@ -193,7 +206,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
mp.channels[s.id] = s
mp.chLock.Unlock()

err := mp.sendMsg(header, []byte(name), time.Time{})
err := mp.sendMsg(context.Background(), header, []byte(name))
if err != nil {
return nil, err
}
Expand All @@ -212,7 +225,8 @@ func (mp *Multiplex) cleanup() {
// Cancel readers
close(msch.reset)
}
msch.closedLocal = true

msch.doCloseLocal();
msch.clLock.Unlock()
}
// Don't remove this nil assignment. We check if this is nil to check if
Expand Down Expand Up @@ -296,7 +310,8 @@ func (mp *Multiplex) handleIncoming() {

// Honestly, this check should never be true... It means we've leaked.
// However, this is an error on *our* side so we shouldn't just bail.
if msch.closedLocal && msch.closedRemote {
isClosed := msch.isClosed()
if isClosed && msch.closedRemote {
msch.clLock.Unlock()
log.Errorf("leaked a completely closed stream")
continue
Expand All @@ -306,7 +321,8 @@ func (mp *Multiplex) handleIncoming() {
close(msch.reset)
}
msch.closedRemote = true
msch.closedLocal = true
msch.doCloseLocal()

msch.clLock.Unlock()

mp.chLock.Lock()
Expand All @@ -329,7 +345,7 @@ func (mp *Multiplex) handleIncoming() {
close(msch.dataIn)
msch.closedRemote = true

cleanup := msch.closedLocal
cleanup := msch.isClosed()

msch.clLock.Unlock()

Expand All @@ -346,7 +362,7 @@ func (mp *Multiplex) handleIncoming() {
// This is a perfectly valid case when we reset
// and forget about the stream.
log.Debugf("message for non-existant stream, dropping data: %d", ch)
go mp.sendMsg(ch.header(resetTag), nil, time.Time{})
go mp.sendMsg(context.Background(), ch.header(resetTag), nil)
continue
}

Expand All @@ -358,7 +374,7 @@ func (mp *Multiplex) handleIncoming() {
pool.Put(b)

log.Errorf("Received data from remote after stream was closed by them. (len = %d)", len(b))
go mp.sendMsg(msch.id.header(resetTag), nil, time.Time{})
go mp.sendMsg(context.Background(), msch.id.header(resetTag), nil)
continue
}

Expand Down
52 changes: 37 additions & 15 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package multiplex

import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"

pool "github.com/libp2p/go-buffer-pool"
"github.com/libp2p/go-buffer-pool"
)

// streamID is a convenience type for operating on stream IDs
Expand Down Expand Up @@ -41,11 +42,14 @@ type Stream struct {
rDeadline time.Time

clLock sync.Mutex
closedLocal bool
closedRemote bool

// Closed when the connection is reset.
reset chan struct{}

// Closed when the writer is closed (reset will also be closed)
closedLocal context.Context
doCloseLocal context.CancelFunc
}

func (s *Stream) Name() string {
Expand Down Expand Up @@ -139,36 +143,52 @@ func (s *Stream) Write(b []byte) (int, error) {

// write sends b to the remote side as a single message frame.
// It fails fast if the stream is already closed locally, and otherwise
// performs the send under a context that is canceled either by the local
// close (s.closedLocal) or by the write deadline, whichever fires first.
// Returns len(b) on success.
func (s *Stream) write(b []byte) (int, error) {
	if s.isClosed() {
		return 0, errors.New("cannot write to closed stream")
	}

	// Derive the write context: when no deadline is set we use the
	// close-cancellation context directly (no CancelFunc to release);
	// otherwise we layer the deadline on top of it.
	wDeadlineCtx, cleanup := func(s *Stream) (context.Context, context.CancelFunc) {
		if s.wDeadline.IsZero() {
			return s.closedLocal, nil
		}
		return context.WithDeadline(s.closedLocal, s.wDeadline)
	}(s)

	err := s.mp.sendMsg(wDeadlineCtx, s.id.header(messageTag), b)

	// Release the deadline timer, if we created one.
	if cleanup != nil {
		cleanup()
	}

	if err != nil {
		// A canceled context means the stream was closed locally while the
		// write was pending; surface that as the standard closed-stream error.
		if err == context.Canceled {
			err = errors.New("cannot write to closed stream")
		}
		return 0, err
	}

	return len(b), nil
}

// isClosed reports whether the local side of the stream has been closed.
// closedLocal is a context canceled by doCloseLocal, so a non-nil Err
// means the write side is shut. context.Context.Err is safe for concurrent
// use, so no lock is taken here — callers such as Reset invoke this while
// already holding clLock, and acquiring it again would deadlock.
func (s *Stream) isClosed() bool {
	return s.closedLocal.Err() != nil
}

func (s *Stream) Close() error {
err := s.mp.sendMsg(s.id.header(closeTag), nil, time.Time{})
err := s.mp.sendMsg(context.Background(), s.id.header(closeTag), nil)

s.clLock.Lock()
if s.closedLocal {
s.clLock.Unlock()
if s.isClosed() {
return nil
}

s.clLock.Lock()
remote := s.closedRemote
s.closedLocal = true
s.clLock.Unlock()

s.doCloseLocal()

if remote {
s.mp.chLock.Lock()
delete(s.mp.channels, s.id)
Expand All @@ -180,18 +200,20 @@ func (s *Stream) Close() error {

func (s *Stream) Reset() error {
s.clLock.Lock()
if s.closedRemote && s.closedLocal {
isClosed := s.isClosed()
if s.closedRemote && isClosed {
s.clLock.Unlock()
return nil
}

if !s.closedRemote {
close(s.reset)
// We generally call this to tell the other side to go away. No point in waiting around.
go s.mp.sendMsg(s.id.header(resetTag), nil, time.Time{})
go s.mp.sendMsg(context.Background(), s.id.header(resetTag), nil)
}

s.closedLocal = true
s.doCloseLocal()

s.closedRemote = true

s.clLock.Unlock()
Expand Down

0 comments on commit f0d2c32

Please sign in to comment.