Skip to content

Commit

Permalink
Merge pull request ipfs/go-bitswap#345 from ipfs/fix/cancel-leak
Browse files Browse the repository at this point in the history
fix: in message queue only send cancel if want was sent

This commit was moved from ipfs/go-bitswap@6728add
  • Loading branch information
Stebalien authored Apr 11, 2020
2 parents cf0893f + 378f7df commit f610dcf
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 45 deletions.
100 changes: 67 additions & 33 deletions bitswap/internal/messagequeue/messagequeue.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
bsmsg "github.com/ipfs/go-bitswap/message"
pb "github.com/ipfs/go-bitswap/message/pb"
bsnet "github.com/ipfs/go-bitswap/network"
"github.com/ipfs/go-bitswap/wantlist"
bswl "github.com/ipfs/go-bitswap/wantlist"
cid "github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log"
Expand Down Expand Up @@ -80,41 +81,44 @@ type MessageQueue struct {
msg bsmsg.BitSwapMessage
}

// recallWantlist keeps a list of pending wants, and a list of all wants that
// have ever been requested
// recallWantlist keeps a list of pending wants and a list of sent wants
type recallWantlist struct {
// The list of all wants that have been requested, including wants that
// have been sent and wants that have not yet been sent
allWants *bswl.Wantlist
// The list of wants that have not yet been sent
pending *bswl.Wantlist
// The list of wants that have been sent
sent *bswl.Wantlist
}

func newRecallWantList() recallWantlist {
return recallWantlist{
allWants: bswl.New(),
pending: bswl.New(),
pending: bswl.New(),
sent: bswl.New(),
}
}

// Add want to both the pending list and the list of all wants
// Add want to the pending list
func (r *recallWantlist) Add(c cid.Cid, priority int32, wtype pb.Message_Wantlist_WantType) {
r.allWants.Add(c, priority, wtype)
r.pending.Add(c, priority, wtype)
}

// Remove wants from both the pending list and the list of all wants
// Remove wants from both the pending list and the list of sent wants
func (r *recallWantlist) Remove(c cid.Cid) {
r.allWants.Remove(c)
r.sent.Remove(c)
r.pending.Remove(c)
}

// Remove wants by type from both the pending list and the list of all wants
// Remove wants by type from both the pending list and the list of sent wants
func (r *recallWantlist) RemoveType(c cid.Cid, wtype pb.Message_Wantlist_WantType) {
r.allWants.RemoveType(c, wtype)
r.sent.RemoveType(c, wtype)
r.pending.RemoveType(c, wtype)
}

// Sent moves the want from the pending to the sent list
func (r *recallWantlist) Sent(e bsmsg.Entry) {
r.pending.RemoveType(e.Cid, e.WantType)
r.sent.Add(e.Cid, e.Priority, e.WantType)
}

type peerConn struct {
p peer.ID
network MessageNetwork
Expand Down Expand Up @@ -251,15 +255,29 @@ func (mq *MessageQueue) AddCancels(cancelKs []cid.Cid) {
mq.wllock.Lock()
defer mq.wllock.Unlock()

workReady := false

// Remove keys from broadcast and peer wants, and add to cancels
for _, c := range cancelKs {
// Check if a want for the key was sent
_, wasSentBcst := mq.bcstWants.sent.Contains(c)
_, wasSentPeer := mq.peerWants.sent.Contains(c)

// Remove the want from tracking wantlists
mq.bcstWants.Remove(c)
mq.peerWants.Remove(c)
mq.cancels.Add(c)

// Only send a cancel if a want was sent
if wasSentBcst || wasSentPeer {
mq.cancels.Add(c)
workReady = true
}
}

// Schedule a message send
mq.signalWorkReady()
if workReady {
mq.signalWorkReady()
}
}

// SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist
Expand Down Expand Up @@ -366,13 +384,13 @@ func (mq *MessageQueue) transferRebroadcastWants() bool {
defer mq.wllock.Unlock()

// Check if there are any wants to rebroadcast
if mq.bcstWants.allWants.Len() == 0 && mq.peerWants.allWants.Len() == 0 {
if mq.bcstWants.sent.Len() == 0 && mq.peerWants.sent.Len() == 0 {
return false
}

// Copy all wants into pending wants lists
mq.bcstWants.pending.Absorb(mq.bcstWants.allWants)
mq.peerWants.pending.Absorb(mq.peerWants.allWants)
// Copy sent wants into pending wants lists
mq.bcstWants.pending.Absorb(mq.bcstWants.sent)
mq.peerWants.pending.Absorb(mq.peerWants.sent)

return true
}
Expand Down Expand Up @@ -405,7 +423,7 @@ func (mq *MessageQueue) sendMessage() {
mq.dhTimeoutMgr.Start()

// Convert want lists to a Bitswap Message
message := mq.extractOutgoingMessage(mq.sender.SupportsHave())
message, onSent := mq.extractOutgoingMessage(mq.sender.SupportsHave())

// After processing the message, clear out its fields to save memory
defer mq.msg.Reset(false)
Expand All @@ -421,7 +439,7 @@ func (mq *MessageQueue) sendMessage() {
for i := 0; i < maxRetries; i++ {
if mq.attemptSendAndRecovery(message) {
// We were able to send successfully.
mq.onMessageSent(wantlist)
onSent(wantlist)

mq.simulateDontHaveWithTimeout(wantlist)

Expand Down Expand Up @@ -452,7 +470,7 @@ func (mq *MessageQueue) simulateDontHaveWithTimeout(wantlist []bsmsg.Entry) {
// Unlikely, but just in case check that the block hasn't been
// received in the interim
c := entry.Cid
if _, ok := mq.peerWants.allWants.Contains(c); ok {
if _, ok := mq.peerWants.sent.Contains(c); ok {
wants = append(wants, c)
}
}
Expand Down Expand Up @@ -522,7 +540,7 @@ func (mq *MessageQueue) pendingWorkCount() int {
}

// Convert the lists of wants into a Bitswap message
func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapMessage {
func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwapMessage, func([]bsmsg.Entry)) {
mq.wllock.Lock()
defer mq.wllock.Unlock()

Expand Down Expand Up @@ -572,19 +590,35 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapM
mq.cancels.Remove(c)
}

return mq.msg
}
// Called when the message has been successfully sent.
onMessageSent := func(wantlist []bsmsg.Entry) {
bcst := keysToSet(bcstEntries)
prws := keysToSet(peerEntries)

// Called when the message has been successfully sent.
func (mq *MessageQueue) onMessageSent(wantlist []bsmsg.Entry) {
// Remove the sent keys from the broadcast and regular wantlists.
mq.wllock.Lock()
defer mq.wllock.Unlock()
mq.wllock.Lock()
defer mq.wllock.Unlock()

for _, e := range wantlist {
mq.bcstWants.pending.Remove(e.Cid)
mq.peerWants.pending.RemoveType(e.Cid, e.WantType)
// Move the keys from pending to sent
for _, e := range wantlist {
if _, ok := bcst[e.Cid]; ok {
mq.bcstWants.Sent(e)
}
if _, ok := prws[e.Cid]; ok {
mq.peerWants.Sent(e)
}
}
}

return mq.msg, onMessageSent
}

// Convert wantlist entries into a set of cids
func keysToSet(wl []wantlist.Entry) map[cid.Cid]struct{} {
set := make(map[cid.Cid]struct{}, len(wl))
for _, e := range wl {
set[e.Cid] = struct{}{}
}
return set
}

func (mq *MessageQueue) initializeSender() error {
Expand Down
57 changes: 45 additions & 12 deletions bitswap/internal/messagequeue/messagequeue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,25 +319,43 @@ func TestCancelOverridesPendingWants(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)

wantHaves := testutil.GenerateCids(2)
wantBlocks := testutil.GenerateCids(2)
cancels := []cid.Cid{wantBlocks[0], wantHaves[0]}

messageQueue.Startup()
messageQueue.AddWants(wantBlocks, wantHaves)
messageQueue.AddCancels([]cid.Cid{wantBlocks[0], wantHaves[0]})
messageQueue.AddCancels(cancels)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)

if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks) {
if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks)-len(cancels) {
t.Fatal("Wrong message count")
}

// Cancelled 1 want-block and 1 want-have before they were sent
// so that leaves 1 want-block and 1 want-have
wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(wantBlocks[1]) {
t.Fatal("Expected 1 want-block")
}
if len(wh) != 1 || !wh[0].Equals(wantHaves[1]) {
t.Fatal("Expected 1 want-have")
}
// Cancelled wants before they were sent, so no cancel should be sent
// to the network
if len(cl) != 0 {
t.Fatal("Expected no cancels")
}

// Cancel the remaining want-blocks and want-haves
cancels = append(wantHaves, wantBlocks...)
messageQueue.AddCancels(cancels)
messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)

// The remaining 2 cancels should be sent to the network as they are for
// wants that were sent to the network
_, _, cl = filterWantTypes(messages[0])
if len(cl) != 2 {
t.Fatal("Expected 2 cancels")
}
Expand All @@ -353,26 +371,41 @@ func TestWantOverridesPendingCancels(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
cancels := testutil.GenerateCids(3)

cids := testutil.GenerateCids(3)
wantBlocks := cids[:1]
wantHaves := cids[1:]

messageQueue.Startup()
messageQueue.AddCancels(cancels)
messageQueue.AddWants([]cid.Cid{cancels[0]}, []cid.Cid{cancels[1]})

// Add 1 want-block and 2 want-haves
messageQueue.AddWants(wantBlocks, wantHaves)

messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != len(wantBlocks)+len(wantHaves) {
t.Fatal("Wrong message count", totalEntriesLength(messages))
}

if totalEntriesLength(messages) != len(cancels) {
t.Fatal("Wrong message count")
// Cancel existing wants
messageQueue.AddCancels(cids)
// Override one cancel with a want-block (before cancel is sent to network)
messageQueue.AddWants(cids[:1], []cid.Cid{})

messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != 3 {
t.Fatal("Wrong message count", totalEntriesLength(messages))
}

// Should send 1 want-block and 2 cancels
wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(cancels[0]) {
if len(wb) != 1 {
t.Fatal("Expected 1 want-block")
}
if len(wh) != 1 || !wh[0].Equals(cancels[1]) {
t.Fatal("Expected 1 want-have")
if len(wh) != 0 {
t.Fatal("Expected 0 want-have")
}
if len(cl) != 1 || !cl[0].Equals(cancels[2]) {
t.Fatal("Expected 1 cancel")
if len(cl) != 2 {
t.Fatal("Expected 2 cancels")
}
}

Expand Down
11 changes: 11 additions & 0 deletions bitswap/internal/sessionwantlist/sessionwantlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
cid "github.com/ipfs/go-cid"
)

// The SessionWantList keeps track of which sessions want a CID
type SessionWantlist struct {
sync.RWMutex
wants map[cid.Cid]map[uint64]struct{}
Expand All @@ -17,6 +18,7 @@ func NewSessionWantlist() *SessionWantlist {
}
}

// The given session wants the keys
func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
swl.Lock()
defer swl.Unlock()
Expand All @@ -29,6 +31,8 @@ func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
}
}

// Remove the keys for all sessions.
// Called when blocks are received.
func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
swl.Lock()
defer swl.Unlock()
Expand All @@ -38,6 +42,8 @@ func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
}
}

// Remove the session's wants, and return wants that are no longer wanted by
// any session.
func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
swl.Lock()
defer swl.Unlock()
Expand All @@ -54,6 +60,7 @@ func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
return deletedKs
}

// Remove the session's wants
func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
swl.Lock()
defer swl.Unlock()
Expand All @@ -68,6 +75,7 @@ func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
}
}

// All keys wanted by all sessions
func (swl *SessionWantlist) Keys() []cid.Cid {
swl.RLock()
defer swl.RUnlock()
Expand All @@ -79,6 +87,7 @@ func (swl *SessionWantlist) Keys() []cid.Cid {
return ks
}

// All sessions that want the given keys
func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
swl.RLock()
defer swl.RUnlock()
Expand All @@ -97,6 +106,7 @@ func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
return ses
}

// Filter for keys that at least one session wants
func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
swl.RLock()
defer swl.RUnlock()
Expand All @@ -110,6 +120,7 @@ func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
return has
}

// Filter for keys that the given session wants
func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set {
swl.RLock()
defer swl.RUnlock()
Expand Down

0 comments on commit f610dcf

Please sign in to comment.