diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 6c89440394..e6be5c5c2f 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -662,7 +662,7 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { // header with given protocol.ID. If there is no connection to p, attempts // to create one. If ProtocolID is "", writes no header. // (Thread-safe) -func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { +func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) { // If the caller wants to prevent the host from dialing, it should use the NoDial option. if nodial, _ := network.GetNoDial(ctx); !nodial { err := h.Connect(ctx, peer.AddrInfo{ID: p}) @@ -680,6 +680,11 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } return nil, fmt.Errorf("failed to open stream: %w", err) } + defer func() { + if strErr != nil && s != nil { + s.Reset() + } + }() // Wait for any in-progress identifies on the connection to finish. This // is faster than negotiating. @@ -689,13 +694,11 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I select { case <-h.ids.IdentifyWait(s.Conn()): case <-ctx.Done(): - _ = s.Reset() return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err()) } pref, err := h.preferredProtocol(p, pids) if err != nil { - _ = s.Reset() return nil, err } @@ -720,7 +723,6 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I select { case err = <-errCh: if err != nil { - s.Reset() return nil, fmt.Errorf("failed to negotiate protocol: %w", err) } case <-ctx.Done():