diff --git a/go.mod b/go.mod index acbcfc0b..37a89a32 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,6 @@ require ( github.com/libp2p/go-libp2p v0.13.0 github.com/libp2p/go-libp2p-core v0.8.5 github.com/libp2p/go-libp2p-record v0.1.1 // indirect - github.com/multiformats/go-multiaddr v0.3.1 github.com/stretchr/testify v1.6.1 github.com/whyrusleeping/cbor-gen v0.0.0-20210219115102-f37d292932f2 go.uber.org/atomic v1.6.0 diff --git a/message.go b/message.go index dd23b17e..ac61e611 100644 --- a/message.go +++ b/message.go @@ -16,7 +16,7 @@ var ( // version of data-transfer (supports do-not-send-first-blocks extension) ProtocolDataTransfer1_2 protocol.ID = "/fil/datatransfer/1.2.0" - // ProtocolDataTransfer1_2 is the protocol identifier for the version + // ProtocolDataTransfer1_1 is the protocol identifier for the version // of data-transfer that supports the do-not-send-cids extension // (but not the do-not-send-first-blocks extension) ProtocolDataTransfer1_1 protocol.ID = "/fil/datatransfer/1.1.0" diff --git a/network/libp2p_impl.go b/network/libp2p_impl.go index 95766628..372e52c8 100644 --- a/network/libp2p_impl.go +++ b/network/libp2p_impl.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "sync" "time" logging "github.com/ipfs/go-log/v2" @@ -13,7 +12,6 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" - ma "github.com/multiformats/go-multiaddr" "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" @@ -53,8 +51,7 @@ type Option func(*libp2pDataTransferNetwork) // DataTransferProtocols OVERWRITES the default libp2p protocols we use for data transfer with the given protocols. func DataTransferProtocols(protocols []protocol.ID) Option { return func(impl *libp2pDataTransferNetwork) { - impl.dtProtocols = nil - impl.dtProtocols = append(impl.dtProtocols, protocols...) + impl.setDataTransferProtocols(protocols) } } @@ -87,17 +84,13 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { minAttemptDuration: defaultMinAttemptDuration, maxAttemptDuration: defaultMaxAttemptDuration, backoffFactor: defaultBackoffFactor, - dtProtocols: defaultDataTransferProtocols, - peerProtocols: make(map[peer.ID]protocol.ID), } + dataTransferNetwork.setDataTransferProtocols(defaultDataTransferProtocols) for _, option := range options { option(&dataTransferNetwork) } - // Listen to network notifications - host.Network().Notify(&dataTransferNetwork) - return &dataTransferNetwork } @@ -114,10 +107,8 @@ type libp2pDataTransferNetwork struct { minAttemptDuration time.Duration maxAttemptDuration time.Duration dtProtocols []protocol.ID + dtProtocolStrings []string backoffFactor float64 - - pplk sync.RWMutex - peerProtocols map[peer.ID]protocol.ID } func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { @@ -143,9 +134,6 @@ func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.I id, nAttempts, impl.maxStreamOpenAttempts, time.Since(start)) } - // Cache the peer's protocol version - impl.setPeerProtocol(id, s.Protocol()) - return s, err } @@ -229,10 +217,7 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { return } - // Cache the peer's protocol version p := s.Conn().RemotePeer() - dtnet.setPeerProtocol(p, s.Protocol()) - for { var received datatransfer.Message var err error @@ -321,12 +306,13 @@ func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s netwo func (impl *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID) (protocol.ID, error) { // Check the cache for the peer's protocol version - impl.pplk.RLock() - proto, ok := impl.peerProtocols[id] - impl.pplk.RUnlock() + firstProto, err := impl.host.Peerstore().FirstSupportedProtocol(id, impl.dtProtocolStrings...) + if err != nil { + return "", err + } - if ok { - return proto, nil + if firstProto != "" { + return protocol.ID(firstProto), nil } // The peer's protocol version is not in the cache, so connect to the peer. @@ -341,27 +327,12 @@ func (impl *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID) return s.Protocol(), nil } -func (impl *libp2pDataTransferNetwork) setPeerProtocol(p peer.ID, proto protocol.ID) { - impl.pplk.Lock() - defer impl.pplk.Unlock() +func (impl *libp2pDataTransferNetwork) setDataTransferProtocols(protocols []protocol.ID) { + impl.dtProtocols = append([]protocol.ID{}, protocols...) - impl.peerProtocols[p] = proto -} - -func (impl *libp2pDataTransferNetwork) clearPeerProtocol(p peer.ID) { - impl.pplk.Lock() - defer impl.pplk.Unlock() - - delete(impl.peerProtocols, p) -} - -// Clear the peer protocol version cache for the peer when the peer disconnects -func (impl *libp2pDataTransferNetwork) Disconnected(n network.Network, conn network.Conn) { - impl.clearPeerProtocol(conn.RemotePeer()) + // Keep a string version of the protocols for performance reasons + impl.dtProtocolStrings = make([]string, 0, len(impl.dtProtocols)) + for _, proto := range impl.dtProtocols { + impl.dtProtocolStrings = append(impl.dtProtocolStrings, string(proto)) + } } - -func (impl *libp2pDataTransferNetwork) Listen(n network.Network, multiaddr ma.Multiaddr) {} -func (impl *libp2pDataTransferNetwork) ListenClose(n network.Network, multiaddr ma.Multiaddr) {} -func (impl *libp2pDataTransferNetwork) Connected(n network.Network, conn network.Conn) {} -func (impl *libp2pDataTransferNetwork) OpenedStream(n network.Network, stream network.Stream) {} -func (impl *libp2pDataTransferNetwork) ClosedStream(n network.Network, stream network.Stream) {}