diff --git a/providerquerymanager/providerquerymanager.go b/providerquerymanager/providerquerymanager.go index 38471479..29065228 100644 --- a/providerquerymanager/providerquerymanager.go +++ b/providerquerymanager/providerquerymanager.go @@ -124,17 +124,24 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, inProgressRequestChan: inProgressRequestChan, }: case <-pqm.ctx.Done(): - return nil + ch := make(chan peer.ID) + close(ch) + return ch case <-sessionCtx.Done(): - return nil + ch := make(chan peer.ID) + close(ch) + return ch } + // DO NOT select on sessionCtx. We only want to abort here if we're + // shutting down because we can't actually _cancel_ the request till we + // get to receiveProviders. var receivedInProgressRequest inProgressRequest select { case <-pqm.ctx.Done(): - return nil - case <-sessionCtx.Done(): - return nil + ch := make(chan peer.ID) + close(ch) + return ch case receivedInProgressRequest = <-inProgressRequestChan: } @@ -170,7 +177,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k case <-pqm.ctx.Done(): return case <-sessionCtx.Done(): - pqm.cancelProviderRequest(k, incomingProviders) + if incomingProviders != nil { + pqm.cancelProviderRequest(k, incomingProviders) + } return case provider, ok := <-incomingProviders: if !ok { @@ -228,7 +237,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() { wg.Add(1) go func(p peer.ID) { defer wg.Done() - err := pqm.network.ConnectTo(pqm.ctx, p) + err := pqm.network.ConnectTo(findProviderCtx, p) if err != nil { log.Debugf("failed to connect to provider %s: %s", p, err) return @@ -397,12 +406,12 @@ func (crm *cancelRequestMessage) debugMessage() string { func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] if !ok { - log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String()) + // Request finished while queued. return } _, ok = requestStatus.listeners[crm.incomingProviders] if !ok { - log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String()) + // Request finished and _restarted_ while queued. return } delete(requestStatus.listeners, crm.incomingProviders) diff --git a/providerquerymanager/providerquerymanager_test.go b/providerquerymanager/providerquerymanager_test.go index 3abe6b0e..efdfd14f 100644 --- a/providerquerymanager/providerquerymanager_test.go +++ b/providerquerymanager/providerquerymanager_test.go @@ -304,3 +304,60 @@ func TestFindProviderTimeout(t *testing.T) { t.Fatal("Find provider request should have timed out, did not") } } + +func TestFindProviderPreCanceled(t *testing.T) { + peers := testutil.GeneratePeers(10) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond) + keys := testutil.GenerateCids(1) + + sessionCtx, cancel := context.WithCancel(ctx) + cancel() + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) + if firstRequestChan == nil { + t.Fatal("expected non-nil channel") + } + select { + case <-firstRequestChan: + case <-time.After(10 * time.Millisecond): + t.Fatal("shouldn't have blocked waiting on a closed context") + } +} + +func TestCancelFindProvidersAfterCompletion(t *testing.T) { + peers := testutil.GeneratePeers(2) + fpn := &fakeProviderNetwork{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := New(ctx, fpn) + providerQueryManager.Startup() + providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond) + keys := testutil.GenerateCids(1) + + sessionCtx, cancel := context.WithCancel(ctx) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) + <-firstRequestChan // wait for everything to start. + time.Sleep(10 * time.Millisecond) // wait for the incoming providres to stop. + cancel() // cancel the context. + + timer := time.NewTimer(10 * time.Millisecond) + defer timer.Stop() + for { + select { + case _, ok := <-firstRequestChan: + if !ok { + return + } + case <-timer.C: + t.Fatal("should have finished receiving responses within timeout") + } + } +}