diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetch.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetch.java index c9c2eac369504..74760beec6d73 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetch.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetch.java @@ -180,8 +180,10 @@ ShareInFlightBatch fetchRecords(final Deserializers deseriali try { int recordsInBatch = 0; - while (recordsInBatch < maxRecords) { - lastRecord = nextFetchedRecord(checkCrcs); + boolean currentBatchHasMoreRecords = false; + + while (recordsInBatch < maxRecords || currentBatchHasMoreRecords) { + currentBatchHasMoreRecords = nextFetchedRecord(checkCrcs); if (lastRecord == null) { // Any remaining acquired records are gaps while (nextAcquired != null) { @@ -323,14 +325,22 @@ private static RecordDeserializationException newRecordDeserializationException( + ". The record has been released.", e); } - private Record nextFetchedRecord(final boolean checkCrcs) { + /** + * Scans for the next record in the available batches, skipping control records + * + * @param checkCrcs Whether to check the CRC of fetched records + * + * @return true if the current batch has more records, else false + */ + private boolean nextFetchedRecord(final boolean checkCrcs) { while (true) { if (records == null || !records.hasNext()) { maybeCloseRecordStream(); if (!batches.hasNext()) { drain(); - return null; + lastRecord = null; + break; } currentBatch = batches.next(); @@ -343,10 +353,13 @@ private Record nextFetchedRecord(final boolean checkCrcs) { // control records are not returned to the user if (!currentBatch.isControlBatch()) { - return record; + lastRecord = record; + break; } } } + + return records != null && records.hasNext(); } private Optional maybeLeaderEpoch(final int leaderEpoch) { diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetchTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetchTest.java index 01bab1d1f9a29..b117af177b17d 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetchTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ShareCompletedFetchTest.java @@ -69,11 +69,11 @@ public class ShareCompletedFetchTest { @Test public void testSimple() { - long firstMessageId = 5; long startingOffset = 10L; - int numRecords = 11; // Records for 10-20 + int numRecordsPerBatch = 10; + int numRecords = 20; // Records for 10-29, in 2 equal batches ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() - .setRecords(newRecords(startingOffset, numRecords, firstMessageId)) + .setRecords(newRecords(startingOffset, numRecordsPerBatch, 2)) .setAcquiredRecords(acquiredRecords(startingOffset, numRecords)); Deserializers deserializers = newStringDeserializers(); @@ -91,7 +91,7 @@ public void testSimple() { batch = completedFetch.fetchRecords(deserializers, 10, true); records = batch.getInFlightRecords(); - assertEquals(1, records.size()); + assertEquals(10, records.size()); record = records.get(0); assertEquals(20L, record.offset()); assertEquals(Optional.of((short) 1), record.deliveryCount()); @@ -105,13 +105,40 @@ record = records.get(0); assertEquals(0, acknowledgements.size()); } + @Test + public void testSoftMaxPollRecordLimit() { + long startingOffset = 10L; + int numRecords = 11; // Records for 10-20, in a single batch + ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() + .setRecords(newRecords(startingOffset, numRecords)) + .setAcquiredRecords(acquiredRecords(startingOffset, numRecords)); + + Deserializers deserializers = newStringDeserializers(); + + ShareCompletedFetch completedFetch = newShareCompletedFetch(partitionData); + + ShareInFlightBatch batch = completedFetch.fetchRecords(deserializers, 10, true); + List> records = batch.getInFlightRecords(); + assertEquals(11, records.size()); + ConsumerRecord record = records.get(0); + assertEquals(10L, record.offset()); + assertEquals(Optional.of((short) 1), record.deliveryCount()); + Acknowledgements acknowledgements = batch.getAcknowledgements(); + assertEquals(0, acknowledgements.size()); + + batch = completedFetch.fetchRecords(deserializers, 10, true); + records = batch.getInFlightRecords(); + assertEquals(0, records.size()); + acknowledgements = batch.getAcknowledgements(); + assertEquals(0, acknowledgements.size()); + } + @Test public void testUnaligned() { - long firstMessageId = 5; long startingOffset = 10L; int numRecords = 10; ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() - .setRecords(newRecords(startingOffset, numRecords + 500, firstMessageId)) + .setRecords(newRecords(startingOffset, numRecords + 500)) .setAcquiredRecords(acquiredRecords(startingOffset + 500, numRecords)); Deserializers deserializers = newStringDeserializers(); @@ -153,11 +180,10 @@ public void testCommittedTransactionRecordsIncluded() { @Test public void testNegativeFetchCount() { - long firstMessageId = 0; int startingOffset = 0; int numRecords = 10; ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() - .setRecords(newRecords(startingOffset, numRecords, firstMessageId)) + .setRecords(newRecords(startingOffset, numRecords)) .setAcquiredRecords(acquiredRecords(0L, 10)); try (final Deserializers deserializers = newStringDeserializers()) { @@ -267,7 +293,6 @@ public void testCorruptedMessage() { @Test public void testAcquiredRecords() { - long firstMessageId = 5; int startingOffset = 0; int numRecords = 10; // Records for 0-9 @@ -275,7 +300,7 @@ public void testAcquiredRecords() { List acquiredRecords = new ArrayList<>(acquiredRecords(0L, 3)); acquiredRecords.addAll(acquiredRecords(6L, 3)); ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() - .setRecords(newRecords(startingOffset, numRecords, firstMessageId)) + .setRecords(newRecords(startingOffset, numRecords)) .setAcquiredRecords(acquiredRecords); Deserializers deserializers = newStringDeserializers(); @@ -299,7 +324,6 @@ record = records.get(3); @Test public void testAcquireOddRecords() { - long firstMessageId = 5; int startingOffset = 0; int numRecords = 10; // Records for 0-9 @@ -310,7 +334,7 @@ public void testAcquireOddRecords() { } ShareFetchResponseData.PartitionData partitionData = new ShareFetchResponseData.PartitionData() - .setRecords(newRecords(startingOffset, numRecords, firstMessageId)) + .setRecords(newRecords(startingOffset, numRecords)) .setAcquiredRecords(acquiredRecords); Deserializers deserializers = newStringDeserializers(); @@ -357,17 +381,45 @@ private static Deserializers newStringDeserializers() { return new Deserializers<>(new StringDeserializer(), new StringDeserializer()); } - private Records newRecords(long baseOffset, int count, long firstMessageId) { + private Records newRecords(long baseOffset, int count) { try (final MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), Compression.NONE, TimestampType.CREATE_TIME, baseOffset)) { for (int i = 0; i < count; i++) - builder.append(0L, "key".getBytes(), ("value-" + (firstMessageId + i)).getBytes()); + builder.append(0L, "key".getBytes(), "value-".getBytes()); return builder.build(); } } + private Records newRecords(long baseOffset, int numRecordsPerBatch, int batchCount) { + Time time = new MockTime(); + ByteBuffer buffer = ByteBuffer.allocate(1024); + + for (long b = 0; b < batchCount; b++) { + try (MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, + RecordBatch.CURRENT_MAGIC_VALUE, + Compression.NONE, + TimestampType.CREATE_TIME, + baseOffset + b * numRecordsPerBatch, + time.milliseconds(), + PRODUCER_ID, + PRODUCER_EPOCH, + 0, + true, + RecordBatch.NO_PARTITION_LEADER_EPOCH)) { + for (int i = 0; i < numRecordsPerBatch; i++) + builder.append(new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes())); + + builder.build(); + } + } + + buffer.flip(); + + return MemoryRecords.readableRecords(buffer); + } + public static List acquiredRecords(long firstOffset, int count) { ShareFetchResponseData.AcquiredRecords acquiredRecords = new ShareFetchResponseData.AcquiredRecords() .setFirstOffset(firstOffset)