Skip to content

Commit

Permalink
Destination S3V2: Skip full metadata search when sync mode is append (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Jan 28, 2025
1 parent 7c12755 commit bfba1b4
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 566 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ data class DestinationStream(
}
}
}

fun shouldBeTruncatedAtEndOfSync(): Boolean {
return importType is Overwrite || minimumGenerationId == generationId
}
}

@Singleton
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ class DefaultDestinationStateManager<T : DestinationState>(
}

override suspend fun persistState(stream: DestinationStream) {
val state =
states[stream.descriptor]
?: throw IllegalStateException("State not found for stream $stream")
persister.persist(stream, state)
states[stream.descriptor]?.let { persister.persist(stream, it) }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ class DefaultInputConsumerTask(
is DestinationFileStreamComplete -> {
reserved.release() // safe because multiple calls conflate
manager.markEndOfStream(true)
fileTransferQueue.close()
val envelope =
BatchEnvelope(
SimpleBatch(Batch.State.COMPLETE),
Expand Down Expand Up @@ -196,6 +195,7 @@ class DefaultInputConsumerTask(
} finally {
log.info { "Closing record queues" }
catalog.streams.forEach { recordQueueSupplier.get(it.descriptor).close() }
fileTransferQueue.close()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,242 +4,122 @@

package io.airbyte.cdk.load.state.object_storage

import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.annotation.JsonProperty
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.file.object_storage.ObjectStorageClient
import io.airbyte.cdk.load.file.object_storage.PathFactory
import io.airbyte.cdk.load.file.object_storage.RemoteObject
import io.airbyte.cdk.load.state.DestinationState
import io.airbyte.cdk.load.state.DestinationStatePersister
import io.airbyte.cdk.load.state.object_storage.ObjectStorageDestinationState.Companion.OPTIONAL_ORDINAL_SUFFIX_PATTERN
import io.airbyte.cdk.load.util.readIntoClass
import io.airbyte.cdk.load.util.serializeToJsonBytes
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.nio.file.Paths
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.fold
import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

@SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION", justification = "Kotlin async continuation")
class ObjectStorageDestinationState(
// (State -> (GenerationId -> (Key -> PartNumber)))
@JsonProperty("generations_by_state")
var generationMap:
ConcurrentHashMap<State, ConcurrentHashMap<Long, ConcurrentHashMap<String, Long>>> =
ConcurrentHashMap(),
@JsonProperty("count_by_key") var countByKey: MutableMap<String, Long> = mutableMapOf()
private val stream: DestinationStream,
private val client: ObjectStorageClient<*>,
private val pathFactory: PathFactory,
) : DestinationState {
enum class State {
STAGED,
FINALIZED
}
private val log = KotlinLogging.logger {}

@JsonIgnore private val countByKeyLock = Mutex()
private val countByKey: ConcurrentHashMap<String, AtomicLong> = ConcurrentHashMap()
private val fileNumbersByPath: ConcurrentHashMap<String, AtomicLong> = ConcurrentHashMap()
private val matcher =
pathFactory.getPathMatcher(stream, suffixPattern = OPTIONAL_ORDINAL_SUFFIX_PATTERN)

companion object {
const val METADATA_GENERATION_ID_KEY = "ab-generation-id"
const val STREAM_NAMESPACE_KEY = "ab-stream-namespace"
const val STREAM_NAME_KEY = "ab-stream-name"
const val OPTIONAL_ORDINAL_SUFFIX_PATTERN = "(-[0-9]+)?"

fun metadataFor(stream: DestinationStream): Map<String, String> =
mapOf(METADATA_GENERATION_ID_KEY to stream.generationId.toString())
}

suspend fun addObject(
generationId: Long,
key: String,
partNumber: Long?,
isStaging: Boolean = false
) {
val state = if (isStaging) State.STAGED else State.FINALIZED
generationMap
.getOrPut(state) { ConcurrentHashMap() }
.getOrPut(generationId) { ConcurrentHashMap() }[key] = partNumber ?: 0L
}

suspend fun removeObject(generationId: Long, key: String, isStaging: Boolean = false) {
val state = if (isStaging) State.STAGED else State.FINALIZED
generationMap[state]?.get(generationId)?.remove(key)
}

suspend fun dropGenerationsBefore(minimumGenerationId: Long) {
State.entries.forEach { state ->
(0 until minimumGenerationId).forEach { generationMap[state]?.remove(it) }
/**
* Returns (generationId, object) for all objects that should be cleaned up.
*
* "should be cleaned up" means
* * stream.shouldBeTruncatedAtEndOfSync() is true
* * object's generation id exists and is less than stream.minimumGenerationId
*/
suspend fun getObjectsToDelete(): List<Pair<Long, RemoteObject<*>>> {
if (!stream.shouldBeTruncatedAtEndOfSync()) {
return emptyList()
}
}

data class Generation(
val isStaging: Boolean,
val generationId: Long,
val objects: List<ObjectAndPart>
)

data class ObjectAndPart(
val key: String,
val partNumber: Long,
)

suspend fun getGenerations(): Sequence<Generation> =
generationMap.entries
.asSequence()
.map { (state, gens) ->
val isStaging = state == State.STAGED
gens.map { (generationId, objects) ->
Generation(
isStaging,
generationId,
objects.map { (key, partNumber) -> ObjectAndPart(key, partNumber) }
)
return client
.list(pathFactory.getLongestStreamConstantPrefix(stream, isStaging = false))
.filter { matcher.match(it.key) != null }
.mapNotNull { obj ->
val generationId =
client.getMetadata(obj.key)[METADATA_GENERATION_ID_KEY]?.toLongOrNull() ?: 0L
if (generationId < stream.minimumGenerationId) {
Pair(generationId, obj)
} else {
null
}
}
.flatten()

suspend fun getNextPartNumber(): Long =
getGenerations().flatMap { it.objects }.map { it.partNumber }.maxOrNull()?.plus(1) ?: 0L

/** Returns generationId -> objectAndPart for all staged objects that should be kept. */
suspend fun getStagedObjectsToFinalize(
minimumGenerationId: Long
): Sequence<Pair<Long, ObjectAndPart>> =
getGenerations()
.filter { it.isStaging && it.generationId >= minimumGenerationId }
.flatMap { it.objects.map { obj -> it.generationId to obj } }
.toList()
}

/**
* Returns generationId -> objectAndPart for all objects (staged and unstaged) that should be
* cleaned up.
* Ensures the key is unique by appending `-${max_suffix + 1}` if there is a conflict. If the
* key is unique, it is returned as-is.
*/
suspend fun getObjectsToDelete(minimumGenerationId: Long): Sequence<Pair<Long, ObjectAndPart>> {
val (toKeep, toDrop) = getGenerations().partition { it.generationId >= minimumGenerationId }
val keepKeys = toKeep.flatMap { it.objects.map { obj -> obj.key } }.toSet()
return toDrop.asSequence().flatMap {
it.objects.filter { obj -> obj.key !in keepKeys }.map { obj -> it.generationId to obj }
}
}

/** Used to guarantee the uniqueness of a key */
suspend fun ensureUnique(key: String): String {
val ordinal =
countByKeyLock.withLock {
countByKey.merge(key, 0L) { old, new -> maxOf(old + 1, new) }
}
?: 0L
return if (ordinal > 0L) {
"$key-$ordinal"
} else {
val count =
countByKey
.getOrPut(key) {
client
.list(key)
.mapNotNull { matcher.match(it.key) }
.fold(-1L) { acc, match ->
maxOf(match.customSuffix?.removePrefix("-")?.toLongOrNull() ?: 0L, acc)
}
.let { AtomicLong(it) }
}
.incrementAndGet()

return if (count == 0L) {
key
} else {
"$key-$count"
}
}
}

@SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION", justification = "Kotlin async continuation")
class ObjectStorageStagingPersister(
private val client: ObjectStorageClient<*>,
private val pathFactory: PathFactory
) : DestinationStatePersister<ObjectStorageDestinationState> {
private val log = KotlinLogging.logger {}
private val fallbackPersister = ObjectStorageFallbackPersister(client, pathFactory)

companion object {
const val STATE_FILENAME = "__airbyte_state.json"
}

private fun keyFor(stream: DestinationStream): String =
Paths.get(pathFactory.getStagingDirectory(stream), STATE_FILENAME).toString()

override suspend fun load(stream: DestinationStream): ObjectStorageDestinationState {
val key = keyFor(stream)
try {
log.info { "Loading destination state from $key" }
return client.get(key) { inputStream ->
inputStream.readIntoClass(ObjectStorageDestinationState::class.java)
}
} catch (e: Exception) {
log.info { "No destination state found at $key: $e; falling back to metadata search" }
return fallbackPersister.load(stream)
/** Returns a shared atomic long referencing the max {part_number} for any given path. */
suspend fun getPartIdCounter(path: String): AtomicLong {
return fileNumbersByPath.getOrPut(path) {
client
.list(path)
.mapNotNull { matcher.match(it.key) }
.fold(-1L) { acc, match -> maxOf(match.partNumber ?: 0L, acc) }
.let { AtomicLong(it) }
}
}

override suspend fun persist(stream: DestinationStream, state: ObjectStorageDestinationState) {
client.put(keyFor(stream), state.serializeToJsonBytes())
}
}

/**
* Note: there's no persisting yet. This will require either a client-provided path to store data or
* a guaranteed sortable set of file names so that we can send the high watermark to the platform.
*/
@SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION", justification = "Kotlin async continuation")
@Singleton
class ObjectStorageFallbackPersister(
private val client: ObjectStorageClient<*>,
private val pathFactory: PathFactory
) : DestinationStatePersister<ObjectStorageDestinationState> {
private val log = KotlinLogging.logger {}
override suspend fun load(stream: DestinationStream): ObjectStorageDestinationState {
// Add a suffix matching an OPTIONAL -[0-9]+ ordinal
val matcher =
pathFactory.getPathMatcher(stream, suffixPattern = OPTIONAL_ORDINAL_SUFFIX_PATTERN)
val longestUnambiguous =
pathFactory.getLongestStreamConstantPrefix(stream, isStaging = false)
log.info {
"Searching path $longestUnambiguous (matching ${matcher.regex}) for destination state metadata"
}
val matches = client.list(longestUnambiguous).mapNotNull { matcher.match(it.key) }.toList()

/* Initialize the unique key counts. */
val countByKey = mutableMapOf<String, Long>()
matches.forEach {
val key = it.path.replace(Regex("-[0-9]+$"), "")
val ordinal = it.customSuffix?.substring(1)?.toLongOrNull() ?: 0
countByKey.merge(key, ordinal) { a, b -> maxOf(a, b) }
}

/* Build (generationId -> (key -> fileNumber)). */
val generationIdToKeyAndFileNumber =
ConcurrentHashMap(
matches
.groupBy {
client
.getMetadata(it.path)[
ObjectStorageDestinationState.METADATA_GENERATION_ID_KEY]
?.toLong()
?: 0L
}
.mapValues { (_, matches) ->
ConcurrentHashMap(matches.associate { it.path to (it.partNumber ?: 0L) })
}
)

return ObjectStorageDestinationState(
ConcurrentHashMap(
mapOf(
ObjectStorageDestinationState.State.FINALIZED to generationIdToKeyAndFileNumber
)
),
countByKey
)
return ObjectStorageDestinationState(stream, client, pathFactory)
}

override suspend fun persist(stream: DestinationStream, state: ObjectStorageDestinationState) {
// No-op; state is persisted when the generation id is set on the object metadata
}
}

@Factory
class ObjectStorageDestinationStatePersisterFactory<T : RemoteObject<*>>(
private val client: ObjectStorageClient<T>,
private val pathFactory: PathFactory
) {
@Singleton
@Secondary
fun create(): DestinationStatePersister<ObjectStorageDestinationState> =
if (pathFactory.supportsStaging) {
ObjectStorageStagingPersister(client, pathFactory)
} else {
ObjectStorageFallbackPersister(client, pathFactory)
}
}
Loading

0 comments on commit bfba1b4

Please sign in to comment.