From 4eb13a73bf71caafcdd309c61659cad63fbdb240 Mon Sep 17 00:00:00 2001
From: wwqgtxx <wwqgtxx@gmail.com>
Date: Mon, 22 Jul 2024 09:57:57 +0800
Subject: [PATCH] fix: wrong usage of RLock

---
 adapter/outboundgroup/loadbalance.go |  2 --
 common/lru/lrucache.go               | 32 ++++++++++++++++++++++++++++
 common/queue/queue.go                |  4 ++--
 common/utils/callback.go             |  4 ++--
 component/sniffer/dispatcher.go      | 20 ++++++-----------
 5 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go
index 4cb0db004f..738ed15479 100644
--- a/adapter/outboundgroup/loadbalance.go
+++ b/adapter/outboundgroup/loadbalance.go
@@ -205,7 +205,6 @@ func strategyStickySessions(url string) strategyFn {
 			proxy := proxies[nowIdx]
 			if proxy.AliveForTestUrl(url) {
 				if nowIdx != idx {
-					lruCache.Delete(key)
 					lruCache.Set(key, nowIdx)
 				}
 
@@ -215,7 +214,6 @@ func strategyStickySessions(url string) strategyFn {
 			}
 		}
 
-		lruCache.Delete(key)
 		lruCache.Set(key, 0)
 		return proxies[0]
 	}
diff --git a/common/lru/lrucache.go b/common/lru/lrucache.go
index 6f32ed18b1..35f605b10c 100644
--- a/common/lru/lrucache.go
+++ b/common/lru/lrucache.go
@@ -223,6 +223,10 @@ func (c *LruCache[K, V]) Delete(key K) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
+	c.delete(key)
+}
+
+func (c *LruCache[K, V]) delete(key K) {
 	if le, ok := c.cache[key]; ok {
 		c.deleteElement(le)
 	}
@@ -255,6 +259,34 @@ func (c *LruCache[K, V]) Clear() error {
 	return nil
 }
 
+// Compute either sets the computed new value for the key or deletes
+// the value for the key. When the delete result of the valueFn function
+// is set to true, the value will be deleted, if it exists. When delete
+// is set to false, the value is updated to the newValue.
+// The ok result indicates whether value was computed and stored, thus, is
+// present in the map. The actual result contains the new value in cases where
+// the value was computed and stored.
+func (c *LruCache[K, V]) Compute(
+	key K,
+	valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
+) (actual V, ok bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if el := c.get(key); el != nil {
+		actual, ok = el.value, true
+	}
+	if newValue, del := valueFn(actual, ok); del {
+		if ok { // data not in cache, so needn't delete
+			c.delete(key)
+		}
+		return lo.Empty[V](), false
+	} else {
+		c.set(key, newValue)
+		return newValue, true
+	}
+}
+
 type entry[K comparable, V any] struct {
 	key     K
 	value   V
diff --git a/common/queue/queue.go b/common/queue/queue.go
index cb58e2f5a2..d1b6beebe5 100644
--- a/common/queue/queue.go
+++ b/common/queue/queue.go
@@ -59,8 +59,8 @@ func (q *Queue[T]) Copy() []T {
 
 // Len returns the number of items in this queue.
 func (q *Queue[T]) Len() int64 {
-	q.lock.Lock()
-	defer q.lock.Unlock()
+	q.lock.RLock()
+	defer q.lock.RUnlock()
 
 	return int64(len(q.items))
 }
diff --git a/common/utils/callback.go b/common/utils/callback.go
index df950d3a81..ad734c0fd6 100644
--- a/common/utils/callback.go
+++ b/common/utils/callback.go
@@ -17,8 +17,8 @@ func NewCallback[T any]() *Callback[T] {
 }
 
 func (c *Callback[T]) Register(item func(T)) io.Closer {
-	c.mutex.RLock()
-	defer c.mutex.RUnlock()
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
 	element := c.list.PushBack(item)
 	return &callbackCloser[T]{
 		element:  element,
diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go
index 97bf162969..4438638dad 100644
--- a/component/sniffer/dispatcher.go
+++ b/component/sniffer/dispatcher.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
-	"sync"
 	"time"
 
 	"github.com/metacubex/mihomo/common/lru"
@@ -30,7 +29,6 @@ type SnifferDispatcher struct {
 	forceDomain     *trie.DomainSet
 	skipSNI         *trie.DomainSet
 	skipList        *lru.LruCache[string, uint8]
-	rwMux           sync.RWMutex
 	forceDnsMapping bool
 	parsePureIp     bool
 }
@@ -85,14 +83,11 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
 			return false
 		}
 
-		sd.rwMux.RLock()
 		dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
 		if count, ok := sd.skipList.Get(dst); ok && count > 5 {
 			log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
-			defer sd.rwMux.RUnlock()
 			return false
 		}
-		sd.rwMux.RUnlock()
 
 		if host, err := sd.sniffDomain(conn, metadata); err != nil {
 			sd.cacheSniffFailed(metadata)
@@ -104,9 +99,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
 				return false
 			}
 
-			sd.rwMux.RLock()
 			sd.skipList.Delete(dst)
-			sd.rwMux.RUnlock()
 
 			sd.replaceDomain(metadata, host, overrideDest)
 			return true
@@ -176,14 +169,13 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
 }
 
 func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
-	sd.rwMux.Lock()
 	dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
-	count, _ := sd.skipList.Get(dst)
-	if count <= 5 {
-		count++
-	}
-	sd.skipList.Set(dst, count)
-	sd.rwMux.Unlock()
+	sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
+		if oldValue <= 5 {
+			oldValue++
+		}
+		return oldValue, false
+	})
 }
 
 func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {