Skip to content
This repository has been archived by the owner on Sep 10, 2022. It is now read-only.

improve correctness of closing connections on failure #19

Merged
merged 2 commits into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ var _ = Describe("Listener", func() {

It("accepts a single connection", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).ToNot(HaveOccurred())
sconn, err := ln.Accept()
Expand All @@ -113,6 +114,7 @@ var _ = Describe("Listener", func() {

It("accepts multiple connections", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()
const num = 10
for i := 0; i < 10; i++ {
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expand All @@ -127,11 +129,15 @@ var _ = Describe("Listener", func() {
const timeout = 200 * time.Millisecond
tpt.AcceptTimeout = timeout
ln := createListener(defaultUpgrader)
defer ln.Close()
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
if !Expect(err).ToNot(HaveOccurred()) {
return
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer conn.Close()
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
// start a Read. It will block until the connection is closed
Expand All @@ -151,10 +157,16 @@ var _ = Describe("Listener", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, _ = ln.Accept()
conn, err := ln.Accept()
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
close(done)
}()
_, _ = dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
Consistently(done).ShouldNot(BeClosed())
// make the goroutine return
ln.Close()
Expand All @@ -178,6 +190,7 @@ var _ = Describe("Listener", func() {
if err != nil {
return
}
conn.Close()
accepted <- conn
}
}()
Expand All @@ -187,8 +200,14 @@ var _ = Describe("Listener", func() {
wg.Add(1)
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if Expect(err).ToNot(HaveOccurred()) {
stream, err := conn.AcceptStream() // wait for conn to be accepted.
if !Expect(err).To(HaveOccurred()) {
stream.Close()
}
conn.Close()
}
wg.Done()
}()
}
Expand All @@ -201,29 +220,40 @@ var _ = Describe("Listener", func() {

It("stops setting up when the more than AcceptQueueLength connections are waiting to get accepted", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()

// setup AcceptQueueLength connections, but don't accept any of them
dialed := make(chan struct{}, 10*st.AcceptQueueLength) // used as a thread-safe counter
dialed := make(chan tpt.Conn, 10*st.AcceptQueueLength) // used as a thread-safe counter
for i := 0; i < st.AcceptQueueLength; i++ {
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{}
dialed <- conn
}()
}
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength))
// dial a new connection. This connection should not complete setup, since the queue is full
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{}
dialed <- conn
}()
Consistently(dialed).Should(HaveLen(st.AcceptQueueLength))
// accept a single connection. Now the new connection should be set up, and fill the queue again
_, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
conn, err := ln.Accept()
if Expect(err).ToNot(HaveOccurred()) {
conn.Close()
}
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength + 1))

// Cleanup
for i := 0; i < st.AcceptQueueLength+1; i++ {
if c := <-dialed; c != nil {
c.Close()
}
}
})
})

Expand All @@ -233,9 +263,12 @@ var _ = Describe("Listener", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
conn, err := ln.Accept()
if Expect(err).To(HaveOccurred()) {
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
} else {
conn.Close()
}
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
Expand All @@ -246,15 +279,20 @@ var _ = Describe("Listener", func() {
It("doesn't accept new connections when it is closed", func() {
ln := createListener(defaultUpgrader)
Expect(ln.Close()).To(Succeed())
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).To(HaveOccurred())
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
})

It("closes incoming connections that have not yet been accepted", func() {
ln := createListener(defaultUpgrader)
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).ToNot(HaveOccurred()) {
ln.Close()
return
}
Expect(conn.IsClosed()).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed())
Eventually(conn.IsClosed).Should(BeTrue())
})
Expand Down
6 changes: 5 additions & 1 deletion upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma
}
smconn, err := u.setupMuxer(ctx, sconn, p)
if err != nil {
conn.Close()
sconn.Close()
return nil, fmt.Errorf("failed to negotiate security stream multiplexer: %s", err)
}
return &transportConn{
Expand Down Expand Up @@ -122,6 +122,10 @@ func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, p peer.ID) (sm
case <-done:
return smconn, err
case <-ctx.Done():
// interrupt this process
conn.Close()
// wait to finish
<-done
return nil, ctx.Err()
}
}