diff --git a/fastdialer/utils/dialwrap.go b/fastdialer/utils/dialwrap.go index 497c04f..1c77f89 100644 --- a/fastdialer/utils/dialwrap.go +++ b/fastdialer/utils/dialwrap.go @@ -70,6 +70,8 @@ type DialWrap struct { mu sync.RWMutex // error returned by first connection err error + + firstConnCond *sync.Cond } // NewDialWrap creates a new dial wrap instance and returns it. @@ -98,6 +100,7 @@ func NewDialWrap(dialer *net.Dialer, ips []string, network, address, port string network: network, address: address, port: port, + firstConnCond: sync.NewCond(&sync.Mutex{}), }, nil } @@ -128,8 +131,12 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con case <-d.hasCompletedFirstConnection(ctx): // if first connection completed and it failed due to other reasons // and not due to context cancellation - if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) { - return nil, d.err + d.firstConnCond.L.Lock() + err := d.err + d.firstConnCond.L.Unlock() + + if err != nil && !errkit.Is(err, ErrInflightCancel) && !errkit.Is(err, context.Canceled) { + return nil, err } return d.dial(ctx) case <-ctx.Done(): @@ -143,27 +150,31 @@ func (d *DialWrap) doFirstConnection(ctx context.Context) chan *dialResult { } d.busyFirstConnection.Store(true) now := time.Now() - defer func() { - d.SetFirstConnectionDuration(time.Since(now)) - }() size := len(d.ipv4) + len(d.ipv6) ch := make(chan *dialResult, size) // dial parallel - conns, err := d.dialAllParallel(ctx) - defer func() { + go func() { + defer close(ch) + + conns, err := d.dialAllParallel(ctx) + + d.firstConnCond.L.Lock() + d.SetFirstConnectionDuration(time.Since(now)) d.completedFirstConnection.Store(true) - close(ch) - }() - if err != nil { + d.firstConnCond.Broadcast() d.err = err - ch <- &dialResult{error: err} - return ch - } - for _, conn := range conns { - ch <- conn - } + d.firstConnCond.L.Unlock() + + if err != nil { + ch <- &dialResult{error: err} + return + } + for _, conn := range conns { + ch <- conn + } + }() return ch } @@ -172,19 +183,22 @@ func (d *DialWrap) hasCompletedFirstConnection(ctx context.Context) chan struct{ go func() { defer close(ch) - for { - if d.completedFirstConnection.Load() { - ch <- struct{}{} - return - } - select { - case <-ctx.Done(): + + // Check immediately first + if d.completedFirstConnection.Load() { + return + } + + d.firstConnCond.L.Lock() + defer d.firstConnCond.L.Unlock() + + for !d.completedFirstConnection.Load() { + if ctx.Err() != nil { return - default: } + d.firstConnCond.Wait() } }() - return ch }