diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index f4cf10358b..3e0af7268a 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -121,11 +121,7 @@ func (c *conn) NewStream() (inet.Stream, error) { } func (c *conn) GetStreams() []inet.Stream { - var out []inet.Stream - for e := c.streams.Front(); e != nil; e = e.Next() { - out = append(out, e.Value.(*stream)) - } - return out + return c.allStreams() } // LocalMultiaddr is the Multiaddr on this side diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 8d01505049..331b8960b8 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "sync/atomic" "time" inet "github.com/libp2p/go-libp2p-net" @@ -24,7 +25,7 @@ type stream struct { writeErr error - protocol protocol.ID + protocol atomic.Value stat inet.Stat } @@ -70,7 +71,9 @@ func (s *stream) Write(p []byte) (n int, err error) { } func (s *stream) Protocol() protocol.ID { - return s.protocol + // Ignore type error. It means that the protocol is unset. + p, _ := s.protocol.Load().(protocol.ID) + return p } func (s *stream) Stat() inet.Stat { @@ -78,7 +81,7 @@ func (s *stream) Stat() inet.Stat { } func (s *stream) SetProtocol(proto protocol.ID) { - s.protocol = proto + s.protocol.Store(proto) } func (s *stream) Close() error {