Skip to content
This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

If peer is first to send a block to session, protect connection #406

Merged
merged 4 commits into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ type SessionPeerManager interface {
Peers() []peer.ID
// Whether there are any peers in the session
HasPeers() bool
// Protect connection from being pruned by the connection manager
ProtectConnection(peer.ID)
}

// ProviderFinder is used to find providers for a given key
Expand Down
41 changes: 37 additions & 4 deletions internal/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,49 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
return bsspm.New(1, newFakePeerTagger())
}

func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
}

type fakePeerTagger struct {
lk sync.Mutex
protectedPeers map[peer.ID]map[string]struct{}
}

func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{}
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {}
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {}

func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()

tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
}

func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {
func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()

if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
return len(tags) > 0
}

return false
}
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {

func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()

return len(fpt.protectedPeers[p]) > 0
}

type fakeProviderFinder struct {
Expand Down
5 changes: 5 additions & 0 deletions internal/session/sessionwantsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid {
// Inform the peer tracker that this peer was the first to send
// us the block
sws.peerRspTrkr.receivedBlockFrom(upd.from)

// Protect the connection to this peer so that we can ensure
// that the connection doesn't get pruned by the connection
// manager
sws.spm.ProtectConnection(upd.from)
}
delete(sws.peerConsecutiveDontHaves, upd.from)
}
Expand Down
57 changes: 57 additions & 0 deletions internal/session/sessionwantsender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager"
bspm "github.com/ipfs/go-bitswap/internal/peermanager"
bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager"
"github.com/ipfs/go-bitswap/internal/testutil"
cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-core/peer"
Expand Down Expand Up @@ -374,6 +375,62 @@ func TestRegisterSessionWithPeerManager(t *testing.T) {
}
}

func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(3)
peerA := peers[0]
peerB := peers[1]
peerC := peers[2]
sid := uint64(1)
pm := newMockPeerManager()
fpt := newFakePeerTagger()
fpm := bsspm.New(1, fpt)
swc := newMockSessionMgr()
bpm := bsbpm.New()
onSend := func(peer.ID, []cid.Cid, []cid.Cid) {}
onPeersExhausted := func([]cid.Cid) {}
spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted)
defer spm.Shutdown()

go spm.Run()

// add cid0
spm.Add(cids[:1])

// peerA: block cid0
spm.Update(peerA, cids[:1], nil, nil)

// Wait for processing to complete
time.Sleep(10 * time.Millisecond)

// Expect peer A to be protected as it was first to send the block
if !fpt.isProtected(peerA) {
t.Fatal("Expected first peer to send block to have protected connection")
}

// peerB: block cid0
spm.Update(peerB, cids[:1], nil, nil)

// Wait for processing to complete
time.Sleep(10 * time.Millisecond)

// Expect peer B not to be protected as it was not first to send the block
if fpt.isProtected(peerB) {
t.Fatal("Expected peer not to be protected")
}

// peerC: block cid1
spm.Update(peerC, cids[1:], nil, nil)

// Wait for processing to complete
time.Sleep(10 * time.Millisecond)

// Expect peer C not to be protected as we didn't want the block it sent
if fpt.isProtected(peerC) {
t.Fatal("Expected peer not to be protected")
}
}

func TestPeerUnavailable(t *testing.T) {
cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(2)
Expand Down
13 changes: 7 additions & 6 deletions internal/sessionmanager/sessionmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() {
type fakeSesPeerManager struct {
}

func (*fakeSesPeerManager) Peers() []peer.ID { return nil }
func (*fakeSesPeerManager) PeersDiscovered() bool { return false }
func (*fakeSesPeerManager) Shutdown() {}
func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) HasPeers() bool { return false }
func (*fakeSesPeerManager) Peers() []peer.ID { return nil }
func (*fakeSesPeerManager) PeersDiscovered() bool { return false }
func (*fakeSesPeerManager) Shutdown() {}
func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) HasPeers() bool { return false }
func (*fakeSesPeerManager) ProtectConnection(peer.ID) {}

type fakePeerManager struct {
lk sync.Mutex
Expand Down
16 changes: 16 additions & 0 deletions internal/sessionpeermanager/sessionpeermanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
type PeerTagger interface {
TagPeer(peer.ID, string, int)
UntagPeer(p peer.ID, tag string)
Protect(peer.ID, string)
Unprotect(peer.ID, string) bool
}

// SessionPeerManager keeps track of peers for a session, and takes care of
Expand Down Expand Up @@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool {
return true
}

// Protect connection to this peer from being pruned by the connection manager
func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
spm.plk.Lock()
defer spm.plk.Unlock()

if _, ok := spm.peers[p]; !ok {
return
}

spm.tagger.Protect(p, spm.tag)
}

// RemovePeer removes the peer from the SessionPeerManager.
// Returns true if the peer was removed, false if it did not exist.
func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
Expand All @@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {

delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)

log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true
Expand Down Expand Up @@ -130,5 +145,6 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers
for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)
}
}
83 changes: 79 additions & 4 deletions internal/sessionpeermanager/sessionpeermanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ import (
)

type fakePeerTagger struct {
lk sync.Mutex
taggedPeers []peer.ID
wait sync.WaitGroup
lk sync.Mutex
taggedPeers []peer.ID
protectedPeers map[peer.ID]map[string]struct{}
wait sync.WaitGroup
}

func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
}

func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
Expand All @@ -36,6 +43,40 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
}
}

func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()

tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
}

func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()

if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
if len(tags) == 0 {
delete(fpt.protectedPeers, p)
}
return len(tags) > 0
}

return false
}

func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()

return len(fpt.protectedPeers[p]) > 0
}

func TestAddPeers(t *testing.T) {
peers := testutil.GeneratePeers(2)
spm := New(1, &fakePeerTagger{})
Expand Down Expand Up @@ -208,9 +249,35 @@ func TestPeerTagging(t *testing.T) {
}
}

func TestProtectConnection(t *testing.T) {
peers := testutil.GeneratePeers(1)
peerA := peers[0]
fpt := newFakePeerTagger()
spm := New(1, fpt)

// Should not protect connection if peer hasn't been added yet
spm.ProtectConnection(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer not to be protected")
}

// Once peer is added, should be able to protect connection
spm.AddPeer(peerA)
spm.ProtectConnection(peerA)
if !fpt.isProtected(peerA) {
t.Fatal("Expected peer to be protected")
}

// Removing peer should unprotect connection
spm.RemovePeer(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer to be unprotected")
}
}

func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2)
fpt := &fakePeerTagger{}
fpt := newFakePeerTagger()
spm := New(1, fpt)

spm.AddPeer(peers[0])
Expand All @@ -219,9 +286,17 @@ func TestShutdown(t *testing.T) {
t.Fatal("Expected to have tagged two peers")
}

spm.ProtectConnection(peers[0])
if !fpt.isProtected(peers[0]) {
t.Fatal("Expected peer to be protected")
}

spm.Shutdown()

if len(fpt.taggedPeers) != 0 {
t.Fatal("Expected to have untagged all peers")
}
if len(fpt.protectedPeers) != 0 {
t.Fatal("Expected to have unprotected all peers")
}
}