Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
delete reuse-connection if they aren't used for more than 10 seconds
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 7, 2019
1 parent 54733f5 commit 5fde971
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 85 deletions.
86 changes: 53 additions & 33 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,12 @@ var _ = Describe("Connection", func() {
return id, priv
}

runServer := func(tr tpt.Transport, multiaddr string) (ma.Multiaddr, <-chan tpt.CapableConn) {
addrChan := make(chan ma.Multiaddr)
connChan := make(chan tpt.CapableConn)
go func() {
defer GinkgoRecover()
addr, err := ma.NewMultiaddr(multiaddr)
Expect(err).ToNot(HaveOccurred())
ln, err := tr.Listen(addr)
Expect(err).ToNot(HaveOccurred())
addrChan <- ln.Multiaddr()
conn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
connChan <- conn
}()
return <-addrChan, connChan
runServer := func(tr tpt.Transport, multiaddr string) tpt.Listener {
addr, err := ma.NewMultiaddr(multiaddr)
Expect(err).ToNot(HaveOccurred())
ln, err := tr.Listen(addr)
Expect(err).ToNot(HaveOccurred())
return ln
}

BeforeEach(func() {
Expand All @@ -73,13 +64,17 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv4", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
serverConn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan
defer serverConn.Close()
Expect(conn.LocalPeer()).To(Equal(clientID))
Expect(conn.LocalPrivateKey()).To(Equal(clientKey))
Expect(conn.RemotePeer()).To(Equal(serverID))
Expand All @@ -93,13 +88,17 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv6", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport, "/ip6/::1/udp/0/quic")
ln := runServer(serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan
defer conn.Close()
serverConn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
defer serverConn.Close()
Expect(conn.LocalPeer()).To(Equal(clientID))
Expect(conn.LocalPrivateKey()).To(Equal(clientKey))
Expect(conn.RemotePeer()).To(Equal(serverID))
Expand All @@ -113,13 +112,17 @@ var _ = Describe("Connection", func() {
It("opens and accepts streams", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close()

clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
serverConn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan
defer serverConn.Close()

str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expand All @@ -138,32 +141,48 @@ var _ = Describe("Connection", func() {

serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")

clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), serverAddr, thirdPartyID)
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
Consistently(serverConnChan).ShouldNot(Receive())

done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
ln.Accept()
}()
Consistently(done).ShouldNot(BeClosed())
ln.Close()
Eventually(done).Should(BeClosed())
})

It("dials to two servers at the same time", func() {
serverID2, serverKey2 := createPeer()

serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
serverTransport2, err := NewTransport(serverKey2)
defer ln1.Close()
Expect(err).ToNot(HaveOccurred())
serverAddr2, serverConnChan2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close()

data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB
// wait for both servers to accept a connection
// then send some data
go func() {
for _, c := range []tpt.CapableConn{<-serverConnChan, <-serverConnChan2} {
serverConn1, err := ln1.Accept()
Expect(err).ToNot(HaveOccurred())
serverConn2, err := ln2.Accept()
Expect(err).ToNot(HaveOccurred())

for _, c := range []tpt.CapableConn{serverConn1, serverConn2} {
go func(conn tpt.CapableConn) {
defer GinkgoRecover()
str, err := conn.OpenStream()
Expand All @@ -177,10 +196,12 @@ var _ = Describe("Connection", func() {

clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
c1, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
c2, err := clientTransport.Dial(context.Background(), serverAddr2, serverID2)
defer c1.Close()
c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2)
Expect(err).ToNot(HaveOccurred())
defer c2.Close()

done := make(chan struct{}, 2)
// receive the data on both connections at the same time
Expand All @@ -193,7 +214,6 @@ var _ = Describe("Connection", func() {
d, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal(data))
conn.Close()
done <- struct{}{}
}(c)
}
Expand Down
27 changes: 27 additions & 0 deletions libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package libp2pquic

import (
"bytes"
mrand "math/rand"
"runtime/pprof"
"strings"
"time"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand All @@ -17,3 +21,26 @@ func TestLibp2pQuicTransport(t *testing.T) {
var _ = BeforeSuite(func() {
mrand.Seed(GinkgoRandomSeed())
})

var garbageCollectIntervalOrig time.Duration
var maxUnusedDurationOrig time.Duration

func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}

var _ = BeforeEach(func() {
Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond
maxUnusedDuration = 0
})

var _ = AfterEach(func() {
Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig
})
2 changes: 1 addition & 1 deletion listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (l *listener) setupConn(sess quic.Session) (tpt.CapableConn, error) {

// Close closes the listener.
func (l *listener) Close() error {
l.conn.DecreaseCount()
defer l.conn.DecreaseCount()
return l.quicListener.Close()
}

Expand Down
3 changes: 3 additions & 0 deletions listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
ln, err := t.Listen(localAddr)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
netAddr := ln.Addr()
Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{}))
port := netAddr.(*net.UDPAddr).Port
Expand All @@ -45,6 +46,7 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
ln, err := t.Listen(localAddr)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
netAddr := ln.Addr()
Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{}))
port := netAddr.(*net.UDPAddr).Port
Expand All @@ -57,6 +59,7 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred())
ln, err := t.Listen(localAddr)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
netAddr := ln.Addr()
Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{}))
port := netAddr.(*net.UDPAddr).Port
Expand Down
87 changes: 82 additions & 5 deletions reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,56 @@ package libp2pquic
import (
"net"
"sync"
"sync/atomic"
"time"

"github.com/vishvananda/netlink"
)

// Constants. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)

type reuseConn struct {
net.PacketConn
refCount int32 // to be used as an atomic

mutex sync.Mutex
refCount int
unusedSince time.Time
}

func newReuseConn(conn net.PacketConn) *reuseConn {
return &reuseConn{PacketConn: conn}
}

func (c *reuseConn) IncreaseCount() { atomic.AddInt32(&c.refCount, 1) }
func (c *reuseConn) DecreaseCount() { atomic.AddInt32(&c.refCount, -1) }
func (c *reuseConn) GetCount() int { return int(atomic.LoadInt32(&c.refCount)) }
func (c *reuseConn) IncreaseCount() {
c.mutex.Lock()
c.refCount++
c.unusedSince = time.Time{}
c.mutex.Unlock()
}

func (c *reuseConn) DecreaseCount() {
c.mutex.Lock()
c.refCount--
if c.refCount == 0 {
c.unusedSince = time.Now()
}
c.mutex.Unlock()
}

func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
}

type reuse struct {
mutex sync.Mutex

garbageCollectorRunning bool

handle *netlink.Handle // Only set on Linux. nil on other systems.

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
Expand All @@ -46,6 +75,51 @@ func newReuse() (*reuse, error) {
}, nil
}

func (r *reuse) runGarbageCollector() {
ticker := time.NewTicker(garbageCollectInterval)
for {
<-ticker.C

now := time.Now()
var shouldExit bool
r.mutex.Lock()
for key, conn := range r.global {
if conn.ShouldGarbageCollect(now) {
delete(r.global, key)
}
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
}
}

// stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false
shouldExit = true
}
r.mutex.Unlock()

if shouldExit {
return
}
}
}

// must be called while holding the mutex
func (r *reuse) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
go r.runGarbageCollector()
}
}

// Get the source IP that the kernel would use for dialing.
// This only works on Linux.
// On other systems, this returns an empty slice of IP addresses.
Expand Down Expand Up @@ -80,6 +154,7 @@ func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}

Expand Down Expand Up @@ -131,6 +206,8 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

r.maybeStartGarbageCollector()

// Deal with listen on a global address
if localAddr.IP.IsUnspecified() {
// The kernel already checked that the laddr is not already listen
Expand Down
Loading

0 comments on commit 5fde971

Please sign in to comment.