diff --git a/component.go b/component.go index 39f5d5a..d25f06a 100644 --- a/component.go +++ b/component.go @@ -19,30 +19,39 @@ type Component struct { valueStartIdx int // Index of the first byte of the Component's value in the bytes array } -func (c Component) AsMultiaddr() Multiaddr { +func (c *Component) AsMultiaddr() Multiaddr { if c.Empty() { return nil } - return []Component{c} + return []Component{*c} } -func (c Component) Encapsulate(o Multiaddr) Multiaddr { +func (c *Component) Encapsulate(o Multiaddr) Multiaddr { return c.AsMultiaddr().Encapsulate(o) } -func (c Component) Decapsulate(o Multiaddr) Multiaddr { +func (c *Component) Decapsulate(o Multiaddr) Multiaddr { return c.AsMultiaddr().Decapsulate(o) } -func (c Component) Empty() bool { +func (c *Component) Empty() bool { + if c == nil { + return true + } return len(c.bytes) == 0 } -func (c Component) Bytes() []byte { +func (c *Component) Bytes() []byte { + if c == nil { + return nil + } return []byte(c.bytes) } -func (c Component) MarshalBinary() ([]byte, error) { +func (c *Component) MarshalBinary() ([]byte, error) { + if c == nil { + return nil, errNilPtr + } return c.Bytes(), nil } @@ -58,7 +67,10 @@ func (c *Component) UnmarshalBinary(data []byte) error { return nil } -func (c Component) MarshalText() ([]byte, error) { +func (c *Component) MarshalText() ([]byte, error) { + if c == nil { + return nil, errNilPtr + } return []byte(c.String()), nil } @@ -79,7 +91,10 @@ func (c *Component) UnmarshalText(data []byte) error { return nil } -func (c Component) MarshalJSON() ([]byte, error) { +func (c *Component) MarshalJSON() ([]byte, error) { + if c == nil { + return nil, errNilPtr + } txt, err := c.MarshalText() if err != nil { return nil, err @@ -101,22 +116,40 @@ func (c *Component) UnmarshalJSON(data []byte) error { return c.UnmarshalText([]byte(v)) } -func (c Component) Equal(o Component) bool { +func (c *Component) Equal(o *Component) bool { + if c == nil || o == nil { + return c == o + } return c.bytes == o.bytes } -func (c Component) Compare(o Component) int { +func (c *Component) Compare(o *Component) int { + if c == nil && o == nil { + return 0 + } + if c == nil { + return -1 + } + if o == nil { + return 1 + } return strings.Compare(c.bytes, o.bytes) } -func (c Component) Protocols() []Protocol { +func (c *Component) Protocols() []Protocol { + if c == nil { + return nil + } if c.protocol == nil { return nil } return []Protocol{*c.protocol} } -func (c Component) ValueForProtocol(code int) (string, error) { +func (c *Component) ValueForProtocol(code int) (string, error) { + if c == nil { + return "", fmt.Errorf("component is nil") + } if c.protocol == nil { return "", fmt.Errorf("component has nil protocol") } @@ -126,18 +159,27 @@ func (c Component) ValueForProtocol(code int) (string, error) { return c.Value(), nil } -func (c Component) Protocol() Protocol { +func (c *Component) Protocol() Protocol { + if c == nil { + return Protocol{} + } if c.protocol == nil { return Protocol{} } return *c.protocol } -func (c Component) RawValue() []byte { +func (c *Component) RawValue() []byte { + if c == nil { + return nil + } return []byte(c.bytes[c.valueStartIdx:]) } -func (c Component) Value() string { +func (c *Component) Value() string { + if c == nil { + return "" + } if c.Empty() { return "" } @@ -146,7 +188,10 @@ func (c Component) Value() string { return value } -func (c Component) valueAndErr() (string, error) { +func (c *Component) valueAndErr() (string, error) { + if c == nil { + return "", errNilPtr + } if c.protocol == nil { return "", fmt.Errorf("component has nil protocol") } @@ -160,7 +205,10 @@ func (c Component) valueAndErr() (string, error) { return value, nil } -func (c Component) String() string { +func (c *Component) String() string { + if c == nil { + return "" + } var b strings.Builder c.writeTo(&b) return b.String() @@ -168,7 +216,10 @@ func (c Component) String() string { // writeTo is an efficient, private function for string-formatting a multiaddr. // Trust me, we tend to allocate a lot when doing this. -func (c Component) writeTo(b *strings.Builder) { +func (c *Component) writeTo(b *strings.Builder) { + if c == nil { + return + } if c.protocol == nil { return } diff --git a/multiaddr.go b/multiaddr.go index 86092ae..3a8a03b 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -16,6 +16,15 @@ var errNilPtr = errors.New("nil ptr") // Multiaddr is the data structure representing a Multiaddr type Multiaddr []Component +func (m Multiaddr) copy() Multiaddr { + if m == nil { + return nil + } + out := make(Multiaddr, len(m)) + copy(out, m) + return out +} + func (m Multiaddr) Empty() bool { if len(m) == 0 { return true @@ -71,7 +80,7 @@ func (m Multiaddr) Equal(m2 Multiaddr) bool { return false } for i, c := range m { - if !c.Equal(m2[i]) { + if !c.Equal(&m2[i]) { return false } } @@ -80,7 +89,7 @@ func (m Multiaddr) Equal(m2 Multiaddr) bool { func (m Multiaddr) Compare(o Multiaddr) int { for i := 0; i < len(m) && i < len(o); i++ { - if cmp := m[i].Compare(o[i]); cmp != 0 { + if cmp := m[i].Compare(&o[i]); cmp != 0 { return cmp } } @@ -177,13 +186,13 @@ func (m Multiaddr) Encapsulate(o Multiaddr) Multiaddr { return Join(m, o) } -func (m Multiaddr) EncapsulateC(c Component) Multiaddr { +func (m Multiaddr) EncapsulateC(c *Component) Multiaddr { if c.Empty() { return m } out := make([]Component, 0, len(m)+1) out = append(out, m...) - out = append(out, c) + out = append(out, *c) return out } @@ -200,7 +209,7 @@ func (m Multiaddr) Decapsulate(rightParts Multiaddr) Multiaddr { break } - foundMatch = rightC.Equal(leftParts[i+j]) + foundMatch = rightC.Equal(&leftParts[i+j]) if !foundMatch { break } diff --git a/multiaddr_test.go b/multiaddr_test.go index 95b57f4..28b068f 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -28,6 +28,12 @@ func TestReturnsNilOnEmpty(t *testing.T) { a, _ = SplitLast(a) require.Nil(t, a) + a, c := SplitLast(nil) + require.Zero(t, len(a.Protocols())) + require.Nil(t, a) + require.Nil(t, c) + require.True(t, c.Empty()) + // Test that empty multiaddr from various operations returns nil a = StringCast("/ip4/1.2.3.4/tcp/1234") _, a = SplitFirst(a) @@ -36,6 +42,11 @@ func TestReturnsNilOnEmpty(t *testing.T) { _, a = SplitFirst(a) require.Nil(t, a) + c, a = SplitFirst(nil) + require.Nil(t, a) + require.Nil(t, c) + require.True(t, c.Empty()) + a = StringCast("/ip4/1.2.3.4/tcp/1234") a = a.Decapsulate(a) require.Nil(t, a) @@ -400,7 +411,7 @@ func TestBytesSplitAndJoin(t *testing.T) { for i, a := range split { if a.String() != res[i] { - t.Errorf("split component failed: %s != %s", a, res[i]) + t.Errorf("split component failed: %s != %s", &a, res[i]) } } @@ -411,7 +422,7 @@ func TestBytesSplitAndJoin(t *testing.T) { for i, a := range split { if a.String() != res[i] { - t.Errorf("split component failed: %s != %s", a, res[i]) + t.Errorf("split component failed: %s != %s", &a, res[i]) } } } @@ -863,7 +874,7 @@ func TestComponentBinaryMarshaler(t *testing.T) { if err = comp2.UnmarshalBinary(b); err != nil { t.Fatal(err) } - if !comp.Equal(comp2) { + if !comp.Equal(&comp2) { t.Error("expected equal components in circular marshaling test") } } @@ -882,7 +893,7 @@ func TestComponentTextMarshaler(t *testing.T) { if err = comp2.UnmarshalText(b); err != nil { t.Fatal(err) } - if !comp.Equal(comp2) { + if !comp.Equal(&comp2) { t.Error("expected equal components in circular marshaling test") } } @@ -901,7 +912,7 @@ func TestComponentJSONMarshaler(t *testing.T) { if err = comp2.UnmarshalJSON(b); err != nil { t.Fatal(err) } - if !comp.Equal(comp2) { + if !comp.Equal(&comp2) { t.Error("expected equal components in circular marshaling test") } } @@ -914,6 +925,9 @@ func TestUseNil(t *testing.T) { _ = f() var foo Multiaddr = nil + _, right := SplitFirst(foo) + right.Protocols() + foo.Protocols() foo.Bytes() foo.Compare(nil) foo.Decapsulate(nil) @@ -930,6 +944,32 @@ func TestUseNil(t *testing.T) { _, _ = foo.ValueForProtocol(0) } +func TestUseNilComponent(t *testing.T) { + var foo *Component + foo.AsMultiaddr() + foo.Encapsulate(nil) + foo.Decapsulate(nil) + foo.Empty() + foo.Bytes() + foo.MarshalBinary() + foo.MarshalJSON() + foo.MarshalText() + foo.UnmarshalBinary(nil) + foo.UnmarshalJSON(nil) + foo.UnmarshalText(nil) + foo.Equal(nil) + foo.Compare(nil) + foo.Protocols() + foo.ValueForProtocol(0) + foo.Protocol() + foo.RawValue() + foo.Value() + _ = foo.String() + + var m Multiaddr = nil + m.EncapsulateC(foo) +} + func TestFilterAddrs(t *testing.T) { bad := []Multiaddr{ newMultiaddr(t, "/ip6/fe80::1/tcp/1234"), diff --git a/util.go b/util.go index 71fb9bf..038c143 100644 --- a/util.go +++ b/util.go @@ -64,26 +64,30 @@ func StringCast(s string) Multiaddr { } // SplitFirst returns the first component and the rest of the multiaddr. -func SplitFirst(m Multiaddr) (Component, Multiaddr) { +func SplitFirst(m Multiaddr) (*Component, Multiaddr) { if m.Empty() { - return Component{}, nil + return nil, nil } if len(m) == 1 { - return m[0], nil + return &m[0], nil } - return m[0], m[1:] + // defensive copy. Users can avoid by doing the split themselves. + copyC := m[0] + return ©C, m[1:].copy() } // SplitLast returns the rest of the multiaddr and the last component. -func SplitLast(m Multiaddr) (Multiaddr, Component) { +func SplitLast(m Multiaddr) (Multiaddr, *Component) { if m.Empty() { - return nil, Component{} + return nil, nil } if len(m) == 1 { // We want to explicitly return a nil slice if the prefix is now empty. - return nil, m[0] + return nil, &m[0] } - return m[:len(m)-1], m[len(m)-1] + // defensive copy. Users can avoid by doing the split themselves. + copyC := m[len(m)-1] + return m[:len(m)-1].copy(), ©C } // SplitFunc splits the multiaddr when the callback first returns true. The @@ -108,7 +112,8 @@ func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { if post.Empty() { post = nil } - return pre, post + // defensive copy. Users can avoid by doing the split themselves. + return pre.copy(), post.copy() } // ForEach walks over the multiaddr, component by component. diff --git a/util_test.go b/util_test.go index 3494486..1409122 100644 --- a/util_test.go +++ b/util_test.go @@ -66,15 +66,15 @@ func TestSplitFirstLast(t *testing.T) { } ci, m := SplitFirst(c.AsMultiaddr()) - if !ci.Equal(c) || m != nil { + if !ci.Equal(&c) || m != nil { t.Error("split first on component failed") } m, ci = SplitLast(c.AsMultiaddr()) - if !ci.Equal(c) || m != nil { + if !ci.Equal(&c) || m != nil { t.Error("split last on component failed") } cis := Split(c.AsMultiaddr()) - if len(cis) != 1 || !cis[0].Equal(c) { + if len(cis) != 1 || !cis[0].Equal(&c) { t.Error("split on component failed") } m1, m2 := SplitFunc(c.AsMultiaddr(), func(c Component) bool { @@ -96,7 +96,7 @@ func TestSplitFirstLast(t *testing.T) { t.Error("expected exactly one component") } i++ - if !ci.Equal(c) { + if !ci.Equal(&c) { t.Error("foreach on component failed") } return true