Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(x/meg): Support capturing components #269

Merged
merged 6 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions meg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) {
meg.Val(P_IP4),
meg.Val(P_IP6),
),
meg.CaptureVal(P_UDP, &udpPort),
meg.CaptureStringVal(P_UDP, &udpPort),
meg.Val(P_QUIC_V1),
meg.Val(P_WEBTRANSPORT),
meg.CaptureZeroOrMore(P_CERTHASH, &certhashes),
meg.CaptureZeroOrMoreStringVals(P_CERTHASH, &certhashes),
)
if !found {
t.Fatal("failed to match")
Expand Down
75 changes: 58 additions & 17 deletions x/meg/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func preallocateCapture() *preallocatedCapture {
),
meg.Val(multiaddr.P_UDP),
meg.Val(multiaddr.P_WEBRTC_DIRECT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
)
return p
}
Expand Down Expand Up @@ -87,19 +87,19 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture {
var sni string
p.matcher = meg.PatternToMatcher(
meg.Or(
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
),
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureVal(multiaddr.P_SNI, &sni),
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
)
wtPrealloc = p
return p
Expand All @@ -120,26 +120,55 @@ func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) {
var certHashesStr []string
matched, _ := m.Match(
meg.Or(
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
),
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureVal(multiaddr.P_SNI, &sni),
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr),
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &certHashesStr),
)
if !matched {
return false, 0
}
return true, len(certHashesStr)
}

func IsWebTransportMultiaddrCaptureBytes(m multiaddr.Multiaddr) (bool, int) {
var dnsName []byte
var ip4Addr []byte
var ip6Addr []byte
var udpPort []byte
var sni []byte
var certHashes [][]byte
matched, _ := m.Match(
meg.Or(
meg.CaptureBytes(multiaddr.P_IP4, &ip4Addr),
meg.CaptureBytes(multiaddr.P_IP6, &ip6Addr),
meg.CaptureBytes(multiaddr.P_DNS4, &dnsName),
meg.CaptureBytes(multiaddr.P_DNS6, &dnsName),
meg.CaptureBytes(multiaddr.P_DNS, &dnsName),
),
meg.CaptureBytes(multiaddr.P_UDP, &udpPort),
meg.Val(multiaddr.P_QUIC_V1),
meg.Optional(
meg.CaptureBytes(multiaddr.P_SNI, &sni),
),
meg.Val(multiaddr.P_WEBTRANSPORT),
meg.CaptureZeroOrMoreBytes(multiaddr.P_CERTHASH, &certHashes),
)
if !matched {
return false, 0
}
return true, len(certHashes)
}

func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) {
matched, _ := m.Match(
meg.Or(
Expand Down Expand Up @@ -355,6 +384,18 @@ func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) {
}
}

func BenchmarkIsWebTransportMultiaddrCaptureBytes(b *testing.B) {
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")

b.ResetTimer()
for i := 0; i < b.N; i++ {
isWT, count := IsWebTransportMultiaddrCaptureBytes(addr)
if !isWT || count != 0 {
b.Fatal("unexpected result")
}
}
}

func BenchmarkIsWebTransportMultiaddr(b *testing.B) {
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")

Expand Down
9 changes: 5 additions & 4 deletions x/meg/meg.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ type MatchState struct {
codeOrKind int
}

type captureFunc func(string) error
type captureFunc func(Matchable) error

// capture is a linked list of capture funcs with values.
type capture struct {
f captureFunc
v string
v Matchable
prev *capture
}

Expand All @@ -54,6 +54,7 @@ func (s MatchState) String() string {
type Matchable interface {
Code() int
Value() string // Used when capturing the value
Bytes() []byte
}

// Match returns whether the given Components match the Pattern defined in MatchState.
Expand Down Expand Up @@ -94,7 +95,7 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
if s.capture != nil {
next := &capture{
f: s.capture,
v: c.Value(),
v: c,
}
if cm == nil {
cm = next
Expand Down Expand Up @@ -123,7 +124,7 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
// to left, but users expect them left to right.
type captureWithVal struct {
f captureFunc
v string
v Matchable
}
reversedCaptures := make([]captureWithVal, 0, 16)
for c != nil {
Expand Down
9 changes: 7 additions & 2 deletions x/meg/meg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ func (c codeAndValue) Value() string {
return c.val
}

// Bytes implements Matchable.
func (c codeAndValue) Bytes() []byte {
return []byte(c.val)
}

var _ Matchable = codeAndValue{}

func TestSimple(t *testing.T) {
Expand Down Expand Up @@ -119,7 +124,7 @@ func TestCapture(t *testing.T) {
{
setup: func() (Matcher, func()) {
var code0str string
return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() {
return PatternToMatcher(CaptureStringVal(0, &code0str), Val(1)), func() {
if code0str != "hello" {
panic("unexpected value")
}
Expand All @@ -130,7 +135,7 @@ func TestCapture(t *testing.T) {
{
setup: func() (Matcher, func()) {
var code0strs []string
return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() {
return PatternToMatcher(CaptureOneOrMoreStringVals(0, &code0strs), Val(1)), func() {
if code0strs[0] != "hello" {
panic("unexpected value")
}
Expand Down
75 changes: 60 additions & 15 deletions x/meg/sugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,55 @@ func Or(p ...Pattern) Pattern {

var errAlreadyCapture = errors.New("already captured")

func captureOneValueOrErr(val *string) captureFunc {
func captureOneBytesOrErr(val *[]byte) captureFunc {
if val == nil {
return nil
}
var set bool
f := func(s string) error {
f := func(s Matchable) error {
if set {
*val = nil
return errAlreadyCapture
}
*val = s.Bytes()
return nil
}
return f
}

func captureOneStringValueOrErr(val *string) captureFunc {
if val == nil {
return nil
}
var set bool
f := func(s Matchable) error {
if set {
*val = ""
return errAlreadyCapture
}
*val = s
*val = s.Value()
return nil
}
return f
}

func captureManyBytes(vals *[][]byte) captureFunc {
if vals == nil {
return nil
}
f := func(s Matchable) error {
*vals = append(*vals, s.Bytes())
return nil
}
return f
}

func captureMany(vals *[]string) captureFunc {
func captureManyStrings(vals *[]string) captureFunc {
if vals == nil {
return nil
}
f := func(s string) error {
*vals = append(*vals, s)
f := func(s Matchable) error {
*vals = append(*vals, s.Value())
return nil
}
return f
Expand All @@ -110,15 +137,19 @@ func captureValWithF(code int, f captureFunc) Pattern {
}

func Val(code int) Pattern {
return CaptureVal(code, nil)
return CaptureStringVal(code, nil)
}

func CaptureVal(code int, val *string) Pattern {
return captureValWithF(code, captureOneValueOrErr(val))
func CaptureStringVal(code int, val *string) Pattern {
return captureValWithF(code, captureOneStringValueOrErr(val))
}

func CaptureBytes(code int, val *[]byte) Pattern {
return captureValWithF(code, captureOneBytesOrErr(val))
}

func ZeroOrMore(code int) Pattern {
return CaptureZeroOrMore(code, nil)
return CaptureZeroOrMoreStringVals(code, nil)
}

func captureZeroOrMoreWithF(code int, f captureFunc) Pattern {
Expand Down Expand Up @@ -146,16 +177,30 @@ func captureZeroOrMoreWithF(code int, f captureFunc) Pattern {
}
}

func CaptureZeroOrMore(code int, vals *[]string) Pattern {
return captureZeroOrMoreWithF(code, captureMany(vals))
func CaptureZeroOrMoreBytes(code int, vals *[][]byte) Pattern {
return captureZeroOrMoreWithF(code, captureManyBytes(vals))
}

func CaptureZeroOrMoreStringVals(code int, vals *[]string) Pattern {
return captureZeroOrMoreWithF(code, captureManyStrings(vals))
}

func OneOrMore(code int) Pattern {
return CaptureOneOrMore(code, nil)
return CaptureOneOrMoreStringVals(code, nil)
}

func CaptureOneOrMoreStringVals(code int, vals *[]string) Pattern {
f := captureManyStrings(vals)
return func(states *[]MatchState, nextIdx int) int {
// First attach the zero-or-more loop.
zeroOrMoreIdx := captureZeroOrMoreWithF(code, f)(states, nextIdx)
// Then put the capture state before the loop.
return captureValWithF(code, f)(states, zeroOrMoreIdx)
}
}

func CaptureOneOrMore(code int, vals *[]string) Pattern {
f := captureMany(vals)
func CaptureOneOrMoreBytes(code int, vals *[][]byte) Pattern {
f := captureManyBytes(vals)
return func(states *[]MatchState, nextIdx int) int {
// First attach the zero-or-more loop.
zeroOrMoreIdx := captureZeroOrMoreWithF(code, f)(states, nextIdx)
Expand Down