Skip to content

Commit

Permalink
Merge pull request #2 from filecoin-project/fix/context-cancel
Browse files Browse the repository at this point in the history
Handle context cancellation properly (ipfs#428)
  • Loading branch information
dirkmc authored Aug 23, 2023
2 parents 226a4d3 + 51e2291 commit cb442d6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 41 deletions.
54 changes: 34 additions & 20 deletions requestmanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,14 @@ func (rm *RequestManager) NewRequest(ctx context.Context,

inProgressRequestChan := make(chan inProgressRequest)

rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
err := rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
if err != nil {
return rm.emptyResponse()
}
var receivedInProgressRequest inProgressRequest
select {
case <-rm.ctx.Done():
return rm.emptyResponse()
case <-ctx.Done():
return rm.emptyResponse()
case receivedInProgressRequest = <-inProgressRequestChan:
}

Expand Down Expand Up @@ -282,12 +283,13 @@ func (rm *RequestManager) cancelRequestAndClose(requestID graphsync.RequestID,
// CancelRequest cancels the given request ID and waits for the request to terminate
func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
terminated := make(chan error, 1)
rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
err := rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-terminated:
return err
}
Expand All @@ -299,19 +301,20 @@ func (rm *RequestManager) ProcessResponses(p peer.ID,
responses []gsmsg.GraphSyncResponse,
blks []blocks.Block) {

rm.send(&processResponsesMessage{p, responses, blks}, nil)
_ = rm.send(&processResponsesMessage{p, responses, blks}, nil)
}

// UnpauseRequest unpauses a request that was paused in a block hook based request ID
// Can also send extensions with unpause
func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -320,12 +323,13 @@ func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsyn
// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -334,26 +338,27 @@ func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.
// UpdateRequest updates an in progress request
func (rm *RequestManager) UpdateRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
}

// GetRequestTask gets data for the given task in the request queue
func (rm *RequestManager) GetRequestTask(p peer.ID, task *peertask.Task, requestExecutionChan chan executor.RequestTask) {
rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
_ = rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
}

// ReleaseRequestTask releases a task request the requestQueue
func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err error) {
done := make(chan struct{}, 1)
rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
_ = rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -363,7 +368,7 @@ func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err
// PeerState gets stats on all outgoing requests for a given peer
func (rm *RequestManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand Down Expand Up @@ -391,11 +396,20 @@ func (rm *RequestManager) Shutdown() {
rm.cancel()
}

func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) {
func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down
55 changes: 34 additions & 21 deletions responsemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,19 @@ func New(ctx context.Context,

// ProcessRequests processes incoming requests for the given peer
func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, requests []gsmsg.GraphSyncRequest) {
rm.send(&processRequestsMessage{p, requests}, ctx.Done())
_ = rm.send(&processRequestsMessage{p, requests}, ctx.Done())
}

// UnpauseResponse unpauses a response that was previously paused
func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -174,12 +175,13 @@ func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphs
// PauseResponse pauses an in progress response (may take 1 or more blocks to process)
func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -188,12 +190,13 @@ func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsyn
// CancelResponse cancels an in progress response
func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
err := rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -202,12 +205,13 @@ func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsy
// UpdateRequest updates an in progress response
func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return errors.New("context cancelled")
case err := <-response:
return err
}
Expand All @@ -216,7 +220,7 @@ func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsy
// Synchronize is a utility method that blocks until all current messages are processed
func (rm *ResponseManager) synchronize() {
sync := make(chan error)
rm.send(&synchronizeMessage{sync}, nil)
_ = rm.send(&synchronizeMessage{sync}, nil)
select {
case <-rm.ctx.Done():
case <-sync:
Expand All @@ -225,18 +229,18 @@ func (rm *ResponseManager) synchronize() {

// StartTask starts the given task from the peer task queue
func (rm *ResponseManager) StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- queryexecutor.ResponseTask) {
rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
_ = rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
}

// GetUpdates is called to read pending updates for a task and clear them
func (rm *ResponseManager) GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
_ = rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
}

// FinishTask marks a task from the task queue as done
func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error) {
done := make(chan struct{}, 1)
rm.send(&finishTaskRequest{task, p, err, done}, nil)
_ = rm.send(&finishTaskRequest{task, p, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -246,7 +250,7 @@ func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error)
// CloseWithNetworkError closes a request due to a network error
func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID) {
done := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
_ = rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -256,7 +260,7 @@ func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID)
// TerminateRequest indicates a request has finished sending data and should no longer be tracked
func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
done := make(chan struct{}, 1)
rm.send(&terminateRequestMessage{requestID, done}, nil)
_ = rm.send(&terminateRequestMessage{requestID, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -266,7 +270,7 @@ func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
// PeerState gets current state of the outgoing responses for a given peer
func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand All @@ -275,11 +279,20 @@ func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
}
}

func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) {
func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down

0 comments on commit cb442d6

Please sign in to comment.