From 3c454602211faf7dcf29bb0002addfdd1f48455d Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Sun, 8 Dec 2024 23:47:36 +0100 Subject: [PATCH] Patching stuck-go routines causing deadline errors (#381) * feat: added timeout to dns + singleflight for caching initial bulk resolutions * fixing singleflight * . * atomic * removing log + finalize * fix routine leak * fix race --------- Co-authored-by: Ice3man --- .github/workflows/build-test.yml | 1 + example/concurrent/concurrent.go | 88 ++++++++++++++++++++++ fastdialer/dialer.go | 34 ++++++--- fastdialer/dialer_private.go | 14 ++-- fastdialer/perf_test | 0 fastdialer/utils/dialwrap.go | 124 ++++++++++++++++++++----------- go.mod | 4 +- go.sum | 4 +- tests/fastdialer_test.go | 2 +- 9 files changed, 203 insertions(+), 68 deletions(-) create mode 100644 example/concurrent/concurrent.go create mode 100644 fastdialer/perf_test diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 48bb9e1..5f599d0 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -26,3 +26,4 @@ jobs: run: | go run -race example/simple/main.go go run -race example/impersonate/main.go + go run -race example/concurrent/concurrent.go diff --git a/example/concurrent/concurrent.go b/example/concurrent/concurrent.go new file mode 100644 index 0000000..1311c1a --- /dev/null +++ b/example/concurrent/concurrent.go @@ -0,0 +1,88 @@ +package main + +// this example is to test the concurrency of the dialer along +// with ensuring that maximum connection time doesn't exceed 3 seconds + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/projectdiscovery/fastdialer/fastdialer" +) + +func main() { + err := BenchmarkDial("scanme.sh", 1000) + if err != nil { + panic(err) + } +} + +type connResult struct { + target string + elapsed time.Duration + err error +} + +func BenchmarkDial(target string, iterations int) error { + options := fastdialer.DefaultOptions + fd, err := fastdialer.NewDialer(options) + if err != nil { + return errors.Join(err, errors.New("failed to create dialer")) + } + + ctx := context.Background() + + tasks := make(chan string, iterations) + results := make(chan connResult, iterations) + + var wg sync.WaitGroup + for w := 0; w < 10; w++ { + wg.Add(1) + go worker(ctx, fd, tasks, results, &wg) + } + + go func() { + for i := 0; i < iterations; i++ { + tasks <- target + } + close(tasks) + }() + + go func() { + wg.Wait() + close(results) + }() + + for result := range results { + if result.err != nil { + return result.err + } + if result.elapsed.Seconds() > 3 { + return errors.New("connection took too long") + } + } + + return nil +} + +func worker(ctx context.Context, fd *fastdialer.Dialer, tasks <-chan string, results chan<- connResult, wg *sync.WaitGroup) { + defer wg.Done() + + for task := range tasks { + start := time.Now() + conn, err := fd.Dial(ctx, "tcp", task+":443") + elapsed := time.Since(start) + + if err == nil && conn != nil { + conn.Close() + } + + results <- connResult{ + target: task, + elapsed: elapsed, + err: err, + } + } +} diff --git a/fastdialer/dialer.go b/fastdialer/dialer.go index cde5473..6c2bbbb 100644 --- a/fastdialer/dialer.go +++ b/fastdialer/dialer.go @@ -8,6 +8,9 @@ import ( "net" "strings" "sync/atomic" + "time" + + "golang.org/x/sync/singleflight" "github.com/Mzack9999/gcache" gounit "github.com/docker/go-units" @@ -64,6 +67,8 @@ type Dialer struct { networkpolicy *networkpolicy.NetworkPolicy dialCache gcache.Cache[string, *utils.DialWrap] dialTimeoutErrors gcache.Cache[string, *atomic.Uint32] + + resolutionsGroup *singleflight.Group } // NewDialer instance @@ -136,7 +141,11 @@ func NewDialer(options Options) (*Dialer, error) { options.Logger.Printf("could not load hosts file: %s\n", err) } } - dnsclient, err := retryabledns.New(resolvers, options.MaxRetries) + dnsclient, err := retryabledns.NewWithOptions(retryabledns.Options{ + BaseResolvers: resolvers, + MaxRetries: options.MaxRetries, + Timeout: 1 * time.Second, + }) if err != nil { return nil, err } @@ -152,17 +161,18 @@ func NewDialer(options Options) (*Dialer, error) { } d := &Dialer{ - dnsclient: dnsclient, - mDnsCache: dnsCache, - hmDnsCache: hmDnsCache, - hostsFileData: hostsFileData, - dialerHistory: dialerHistory, - dialerTLSData: dialerTLSData, - dialer: dialer, - proxyDialer: options.ProxyDialer, - options: &options, - networkpolicy: np, - dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(), + dnsclient: dnsclient, + mDnsCache: dnsCache, + hmDnsCache: hmDnsCache, + hostsFileData: hostsFileData, + dialerHistory: dialerHistory, + dialerTLSData: dialerTLSData, + dialer: dialer, + proxyDialer: options.ProxyDialer, + options: &options, + networkpolicy: np, + dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(), + resolutionsGroup: &singleflight.Group{}, } if options.MaxTemporaryErrors > 0 && options.MaxTemporaryToPermanentDuration > 0 { diff --git a/fastdialer/dialer_private.go b/fastdialer/dialer_private.go index eaa00cf..98afa2e 100644 --- a/fastdialer/dialer_private.go +++ b/fastdialer/dialer_private.go @@ -15,6 +15,7 @@ import ( "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" "github.com/projectdiscovery/fastdialer/fastdialer/utils" + retryabledns "github.com/projectdiscovery/retryabledns" ctxutil "github.com/projectdiscovery/utils/context" cryptoutil "github.com/projectdiscovery/utils/crypto" "github.com/projectdiscovery/utils/errkit" @@ -110,14 +111,14 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er if fixedIP != "" { IPS = append(IPS, fixedIP) } else { - data, err := d.GetDNSData(hostname) - if err != nil { - // otherwise attempt to retrieve it - data, err = d.dnsclient.Resolve(hostname) - } - if data == nil { + cacheData, err, _ := d.resolutionsGroup.Do(hostname, func() (interface{}, error) { + return d.GetDNSData(hostname) + }) + + if cacheData == nil { return nil, ResolveHostError } + data := cacheData.(*retryabledns.DNSData) if err != nil || len(data.A)+len(data.AAAA) == 0 { return nil, NoAddressFoundError } @@ -161,7 +162,6 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er // 2. it is a domain and not ip // 3. it has at least 1 valid ip // 4. proxy dialer is not set - dw, err = utils.NewDialWrap(d.dialer, IPS, opts.network, opts.address, opts.port) if err != nil { return nil, errkit.Wrap(err, "could not create dialwrap") diff --git a/fastdialer/perf_test b/fastdialer/perf_test new file mode 100644 index 0000000..e69de29 diff --git a/fastdialer/utils/dialwrap.go b/fastdialer/utils/dialwrap.go index f7e850e..497c04f 100644 --- a/fastdialer/utils/dialwrap.go +++ b/fastdialer/utils/dialwrap.go @@ -61,14 +61,15 @@ type DialWrap struct { network string address string port string - // below fields implement a singleflight like pattern - // where first connection is established and subsequent calls receive - // a shared result - wg sync.WaitGroup - mu sync.Mutex - completedFirstFlight *atomic.Bool - dups uint8 - err error // error returned by first flight + + // all connections blocks until a first connection is established + // subsequent calls will behave upon first result + busyFirstConnection *atomic.Bool + completedFirstConnection *atomic.Bool + firstConnectionDuration time.Duration + mu sync.RWMutex + // error returned by first connection + err error } // NewDialWrap creates a new dial wrap instance and returns it. @@ -88,31 +89,22 @@ func NewDialWrap(dialer *net.Dialer, ips []string, network, address, port string return nil, ErrNoIPs } return &DialWrap{ - dialer: dialer, - ipv4: ipv4, - ipv6: ipv6, - ips: valid, - completedFirstFlight: &atomic.Bool{}, - network: network, - address: address, - port: port, + dialer: dialer, + ipv4: ipv4, + ipv6: ipv6, + ips: valid, + completedFirstConnection: &atomic.Bool{}, + busyFirstConnection: &atomic.Bool{}, + network: network, + address: address, + port: port, }, nil } // DialContext is the main entry point for dialing func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Conn, error) { - if d.completedFirstFlight.Load() { - // if first flight completed and it failed due to other reasons - // and not due to context cancellation - if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) { - return nil, d.err - } - return d.dial(ctx) - } select { - case <-ctx.Done(): - return nil, errkit.Append(ErrInflightCancel, ctx.Err()) - case res, ok := <-d.firstFlight(ctx): + case res, ok := <-d.doFirstConnection(ctx): if !ok { // closed channel so depending on the error // either dial new or return the error @@ -133,27 +125,35 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con return nil, d.err } return nil, res.error + case <-d.hasCompletedFirstConnection(ctx): + // if first connection completed and it failed due to other reasons + // and not due to context cancellation + if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) { + return nil, d.err + } + return d.dial(ctx) + case <-ctx.Done(): + return nil, errkit.Append(ErrInflightCancel, ctx.Err()) } } -// firstFlight is a singleflight pattern implementation -func (d *DialWrap) firstFlight(ctx context.Context) chan *dialResult { +func (d *DialWrap) doFirstConnection(ctx context.Context) chan *dialResult { + if d.busyFirstConnection.Load() { + return nil + } + d.busyFirstConnection.Store(true) + now := time.Now() + defer func() { + d.SetFirstConnectionDuration(time.Since(now)) + }() + size := len(d.ipv4) + len(d.ipv6) ch := make(chan *dialResult, size) - d.mu.Lock() - if d.dups > 0 { - d.mu.Unlock() - d.wg.Wait() - return ch - } - d.dups++ - d.wg.Add(1) - d.mu.Unlock() - defer d.wg.Done() + // dial parallel conns, err := d.dialAllParallel(ctx) defer func() { - d.completedFirstFlight.Store(true) + d.completedFirstConnection.Store(true) close(ch) }() if err != nil { @@ -167,6 +167,27 @@ func (d *DialWrap) firstFlight(ctx context.Context) chan *dialResult { return ch } +func (d *DialWrap) hasCompletedFirstConnection(ctx context.Context) chan struct{} { + ch := make(chan struct{}, 1) + + go func() { + defer close(ch) + for { + if d.completedFirstConnection.Load() { + ch <- struct{}{} + return + } + select { + case <-ctx.Done(): + return + default: + } + } + }() + + return ch +} + // dialAllParallel connects to all the given addresses in parallel, returning // the first successful connection, or the first error. func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) { @@ -261,11 +282,12 @@ func (d *DialWrap) dial(ctx context.Context) (net.Conn, error) { // // Or zero, if none of Timeout, Deadline, or context's deadline is set. func (d *DialWrap) deadline(ctx context.Context, now time.Time) (earliest time.Time) { - if d.dialer.Timeout != 0 { // including negative, for historical reasons - earliest = now.Add(d.dialer.Timeout) + // including negative, for historical reasons + if d.dialer.Timeout != 0 { + earliest = now.Add(d.dialer.Timeout + d.FirstConnectionTook()) } - if d, ok := ctx.Deadline(); ok { - earliest = minNonzeroTime(earliest, d) + if de, ok := ctx.Deadline(); ok { + earliest = minNonzeroTime(earliest, de.Add(d.FirstConnectionTook())) } return earliest } @@ -408,3 +430,17 @@ func minNonzeroTime(a, b time.Time) time.Time { } return b } + +func (d *DialWrap) FirstConnectionTook() time.Duration { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.firstConnectionDuration +} + +func (d *DialWrap) SetFirstConnectionDuration(dur time.Duration) { + d.mu.Lock() + defer d.mu.Unlock() + + d.firstConnectionDuration = dur +} diff --git a/go.mod b/go.mod index d3577ac..22e9f60 100644 --- a/go.mod +++ b/go.mod @@ -7,16 +7,17 @@ require ( github.com/dimchansky/utfbom v1.1.1 github.com/docker/go-units v0.5.0 github.com/pkg/errors v0.9.1 + github.com/projectdiscovery/goleak v0.0.0-20240729222606-a7d18edc33f8 github.com/projectdiscovery/hmap v0.0.69 github.com/projectdiscovery/networkpolicy v0.0.9 github.com/projectdiscovery/retryabledns v1.0.87 github.com/projectdiscovery/utils v0.3.0 github.com/refraction-networking/utls v1.6.7 github.com/stretchr/testify v1.9.0 - github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 github.com/zmap/zcrypto v0.0.0-20230422215203-9a665e1e9968 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/net v0.29.0 + golang.org/x/sync v0.8.0 ) require ( @@ -55,7 +56,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/text v0.18.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect diff --git a/go.sum b/go.sum index e16e1d7..4ae1b01 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/projectdiscovery/blackrock v0.0.1 h1:lHQqhaaEFjgf5WkuItbpeCZv2DUIE45k0VbGJyft6LQ= github.com/projectdiscovery/blackrock v0.0.1/go.mod h1:ANUtjDfaVrqB453bzToU+YB4cUbvBRpLvEwoWIwlTss= +github.com/projectdiscovery/goleak v0.0.0-20240729222606-a7d18edc33f8 h1:M86+KhVmrurDS2ry8kwI0Z8LosZUwKW1K08vEDHlJ4M= +github.com/projectdiscovery/goleak v0.0.0-20240729222606-a7d18edc33f8/go.mod h1:ZkbDKjIe4ojX5CyEk8dYe8odTs8bnPB5s0nzIm4bnMY= github.com/projectdiscovery/hmap v0.0.69 h1:e30pCr6JShf/UyJmKQpx++Yceiijw4GWj3lFHGZ1yko= github.com/projectdiscovery/hmap v0.0.69/go.mod h1:LgZHcgcxOvA3X8tuFtfu4dofJjAHAfpMno27Jx0J34w= github.com/projectdiscovery/networkpolicy v0.0.9 h1:IrlDoYZagNNO8y+7iZeHT8k5izE+nek7TdtvEBwCxqk= @@ -109,8 +111,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= -github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 h1:GXIyLuIJ5Qk46lI8WJ83qHBZKUI3zhmMmuoY9HICUIQ= -github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50= github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8= github.com/tidwall/btree v1.4.3 h1:Lf5U/66bk0ftNppOBjVoy/AIPBrLMkheBp4NnSNiYOo= diff --git a/tests/fastdialer_test.go b/tests/fastdialer_test.go index e4dbb31..5459f8f 100644 --- a/tests/fastdialer_test.go +++ b/tests/fastdialer_test.go @@ -10,8 +10,8 @@ import ( "time" "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/goleak" "github.com/stretchr/testify/require" - "github.com/tarunKoyalwar/goleak" ) func TestFastDialerIP(t *testing.T) {