diff --git a/rules/logic/logic.go b/rules/logic/logic.go index 8c79cab537..6e67285268 100644 --- a/rules/logic/logic.go +++ b/rules/logic/logic.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" "strings" + "sync" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/rules/common" @@ -13,19 +14,19 @@ import ( type Logic struct { *common.Base - payload string - adapter string - ruleType C.RuleType - rules []C.Rule - subRules map[string][]C.Rule - needIP bool - needProcess bool + payload string + adapter string + ruleType C.RuleType + rules []C.Rule + subRules map[string][]C.Rule + + payloadOnce sync.Once } type ParseRuleFunc func(tp, payload, target string, params []string, subRules map[string][]C.Rule) (C.Rule, error) func NewSubRule(payload, adapter string, subRules map[string][]C.Rule, parseRule ParseRuleFunc) (*Logic, error) { - logic := &Logic{Base: &common.Base{}, payload: payload, adapter: adapter, ruleType: C.SubRules} + logic := &Logic{Base: &common.Base{}, payload: payload, adapter: adapter, ruleType: C.SubRules, subRules: subRules} err := logic.parsePayload(fmt.Sprintf("(%s)", payload), parseRule) if err != nil { return nil, err @@ -34,15 +35,6 @@ func NewSubRule(payload, adapter string, subRules map[string][]C.Rule, parseRule if len(logic.rules) != 1 { return nil, fmt.Errorf("Sub-Rule rule must contain one rule") } - for _, rule := range subRules[adapter] { - if rule.ShouldResolveIP() { - logic.needIP = true - } - if rule.ShouldFindProcess() { - logic.needProcess = true - } - } - logic.subRules = subRules return logic, nil } @@ -56,9 +48,6 @@ func NewNOT(payload string, adapter string, parseRule ParseRuleFunc) (*Logic, er if len(logic.rules) != 1 { return nil, fmt.Errorf("not rule must contain one rule") } - logic.needIP = logic.rules[0].ShouldResolveIP() - logic.needProcess = logic.rules[0].ShouldFindProcess() - logic.payload = fmt.Sprintf("(!(%s,%s))", logic.rules[0].RuleType(), logic.rules[0].Payload()) return logic, nil } @@ -68,40 +57,15 @@ func NewOR(payload string, adapter string, parseRule ParseRuleFunc) (*Logic, err if err != nil { return nil, err } - - payloads := make([]string, 0, len(logic.rules)) - for _, rule := range logic.rules { - payloads = append(payloads, fmt.Sprintf("(%s,%s)", rule.RuleType().String(), rule.Payload())) - if rule.ShouldResolveIP() { - logic.needIP = true - } - if rule.ShouldFindProcess() { - logic.needProcess = true - } - } - logic.payload = fmt.Sprintf("(%s)", strings.Join(payloads, " || ")) - return logic, nil } + func NewAND(payload string, adapter string, parseRule ParseRuleFunc) (*Logic, error) { logic := &Logic{Base: &common.Base{}, payload: payload, adapter: adapter, ruleType: C.AND} err := logic.parsePayload(payload, parseRule) if err != nil { return nil, err } - - payloads := make([]string, 0, len(logic.rules)) - for _, rule := range logic.rules { - payloads = append(payloads, fmt.Sprintf("(%s,%s)", rule.RuleType().String(), rule.Payload())) - if rule.ShouldResolveIP() { - logic.needIP = true - } - if rule.ShouldFindProcess() { - logic.needProcess = true - } - } - logic.payload = fmt.Sprintf("(%s)", strings.Join(payloads, " && ")) - return logic, nil } @@ -218,13 +182,6 @@ func (logic *Logic) parsePayload(payload string, parseRule ParseRuleFunc) error return err } - if rule.ShouldResolveIP() { - logic.needIP = true - } - if rule.ShouldFindProcess() { - logic.needProcess = true - } - rules = append(rules, rule) } @@ -279,9 +236,9 @@ func (logic *Logic) Match(metadata *C.Metadata) (bool, string) { } } return true, logic.adapter + default: + return false, "" } - - return false, "" } func (logic *Logic) Adapter() string { @@ -289,15 +246,58 @@ func (logic *Logic) Adapter() string { } func (logic *Logic) Payload() string { + logic.payloadOnce.Do(func() { // a little bit expensive, so only computed once + switch logic.ruleType { + case C.NOT: + logic.payload = fmt.Sprintf("(!(%s,%s))", logic.rules[0].RuleType(), logic.rules[0].Payload()) + case C.OR: + payloads := make([]string, 0, len(logic.rules)) + for _, rule := range logic.rules { + payloads = append(payloads, fmt.Sprintf("(%s,%s)", rule.RuleType().String(), rule.Payload())) + } + logic.payload = fmt.Sprintf("(%s)", strings.Join(payloads, " || ")) + case C.AND: + payloads := make([]string, 0, len(logic.rules)) + for _, rule := range logic.rules { + payloads = append(payloads, fmt.Sprintf("(%s,%s)", rule.RuleType().String(), rule.Payload())) + } + logic.payload = fmt.Sprintf("(%s)", strings.Join(payloads, " && ")) + default: + } + }) return logic.payload } func (logic *Logic) ShouldResolveIP() bool { - return logic.needIP + if logic.ruleType == C.SubRules { + for _, rule := range logic.subRules[logic.adapter] { + if rule.ShouldResolveIP() { + return true + } + } + } + for _, rule := range logic.rules { + if rule.ShouldResolveIP() { + return true + } + } + return false } func (logic *Logic) ShouldFindProcess() bool { - return logic.needProcess + if logic.ruleType == C.SubRules { + for _, rule := range logic.subRules[logic.adapter] { + if rule.ShouldFindProcess() { + return true + } + } + } + for _, rule := range logic.rules { + if rule.ShouldFindProcess() { + return true + } + } + return false } func (logic *Logic) ProviderNames() (names []string) {