Skip to content

Commit

Permalink
bulk-cdk-extract*: remove StateQuerier (#52051)
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Posta authored Jan 30, 2025
1 parent 7c28d4e commit 14c6c9d
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.time.ZoneOffset
/**
* [FeedBootstrap] is the input to a [PartitionsCreatorFactory].
*
* This object conveniently packages the [StateQuerier] singleton with the [feed] for which the
* This object conveniently packages the [StateManager] singleton with the [feed] for which the
* [PartitionsCreatorFactory] is to operate on, eventually causing the emission of Airbyte RECORD
* messages for the [Stream]s in the [feed]. For this purpose, [FeedBootstrap] provides
* [StreamRecordConsumer] instances which essentially provide a layer of caching over
Expand All @@ -34,15 +34,30 @@ sealed class FeedBootstrap<T : Feed>(
* The [MetaFieldDecorator] instance which [StreamRecordConsumer] will use to decorate records.
*/
val metaFieldDecorator: MetaFieldDecorator,
/** [StateQuerier] singleton for use by [PartitionsCreatorFactory]. */
val stateQuerier: StateQuerier,
/** [StateManager] singleton which is encapsulated by this [FeedBootstrap]. */
private val stateManager: StateManager,
/** [Feed] to emit records for. */
val feed: T
) {

/** Convenience getter for the current state value for the [feed]. */
/** Delegates to [StateManager.feeds]. */
val feeds: List<Feed>
get() = stateManager.feeds

/** Deletages to [StateManager] to return the current state value for any [Feed]. */
fun currentState(feed: Feed): OpaqueStateValue? = stateManager.scoped(feed).current()

/** Convenience getter for the current state value for this [feed]. */
val currentState: OpaqueStateValue?
get() = stateQuerier.current(feed)
get() = currentState(feed)

/** Resets the state value of this feed and the streams in it to zero. */
fun resetAll() {
stateManager.scoped(feed).reset()
for (stream in feed.streams) {
stateManager.scoped(stream).reset()
}
}

/** A map of all [StreamRecordConsumer] for this [feed]. */
fun streamRecordConsumers(): Map<StreamIdentifier, StreamRecordConsumer> =
Expand Down Expand Up @@ -98,7 +113,7 @@ sealed class FeedBootstrap<T : Feed>(
}

private val precedingGlobalFeed: Global? =
stateQuerier.feeds
stateManager.feeds
.filterIsInstance<Global>()
.filter { it.streams.contains(stream) }
.firstOrNull()
Expand All @@ -109,7 +124,7 @@ sealed class FeedBootstrap<T : Feed>(
if (feed is Stream && precedingGlobalFeed != null) {
metaFieldDecorator.decorateRecordData(
timestamp = outputConsumer.recordEmittedAt.atOffset(ZoneOffset.UTC),
globalStateValue = stateQuerier.current(precedingGlobalFeed),
globalStateValue = stateManager.scoped(precedingGlobalFeed).current(),
stream,
recordData,
)
Expand Down Expand Up @@ -192,14 +207,14 @@ sealed class FeedBootstrap<T : Feed>(
fun create(
outputConsumer: OutputConsumer,
metaFieldDecorator: MetaFieldDecorator,
stateQuerier: StateQuerier,
stateManager: StateManager,
feed: Feed,
): FeedBootstrap<*> =
when (feed) {
is Global ->
GlobalFeedBootstrap(outputConsumer, metaFieldDecorator, stateQuerier, feed)
GlobalFeedBootstrap(outputConsumer, metaFieldDecorator, stateManager, feed)
is Stream ->
StreamFeedBootstrap(outputConsumer, metaFieldDecorator, stateQuerier, feed)
StreamFeedBootstrap(outputConsumer, metaFieldDecorator, stateManager, feed)
}
}
}
Expand Down Expand Up @@ -241,17 +256,17 @@ enum class FieldValueChange {
class GlobalFeedBootstrap(
outputConsumer: OutputConsumer,
metaFieldDecorator: MetaFieldDecorator,
stateQuerier: StateQuerier,
stateManager: StateManager,
global: Global,
) : FeedBootstrap<Global>(outputConsumer, metaFieldDecorator, stateQuerier, global)
) : FeedBootstrap<Global>(outputConsumer, metaFieldDecorator, stateManager, global)

/** [FeedBootstrap] implementation for [Stream] feeds. */
class StreamFeedBootstrap(
outputConsumer: OutputConsumer,
metaFieldDecorator: MetaFieldDecorator,
stateQuerier: StateQuerier,
stateManager: StateManager,
stream: Stream,
) : FeedBootstrap<Stream>(outputConsumer, metaFieldDecorator, stateQuerier, stream) {
) : FeedBootstrap<Stream>(outputConsumer, metaFieldDecorator, stateManager, stream) {

/** A [StreamRecordConsumer] instance for this [Stream]. */
fun streamRecordConsumer(): StreamRecordConsumer = streamRecordConsumers()[feed.id]!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import io.airbyte.cdk.read.PartitionsCreator.TryAcquireResourcesStatus
interface PartitionsCreatorFactory {
/**
* Returns a [PartitionsCreator] which will cause the READ to advance for the [Feed] for which
* the [FeedBootstrap] argument is associated to. The latter exposes a [StateQuerier] to obtain
* the current [OpaqueStateValue] for this [feed] but may also be used to peek at the state of
* other [Feed]s. This may be useful for synchronizing the READ for this [feed] by waiting for
* other [Feed]s to reach a desired state before proceeding; the waiting may be triggered by
* the [FeedBootstrap] argument is associated to. The latter exposes methods to obtain the
* current [OpaqueStateValue] for this [feed] but also to peek at the state of other [Feed]s.
* This may be useful for synchronizing the READ for this [feed] by waiting for other [Feed]s to
* reach a desired state before proceeding; the waiting may be triggered by
* [PartitionsCreator.tryAcquireResources] or [PartitionReader.tryAcquireResources].
*
* Returns null when the factory is unable to generate a [PartitionsCreator]. This causes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,12 @@ import io.airbyte.protocol.models.v0.AirbyteStateMessage
import io.airbyte.protocol.models.v0.AirbyteStateStats
import io.airbyte.protocol.models.v0.AirbyteStreamState

/** A [StateQuerier] is like a read-only [StateManager]. */
interface StateQuerier {
/** [feeds] is all the [Feed]s in the configured catalog passed via the CLI. */
val feeds: List<Feed>

/** Returns the current state value for the given [feed]. */
fun current(feed: Feed): OpaqueStateValue?

/** Rolls back each feed state. This is required when resyncing CDC from scratch */
fun resetFeedStates()
}

/** Singleton object which tracks the state of an ongoing READ operation. */
class StateManager(
global: Global? = null,
initialGlobalState: OpaqueStateValue? = null,
initialStreamStates: Map<Stream, OpaqueStateValue?> = mapOf(),
) : StateQuerier {
) {
private val global: GlobalStateManager?
private val nonGlobal: Map<StreamIdentifier, NonGlobalStreamStateManager>

Expand All @@ -52,16 +40,14 @@ class StateManager(
}
}

override val feeds: List<Feed> =
/** [feeds] is all the [Feed]s in the configured catalog passed via the CLI. */
val feeds: List<Feed> =
listOfNotNull(this.global?.feed) +
(this.global?.streamStateManagers?.values?.map { it.feed } ?: listOf()) +
nonGlobal.values.map { it.feed }

override fun current(feed: Feed): OpaqueStateValue? = scoped(feed).current()

override fun resetFeedStates() {
feeds.forEach { f -> scoped(f).set(Jsons.objectNode(), 0) }
}
/** Returns the current state value for the given [feed]. */
fun current(feed: Feed): OpaqueStateValue? = scoped(feed).current()

/** Returns a [StateManagerScopedToFeed] instance scoped to this [feed]. */
fun scoped(feed: Feed): StateManagerScopedToFeed =
Expand All @@ -86,6 +72,9 @@ class StateManager(
state: OpaqueStateValue,
numRecords: Long,
)

/** Resets the current state value in the [StateManager] for this [feed] to zero. */
fun reset()
}

/**
Expand Down Expand Up @@ -119,6 +108,13 @@ class StateManager(
pendingNumRecords += numRecords
}

@Synchronized
override fun reset() {
currentStateValue = null
pendingStateValue = null
pendingNumRecords = 0L
}

/**
* Called by [StateManager.checkpoint] to generate the Airbyte STATE messages for the
* checkpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,13 @@ class FeedBootstrapTest {

val global = Global(listOf(stream))

fun stateQuerier(
fun stateManager(
globalStateValue: OpaqueStateValue? = null,
streamStateValue: OpaqueStateValue? = null
): StateQuerier =
object : StateQuerier {
override val feeds: List<Feed> = listOf(global, stream)
): StateManager = StateManager(global, globalStateValue, mapOf(stream to streamStateValue))

override fun current(feed: Feed): OpaqueStateValue? =
when (feed) {
is Global -> globalStateValue
is Stream -> streamStateValue
}

override fun resetFeedStates() {
// no-op
}
}

fun Feed.bootstrap(stateQuerier: StateQuerier): FeedBootstrap<*> =
FeedBootstrap.create(outputConsumer, metaFieldDecorator, stateQuerier, this)
fun Feed.bootstrap(stateManager: StateManager): FeedBootstrap<*> =
FeedBootstrap.create(outputConsumer, metaFieldDecorator, stateManager, this)

fun expected(vararg data: String): List<String> {
val ts = outputConsumer.recordEmittedAt.toEpochMilli()
Expand All @@ -76,7 +63,7 @@ class FeedBootstrapTest {

@Test
fun testGlobalColdStart() {
val globalBootstrap: FeedBootstrap<*> = global.bootstrap(stateQuerier())
val globalBootstrap: FeedBootstrap<*> = global.bootstrap(stateManager())
Assertions.assertNull(globalBootstrap.currentState)
Assertions.assertEquals(1, globalBootstrap.streamRecordConsumers().size)
val (actualStreamID, consumer) = globalBootstrap.streamRecordConsumers().toList().first()
Expand All @@ -91,7 +78,7 @@ class FeedBootstrapTest {
@Test
fun testGlobalWarmStart() {
val globalBootstrap: FeedBootstrap<*> =
global.bootstrap(stateQuerier(globalStateValue = Jsons.objectNode()))
global.bootstrap(stateManager(globalStateValue = Jsons.objectNode()))
Assertions.assertEquals(Jsons.objectNode(), globalBootstrap.currentState)
Assertions.assertEquals(1, globalBootstrap.streamRecordConsumers().size)
val (actualStreamID, consumer) = globalBootstrap.streamRecordConsumers().toList().first()
Expand All @@ -103,10 +90,36 @@ class FeedBootstrapTest {
)
}

@Test
fun testGlobalReset() {
val stateManager: StateManager =
stateManager(
streamStateValue = Jsons.objectNode(),
globalStateValue = Jsons.objectNode()
)
val globalBootstrap: FeedBootstrap<*> = global.bootstrap(stateManager)
Assertions.assertEquals(Jsons.objectNode(), globalBootstrap.currentState)
Assertions.assertEquals(Jsons.objectNode(), globalBootstrap.currentState(stream))
// Reset.
globalBootstrap.resetAll()
Assertions.assertNull(globalBootstrap.currentState)
Assertions.assertNull(globalBootstrap.currentState(stream))
// Set new global state and checkpoint
stateManager.scoped(global).set(Jsons.arrayNode(), 0L)
stateManager.checkpoint().forEach { outputConsumer.accept(it) }
// Check that everything is as it should be.
Assertions.assertEquals(Jsons.arrayNode(), globalBootstrap.currentState)
Assertions.assertNull(globalBootstrap.currentState(stream))
Assertions.assertEquals(
listOf(RESET_STATE),
outputConsumer.states().map(Jsons::writeValueAsString)
)
}

@Test
fun testStreamColdStart() {
val streamBootstrap: FeedBootstrap<*> =
stream.bootstrap(stateQuerier(globalStateValue = Jsons.objectNode()))
stream.bootstrap(stateManager(globalStateValue = Jsons.objectNode()))
Assertions.assertNull(streamBootstrap.currentState)
Assertions.assertEquals(1, streamBootstrap.streamRecordConsumers().size)
val (actualStreamID, consumer) = streamBootstrap.streamRecordConsumers().toList().first()
Expand All @@ -122,7 +135,7 @@ class FeedBootstrapTest {
fun testStreamWarmStart() {
val streamBootstrap: FeedBootstrap<*> =
stream.bootstrap(
stateQuerier(
stateManager(
globalStateValue = Jsons.objectNode(),
streamStateValue = Jsons.arrayNode(),
)
Expand All @@ -140,15 +153,8 @@ class FeedBootstrapTest {

@Test
fun testChanges() {
val stateQuerier =
object : StateQuerier {
override val feeds: List<Feed> = listOf(stream)
override fun current(feed: Feed): OpaqueStateValue? = null
override fun resetFeedStates() {
// no-op
}
}
val streamBootstrap = stream.bootstrap(stateQuerier) as StreamFeedBootstrap
val stateManager = StateManager(initialStreamStates = mapOf(stream to null))
val streamBootstrap = stream.bootstrap(stateManager) as StreamFeedBootstrap
val consumer: StreamRecordConsumer = streamBootstrap.streamRecordConsumer()
val changes =
mapOf(
Expand Down Expand Up @@ -184,5 +190,7 @@ class FeedBootstrapTest {
const val STREAM_RECORD_INPUT_DATA = """{"k":2,"v":"bar"}"""
const val STREAM_RECORD_OUTPUT_DATA =
"""{"k":2,"v":"bar","_ab_cdc_lsn":{},"_ab_cdc_updated_at":"2069-04-20T00:00:00.000000Z","_ab_cdc_deleted_at":null}"""
const val RESET_STATE =
"""{"type":"GLOBAL","global":{"shared_state":[],"stream_states":[{"stream_descriptor":{"name":"tbl","namespace":"ns"},"stream_state":{}}]},"sourceStats":{"recordCount":0.0}}"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CdcPartitionsCreator<T : Comparable<T>>(
)
}
val activeStreams: List<Stream> by lazy {
feedBootstrap.feed.streams.filter { feedBootstrap.stateQuerier.current(it) != null }
feedBootstrap.feed.streams.filter { feedBootstrap.currentState(it) != null }
}
val syntheticOffset: DebeziumOffset by lazy { creatorOps.generateColdStartOffset() }
// Ensure that the WAL position upper bound has been computed for this sync.
Expand Down Expand Up @@ -96,7 +96,7 @@ class CdcPartitionsCreator<T : Comparable<T>>(
// TransientErrorException. The next sync will then snapshot the tables.
resetReason.set(warmStartState.reason)
log.info { "Resetting invalid incumbent CDC state with synthetic state." }
feedBootstrap.stateQuerier.resetFeedStates()
feedBootstrap.resetAll()
debeziumProperties = creatorOps.generateColdStartProperties()
startingOffset = syntheticOffset
startingSchemaHistory = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import io.airbyte.cdk.read.ConfiguredSyncMode
import io.airbyte.cdk.read.Global
import io.airbyte.cdk.read.GlobalFeedBootstrap
import io.airbyte.cdk.read.PartitionReader
import io.airbyte.cdk.read.StateQuerier
import io.airbyte.cdk.read.Stream
import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.StreamDescriptor
Expand All @@ -41,8 +40,6 @@ class CdcPartitionsCreatorTest {

@MockK lateinit var readerOps: CdcPartitionReaderDebeziumOperations<CreatorPosition>

@MockK lateinit var stateQuerier: StateQuerier

@MockK lateinit var globalFeedBootstrap: GlobalFeedBootstrap

val stream =
Expand Down Expand Up @@ -78,9 +75,8 @@ class CdcPartitionsCreatorTest {
@BeforeEach
fun setup() {
every { globalFeedBootstrap.feed } returns global
every { globalFeedBootstrap.stateQuerier } returns stateQuerier
every { globalFeedBootstrap.feeds } returns listOf(global, stream)
every { globalFeedBootstrap.streamRecordConsumers() } returns emptyMap()
every { stateQuerier.feeds } returns listOf(global, stream)
every { creatorOps.position(syntheticOffset) } returns 123L
every { creatorOps.position(incumbentOffset) } returns 123L
every { creatorOps.generateColdStartOffset() } returns syntheticOffset
Expand All @@ -91,7 +87,7 @@ class CdcPartitionsCreatorTest {
@Test
fun testCreateWithSyntheticOffset() {
every { globalFeedBootstrap.currentState } returns null
every { stateQuerier.current(stream) } returns null
every { globalFeedBootstrap.currentState(stream) } returns null
val syntheticOffset = DebeziumOffset(mapOf(Jsons.nullNode() to Jsons.nullNode()))
every { creatorOps.position(syntheticOffset) } returns 123L
every { creatorOps.generateColdStartOffset() } returns syntheticOffset
Expand All @@ -106,7 +102,7 @@ class CdcPartitionsCreatorTest {
@Test
fun testCreateWithDeserializedOffset() {
every { globalFeedBootstrap.currentState } returns Jsons.objectNode()
every { stateQuerier.current(stream) } returns Jsons.objectNode()
every { globalFeedBootstrap.currentState(stream) } returns Jsons.objectNode()
val deserializedState =
ValidDebeziumWarmStartState(offset = incumbentOffset, schemaHistory = null)
every { creatorOps.deserializeState(Jsons.objectNode()) } returns deserializedState
Expand All @@ -121,7 +117,7 @@ class CdcPartitionsCreatorTest {
@Test
fun testCreateNothing() {
every { globalFeedBootstrap.currentState } returns Jsons.objectNode()
every { stateQuerier.current(stream) } returns Jsons.objectNode()
every { globalFeedBootstrap.currentState(stream) } returns Jsons.objectNode()
val deserializedState =
ValidDebeziumWarmStartState(offset = incumbentOffset, schemaHistory = null)
every { creatorOps.deserializeState(Jsons.objectNode()) } returns deserializedState
Expand All @@ -133,7 +129,7 @@ class CdcPartitionsCreatorTest {
@Test
fun testCreateWithFailedValidation() {
every { globalFeedBootstrap.currentState } returns Jsons.objectNode()
every { stateQuerier.current(stream) } returns Jsons.objectNode()
every { globalFeedBootstrap.currentState(stream) } returns Jsons.objectNode()
every { creatorOps.deserializeState(Jsons.objectNode()) } returns
AbortDebeziumWarmStartState("boom")
assertThrows(ConfigErrorException::class.java) { runBlocking { creator.run() } }
Expand Down
Loading

0 comments on commit 14c6c9d

Please sign in to comment.