domain: ensure every Matcher implementation is fqdn- and case-insensitive
except for RegexMatcher
urlesistiana committed Oct 26, 2022
1 parent 068a688 commit f420fb6
Showing 5 changed files with 144 additions and 227 deletions.
5 changes: 5 additions & 0 deletions pkg/matcher/domain/interface.go
@@ -19,6 +19,11 @@

package domain

// "fqdn-insensitive" means the domain in Add() and Match() call
// is fqdn-insensitive. "google.com" and "google.com." will get
// the same outcome.
// The logic for case-insensitive is the same as above.

type Matcher[T any] interface {
// Match matches the domain s.
// s may or may not be a fqdn; matching is case-insensitive.
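
The comment above is the whole contract: trailing dots and letter case never change the result. A minimal sketch of what that means for callers, using the FullMatcher touched by this commit (the import path is inferred from the file layout in the diff; the rule and values are purely illustrative):

```go
package main

import (
	"fmt"

	"github.com/IrineSistiana/mosdns/v4/pkg/matcher/domain"
)

func main() {
	m := domain.NewFullMatcher[int]()
	// The rule can be written as a fqdn or not; it is normalized on Add.
	_ = m.Add("google.com.", 1)

	// After this commit, all of these hit the same entry.
	for _, s := range []string{"google.com", "google.com.", "GOOGLE.COM."} {
		v, ok := m.Match(s)
		fmt.Println(s, v, ok) // expected: 1 true
	}
}
```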
194 changes: 33 additions & 161 deletions pkg/matcher/domain/matcher.go
@@ -24,7 +24,6 @@ import (
"github.com/IrineSistiana/mosdns/v4/pkg/utils"
"regexp"
"strings"
"sync"
)

var _ WriteableMatcher[any] = (*MixMatcher[any])(nil)
@@ -34,23 +33,24 @@ var _ WriteableMatcher[any] = (*KeywordMatcher[any])(nil)
var _ WriteableMatcher[any] = (*RegexMatcher[any])(nil)

type SubDomainMatcher[T any] struct {
root *LabelNode[T]
root *labelNode[T]
}

func NewSubDomainMatcher[T any]() *SubDomainMatcher[T] {
return &SubDomainMatcher[T]{root: new(LabelNode[T])}
return &SubDomainMatcher[T]{root: new(labelNode[T])}
}

func (m *SubDomainMatcher[T]) Match(s string) (T, bool) {
s = NormalizeDomain(s)
ds := NewReverseDomainScanner(s)
currentNode := m.root
ds := NewUnifiedDomainScanner(s)
var v T
var ok bool
for ds.Scan() {
label, _ := ds.PrevLabel()
if nextNode := currentNode.GetChild(label); nextNode != nil {
if nextNode.HasValue() {
v, ok = nextNode.GetValue()
label := ds.NextLabel()
if nextNode := currentNode.getChild(label); nextNode != nil {
if nextNode.hasValue() {
v, ok = nextNode.getValue()
}
currentNode = nextNode
} else {
@@ -61,71 +61,27 @@ func (m *SubDomainMatcher[T]) Match(s string) (T, bool) {
}

func (m *SubDomainMatcher[T]) Len() int {
return m.root.Len()
return m.root.len()
}

func (m *SubDomainMatcher[T]) Add(s string, v T) error {
s = NormalizeDomain(s)
ds := NewReverseDomainScanner(s)
currentNode := m.root
ds := NewUnifiedDomainScanner(s)
for ds.Scan() {
label, _ := ds.PrevLabel()
if child := currentNode.GetChild(label); child != nil {
label := ds.NextLabel()
if child := currentNode.getChild(label); child != nil {
currentNode = child
} else {
currentNode = currentNode.NewChild(label)
currentNode = currentNode.newChild(label)
}
}
currentNode.StoreValue(v)
currentNode.storeValue(v)
return nil
}

// LabelNode can store dns labels.
type LabelNode[T any] struct {
children map[string]*LabelNode[T] // lazy init

v T
hasV bool
}

func (n *LabelNode[T]) StoreValue(v T) {
n.v = v
n.hasV = true
}

func (n *LabelNode[T]) GetValue() (T, bool) {
return n.v, n.hasV
}

func (n *LabelNode[T]) HasValue() bool {
return n.hasV
}

func (n *LabelNode[T]) NewChild(key string) *LabelNode[T] {
if n.children == nil {
n.children = make(map[string]*LabelNode[T])
}
node := new(LabelNode[T])
n.children[key] = node
return node
}

func (n *LabelNode[T]) GetChild(key string) *LabelNode[T] {
return n.children[key]
}

func (n *LabelNode[T]) Len() int {
l := 0
for _, node := range n.children {
l += node.Len()
if node.HasValue() {
l++
}
}
return l
}

type FullMatcher[T any] struct {
m map[string]T // string must be a fqdn.
m map[string]T // string in this map must be a normalized domain (see NormalizeDomain).
}

func NewFullMatcher[T any]() *FullMatcher[T] {
@@ -134,13 +90,16 @@ func NewFullMatcher[T any]() *FullMatcher[T] {
}
}

// Add adds domain s to this matcher. s can be a fqdn or not.
func (m *FullMatcher[T]) Add(s string, v T) error {
m.m[UnifyDomain(s)] = v
s = NormalizeDomain(s)
m.m[s] = v
return nil
}

func (m *FullMatcher[T]) Match(s string) (v T, ok bool) {
v, ok = m.m[UnifyDomain(s)]
s = NormalizeDomain(s)
v, ok = m.m[s]
return
}

@@ -159,14 +118,15 @@ func NewKeywordMatcher[T any]() *KeywordMatcher[T] {
}

func (m *KeywordMatcher[T]) Add(keyword string, v T) error {
keyword = NormalizeDomain(keyword) // fqdn-insensitive and case-insensitive
m.kws[keyword] = v
return nil
}

func (m *KeywordMatcher[T]) Match(s string) (v T, ok bool) {
domain := UnifyDomain(s)
s = NormalizeDomain(s)
for k, v := range m.kws {
if strings.Contains(domain, k) {
if strings.Contains(s, k) {
return v, true
}
}
@@ -177,9 +137,10 @@ func (m *KeywordMatcher[T]) Len() int {
return len(m.kws)
}

// RegexMatcher contains regexp rules.
// Note: the regexp rule is expected to match a lower-case, non-fqdn domain.
type RegexMatcher[T any] struct {
regs map[string]*regElem[T]
cache *regCache[T]
regs map[string]*regElem[T]
}

type regElem[T any] struct {
Expand All @@ -191,10 +152,6 @@ func NewRegexMatcher[T any]() *RegexMatcher[T] {
return &RegexMatcher[T]{regs: make(map[string]*regElem[T])}
}

func NewRegexMatcherWithCache[T any](cap int) *RegexMatcher[T] {
return &RegexMatcher[T]{regs: make(map[string]*regElem[T]), cache: newRegCache[T](cap)}
}

func (m *RegexMatcher[T]) Add(expr string, v T) error {
e := m.regs[expr]
if e == nil {
Expand All @@ -209,37 +166,16 @@ func (m *RegexMatcher[T]) Add(expr string, v T) error {
} else {
e.v = v
}

return nil
}

func (m *RegexMatcher[T]) Match(s string) (v T, ok bool) {
return m.match(TrimDot(s))
}

func (m *RegexMatcher[T]) match(domain string) (v T, ok bool) {
if m.cache != nil {
if e, ok := m.cache.lookup(domain); ok { // cache hit
if e != nil {
return e.v, true // matched
}
var zeroT T
return zeroT, false // not matched
}
}

s = NormalizeDomain(s)
for _, e := range m.regs {
if e.reg.MatchString(domain) {
if m.cache != nil {
m.cache.cache(domain, e)
}
if e.reg.MatchString(s) {
return e.v, true
}
}

if m.cache != nil { // cache the string
m.cache.cache(domain, nil)
}
var zeroT T
return zeroT, false
}
Expand All @@ -248,59 +184,6 @@ func (m *RegexMatcher[T]) Len() int {
return len(m.regs)
}

func (m *RegexMatcher[T]) ResetCache() {
if m.cache != nil {
m.cache.reset()
}
}

type regCache[T any] struct {
cap int
sync.RWMutex
m map[string]*regElem[T]
}

func newRegCache[T any](cap int) *regCache[T] {
return &regCache[T]{
cap: cap,
m: make(map[string]*regElem[T], cap),
}
}

func (c *regCache[T]) cache(s string, res *regElem[T]) {
c.Lock()
defer c.Unlock()

c.tryEvictUnderLock()
c.m[s] = res
}

func (c *regCache[T]) lookup(s string) (res *regElem[T], ok bool) {
c.RLock()
defer c.RUnlock()
res, ok = c.m[s]
return
}

func (c *regCache[T]) reset() {
c.Lock()
defer c.Unlock()
c.m = make(map[string]*regElem[T], c.cap)
}

func (c *regCache[T]) tryEvictUnderLock() {
if len(c.m) >= c.cap {
i := c.cap / 8
for key := range c.m { // evict 1/8 cache
delete(c.m, key)
i--
if i < 0 {
return
}
}
}
}

const (
MatcherFull = "full"
MatcherDomain = "domain"
@@ -311,10 +194,10 @@ const (
type MixMatcher[T any] struct {
defaultMatcher string

full WriteableMatcher[T]
domain WriteableMatcher[T]
regex WriteableMatcher[T]
keyword WriteableMatcher[T]
full *FullMatcher[T]
domain *SubDomainMatcher[T]
regex *RegexMatcher[T]
keyword *KeywordMatcher[T]
}

func NewMixMatcher[T any]() *MixMatcher[T] {
@@ -388,14 +271,3 @@ func (m *MixMatcher[T]) splitTypeAndPattern(s string) (string, string) {
}
return typ, pattern
}

// TrimDot trims the suffix '.'.
func TrimDot(s string) string {
return strings.TrimSuffix(s, ".")
}

// UnifyDomain unifies domain strings.
// It removes the suffix "." and makes sure the domain is in lower case.
func UnifyDomain(s string) string {
return strings.ToLower(TrimDot(s))
}
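
For a quick sense of the reworked SubDomainMatcher, here is a hedged sketch of its behavior after this change: a rule added for a parent domain matches that domain and any of its subdomains, regardless of case or trailing dot (import path inferred from the file layout; rule and values illustrative):

```go
package main

import (
	"fmt"

	"github.com/IrineSistiana/mosdns/v4/pkg/matcher/domain"
)

func main() {
	m := domain.NewSubDomainMatcher[string]()
	_ = m.Add("google.com", "rule-1") // matches google.com and all of its subdomains

	v, ok := m.Match("Maps.GOOGLE.com.") // input is normalized before the label-trie walk
	fmt.Println(v, ok)                   // expected: rule-1 true

	_, ok = m.Match("example.org")
	fmt.Println(ok) // expected: false

	fmt.Println(m.Len()) // expected: 1
}
```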
32 changes: 12 additions & 20 deletions pkg/matcher/domain/matcher_test.go
@@ -21,7 +21,6 @@ package domain

import (
"reflect"
"strconv"
"testing"
)

@@ -87,6 +86,10 @@ func TestDomainMatcher(t *testing.T) {
assert("b.sub", true, 1)
assert("a.sub", true, 2)
assert("a.a.sub", true, 2)

// test case-insensitive
add("UPpER", 1)
assert("LowER.Upper", true, 1)
}

func assertInt(t testing.TB, want, got int) {
@@ -119,6 +122,10 @@ func Test_FullMatcher(t *testing.T) {
assert("append", true, nil)

assertInt(t, m.Len(), 3)

// test case-insensitive
add("UPpER", 1)
assert("Upper", true, 1)
}

func Test_KeywordMatcher(t *testing.T) {
@@ -146,6 +153,10 @@ func Test_KeywordMatcher(t *testing.T) {
assert("append", true, nil)

assertInt(t, m.Len(), 3)

// test case-insensitive
add("UPpER", 1)
assert("L.Upper.U", true, 1)
}

func Test_RegexMatcher(t *testing.T) {
@@ -183,22 +194,3 @@ func Test_RegexMatcher(t *testing.T) {
expr = "*"
add(expr, nil, true)
}

func Test_regCache(t *testing.T) {
c := newRegCache[any](128)
for i := 0; i < 1024; i++ {
s := strconv.Itoa(i)
res := new(regElem[any])
c.cache(s, res)
if len(c.m) > 128 {
t.Fatal("cache overflowed")
}
got, ok := c.lookup(s)
if !ok {
t.Fatal("cache lookup failed")
}
if got != res {
t.Fatal("cache item mismatched")
}
}
}
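
One point the tests above only hint at: RegexMatcher normalizes the queried domain but leaves the expressions untouched, so rules should target the lower-case, non-fqdn form — which is why this matcher is the exception named in the commit message. A hedged sketch, assuming NormalizeDomain lower-cases and strips the trailing dot (as the removed UnifyDomain did); the expression and values are illustrative:

```go
package main

import (
	"fmt"

	"github.com/IrineSistiana/mosdns/v4/pkg/matcher/domain"
)

func main() {
	m := domain.NewRegexMatcher[int]()

	// Write the rule against the normalized form: lower-case, no trailing dot.
	if err := m.Add(`^(.+\.)?google\.com$`, 1); err != nil {
		panic(err)
	}

	v, ok := m.Match("WWW.Google.Com.") // input is normalized before the regexps run
	fmt.Println(v, ok)                  // expected: 1 true

	_, ok = m.Match("google.com.cn")
	fmt.Println(ok) // expected: false
}
```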