Skip to content

Commit

Permalink
Fix data race while accessing connection in partitionConsumer
Browse files Browse the repository at this point in the history
The partitionConsumer maintains a few internal go-routines, two of which
access the underlying internal.Connection.  The main runEvenstLoop()
go-routine reads the connection field while a separate go-routine is used
to detect connnection loss, initiate reconnection, and sets the connection.

Previously, access to the conn field was not synchronized.

Now, the conn field is read and written atomically; avoiding race
conditions.

Signed-off-by: Daniel Ferstay <[email protected]>
  • Loading branch information
Daniel Ferstay committed Jun 21, 2021
1 parent 8a78d2c commit 6339d74
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ type partitionConsumer struct {
state atomic.Int32
options *partitionConsumerOpts

conn internal.Connection
conn atomic.Value

topic string
name string
Expand Down Expand Up @@ -238,7 +238,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
if err != nil {
pc.log.WithError(err).Error("Failed to unsubscribe consumer")
unsub.err = err
Expand All @@ -248,7 +248,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
return
}

pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
res, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID,
res, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID,
pb.BaseCommand_GET_LAST_MESSAGE_ID, cmdGetLastMessageID)
if err != nil {
pc.log.WithError(err).Error("Failed to get last message id")
Expand Down Expand Up @@ -326,7 +326,7 @@ func (pc *partitionConsumer) internalRedeliver(req *redeliveryRequest) {
}
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_REDELIVER_UNACKNOWLEDGED_MESSAGES, &pb.CommandRedeliverUnacknowledgedMessages{
ConsumerId: proto.Uint64(pc.consumerID),
MessageIds: msgIDDataList,
Expand Down Expand Up @@ -399,7 +399,7 @@ func (pc *partitionConsumer) requestSeekWithoutClear(msgID messageID) error {
MessageId: id,
}

_, err = pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err = pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message id")
return err
Expand Down Expand Up @@ -435,7 +435,7 @@ func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
MessagePublishTime: proto.Uint64(uint64(seek.publishTime.UnixNano() / int64(time.Millisecond))),
}

_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message publish time")
seek.err = err
Expand Down Expand Up @@ -465,7 +465,7 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
AckType: pb.CommandAck_Individual.Enum(),
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_ACK, cmdAck)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_ACK, cmdAck)
}

func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, headersAndPayload internal.Buffer) error {
Expand Down Expand Up @@ -607,7 +607,7 @@ func (pc *partitionConsumer) internalFlow(permits uint32) error {
ConsumerId: proto.Uint64(pc.consumerID),
MessagePermits: proto.Uint32(permits),
}
pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_FLOW, cmdFlow)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_FLOW, cmdFlow)

return nil
}
Expand Down Expand Up @@ -843,7 +843,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
ConsumerId: proto.Uint64(pc.consumerID),
RequestId: proto.Uint64(requestID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
if err != nil {
pc.log.WithError(err).Warn("Failed to close consumer")
} else {
Expand All @@ -855,7 +855,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
}

pc.setConsumerState(consumerClosed)
pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -971,9 +971,9 @@ func (pc *partitionConsumer) grabConn() error {
pc.name = res.Response.ConsumerStatsResponse.GetConsumerName()
}

pc.conn = res.Cnx
pc._setConn(res.Cnx)
pc.log.Info("Connected consumer")
pc.conn.AddConsumeHandler(pc.consumerID, pc)
pc._getConn().AddConsumeHandler(pc.consumerID, pc)

msgType := res.Response.GetType()

Expand Down Expand Up @@ -1104,7 +1104,7 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
"validationError": validationError,
}).Error("Discarding corrupted message")

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_ACK, &pb.CommandAck{
ConsumerId: proto.Uint64(pc.consumerID),
MessageId: []*pb.MessageIdData{msgID},
Expand All @@ -1113,6 +1113,17 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
})
}

func (pc *partitionConsumer) _setConn(conn internal.Connection) {
pc.conn.Store(conn)
}

func (pc *partitionConsumer) _getConn() internal.Connection {
// Invariant: The conn must be non-nill for the lifetime of the partitionConsumer.
// For this reason we leave this cast unchecked and panic() if the
// invariant is broken
return pc.conn.Load().(internal.Connection)
}

func convertToMessageIDData(msgID trackingMessageID) *pb.MessageIdData {
if msgID.Undefined() {
return nil
Expand Down

0 comments on commit 6339d74

Please sign in to comment.