diff --git a/cluster.go b/cluster.go index 42b61036..8f991323 100644 --- a/cluster.go +++ b/cluster.go @@ -133,12 +133,14 @@ func (s clusterslots) parse(tls bool) map[string]group { return parseShards(s.reply.val, s.addr, tls) } -func getClusterSlots(c conn) clusterslots { +func getClusterSlots(c conn, timeout time.Duration) clusterslots { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() v := c.Version() if v < 7 { - return clusterslots{reply: c.Do(context.Background(), cmds.SlotCmd), addr: c.Addr(), ver: v} + return clusterslots{reply: c.Do(ctx, cmds.SlotCmd), addr: c.Addr(), ver: v} } - return clusterslots{reply: c.Do(context.Background(), cmds.ShardsCmd), addr: c.Addr(), ver: v} + return clusterslots{reply: c.Do(ctx, cmds.ShardsCmd), addr: c.Addr(), ver: v} } func (c *clusterClient) _refresh() (err error) { @@ -154,9 +156,9 @@ func (c *clusterClient) _refresh() (err error) { for i := 0; i < cap(results); i++ { if i&3 == 0 { // batch CLUSTER SLOTS/CLUSTER SHARDS for every 4 connections for j := i; j < i+4 && j < len(pending); j++ { - go func(c conn) { - results <- getClusterSlots(c) - }(pending[j]) + go func(c conn, timeout time.Duration) { + results <- getClusterSlots(c, timeout) + }(pending[j], c.opt.ConnWriteTimeout) } } result = <-results