Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache policies to CachedTool #792

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,65 @@ import kotlin.time.Duration.Companion.days

data class CachedToolKey<K>(val value: K, val seed: String)

data class CachedToolValue<V>(val value: V, val timestamp: Long)
data class CachedToolValue<V>(val value: V, val accessTimestamp: Long, val writeTimestamp: Long) {
fun withAccessTimestamp() = copy(accessTimestamp = timeInMillis())

companion object {
fun <V> withActualResponse(response: V): CachedToolValue<V> =
CachedToolValue(
value = response,
accessTimestamp = timeInMillis(),
writeTimestamp = timeInMillis()
)
}
}

data class CachedToolConfig(
val timeCachePolicy: Duration,
val cacheExpirationPolicy: CacheExpirationPolicy,
val cacheEvictionPolicy: CacheEvictionPolicy
) {

/** Policy to expire the entries in the cache, based on last access or last write time. */
enum class CacheExpirationPolicy {
/** Last access time is used to determine expiration */
ACCESS,
/** Last write time is used to determine expiration */
WRITE
}

/** Policy to evict the expired entries from the cache, based on one or all expired entries. */
enum class CacheEvictionPolicy {
/** Removes the expired entry when found */
SINGLE,
/** Removes all expired entries when one expired entry found */
ALL
}

companion object {
val Default =
CachedToolConfig(
timeCachePolicy = 1.days,
cacheEvictionPolicy = CacheEvictionPolicy.ALL,
cacheExpirationPolicy = CacheExpirationPolicy.WRITE
)
}
}

/**
* Tool that caches the result of the execution of [onCacheMissed] if [shouldUseCache] returns true.
* Otherwise, returns the result of [onCacheMissed]. This output is added to the cache when
* [shouldCacheOutput] returns true.
*
* Cache is stored in a [Map] of [CachedToolKey] to [CachedToolValue].
*
* Supports expiration policies using [CachedToolConfig].
*/
abstract class CachedTool<Input, Output>(
private val cache: Atomic<MutableMap<CachedToolKey<Input>, CachedToolValue<Output>>>,
private val seed: String,
private val timeCachePolicy: Duration = 1.days
private val config: CachedToolConfig = CachedToolConfig.Default
) : Tool<Input, Output> {

/**
* Logic to be executed when the cache is missed.
*
Expand Down Expand Up @@ -49,44 +100,59 @@ abstract class CachedTool<Input, Output>(
else onCacheMissed(input)

/**
* Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
* [timeCachePolicy]. Removes expired cache entries.
* Returns a snapshot of the cache as a [Map] of [Input] to [Output] filtered by instance [seed]
* and removing expired cache entries with the given [config] policies. Does not modify the cache.
*
* @return the map of input to output.
*/
suspend fun getCache(): Map<Input, Output> {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
val withoutExpired =
suspend fun getValidCacheSnapshot(): Map<Input, Output> {
val validEntries =
cache.modify { cachedToolInfo ->
// Filter entries belonging to the current seed and have not expired
val validEntries =
cachedToolInfo
.filter { (key, value) ->
if (key.seed == seed) lastTimeInCache <= value.timestamp else true
}
.toMutableMap()
// Remove expired entries for the current seed only
cachedToolInfo.keys.removeAll { key -> key.seed == seed && !validEntries.containsKey(key) }
// Modifies state A, and returns state B
val validEntries = cachedToolInfo.filterExpired().filter { (key, _) -> key.seed == seed }
Pair(cachedToolInfo, validEntries)
}
return withoutExpired.map { it.key.value to it.value.value }.toMap()
return validEntries.map { it.key.value to it.value.value }.toMap()
}

private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output {
val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
if (lastTimeInCache > cachedToolInfo.timestamp) {
cache.get().remove(input)
} else {
return cachedToolInfo.value
}
private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output =
cache.modify { cachedToolInfo ->
cachedToolInfo[input]?.let { output ->
if (output.isExpired()) {
val updatedCache =
when (config.cacheEvictionPolicy) {
CachedToolConfig.CacheEvictionPolicy.SINGLE -> cachedToolInfo.apply { remove(input) }
CachedToolConfig.CacheEvictionPolicy.ALL -> cachedToolInfo.filterExpired()
}
Pair(updatedCache, null)
} else {
val updatedOutput = output.withAccessTimestamp()
Pair(cachedToolInfo, updatedOutput.value)
}
} ?: Pair(cachedToolInfo, null)
}
val response = block()
if (shouldCacheOutput(input.value, response)) {
cache.get()[input] = CachedToolValue(response, timeInMillis())
?: run {
val response = block()
if (shouldCacheOutput(input.value, response)) {
cache.update { cachedToolInfo ->
cachedToolInfo[input] = CachedToolValue.withActualResponse(response)
cachedToolInfo
}
}
response
}

private fun MutableMap<CachedToolKey<Input>, CachedToolValue<Output>>.filterExpired() =
this.filter { (_, value) -> !value.isExpired() }.toMutableMap()

private fun CachedToolValue<Output>.isExpired(): Boolean =
when (config.cacheExpirationPolicy) {
CachedToolConfig.CacheExpirationPolicy.ACCESS -> {
val lastTimeInCache = timeInMillis() - accessTimestamp
lastTimeInCache > config.timeCachePolicy.inWholeMilliseconds
}
CachedToolConfig.CacheExpirationPolicy.WRITE -> {
val lastTimeInCache = timeInMillis() - writeTimestamp
lastTimeInCache > config.timeCachePolicy.inWholeMilliseconds
}
}
return response
}
}
Loading