Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(x/meg): Support capturing components #269

Merged
merged 6 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions meg_capturers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package multiaddr

import (
"encoding/binary"
"fmt"
"net/netip"

"github.com/multiformats/go-multiaddr/x/meg"
)

func CaptureAddrPort(network *string, ipPort *netip.AddrPort) (capturePattern meg.Pattern) {
var ipOnly netip.Addr
capturePort := func(s meg.Matchable) error {
switch s.Code() {
case P_UDP:
*network = "udp"
case P_TCP:
*network = "tcp"
default:
return fmt.Errorf("invalid network: %s", s.Value())
}

port := binary.BigEndian.Uint16(s.RawValue())
*ipPort = netip.AddrPortFrom(ipOnly, port)
return nil
}

pattern := meg.Cat(
meg.Or(
meg.CaptureWithF(P_IP4, func(s meg.Matchable) error {
var ok bool
ipOnly, ok = netip.AddrFromSlice(s.RawValue())
if !ok {
return fmt.Errorf("invalid ip4 address: %s", s.Value())
}
return nil
}),
meg.CaptureWithF(P_IP6, func(s meg.Matchable) error {
var ok bool
ipOnly, ok = netip.AddrFromSlice(s.RawValue())
if !ok {
return fmt.Errorf("invalid ip6 address: %s", s.Value())
}
return nil
}),
),
meg.Or(
meg.CaptureWithF(P_UDP, capturePort),
meg.CaptureWithF(P_TCP, capturePort),
),
)

return pattern
}
31 changes: 29 additions & 2 deletions meg_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multiaddr

import (
"net/netip"
"testing"

"github.com/multiformats/go-multiaddr/x/meg"
Expand All @@ -16,10 +17,10 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) {
meg.Val(P_IP4),
meg.Val(P_IP6),
),
meg.CaptureVal(P_UDP, &udpPort),
meg.CaptureStringVal(P_UDP, &udpPort),
meg.Val(P_QUIC_V1),
meg.Val(P_WEBTRANSPORT),
meg.CaptureZeroOrMore(P_CERTHASH, &certhashes),
meg.CaptureZeroOrMoreStringVals(P_CERTHASH, &certhashes),
)
if !found {
t.Fatal("failed to match")
Expand All @@ -43,3 +44,29 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) {
}
}
}

func TestCaptureAddrPort(t *testing.T) {
m := StringCast("/ip4/1.2.3.4/udp/8231/quic-v1/webtransport")
var addrPort netip.AddrPort
var network string

found, err := m.Match(
CaptureAddrPort(&network, &addrPort),
meg.ZeroOrMore(meg.Any),
)
if err != nil {
t.Fatal("error", err)
}
if !found {
t.Fatal("failed to match")
}
if !addrPort.IsValid() {
t.Fatal("failed to capture addrPort")
}
if network != "udp" {
t.Fatal("unexpected network", network)
}
if addrPort.String() != "1.2.3.4:8231" {
t.Fatal("unexpected ipPort", addrPort)
}
}
75 changes: 58 additions & 17 deletions x/meg/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func preallocateCapture() *preallocatedCapture {
),
meg.Val(multiaddr.P_UDP),
meg.Val(multiaddr.P_WEBRTC_DIRECT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
)
return p
}
Expand Down Expand Up @@ -87,19 +87,19 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture {
var sni string
p.matcher = meg.PatternToMatcher(
meg.Or(
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
),
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureVal(multiaddr.P_SNI, &sni),
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
)
wtPrealloc = p
return p
Expand All @@ -120,26 +120,55 @@ func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) {
var certHashesStr []string
matched, _ := m.Match(
meg.Or(
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
),
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureVal(multiaddr.P_SNI, &sni),
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &certHashesStr),
)
if !matched {
return false, 0
}
return true, len(certHashesStr)
}

func IsWebTransportMultiaddrCaptureBytes(m multiaddr.Multiaddr) (bool, int) {
var dnsName []byte
var ip4Addr []byte
var ip6Addr []byte
var udpPort []byte
var sni []byte
var certHashes [][]byte
matched, _ := m.Match(
meg.Or(
meg.CaptureBytes(multiaddr.P_IP4, &ip4Addr),
meg.CaptureBytes(multiaddr.P_IP6, &ip6Addr),
meg.CaptureBytes(multiaddr.P_DNS4, &dnsName),
meg.CaptureBytes(multiaddr.P_DNS6, &dnsName),
meg.CaptureBytes(multiaddr.P_DNS, &dnsName),
),
meg.CaptureBytes(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureBytes(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMoreBytes(multiaddr.P_CERTHASH, &certHashes),
)
if !matched {
return false, 0
}
return true, len(certHashes)
}

func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) {
matched, _ := m.Match(
meg.Or(
Expand Down Expand Up @@ -355,6 +384,18 @@ func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) {
}
}

func BenchmarkIsWebTransportMultiaddrCaptureBytes(b *testing.B) {
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")

b.ResetTimer()
for i := 0; i < b.N; i++ {
isWT, count := IsWebTransportMultiaddrCaptureBytes(addr)
if !isWT || count != 0 {
b.Fatal("unexpected result")
}
}
}

func BenchmarkIsWebTransportMultiaddr(b *testing.B) {
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")

Expand Down
40 changes: 26 additions & 14 deletions x/meg/meg.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,30 @@ import (
type stateKind = int

const (
done stateKind = (iota * -1) - 1
// split anything else that is negative
matchAny stateKind = (iota * -1) - 1
// done MUST be the last stateKind in this list. We use it to determine if a
// state is a split index.
done
// Anything that is less than done is a split index
)

// MatchState is the Thompson NFA for a regular expression.
type MatchState struct {
capture captureFunc
capture CaptureFunc
// next is is the index of the next state. in the MatchState array.
next int
// If codeOrKind is negative, it is a kind.
// If it is negative, but not a `done`, then it is the index to the next split.
// If it is negative, and less than `done`, then it is the index to the next split.
// This is done to keep the `MatchState` struct small and cache friendly.
codeOrKind int
}

type captureFunc func(string) error
type CaptureFunc func(Matchable) error

// capture is a linked list of capture funcs with values.
type capture struct {
f captureFunc
v string
f CaptureFunc
v Matchable
prev *capture
}

Expand All @@ -53,7 +56,14 @@ func (s MatchState) String() string {

type Matchable interface {
Code() int
Value() string // Used when capturing the value
// Value() returns the string representation of the matchable.
Value() string
// RawValue() returns the byte representation of the Value
RawValue() []byte
// Bytes() returns the underlying bytes of the matchable. For multiaddr
// Components, this includes the protocol code and possibly the varint
// encoded size.
Bytes() []byte
}

// Match returns whether the given Components match the Pattern defined in MatchState.
Expand Down Expand Up @@ -89,12 +99,12 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
}
for i, stateIndex := range currentStates.states {
s := states[stateIndex]
if s.codeOrKind >= 0 && s.codeOrKind == c.Code() {
if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == c.Code()) {
cm := currentStates.captures[i]
if s.capture != nil {
next := &capture{
f: s.capture,
v: c.Value(),
v: c,
}
if cm == nil {
cm = next
Expand Down Expand Up @@ -122,8 +132,8 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
// Flip the order of the captures because we see captures from right
// to left, but users expect them left to right.
type captureWithVal struct {
f captureFunc
v string
f CaptureFunc
v Matchable
}
reversedCaptures := make([]captureWithVal, 0, 16)
for c != nil {
Expand Down Expand Up @@ -190,10 +200,12 @@ func appendState(arr statesAndCaptures, states []MatchState, stateIndex int, c *
return arr
}

const splitIdxOffset = (-1 * (done - 1))

func storeSplitIdx(codeOrKind int) int {
return (codeOrKind + 2) * -1
return (codeOrKind + splitIdxOffset) * -1
}

func restoreSplitIdx(splitIdx int) int {
return (splitIdx * -1) - 2
return (splitIdx * -1) - splitIdxOffset
}
30 changes: 28 additions & 2 deletions x/meg/meg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ func (c codeAndValue) Value() string {
return c.val
}

// Bytes implements Matchable.
func (c codeAndValue) Bytes() []byte {
return []byte(c.val)
}

// RawValue implements Matchable.
func (c codeAndValue) RawValue() []byte {
return []byte(c.val)
}

var _ Matchable = codeAndValue{}

func TestSimple(t *testing.T) {
Expand All @@ -33,6 +43,22 @@ func TestSimple(t *testing.T) {
}
testCases :=
[]testCase{
{
pattern: PatternToMatcher(Val(Any), Val(1)),
shouldMatch: [][]int{
{0, 1},
{1, 1},
{2, 1},
{3, 1},
{4, 1},
},
shouldNotMatch: [][]int{
{0},
{0, 0},
{0, 1, 0},
},
skipQuickCheck: true,
},
{
pattern: PatternToMatcher(Val(0), Val(1)),
shouldMatch: [][]int{{0, 1}},
Expand Down Expand Up @@ -119,7 +145,7 @@ func TestCapture(t *testing.T) {
{
setup: func() (Matcher, func()) {
var code0str string
return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() {
return PatternToMatcher(CaptureStringVal(0, &code0str), Val(1)), func() {
if code0str != "hello" {
panic("unexpected value")
}
Expand All @@ -130,7 +156,7 @@ func TestCapture(t *testing.T) {
{
setup: func() (Matcher, func()) {
var code0strs []string
return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() {
return PatternToMatcher(CaptureOneOrMoreStringVals(0, &code0strs), Val(1)), func() {
if code0strs[0] != "hello" {
panic("unexpected value")
}
Expand Down
Loading