package rueidis

import (
	"sync"
	"sync/atomic"
)

// PubSubMessage represents a pubsub message from redis.
type PubSubMessage struct {
	// Pattern is only available with pmessage.
	Pattern string
	// Channel is the channel the message belongs to.
	Channel string
	// Message is the message content.
	Message string
}

// PubSubSubscription represents a pubsub "subscribe", "unsubscribe", "psubscribe" or "punsubscribe" event.
type PubSubSubscription struct {
	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the event subject.
	Channel string
	// Count is the current number of subscriptions for connection.
	Count int64
}

// PubSubHooks can be registered into DedicatedClient to process pubsub messages without using Client.Receive.
type PubSubHooks struct {
	// OnMessage will be called when receiving "message" and "pmessage" event.
	OnMessage func(m PubSubMessage)
	// OnSubscription will be called when receiving "subscribe", "unsubscribe", "psubscribe" and "punsubscribe" event.
	OnSubscription func(s PubSubSubscription)
}

// isZero reports whether neither hook is set.
func (h *PubSubHooks) isZero() bool {
	return h.OnMessage == nil && h.OnSubscription == nil
}

// newSubs returns an empty, open subscription registry.
func newSubs() *subs {
	return &subs{chs: make(map[string]chs), sub: make(map[uint64]*sub)}
}

// subs fans out published messages to locally registered subscribers.
// Close sets chs and sub to nil, which marks the registry as closed.
type subs struct {
	// chs maps a channel name to the subscribers registered on it.
	chs map[string]chs
	// sub maps a subscriber id to its subscriber.
	sub map[uint64]*sub
	// cnt is the last issued subscriber id. It only grows, so Publish and
	// Unsubscribe also use it as a lock-free "has anyone ever subscribed" check.
	// Accessed atomically; keep its position so it stays 64-bit aligned.
	cnt uint64
	mu  sync.RWMutex
}

type chs struct {
	sub map[uint64]*sub
}

type sub struct {
	ch chan PubSubMessage
	cs []string
}

// Publish sends msg to every subscriber registered on channel.
// The read lock is held while sending, so Publish blocks if a subscriber's
// buffered channel is full and nothing is receiving from it.
func (s *subs) Publish(channel string, msg PubSubMessage) {
	if atomic.LoadUint64(&s.cnt) == 0 {
		return
	}
	s.mu.RLock()
	for _, sb := range s.chs[channel].sub {
		sb.ch <- msg
	}
	s.mu.RUnlock()
}

// Subscribe registers a new subscriber on channels. It returns the
// subscriber's message channel (buffer of 16) and a cancel func that
// unregisters it and closes that channel. If the registry has already been
// closed, ch is nil and cancel is a no-op (never nil, so `defer cancel()` is safe).
func (s *subs) Subscribe(channels []string) (ch chan PubSubMessage, cancel func()) {
	id := atomic.AddUint64(&s.cnt, 1)
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.chs == nil {
		return nil, func() {}
	}
	out := make(chan PubSubMessage, 16)
	// Copy the names so later mutation of the caller's slice cannot make
	// remove walk the wrong channels.
	sb := &sub{cs: append([]string(nil), channels...), ch: out}
	s.sub[id] = sb
	for _, channel := range sb.cs {
		c := s.chs[channel].sub
		if c == nil {
			c = make(map[uint64]*sub, 1)
			s.chs[channel] = chs{sub: c}
		}
		c[id] = sb
	}
	cancel = func() {
		// Drain concurrently: a Publish blocked on this full channel holds the
		// read lock we are about to request. The goroutine ends once out is
		// closed, either by remove below or by Close.
		go func() {
			for range out {
			}
		}()
		s.mu.Lock()
		if s.chs != nil {
			s.remove(id)
		}
		s.mu.Unlock()
	}
	return out, cancel
}

// remove unregisters subscriber id from every channel it joined, drops
// channel entries that are left with no subscribers, and closes its message
// channel. It is a no-op for unknown ids. Callers must hold s.mu for writing.
func (s *subs) remove(id uint64) {
	sb := s.sub[id]
	if sb == nil {
		return
	}
	for _, channel := range sb.cs {
		if c := s.chs[channel].sub; c != nil {
			delete(c, id)
			if len(c) == 0 {
				// Otherwise every channel name ever subscribed stays in s.chs.
				delete(s.chs, channel)
			}
		}
	}
	close(sb.ch)
	delete(s.sub, id)
}

// Unsubscribe removes every subscriber registered on channel, including its
// registrations on other channels, and closes their message channels.
func (s *subs) Unsubscribe(channel string) {
	if atomic.LoadUint64(&s.cnt) == 0 {
		return
	}
	s.mu.Lock()
	for id := range s.chs[channel].sub {
		s.remove(id)
	}
	delete(s.chs, channel)
	s.mu.Unlock()
}

// Close marks the registry closed and closes all remaining subscriber
// channels. The channels are closed after releasing the lock. This is safe
// because, once chs is nil, Publish finds no subscribers and cancel/Unsubscribe
// no longer call remove, so nothing else sends on or closes them.
func (s *subs) Close() {
	s.mu.Lock()
	sbs := s.sub
	s.chs = nil
	s.sub = nil
	s.mu.Unlock()
	for _, sb := range sbs {
		close(sb.ch)
	}
}