Skip to content

Commit

Permalink
Use *Protocol in Component
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Feb 6, 2025
1 parent 55da517 commit 493f175
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
8 changes: 6 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,16 @@ func readComponent(b []byte) (int, Component, error) {
if p.Code == 0 {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
}
pPtr := protocolPtrByCode[code]
if pPtr == nil {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
}

if p.Size == 0 {
c, err := validateComponent(Component{
bytes: string(b[:offset]),
valueStartIdx: offset,
protocol: p,
protocol: pPtr,
})

return offset, c, err
Expand All @@ -110,7 +114,7 @@ func readComponent(b []byte) (int, Component, error) {

c, err := validateComponent(Component{
bytes: string(b[:offset+size]),
protocol: p,
protocol: pPtr,
valueStartIdx: offset,
})

Expand Down
31 changes: 27 additions & 4 deletions component.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Component struct {
// bytes is the raw bytes of the component. It includes the protocol code as
// varint, possibly the size of the value, and the value.
bytes string // string for immutability.
protocol Protocol
protocol *Protocol
valueStartIdx int // Index of the first byte of the Component's value in the bytes array
}

Expand Down Expand Up @@ -110,18 +110,27 @@ func (c Component) Compare(o Component) int {
}

func (c Component) Protocols() []Protocol {
return []Protocol{c.protocol}
if c.protocol == nil {
return nil
}
return []Protocol{*c.protocol}
}

func (c Component) ValueForProtocol(code int) (string, error) {
if c.protocol == nil {
return "", fmt.Errorf("component has nil protocol")
}
if c.protocol.Code != code {
return "", ErrProtocolNotFound
}
return c.Value(), nil
}

func (c Component) Protocol() Protocol {
return c.protocol
if c.protocol == nil {
return Protocol{}
}
return *c.protocol
}

func (c Component) RawValue() []byte {
Expand All @@ -138,6 +147,9 @@ func (c Component) Value() string {
}

func (c Component) valueAndErr() (string, error) {
if c.protocol == nil {
return "", fmt.Errorf("component has nil protocol")
}
if c.protocol.Transcoder == nil {
return "", nil
}
Expand All @@ -157,6 +169,9 @@ func (c Component) String() string {
// writeTo is an efficient, private function for string-formatting a multiaddr.
// Trust me, we tend to allocate a lot when doing this.
func (c Component) writeTo(b *strings.Builder) {
if c.protocol == nil {
return
}
b.WriteByte('/')
b.WriteString(c.protocol.Name)
value := c.Value()
Expand Down Expand Up @@ -188,6 +203,11 @@ func NewComponent(protocol, value string) (Component, error) {
}

func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
protocolPtr := protocolPtrByCode[protocol.Code]
if protocolPtr == nil {
protocolPtr = &protocol
}

size := len(bvalue)
size += len(protocol.VCode)
if protocol.Size < 0 {
Expand All @@ -209,7 +229,7 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
return validateComponent(
Component{
bytes: string(maddr),
protocol: protocol,
protocol: protocolPtr,
valueStartIdx: offset,
})
}
Expand All @@ -218,6 +238,9 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
// It ensures that we will be able to call all methods on Component without
// error.
func validateComponent(c Component) (Component, error) {
if c.protocol == nil {
return Component{}, fmt.Errorf("component is missing its protocol")
}
if c.valueStartIdx > len(c.bytes) {
return Component{}, fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
}
Expand Down
26 changes: 26 additions & 0 deletions multiaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1146,3 +1146,29 @@ func BenchmarkComponentValidation(b *testing.B) {
}
}
}

func FuzzComponents(f *testing.F) {
for _, v := range good {
m := StringCast(v)
for _, c := range m {
f.Add(c.Bytes())
}
}
f.Fuzz(func(t *testing.T, compBytes []byte) {
n, c, err := readComponent(compBytes)
if err != nil {
t.Skip()
}
if c.protocol == nil {
t.Fatal("component has nil protocol")
}
if c.protocol.Code == 0 {
t.Fatal("component has nil protocol code")
}
if !bytes.Equal(c.Bytes(), compBytes[:n]) {
t.Logf("component bytes: %v", c.Bytes())
t.Logf("original bytes: %v", compBytes[:n])
t.Fatal("component bytes are not equal to the original bytes")
}
})
}
4 changes: 4 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ type Protocol struct {
var protocolsByName = map[string]Protocol{}
var protocolsByCode = map[int]Protocol{}

// Keep a map of pointers so that we can reuse the same pointer for the same protocol.
var protocolPtrByCode = map[int]*Protocol{}

// Protocols is the list of multiaddr protocols supported by this module.
var Protocols = []Protocol{}

Expand All @@ -72,6 +75,7 @@ func AddProtocol(p Protocol) error {
Protocols = append(Protocols, p)
protocolsByName[p.Name] = p
protocolsByCode[p.Code] = p
protocolPtrByCode[p.Code] = &p
return nil
}

Expand Down

0 comments on commit 493f175

Please sign in to comment.