Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt committed Feb 19, 2025
1 parent 7500386 commit 3570a29
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ data class StreamProcessingFailed(val streamException: Exception) : StreamResult

data object StreamProcessingSucceeded : StreamResult

data class CheckpointId(val id: Long)
@JvmInline value class CheckpointId(val id: Int)

/** Manages the state of a single stream. */
interface StreamManager {
/**
* Count incoming record and return the record's *index*. If [markEndOfStream] has been called,
* this should throw an exception.
*/
fun countRecordIn(): Long
fun recordCount(): Long
fun incrementReadCount(): Long
fun readCount(): Long

/**
* Mark the end-of-stream, set the end of stream variant (complete or incomplete) and return the
Expand Down Expand Up @@ -103,20 +103,20 @@ interface StreamManager {
fun getNextCheckpointId(): CheckpointId

/** Update the counts of persisted for a given checkpoint. */
fun countPersisted(checkpointId: CheckpointId, count: Long)
fun incrementPersistedCount(checkpointId: CheckpointId, count: Long)

/** Update the counts of completed for a given checkpoint. */
fun countCompleted(checkpointId: CheckpointId, count: Long)
fun incrementCompletedCount(checkpointId: CheckpointId, count: Long)

/**
* True if persisted counts for each checkpoint up to and including checkpoint id match the
* True if persisted counts for each checkpoint up to and including [checkpointId] match the
* number of records read for that id.
*/
fun areRecordsPersistedUntilCheckpoint(checkpointId: CheckpointId): Boolean

/**
* True if completed counts for all checkpoints match the number of records read AND all records
* have been read.
* True if all records in the stream have been marked as completed AND the stream has been marked as
* complete.
*/
fun isBatchProcessingCompleteForCheckpoints(): Boolean
}
Expand All @@ -139,26 +139,27 @@ class DefaultStreamManager(

private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()

private val lastCheckpointIndex = AtomicLong(0)
private val countsPerCheckpoint: ConcurrentLinkedQueue<Long> = ConcurrentLinkedQueue()
private val persistedForCheckpoint: ConcurrentHashMap<CheckpointId, AtomicLong> =
ConcurrentHashMap()
private val completedForCheckpoint: ConcurrentHashMap<CheckpointId, AtomicLong> =
ConcurrentHashMap()
data class CheckpointCounts(
val recordsRead: Long = 0L,
val recordsPersisted: AtomicLong = AtomicLong(0L),
val recordsCompleted: AtomicLong = AtomicLong(0L),
)
private val lastCheckpointRecordIndex = AtomicLong(0L)
private val checkpointCounts: ConcurrentLinkedQueue<CheckpointCounts> = ConcurrentLinkedQueue()

init {
Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() }
}

override fun countRecordIn(): Long {
override fun incrementReadCount(): Long {
if (markedEndOfStream.get()) {
throw IllegalStateException("Stream is closed for reading")
}

return recordCount.getAndIncrement()
}

override fun recordCount(): Long {
override fun readCount(): Long {
return recordCount.get()
}

Expand All @@ -180,10 +181,10 @@ class DefaultStreamManager(
}

override fun markCheckpoint(): Pair<Long, Long> {
val index = recordCount.get()
val count = index - lastCheckpointIndex.getAndSet(index)
countsPerCheckpoint.add(count)
return Pair(index, count)
val recordIndex = recordCount.get()
val count = recordIndex - lastCheckpointRecordIndex.getAndSet(recordIndex)
checkpointCounts.add(CheckpointCounts(count))
return Pair(recordIndex, count)
}

override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {
Expand Down Expand Up @@ -332,43 +333,48 @@ class DefaultStreamManager(
}

override fun getNextCheckpointId(): CheckpointId {
return CheckpointId(countsPerCheckpoint.size.toLong())
return CheckpointId(checkpointCounts.size)
}

override fun countPersisted(checkpointId: CheckpointId, count: Long) {
val result =
persistedForCheckpoint.getOrPut(checkpointId) { AtomicLong(0L) }.addAndGet(count)
val original =
countsPerCheckpoint.elementAtOrNull(checkpointId.id.toInt())
?: throw IllegalStateException("No checkpoint found for $checkpointId")
if (result > original) {
throw IllegalStateException(
"Persisted count $result for $checkpointId exceeds read count $original"
)
override fun incrementPersistedCount(checkpointId: CheckpointId, count: Long) {
checkpointCounts.elementAtOrNull(checkpointId.id)?.let {
val result = it.recordsPersisted.addAndGet(count)
if (result > it.recordsRead) {
throw IllegalStateException(
"Persisted count $result for $checkpointId exceeds read count ${it.recordsRead}"
)
}
}
?: throw IllegalStateException("No checkpoint found for $checkpointId")
}

override fun countCompleted(checkpointId: CheckpointId, count: Long) {
val result =
completedForCheckpoint.getOrPut(checkpointId) { AtomicLong(0L) }.addAndGet(count)
val original =
countsPerCheckpoint.elementAtOrNull(checkpointId.id.toInt())
?: throw IllegalStateException("No checkpoint found for $checkpointId")
if (result > original) {
throw IllegalStateException(
"Completed count $result for $checkpointId exceeds read count $original"
)
override fun incrementCompletedCount(checkpointId: CheckpointId, count: Long) {
checkpointCounts.elementAtOrNull(checkpointId.id)?.let {
val result = it.recordsCompleted.addAndGet(count)
if (result > it.recordsRead) {
throw IllegalStateException(
"Completed count $result for $checkpointId exceeds read count ${it.recordsRead}"
)
}
}
?: throw IllegalStateException("No checkpoint found for $checkpointId")
}

override fun areRecordsPersistedUntilCheckpoint(checkpointId: CheckpointId): Boolean {
val readCount = countsPerCheckpoint.take(checkpointId.id.toInt() + 1).sum()
val persistedCount = persistedForCheckpoint.sumUpTo(checkpointId)
val (readCount, persistedCount, completedCount) =
checkpointCounts.take(checkpointId.id + 1).fold(Triple(0L, 0L, 0L)) {
acc,
checkpointCount ->
Triple(
acc.first + checkpointCount.recordsRead,
acc.second + checkpointCount.recordsPersisted.get(),
acc.third + checkpointCount.recordsCompleted.get(),
)
}
if (persistedCount == readCount) {
return true
}
// Completed implies persisted.
val completedCount = completedForCheckpoint.sumUpTo(checkpointId)
return completedCount == readCount
}

Expand All @@ -382,13 +388,9 @@ class DefaultStreamManager(
return true
}

val completedCount = completedForCheckpoint.values.sumOf { it.get() }
val completedCount = checkpointCounts.sumOf { it.recordsCompleted.get() }
return completedCount == readCount
}

private fun Map<CheckpointId, AtomicLong>.sumUpTo(checkpointId: CheckpointId): Long {
return filter { it.key.id <= checkpointId.id }.values.sumOf { it.get() }
}
}

interface StreamManagerFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class DefaultInputConsumerTask(
is DestinationRecord -> {
val wrapped =
StreamRecordEvent(
index = manager.countRecordIn(),
index = manager.incrementReadCount(),
sizeBytes = sizeBytes,
payload = message.asRecordSerialized()
)
Expand All @@ -104,7 +104,7 @@ class DefaultInputConsumerTask(
recordQueue.close()
}
is DestinationFile -> {
val index = manager.countRecordIn()
val index = manager.incrementReadCount()
// destinationTaskLauncher.handleFile(stream, message, index)
fileTransferQueue.publish(FileTransferQueueMessage(stream, message, index))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ class CheckpointManagerTest {
* the index of the message.
*/
val streamManager = syncManager.getStreamManager(it.stream.descriptor)
val recordCount = streamManager.recordCount()
val recordCount = streamManager.readCount()
(recordCount until it.index).forEach { _ ->
syncManager.getStreamManager(it.stream.descriptor).countRecordIn()
syncManager.getStreamManager(it.stream.descriptor).incrementReadCount()
}
checkpointManager.addStreamCheckpoint(
it.stream.descriptor,
Expand Down
Loading

0 comments on commit 3570a29

Please sign in to comment.