Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load CDK: CheckpointManager support for Index-Based Checkpointing #53663

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import io.airbyte.cdk.load.util.use
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import io.micronaut.context.annotation.Value
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
Expand All @@ -28,7 +29,7 @@ import kotlinx.coroutines.sync.withLock
* requests to flush all data-sufficient checkpoints.
*/
interface CheckpointManager<K, T> {
suspend fun addStreamCheckpoint(key: K, index: Long, checkpointMessage: T)
suspend fun addStreamCheckpoint(key: K, indexOrId: Long, checkpointMessage: T)
suspend fun addGlobalCheckpoint(keyIndexes: List<Pair<K, Long>>, checkpointMessage: T)
suspend fun flushReadyCheckpointMessages()
suspend fun getLastSuccessfulFlushTimeMs(): Long
Expand Down Expand Up @@ -59,6 +60,14 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
abstract val outputConsumer: suspend (T) -> Unit
abstract val timeProvider: TimeProvider

/**
* Whether or not we are using the new style checkpoint-by-id or the old style
* checkpoint-by-range.
*
* TODO: Remove this once everything is using the new interface.
*/
abstract val checkpointById: Boolean

data class GlobalCheckpoint<T>(
val streamIndexes: List<Pair<DestinationStream.Descriptor, Long>>,
val checkpointMessage: T
Expand All @@ -74,7 +83,7 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream

override suspend fun addStreamCheckpoint(
key: DestinationStream.Descriptor,
index: Long,
indexOrId: Long,
checkpointMessage: T
) {
flushLock.withLock {
Expand All @@ -89,15 +98,15 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
if (indexedMessages.isNotEmpty()) {
// Make sure the messages are coming in order
val (latestIndex, _) = indexedMessages.last()!!
if (latestIndex > index) {
if (latestIndex > indexOrId) {
throw IllegalStateException(
"Checkpoint message received out of order ($latestIndex before $index)"
"Checkpoint message received out of order ($latestIndex before $indexOrId)"
)
}
}
indexedMessages.add(index to checkpointMessage)
indexedMessages.add(indexOrId to checkpointMessage)

log.info { "Added checkpoint for stream: $key at index: $index" }
log.info { "Added checkpoint for stream: $key at index: $indexOrId" }
}
}

Expand Down Expand Up @@ -155,7 +164,13 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
val head = globalCheckpoints.peek()
val allStreamsPersisted =
head.streamIndexes.all { (stream, index) ->
syncManager.getStreamManager(stream).areRecordsPersistedUntil(index)
if (!checkpointById) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: invert this if statement (i.e. if (checkpointById) { ... CheckpointId()... } else {...})

syncManager.getStreamManager(stream).areRecordsPersistedUntil(index)
} else {
syncManager
.getStreamManager(stream)
.areRecordsPersistedUntilCheckpoint(CheckpointId(index.toInt()))
}
}
if (allStreamsPersisted) {
log.info { "Flushing global checkpoint with stream indexes: ${head.streamIndexes}" }
Expand Down Expand Up @@ -183,7 +198,13 @@ abstract class StreamsCheckpointManager<T> : CheckpointManager<DestinationStream
}
while (true) {
val (nextIndex, nextMessage) = streamCheckpoints.peek() ?: break
if (manager.areRecordsPersistedUntil(nextIndex)) {
val persisted =
if (checkpointById) {
manager.areRecordsPersistedUntilCheckpoint(CheckpointId(nextIndex.toInt()))
} else {
manager.areRecordsPersistedUntil(nextIndex)
}
if (persisted) {

log.info {
"Flushing checkpoint for stream: ${stream.descriptor} at index: $nextIndex"
Expand Down Expand Up @@ -268,7 +289,9 @@ class DefaultCheckpointManager(
override val catalog: DestinationCatalog,
override val syncManager: SyncManager,
override val outputConsumer: suspend (Reserved<CheckpointMessage>) -> Unit,
override val timeProvider: TimeProvider
override val timeProvider: TimeProvider,
@Value("\${airbyte.destination.core.checkpoint-by-id:false}")
override val checkpointById: Boolean = false
) : StreamsCheckpointManager<Reserved<CheckpointMessage>>() {
init {
lastFlushTimeMs.set(timeProvider.currentTimeMillis())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ class CheckpointManagerTest {
override val catalog: DestinationCatalog,
override val syncManager: SyncManager,
override val outputConsumer: MockOutputConsumer,
override val timeProvider: TimeProvider
) : StreamsCheckpointManager<MockCheckpoint>()
override val timeProvider: TimeProvider,
) : StreamsCheckpointManager<MockCheckpoint>() {
override val checkpointById: Boolean = false
}

sealed class TestEvent
data class TestStreamMessage(val stream: DestinationStream, val index: Long, val message: Int) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.mockk.coVerify
import io.mockk.impl.annotations.MockK
import io.mockk.mockk
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class CheckpointManagerUTest {
Expand All @@ -23,33 +24,50 @@ class CheckpointManagerUTest {
private val outputConsumer: suspend (Reserved<CheckpointMessage>) -> Unit =
mockk<suspend (Reserved<CheckpointMessage>) -> Unit>(relaxed = true)
@MockK(relaxed = true) lateinit var timeProvider: TimeProvider
@MockK(relaxed = true) lateinit var streamManager1: StreamManager
@MockK(relaxed = true) lateinit var streamManager2: StreamManager

@Test
fun `test checkpoint manager does not ignore ready checkpoint after empty one`() = runTest {
// Populate the mock catalog with two streams in order
val stream1 =
DestinationStream(
DestinationStream.Descriptor("test", "stream1"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)
val stream2 =
DestinationStream(
DestinationStream.Descriptor("test", "stream2"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)
private val stream1 =
DestinationStream(
DestinationStream.Descriptor("test", "stream1"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)

private val stream2 =
DestinationStream(
DestinationStream.Descriptor("test", "stream2"),
importType = Append,
schema = ObjectTypeWithEmptySchema,
generationId = 10L,
minimumGenerationId = 10L,
syncId = 101L
)

@BeforeEach
fun setup() {
coEvery { catalog.streams } returns listOf(stream1, stream2)
coEvery { outputConsumer.invoke(any()) } returns Unit
coEvery { syncManager.getStreamManager(stream1.descriptor) } returns streamManager1
coEvery { syncManager.getStreamManager(stream2.descriptor) } returns streamManager2
}

val checkpointManager =
DefaultCheckpointManager(catalog, syncManager, outputConsumer, timeProvider)
private fun makeCheckpointManager(checkpointById: Boolean): DefaultCheckpointManager {
return DefaultCheckpointManager(
catalog,
syncManager,
outputConsumer,
timeProvider,
checkpointById = checkpointById
)
}

@Test
fun `test checkpoint manager does not ignore ready checkpoint after empty one`() = runTest {
val checkpointManager = makeCheckpointManager(checkpointById = false)

// Add a checkpoint for only the second stream
val message = mockk<Reserved<CheckpointMessage>>(relaxed = true)
Expand All @@ -65,4 +83,28 @@ class CheckpointManagerUTest {
checkpointManager.flushReadyCheckpointMessages()
coVerify { outputConsumer.invoke(message) }
}

@Test
fun `test checkpoint-by-id`() = runTest {
val checkpointManager = makeCheckpointManager(checkpointById = true)

val message1 = mockk<Reserved<CheckpointMessage>>(relaxed = true)
val message2 = mockk<Reserved<CheckpointMessage>>(relaxed = true)

checkpointManager.addStreamCheckpoint(stream1.descriptor, 10, message1)
checkpointManager.addStreamCheckpoint(stream2.descriptor, 10, message2)

// Make stream1 data sufficient by old method, stream2 data sufficient by new.
coEvery { streamManager1.areRecordsPersistedUntil(10) } returns true
coEvery { streamManager1.areRecordsPersistedUntilCheckpoint(CheckpointId(10)) } returns
false

coEvery { streamManager2.areRecordsPersistedUntil(10) } returns false
coEvery { streamManager2.areRecordsPersistedUntilCheckpoint(CheckpointId(10)) } returns true

// Only stream2 should be flushed.
checkpointManager.flushReadyCheckpointMessages()
coVerify(exactly = 0) { outputConsumer.invoke(message1) }
coVerify(exactly = 1) { outputConsumer.invoke(message2) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class MockCheckpointManager : CheckpointManager<DestinationStream.Descriptor, Ch

override suspend fun addStreamCheckpoint(
key: DestinationStream.Descriptor,
index: Long,
indexOrId: Long,
checkpointMessage: CheckpointMessage
) {
streamStates.getOrPut(key) { mutableListOf() }.add(index to checkpointMessage)
streamStates.getOrPut(key) { mutableListOf() }.add(indexOrId to checkpointMessage)
}

override suspend fun addGlobalCheckpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ class StreamManagerTest {
val manager = DefaultStreamManager(stream1)

val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand All @@ -545,6 +546,7 @@ class StreamManagerTest {
Assertions.assertFalse(manager.areRecordsPersistedUntilCheckpoint(checkpointId2))

manager.incrementPersistedCount(checkpointId1, 10)

Assertions.assertTrue(manager.areRecordsPersistedUntilCheckpoint(checkpointId1))
Assertions.assertTrue(manager.areRecordsPersistedUntilCheckpoint(checkpointId2))
}
Expand All @@ -554,6 +556,7 @@ class StreamManagerTest {
val manager = DefaultStreamManager(stream1)

val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand Down Expand Up @@ -587,6 +590,7 @@ class StreamManagerTest {
cases.forEach { steps ->
val manager = DefaultStreamManager(stream1)
val checkpointId1 = manager.getNextCheckpointId()

repeat(10) { manager.incrementReadCount() }
manager.markCheckpoint()

Expand Down
Loading