Skip to content

Commit

Permalink
Load CDK: CheckpointManager support for Index-Based Checkpointing (#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Feb 19, 2025
1 parent 9424806 commit 5330830
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import io.airbyte.cdk.load.util.use
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import io.micronaut.context.annotation.Value
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
Expand All @@ -28,7 +29,7 @@ import kotlinx.coroutines.sync.withLock
* requests to flush all data-sufficient checkpoints.
*/
interface CheckpointManager<K, T> {
suspend fun addStreamCheckpoint(key: K, index: Long, checkpointMessage: T)
suspend fun addStreamCheckpoint(key: K, indexOrId: Long, checkpointMessage: T)
suspend fun addGlobalCheckpoint(keyIndexes: List<Pair<K, Long>>, checkpointMessage: T)
suspend fun flushReadyCheckpointMessages()
suspend fun getLastSuccessfulFlushTimeMs(): Long
Expand Down Expand Up @@ -59,6 +60,14 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
abstract val outputConsumer: suspend (T) -> Unit
abstract val timeProvider: TimeProvider

/**
* Selects how checkpoint readiness is determined when flushing.
*
* When `true` (new-style checkpoint-by-id), the `Long` stored with each checkpoint is wrapped in
* a `CheckpointId` and checked with `areRecordsPersistedUntilCheckpoint`. When `false`
* (old-style checkpoint-by-range), it is treated as a record index and checked with
* `areRecordsPersistedUntil`.
*
* NOTE(review): in by-id mode the value is narrowed with `toInt()` before being wrapped, so ids
* are presumably expected to fit in an `Int` -- confirm with callers.
*
* TODO: Remove this once everything is using the new interface.
*/
abstract val checkpointById: Boolean

data class GlobalCheckpoint<T>(
val streamIndexes: List<Pair<DestinationStream.Descriptor, Long>>,
val checkpointMessage: T
Expand All @@ -74,7 +83,7 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream

override suspend fun addStreamCheckpoint(
key: DestinationStream.Descriptor,
index: Long,
indexOrId: Long,
checkpointMessage: T
) {
flushLock.withLock {
Expand All @@ -89,15 +98,15 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
if (indexedMessages.isNotEmpty()) {
// Make sure the messages are coming in order
val (latestIndex, _) = indexedMessages.last()!!
if (latestIndex > index) {
if (latestIndex > indexOrId) {
throw IllegalStateException(
"Checkpoint message received out of order ($latestIndex before $index)"
"Checkpoint message received out of order ($latestIndex before $indexOrId)"
)
}
}
indexedMessages.add(index to checkpointMessage)
indexedMessages.add(indexOrId to checkpointMessage)

log.info { "Added checkpoint for stream: $key at index: $index" }
log.info { "Added checkpoint for stream: $key at index: $indexOrId" }
}
}

Expand Down Expand Up @@ -155,7 +164,13 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
val head = globalCheckpoints.peek()
val allStreamsPersisted =
head.streamIndexes.all { (stream, index) ->
syncManager.getStreamManager(stream).areRecordsPersistedUntil(index)
if (!checkpointById) {
syncManager.getStreamManager(stream).areRecordsPersistedUntil(index)
} else {
syncManager
.getStreamManager(stream)
.areRecordsPersistedUntilCheckpoint(CheckpointId(index.toInt()))
}
}
if (allStreamsPersisted) {
log.info { "Flushing global checkpoint with stream indexes: ${head.streamIndexes}" }
Expand Down Expand Up @@ -183,7 +198,13 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
}
while (true) {
val (nextIndex, nextMessage) = streamCheckpoints.peek() ?: break
if (manager.areRecordsPersistedUntil(nextIndex)) {
val persisted =
if (checkpointById) {
manager.areRecordsPersistedUntilCheckpoint(CheckpointId(nextIndex.toInt()))
} else {
manager.areRecordsPersistedUntil(nextIndex)
}
if (persisted) {

log.info {
"Flushing checkpoint for stream: ${stream.descriptor} at index: $nextIndex"
Expand Down Expand Up @@ -268,7 +289,9 @@ class DefaultCheckpointManager(
override val catalog: DestinationCatalog,
override val syncManager: SyncManager,
override val outputConsumer: suspend (Reserved<CheckpointMessage>) -> Unit,
override val timeProvider: TimeProvider
override val timeProvider: TimeProvider,
@Value("\${airbyte.destination.core.checkpoint-by-id:false}")
override val checkpointById: Boolean = false
) : StreamsCheckpointManager<Reserved<CheckpointMessage>>() {
init {
lastFlushTimeMs.set(timeProvider.currentTimeMillis())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ class CheckpointManagerTest {
override val catalog: DestinationCatalog,
override val syncManager: SyncManager,
override val outputConsumer: MockOutputConsumer,
override val timeProvider: TimeProvider
) : StreamsCheckpointManager<MockCheckpoint>()
override val timeProvider: TimeProvider,
) : StreamsCheckpointManager<MockCheckpoint>() {
override val checkpointById: Boolean = false
}

sealed class TestEvent
data class TestStreamMessage(val stream: DestinationStream, val index: Long, val message: Int) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.mockk.coVerify
import io.mockk.impl.annotations.MockK
import io.mockk.mockk
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class CheckpointManagerUTest {
Expand All @@ -23,33 +24,50 @@ class CheckpointManagerUTest {
private val outputConsumer: suspend (Reserved<CheckpointMessage>) -> Unit =
mockk<suspend (Reserved<CheckpointMessage>) -> Unit>(relaxed = true)
@MockK(relaxed = true) lateinit var timeProvider: TimeProvider
@MockK(relaxed = true) lateinit var streamManager1: StreamManager
@MockK(relaxed = true) lateinit var streamManager2: StreamManager

@Test
fun `test checkpoint manager does not ignore ready checkpoint after empty one`() = runTest {
// Populate the mock catalog with two streams in order
val stream1 =
DestinationStream(
DestinationStream.Descriptor("test", "stream1"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)
val stream2 =
DestinationStream(
DestinationStream.Descriptor("test", "stream2"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)
/**
* First fixture stream, described by `Descriptor("test", "stream1")`. `setup` lists it in the
* mocked catalog and routes its descriptor to `streamManager1`.
*/
private val stream1 =
DestinationStream(
DestinationStream.Descriptor("test", "stream1"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)

/**
* Second fixture stream, described by `Descriptor("test", "stream2")`. `setup` lists it in the
* mocked catalog and routes its descriptor to `streamManager2`.
*/
private val stream2 =
DestinationStream(
DestinationStream.Descriptor("test", "stream2"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)

/** Stubs the shared mocks before each test. */
@BeforeEach
fun setup() {
// The mocked catalog exposes both fixture streams, stream1 first.
coEvery { catalog.streams } returns listOf(stream1, stream2)
coEvery { outputConsumer.invoke(any()) } returns Unit
// Each stream gets its own mocked StreamManager so tests can stub persistence per stream.
coEvery { syncManager.getStreamManager(stream1.descriptor) } returns streamManager1
coEvery { syncManager.getStreamManager(stream2.descriptor) } returns streamManager2
}

val checkpointManager =
DefaultCheckpointManager(catalog, syncManager, outputConsumer, timeProvider)
/**
 * Builds a [DefaultCheckpointManager] wired to this test's mocks.
 *
 * @param checkpointById forwarded to the manager to choose between checkpoint-by-id (`true`)
 * and checkpoint-by-range (`false`).
 */
private fun makeCheckpointManager(checkpointById: Boolean): DefaultCheckpointManager =
    DefaultCheckpointManager(
        catalog,
        syncManager,
        outputConsumer,
        timeProvider,
        checkpointById = checkpointById
    )

@Test
fun `test checkpoint manager does not ignore ready checkpoint after empty one`() = runTest {
val checkpointManager = makeCheckpointManager(checkpointById = false)

// Add a checkpoint for only the second stream
val message = mockk<Reserved<CheckpointMessage>>(relaxed = true)
Expand All @@ -65,4 +83,28 @@ class CheckpointManagerUTest {
checkpointManager.flushReadyCheckpointMessages()
coVerify { outputConsumer.invoke(message) }
}

/**
* With `checkpointById = true`, a stream checkpoint must be flushed only when
* `areRecordsPersistedUntilCheckpoint` reports it persisted; the answer from
* `areRecordsPersistedUntil` must not matter.
*/
@Test
fun `test checkpoint-by-id`() = runTest {
val checkpointManager = makeCheckpointManager(checkpointById = true)

val message1 = mockk<Reserved<CheckpointMessage>>(relaxed = true)
val message2 = mockk<Reserved<CheckpointMessage>>(relaxed = true)

// Queue one checkpoint per stream, both with the value 10.
checkpointManager.addStreamCheckpoint(stream1.descriptor, 10, message1)
checkpointManager.addStreamCheckpoint(stream2.descriptor, 10, message2)

// Make stream1 data-sufficient only by the old (index) method and stream2 only by the new
// (checkpoint id) method, so the two methods disagree for both streams.
coEvery { streamManager1.areRecordsPersistedUntil(10) } returns true
coEvery { streamManager1.areRecordsPersistedUntilCheckpoint(CheckpointId(10)) } returns
false

coEvery { streamManager2.areRecordsPersistedUntil(10) } returns false
coEvery { streamManager2.areRecordsPersistedUntilCheckpoint(CheckpointId(10)) } returns true

// Only stream2 should be flushed, because the manager is in checkpoint-by-id mode.
checkpointManager.flushReadyCheckpointMessages()
coVerify(exactly = 0) { outputConsumer.invoke(message1) }
coVerify(exactly = 1) { outputConsumer.invoke(message2) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class MockCheckpointManager : CheckpointManager<DestinationStream.Descriptor, Ch

override suspend fun addStreamCheckpoint(
key: DestinationStream.Descriptor,
index: Long,
indexOrId: Long,
checkpointMessage: CheckpointMessage
) {
streamStates.getOrPut(key) { mutableListOf() }.add(index to checkpointMessage)
streamStates.getOrPut(key) { mutableListOf() }.add(indexOrId to checkpointMessage)
}

override suspend fun addGlobalCheckpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ class StreamManagerTest {
val manager = DefaultStreamManager(stream1)

val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand All @@ -545,6 +546,7 @@ class StreamManagerTest {
Assertions.assertFalse(manager.areRecordsPersistedUntilCheckpoint(checkpointId2))

manager.incrementPersistedCount(checkpointId1, 10)

Assertions.assertTrue(manager.areRecordsPersistedUntilCheckpoint(checkpointId1))
Assertions.assertTrue(manager.areRecordsPersistedUntilCheckpoint(checkpointId2))
}
Expand All @@ -554,6 +556,7 @@ class StreamManagerTest {
val manager = DefaultStreamManager(stream1)

val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand Down Expand Up @@ -587,6 +590,7 @@ class StreamManagerTest {
cases.forEach { steps ->
val manager = DefaultStreamManager(stream1)
val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand Down

0 comments on commit 5330830

Please sign in to comment.