diff --git a/go.mod b/go.mod index 4ca2a330a3..2c5014bab4 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( go.uber.org/fx v1.19.2 go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.7.0 + golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/sync v0.1.0 golang.org/x/sys v0.7.0 golang.org/x/tools v0.7.0 @@ -109,7 +110,6 @@ require ( go.uber.org/dig v1.16.1 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.24.0 // indirect - golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/text v0.8.0 // indirect diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 40887f2c3a..4173f2bb79 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -1,12 +1,16 @@ package identify import ( + "bytes" "context" "fmt" "io" + "sort" "sync" "time" + "golang.org/x/exp/slices" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" @@ -62,6 +66,31 @@ type identifySnapshot struct { record *record.Envelope } +// Equal says if two snapshots are identical. +// It does NOT compare the sequence number. +func (s identifySnapshot) Equal(other *identifySnapshot) bool { + hasRecord := s.record != nil + otherHasRecord := other.record != nil + if hasRecord != otherHasRecord { + return false + } + if hasRecord && !s.record.Equal(other.record) { + return false + } + if !slices.Equal(s.protocols, other.protocols) { + return false + } + if len(s.addrs) != len(other.addrs) { + return false + } + for i, a := range s.addrs { + if !a.Equal(other.addrs[i]) { + return false + } + } + return true +} + type IDService interface { // IdentifyConn synchronously triggers an identify request on the connection and // waits for it to complete. If the connection is being identified by another @@ -249,10 +278,12 @@ func (ids *idService) loop(ctx context.Context) { if !ok { return } + if updated := ids.updateSnapshot(); !updated { + continue + } if ids.metricsTracer != nil { ids.metricsTracer.TriggeredPushes(e) } - ids.updateSnapshot() select { case triggerPush <- struct{}{}: default: // we already have one more push queued, no need to queue another one @@ -529,11 +560,16 @@ func readAllIDMessages(r pbio.Reader, finalMsg proto.Message) error { return fmt.Errorf("too many parts") } -func (ids *idService) updateSnapshot() { +func (ids *idService) updateSnapshot() (updated bool) { + addrs := ids.Host.Addrs() + sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) == -1 }) + protos := ids.Host.Mux().Protocols() + sort.Slice(protos, func(i, j int) bool { return protos[i] < protos[j] }) snapshot := identifySnapshot{ - addrs: ids.Host.Addrs(), - protocols: ids.Host.Mux().Protocols(), + addrs: addrs, + protocols: protos, } + if !ids.disableSignedPeerRecord { if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok { snapshot.record = cab.GetPeerRecord(ids.Host.ID()) @@ -541,11 +577,17 @@ func (ids *idService) updateSnapshot() { } ids.currentSnapshot.Lock() + defer ids.currentSnapshot.Unlock() + + if ids.currentSnapshot.snapshot.Equal(&snapshot) { + return false + } + snapshot.seq = ids.currentSnapshot.snapshot.seq + 1 ids.currentSnapshot.snapshot = snapshot - ids.currentSnapshot.Unlock() log.Debugw("updating snapshot", "seq", snapshot.seq, "addrs", snapshot.addrs) + return true } func (ids *idService) writeChunkedIdentifyMsg(s network.Stream, mes *pb.Identify) error { diff --git a/p2p/protocol/identify/snapshot_test.go b/p2p/protocol/identify/snapshot_test.go new file mode 100644 index 0000000000..55354a49a6 --- /dev/null +++ b/p2p/protocol/identify/snapshot_test.go @@ -0,0 +1,47 @@ +package identify + +import ( + "crypto/rand" + "testing" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/record" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func TestSnapshotEquality(t *testing.T) { + addr1 := ma.StringCast("/ip4/127.0.0.1/tcp/1234") + addr2 := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") + + _, pubKey1, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + _, pubKey2, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + record1 := &record.Envelope{PublicKey: pubKey1} + record2 := &record.Envelope{PublicKey: pubKey2} + + for _, tc := range []struct { + s1, s2 *identifySnapshot + result bool + }{ + {s1: &identifySnapshot{record: record1}, s2: &identifySnapshot{record: record1}, result: true}, + {s1: &identifySnapshot{record: record1}, s2: &identifySnapshot{record: record2}, result: false}, + {s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, result: true}, + {s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr2}}, result: false}, + {s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1, addr2}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr2}}, result: false}, + {s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr1, addr2}}, result: false}, + {s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, result: true}, + {s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/bar"}}, result: false}, + {s1: &identifySnapshot{protocols: []protocol.ID{"/foo", "/bar"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/bar"}}, result: false}, + {s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/foo", "/bar"}}, result: false}, + } { + if tc.result { + require.Truef(t, tc.s1.Equal(tc.s2), "expected equal: %+v and %+v", tc.s1, tc.s2) + } else { + require.Falsef(t, tc.s1.Equal(tc.s2), "expected unequal: %+v and %+v", tc.s1, tc.s2) + } + } +}