Skip to content

Commit

Permalink
Fix data race while accessing connection in partitionConsumer (#535)
Browse files Browse the repository at this point in the history
The partitionConsumer maintains a few internal go-routines, two of which
access the underlying internal.Connection.  The main runEvenstLoop()
go-routine reads the connection field while a separate go-routine is used
to detect connnection loss, initiate reconnection, and sets the connection.

Previously, access to the conn field was not synchronized.

Now, the conn field is read and written atomically; resolving the data race.

Signed-off-by: Daniel Ferstay <[email protected]>

Co-authored-by: Daniel Ferstay <[email protected]>
  • Loading branch information
dferstay and Daniel Ferstay authored Jun 24, 2021
1 parent f33b594 commit 96ce2de
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ type partitionConsumer struct {
state atomic.Int32
options *partitionConsumerOpts

conn internal.Connection
conn atomic.Value

topic string
name string
Expand Down Expand Up @@ -238,7 +238,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_UNSUBSCRIBE, cmdUnsubscribe)
if err != nil {
pc.log.WithError(err).Error("Failed to unsubscribe consumer")
unsub.err = err
Expand All @@ -248,7 +248,7 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
return
}

pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error
RequestId: proto.Uint64(requestID),
ConsumerId: proto.Uint64(pc.consumerID),
}
res, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID,
res, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID,
pb.BaseCommand_GET_LAST_MESSAGE_ID, cmdGetLastMessageID)
if err != nil {
pc.log.WithError(err).Error("Failed to get last message id")
Expand Down Expand Up @@ -326,7 +326,7 @@ func (pc *partitionConsumer) internalRedeliver(req *redeliveryRequest) {
}
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_REDELIVER_UNACKNOWLEDGED_MESSAGES, &pb.CommandRedeliverUnacknowledgedMessages{
ConsumerId: proto.Uint64(pc.consumerID),
MessageIds: msgIDDataList,
Expand Down Expand Up @@ -399,7 +399,7 @@ func (pc *partitionConsumer) requestSeekWithoutClear(msgID messageID) error {
MessageId: id,
}

_, err = pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err = pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message id")
return err
Expand Down Expand Up @@ -435,7 +435,7 @@ func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
MessagePublishTime: proto.Uint64(uint64(seek.publishTime.UnixNano() / int64(time.Millisecond))),
}

_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_SEEK, cmdSeek)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_SEEK, cmdSeek)
if err != nil {
pc.log.WithError(err).Error("Failed to reset to message publish time")
seek.err = err
Expand Down Expand Up @@ -465,7 +465,7 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
AckType: pb.CommandAck_Individual.Enum(),
}

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_ACK, cmdAck)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_ACK, cmdAck)
}

func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, headersAndPayload internal.Buffer) error {
Expand Down Expand Up @@ -607,7 +607,7 @@ func (pc *partitionConsumer) internalFlow(permits uint32) error {
ConsumerId: proto.Uint64(pc.consumerID),
MessagePermits: proto.Uint32(permits),
}
pc.client.rpcClient.RequestOnCnxNoWait(pc.conn, pb.BaseCommand_FLOW, cmdFlow)
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_FLOW, cmdFlow)

return nil
}
Expand Down Expand Up @@ -843,7 +843,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
ConsumerId: proto.Uint64(pc.consumerID),
RequestId: proto.Uint64(requestID),
}
_, err := pc.client.rpcClient.RequestOnCnx(pc.conn, requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), requestID, pb.BaseCommand_CLOSE_CONSUMER, cmdClose)
if err != nil {
pc.log.WithError(err).Warn("Failed to close consumer")
} else {
Expand All @@ -855,7 +855,7 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
}

pc.setConsumerState(consumerClosed)
pc.conn.DeleteConsumeHandler(pc.consumerID)
pc._getConn().DeleteConsumeHandler(pc.consumerID)
if pc.nackTracker != nil {
pc.nackTracker.Close()
}
Expand Down Expand Up @@ -971,9 +971,9 @@ func (pc *partitionConsumer) grabConn() error {
pc.name = res.Response.ConsumerStatsResponse.GetConsumerName()
}

pc.conn = res.Cnx
pc._setConn(res.Cnx)
pc.log.Info("Connected consumer")
pc.conn.AddConsumeHandler(pc.consumerID, pc)
pc._getConn().AddConsumeHandler(pc.consumerID, pc)

msgType := res.Response.GetType()

Expand Down Expand Up @@ -1104,7 +1104,7 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
"validationError": validationError,
}).Error("Discarding corrupted message")

pc.client.rpcClient.RequestOnCnxNoWait(pc.conn,
pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(),
pb.BaseCommand_ACK, &pb.CommandAck{
ConsumerId: proto.Uint64(pc.consumerID),
MessageId: []*pb.MessageIdData{msgID},
Expand All @@ -1113,6 +1113,21 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
})
}

// _setConn sets the internal connection field of this partition consumer atomically.
// Note: should only be called by this partition consumer when a new connection is available.
func (pc *partitionConsumer) _setConn(conn internal.Connection) {
pc.conn.Store(conn)
}

// _getConn returns internal connection field of this partition consumer atomically.
// Note: should only be called by this partition consumer before attempting to use the connection
func (pc *partitionConsumer) _getConn() internal.Connection {
// Invariant: The conn must be non-nill for the lifetime of the partitionConsumer.
// For this reason we leave this cast unchecked and panic() if the
// invariant is broken
return pc.conn.Load().(internal.Connection)
}

func convertToMessageIDData(msgID trackingMessageID) *pb.MessageIdData {
if msgID.Undefined() {
return nil
Expand Down

0 comments on commit 96ce2de

Please sign in to comment.