Skip to content

Commit

Permalink
Fix data race while accessing connection in partitionConsumer
Browse files Browse the repository at this point in the history
The partitionConsumer maintains a few internal goroutines, two of which
access the underlying internal.Connection. The main runEventsLoop()
goroutine reads the conn field, while a separate goroutine detects
connection loss, initiates reconnection, and sets the conn field.

Previously, access to the conn field was not synchronized.

Now, the conn field is read and written atomically, avoiding race
conditions.

Signed-off-by: Daniel Ferstay <[email protected]>
  • Loading branch information
Daniel Ferstay committed Jun 7, 2021
1 parent 8a78d2c commit d3c9ab3
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"math"
"sync"
syncAtomic "sync/atomic"
"time"

"github.com/gogo/protobuf/proto"
Expand Down Expand Up @@ -107,7 +108,7 @@ type partitionConsumer struct {
state atomic.Int32
options *partitionConsumerOpts

conn internal.Connection
conn syncAtomic.Value

topic string
name string
Expand Down Expand Up @@ -238,7 +239,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
if err != nil {
pc.log.WithError(err).Error("Failed to unsubscribe consumer")
unsub.err = err
Expand All @@ -248,7 +249,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
return
}

pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -276,7 +277,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
res, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID,
res, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID,
pb.BaseCommand_GET_LAST_MESSAGE_ID, cmdGetLastMessageID)
if err != nil {
pc.log.WithError(err).Error("Failed to get last message id")
Expand Down Expand Up @@ -326,7 +327,7 @@ func (pc *partitionConsumer) internalRedeliver(req *redeliveryRequest) {
}
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_REDELIVER_UNACKNOWLEDGED_MESSAGES, &pb.CommandRedeliverUnacknowledgedMessages{
ConsumerId: proto.Uint64(pc.consumerID),
MessageIds: msgIDDataList,
Expand Down Expand Up @@ -399,7 +400,7 @@ func (pc *partitionConsumer) requestSeekWithoutClear(msgID messageID) error {
MessageId: id,
}

_, err = pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err = pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message id")
return err
Expand Down Expand Up @@ -435,7 +436,7 @@ func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
MessagePublishTime: proto.Uint64(uint64(seek.publishTime.UnixNano() / int64(time.Millisecond))),
}

_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message publish time")
seek.err = err
Expand Down Expand Up @@ -465,7 +466,7 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
AckType: pb.CommandAck_Individual.Enum(),
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_ACK, cmdAck)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_ACK, cmdAck)
}

func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, headersAndPayload internal.Buffer) error {
Expand Down Expand Up @@ -607,7 +608,7 @@ func (pc *partitionConsumer) internalFlow(permits uint32) error {
ConsumerId: proto.Uint64(pc.consumerID),
MessagePermits: proto.Uint32(permits),
}
pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_FLOW, cmdFlow)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_FLOW, cmdFlow)

return nil
}
Expand Down Expand Up @@ -843,7 +844,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
ConsumerId: proto.Uint64(pc.consumerID),
RequestId: proto.Uint64(requestID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
if err != nil {
pc.log.WithError(err).Warn("Failed to close consumer")
} else {
Expand All @@ -855,7 +856,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
}

pc.setConsumerState(consumerClosed)
pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -971,9 +972,9 @@ func (pc *partitionConsumer) grabConn() error {
pc.name = res.Response.ConsumerStatsResponse.GetConsumerName()
}

pc.conn = res.Cnx
pc._setConn(res.Cnx)
pc.log.Info("Connected consumer")
pc.conn.AddConsumeHandler(pc.consumerID, pc)
pc._getConn().AddConsumeHandler(pc.consumerID, pc)

msgType := res.Response.GetType()

Expand Down Expand Up @@ -1104,7 +1105,7 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
"validationError": validationError,
}).Error("Discarding corrupted message")

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_ACK, &pb.CommandAck{
ConsumerId: proto.Uint64(pc.consumerID),
MessageId: []*pb.MessageIdData{msgID},
Expand All @@ -1113,6 +1114,18 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
})
}

func (pc *partitionConsumer) _setConn(conn internal.Connection) {
pc.conn.Store(conn)
}

func (pc *partitionConsumer) _getConn() internal.Connection {
conn, ok := pc.conn.Load().(internal.Connection)
if !ok {
return nil
}
return conn
}

func convertToMessageIDData(msgID trackingMessageID) *pb.MessageIdData {
if msgID.Undefined() {
return nil
Expand Down

0 comments on commit d3c9ab3

Please sign in to comment.