Skip to content

Commit

Permalink
better apply tcp keepalive to listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Sep 25, 2024
1 parent b14c527 commit e347645
Show file tree
Hide file tree
Showing 31 changed files with 180 additions and 99 deletions.
22 changes: 22 additions & 0 deletions adapter/inbound/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package inbound
import (
"context"
"net"
"sync"

"github.com/metacubex/mihomo/component/keepalive"

"github.com/metacubex/tfo-go"
)
Expand All @@ -11,28 +14,47 @@ var (
lc = tfo.ListenConfig{
DisableTFO: true,
}
mutex sync.RWMutex
)

func SetTfo(open bool) {
mutex.Lock()
defer mutex.Unlock()
lc.DisableTFO = !open
}

func Tfo() bool {
mutex.RLock()
defer mutex.RUnlock()
return !lc.DisableTFO
}

func SetMPTCP(open bool) {
mutex.Lock()
defer mutex.Unlock()
setMultiPathTCP(&lc.ListenConfig, open)
}

func MPTCP() bool {
mutex.RLock()
defer mutex.RUnlock()
return getMultiPathTCP(&lc.ListenConfig)
}

func ListenContext(ctx context.Context, network, address string) (net.Listener, error) {
mutex.RLock()
defer mutex.RUnlock()
return lc.Listen(ctx, network, address)
}

func Listen(network, address string) (net.Listener, error) {
return ListenContext(context.Background(), network, address)
}

func init() {
keepalive.SetDisableKeepAliveCallback.Register(func(b bool) {
mutex.Lock()
defer mutex.Unlock()
keepalive.SetNetListenConfig(&lc.ListenConfig)
})
}
2 changes: 0 additions & 2 deletions adapter/outbound/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
Expand All @@ -26,7 +25,6 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...
if err != nil {
return nil, err
}
N.TCPKeepAlive(c)
return NewConn(c, d), nil
}

Expand Down
2 changes: 0 additions & 2 deletions adapter/outbound/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"net/url"
"strconv"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
Expand Down Expand Up @@ -77,7 +76,6 @@ func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", h.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down
1 change: 0 additions & 1 deletion adapter/outbound/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down
1 change: 0 additions & 1 deletion adapter/outbound/shadowsocksr.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down
6 changes: 1 addition & 5 deletions adapter/outbound/snell.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net"
"strconv"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
Expand Down Expand Up @@ -94,7 +93,6 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down Expand Up @@ -122,7 +120,6 @@ func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil {
return nil, err
}
N.TCPKeepAlive(c)
c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})

err = snell.WriteUDPHeader(c, s.version)
Expand Down Expand Up @@ -207,8 +204,7 @@ func NewSnell(option SnellOption) (*Snell, error) {
if err != nil {
return nil, err
}

N.TCPKeepAlive(c)

return streamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil
})
}
Expand Down
3 changes: 0 additions & 3 deletions adapter/outbound/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/netip"
"strconv"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
Expand Down Expand Up @@ -82,7 +81,6 @@ func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, me
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down Expand Up @@ -128,7 +126,6 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
safeConnClose(c, err)
}(c)

N.TCPKeepAlive(c)
var user *socks5.User
if ss.user != "" {
user = &socks5.User{
Expand Down
1 change: 0 additions & 1 deletion adapter/outbound/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func (s *sshClient) connect(ctx context.Context, cDialer C.Dialer, addr string)
if err != nil {
return nil, err
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down
4 changes: 0 additions & 4 deletions adapter/outbound/trojan.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"strconv"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
Expand Down Expand Up @@ -148,7 +147,6 @@ func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
}
N.TCPKeepAlive(c)

defer func(c net.Conn) {
safeConnClose(c, err)
Expand Down Expand Up @@ -206,7 +204,6 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me
defer func(c net.Conn) {
safeConnClose(c, err)
}(c)
N.TCPKeepAlive(c)
c, err = t.plainStream(ctx, c)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
Expand Down Expand Up @@ -300,7 +297,6 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
}
N.TCPKeepAlive(c)
return c, nil
}

Expand Down
3 changes: 0 additions & 3 deletions adapter/outbound/vmess.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
N.TCPKeepAlive(c)
defer func(c net.Conn) {
safeConnClose(c, err)
}(c)
Expand Down Expand Up @@ -369,7 +368,6 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
N.TCPKeepAlive(c)
defer func(c net.Conn) {
safeConnClose(c, err)
}(c)
Expand Down Expand Up @@ -481,7 +479,6 @@ func NewVmess(option VmessOption) (*Vmess, error) {
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
N.TCPKeepAlive(c)
return c, nil
}

Expand Down
23 changes: 0 additions & 23 deletions common/net/tcp_keepalive.go

This file was deleted.

10 changes: 0 additions & 10 deletions common/net/tcp_keepalive_go122.go

This file was deleted.

19 changes: 0 additions & 19 deletions common/net/tcp_keepalive_go123.go

This file was deleted.

2 changes: 2 additions & 0 deletions component/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sync"
"time"

"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/resolver"
)

Expand Down Expand Up @@ -138,6 +139,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
}

dialer := netDialer.(*net.Dialer)
keepalive.SetNetDialer(dialer)
if opt.mpTcp {
setMultiPathTCP(dialer)
}
Expand Down
65 changes: 65 additions & 0 deletions component/keepalive/tcp_keepalive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package keepalive

import (
"net"
"runtime"
"time"

"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/utils"
)

var (
keepAliveIdle = atomic.NewTypedValue[time.Duration](0 * time.Second)
keepAliveInterval = atomic.NewTypedValue[time.Duration](0 * time.Second)
disableKeepAlive = atomic.NewBool(false)

SetDisableKeepAliveCallback = utils.NewCallback[bool]()
)

func SetKeepAliveIdle(t time.Duration) {
keepAliveIdle.Store(t)
}

func SetKeepAliveInterval(t time.Duration) {
keepAliveInterval.Store(t)
}

func KeepAliveIdle() time.Duration {
return keepAliveIdle.Load()
}

func KeepAliveInterval() time.Duration {
return keepAliveInterval.Load()
}

func SetDisableKeepAlive(disable bool) {
if runtime.GOOS == "android" {
setDisableKeepAlive(false)
} else {
setDisableKeepAlive(disable)
}
}

func setDisableKeepAlive(disable bool) {
disableKeepAlive.Store(disable)
SetDisableKeepAliveCallback.Emit(disable)
}

func DisableKeepAlive() bool {
return disableKeepAlive.Load()
}

func SetNetDialer(dialer *net.Dialer) {
setNetDialer(dialer)
}

func SetNetListenConfig(lc *net.ListenConfig) {
setNetListenConfig(lc)
}

func TCPKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok && tcp != nil {
tcpKeepAlive(tcp)
}
}
30 changes: 30 additions & 0 deletions component/keepalive/tcp_keepalive_go122.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//go:build !go1.23

package keepalive

import "net"

func tcpKeepAlive(tcp *net.TCPConn) {
if DisableKeepAlive() {
_ = tcp.SetKeepAlive(false)
} else {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetKeepAlivePeriod(KeepAliveInterval())
}
}

func setNetDialer(dialer *net.Dialer) {
if DisableKeepAlive() {
dialer.KeepAlive = -1 // If negative, keep-alive probes are disabled.
} else {
dialer.KeepAlive = KeepAliveInterval()
}
}

func setNetListenConfig(lc *net.ListenConfig) {
if DisableKeepAlive() {
lc.KeepAlive = -1 // If negative, keep-alive probes are disabled.
} else {
lc.KeepAlive = KeepAliveInterval()
}
}
Loading

0 comments on commit e347645

Please sign in to comment.