diff --git a/pulsar/default_router.go b/pulsar/default_router.go index b5e24a6214..6945ff195a 100644 --- a/pulsar/default_router.go +++ b/pulsar/default_router.go @@ -18,7 +18,6 @@ package pulsar import ( - "math" "math/rand" "sync/atomic" "time" @@ -27,7 +26,7 @@ import ( type defaultRouter struct { currentPartitionCursor uint32 - lastChangeTimestamp int64 + lastBatchTimestamp int64 msgCounter uint32 cumulativeBatchSize uint32 } @@ -45,7 +44,7 @@ func NewDefaultRouter( disableBatching bool) func(*ProducerMessage, uint32) int { state := &defaultRouter{ currentPartitionCursor: rand.Uint32(), - lastChangeTimestamp: math.MinInt64, + lastBatchTimestamp: time.Now().UnixNano(), } readClockAfterNumMessages := uint32(maxBatchingMessages / 10) @@ -75,37 +74,38 @@ func NewDefaultRouter( // If there's no key, we do round-robin across partition, sticking with a given // partition for a certain amount of messages or volume buffered or the max delay to batch is reached so that // we ensure having a decent amount of batching of the messages. - // Note that it is possible that we skip more than one partition if multiple goroutines increment - // currentPartitionCursor at the same time. If that happens it shouldn't be a problem because we only want to - // spread the data on different partitions but not necessarily in a specific sequence. var now int64 size := uint32(len(message.Payload)) - previousMessageCount := atomic.LoadUint32(&state.msgCounter) - previousBatchingMaxSize := atomic.LoadUint32(&state.cumulativeBatchSize) - previousLastChange := atomic.LoadInt64(&state.lastChangeTimestamp) + partitionCursor := atomic.LoadUint32(&state.currentPartitionCursor) + messageCount := atomic.AddUint32(&state.msgCounter, 1) + batchSize := atomic.AddUint32(&state.cumulativeBatchSize, size) - messageCountReached := previousMessageCount >= uint32(maxBatchingMessages-1) - sizeReached := (size >= uint32(maxBatchingSize)-previousBatchingMaxSize) + // Note: use greater-than for the threshold check so that we don't route this message to a new partition + // before a batch is complete. + messageCountReached := messageCount > uint32(maxBatchingMessages) + sizeReached := batchSize > uint32(maxBatchingSize) durationReached := false - if readClockAfterNumMessages == 0 || previousMessageCount%readClockAfterNumMessages == 0 { + if readClockAfterNumMessages == 0 || messageCount%readClockAfterNumMessages == 0 { now = time.Now().UnixNano() - durationReached = now-previousLastChange >= maxBatchingDelay.Nanoseconds() + lastBatchTime := atomic.LoadInt64(&state.lastBatchTimestamp) + durationReached = now-lastBatchTime > maxBatchingDelay.Nanoseconds() } if messageCountReached || sizeReached || durationReached { - atomic.AddUint32(&state.currentPartitionCursor, 1) - atomic.StoreUint32(&state.msgCounter, 0) - atomic.StoreUint32(&state.cumulativeBatchSize, 0) - if now != 0 { - atomic.StoreInt64(&state.lastChangeTimestamp, now) + // Note: CAS to ensure that concurrent go-routines can only move the cursor forward by one so that + // partitions are not skipped. + newCursor := partitionCursor + 1 + if atomic.CompareAndSwapUint32(&state.currentPartitionCursor, partitionCursor, newCursor) { + atomic.StoreUint32(&state.msgCounter, 0) + atomic.StoreUint32(&state.cumulativeBatchSize, 0) + if now == 0 { + now = time.Now().UnixNano() + } + atomic.StoreInt64(&state.lastBatchTimestamp, now) } - return int(state.currentPartitionCursor % numPartitions) - } - atomic.AddUint32(&state.msgCounter, 1) - atomic.AddUint32(&state.cumulativeBatchSize, size) - if now != 0 { - atomic.StoreInt64(&state.lastChangeTimestamp, now) + return int(newCursor % numPartitions) } - return int(state.currentPartitionCursor % numPartitions) + + return int(partitionCursor % numPartitions) } } diff --git a/pulsar/default_router_test.go b/pulsar/default_router_test.go index 31b27aff99..3c42e66d88 100644 --- a/pulsar/default_router_test.go +++ b/pulsar/default_router_test.go @@ -71,16 +71,21 @@ func TestDefaultRouterRoutingBecauseMaxNumberOfMessagesReached(t *testing.T) { const numPartitions = uint32(3) p1 := router(&ProducerMessage{ Payload: []byte("message 1"), - }, 3) + }, numPartitions) assert.LessOrEqual(t, p1, int(numPartitions)) p2 := router(&ProducerMessage{ Payload: []byte("message 2"), }, numPartitions) - if p1 == int(numPartitions-1) { - assert.Equal(t, 0, p2) + assert.Equal(t, p1, p2) + + p3 := router(&ProducerMessage{ + Payload: []byte("message 3"), + }, numPartitions) + if p2 == int(numPartitions-1) { + assert.Equal(t, 0, p3) } else { - assert.Equal(t, p1+1, p2) + assert.Equal(t, p2+1, p3) } }