Skip to content

Commit

Permalink
Merge pull request #2 from amishra-u/hot_keys_2
Browse files Browse the repository at this point in the history
feat: Hot CAS Entries - Update read counts in Redis
  • Loading branch information
amishra-u authored Nov 8, 2023
2 parents f5ec1e0 + 77003d7 commit 28c89a1
Show file tree
Hide file tree
Showing 13 changed files with 297 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
6.1.2
6.2.0
1 change: 1 addition & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def buildfarm_init(name = "buildfarm"):
"com.fasterxml.jackson.core:jackson-databind:2.15.0",
"com.github.ben-manes.caffeine:caffeine:2.9.0",
"com.github.docker-java:docker-java:3.2.11",
"com.github.fppt:jedis-mock:1.0.10",
"com.github.jnr:jffi:1.2.16",
"com.github.jnr:jffi:jar:native:1.2.16",
"com.github.jnr:jnr-constants:0.9.9",
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/build/buildfarm/backplane/Backplane.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import net.jcip.annotations.ThreadSafe;

@ThreadSafe
Expand Down Expand Up @@ -278,4 +279,21 @@ boolean pollOperation(QueueEntry queueEntry, ExecutionStage.Value stage, long re
Boolean propertiesEligibleForQueue(List<Platform.Property> provisions);

GetClientStartTimeResult getClientStartTime(GetClientStartTimeRequest request) throws IOException;

/**
* Updates the read count for CAS entries based on the provided stream of digest and count.
*
* @param casReadCountStream A Stream of Digest and its corresponding read count.
* @return A Map containing the updated read counts for the specified CAS entries.
*/
Map<String, Integer> updateCasReadCount(Stream<Map.Entry<Digest, Integer>> casReadCountStream)
throws IOException;

/**
* Removes the CAS read count entries from the storage.
*
* @param digestsToBeRemoved CAS entries for which each read count needs to be removed.
* @return total count of cas read count entries removed.
*/
int removeCasReadCountEntries(Stream<Digest> digestsToBeRemoved) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,26 @@

import build.bazel.remote.execution.v2.Digest;
import build.buildfarm.backplane.Backplane;
import build.buildfarm.common.DigestUtil;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import java.time.Duration;
import java.time.Instant;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Level;
import lombok.extern.java.Log;

/**
Expand All @@ -44,7 +48,6 @@ public final class CASAccessMetricsRecorder {

private boolean running = false;

@SuppressWarnings("unused")
private final Backplane backplane;

final ReadWriteLock lock = new ReentrantReadWriteLock();
Expand Down Expand Up @@ -166,7 +169,13 @@ private void updateReadCount() {
effectiveReadCount.compute(
k, (digest, readCount) -> readCount == null ? -v.get() : readCount - v.get()));

// TODO : Implement logic to update read counts in redis.
try {
Map<String, Integer> updatedReadCount =
backplane.updateCasReadCount(effectiveReadCount.entrySet().stream());
expireUnreadEntries(updatedReadCount, firstIntervalReadCount);
} catch (Exception e) {
log.log(Level.WARNING, "Failed to update the cas read count to backplane", e);
}

long timeToUpdate = stopwatch.stop().elapsed().toMillis();
log.fine(format("Took %d ms to update read count.", timeToUpdate));
Expand All @@ -177,6 +186,33 @@ private void updateReadCount() {
}
}

private void expireUnreadEntries(
Map<String, Integer> casReadCount, Map<Digest, AtomicInteger> firstIntervalReadCount) {
Set<Digest> digestsToExpire = new HashSet<>();
firstIntervalReadCount.forEach(
(digest, readCount) -> {
Integer totalReadCount = casReadCount.get(digest.getHash());
if (totalReadCount != null && totalReadCount == 0) {
digestsToExpire.add(digest);
if (readCount.get() == 0) {
log.info(
format(
"Digest %s was not read once after being written in last %s duration",
DigestUtil.toString(digest), casEntryReadCountWindow.toString()));
}
}
});
if (digestsToExpire.isEmpty()) {
return;
}
try {
int removedCount = backplane.removeCasReadCountEntries(digestsToExpire.stream());
log.fine(format("Number of cas read count entries removed : %d", removedCount));
} catch (Exception e) {
log.log(Level.WARNING, "Failed to remove cas read count entries", e);
}
}

@VisibleForTesting
Deque<Map<Digest, AtomicInteger>> getReadIntervalCountQueue() {
return readIntervalCountQueue;
Expand Down
1 change: 1 addition & 0 deletions src/main/java/build/buildfarm/common/config/Backplane.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum BACKPLANE_TYPE {
private String operationChannelPrefix = "OperationChannel";
private String casPrefix = "ContentAddressableStorage";
private int casExpire = 604800; // 1 Week
private String casReadCountSetName = "CasReadCount";

@Getter(AccessLevel.NONE)
private boolean subscribeToBackplane = true; // deprecated
Expand Down
84 changes: 84 additions & 0 deletions src/main/java/build/buildfarm/common/redis/RedisSortedSet.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package build.buildfarm.common.redis;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisClusterPipeline;
import redis.clients.jedis.Response;

/**
* A redis sorted set is an implementation of a concurrent skip list set data structure which
* internally redis to store and distribute the data. It's important to know that the lifetime of
* the set persists before and after this data structure is created (since it exists in redis).
* Therefore, two redis sets with the same name, would in fact be the same underlying redis sets.
*/
public class RedisSortedSet {

/**
* The name is used by the redis cluster client to access the set data. If two set had the same
* name, they would be instances of the same underlying redis set.
*/
private final String name;

/**
* Construct a named redis set with an established redis cluster.
*
* @param name The global name of the set.
*/
public RedisSortedSet(String name) {
this.name = name;
}

/**
* Increments scores for members in a sorted set using a JedisCluster client.
*
* <p>This method increments scores for the specified members in the sorted set. If a member does
* not exist, it is added with the given score. The operation is atomic, ensuring consistency in
* the sorted set.
*
* @param jedis JedisCluster client to interact with Redis cluster.
* @param memberAndScore A map where keys are member names and values are the increment scores.
* @return A map containing updated scores for each member after the increment operation.
*/
public Map<String, Integer> incrementMembersScore(
JedisCluster jedis, Stream<Map.Entry<String, Integer>> memberAndScore) {
JedisClusterPipeline pipeline = jedis.pipelined();
Stream<AbstractMap.SimpleEntry<String, Response<Double>>> updatedScoreResponse =
memberAndScore.map(
entry ->
new AbstractMap.SimpleEntry<>(
entry.getKey(), pipeline.zincrby(this.name, entry.getValue(), entry.getKey())));
pipeline.sync();
// keep the last score for a key.
return updatedScoreResponse.collect(
Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().get().intValue()));
}

/**
* Removes the specified members from the sorted set and returns the count of removed members.
* Removal is performed in batches for improved performance.
*
* @param jedis JedisCluster client to interact with Redis cluster.
* @param members A stream of members to be removed from the set.
* @return total count of members removed.
*/
public int removeMembers(JedisCluster jedis, Stream<String> members) {
Iterator<String> iterator = members.iterator();
int batchSize = 128;
int membersRemoved = 0;
while (true) {
List<String> batch = new ArrayList<>(batchSize);
for (int i = 0; i < batchSize && iterator.hasNext(); i++) {
batch.add(iterator.next());
}
if (batch.isEmpty()) break;
membersRemoved += jedis.zrem(this.name, batch.toArray(new String[0]));
}
return membersRemoved;
}
}
10 changes: 10 additions & 0 deletions src/main/java/build/buildfarm/instance/shard/DistributedState.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import build.buildfarm.common.redis.BalancedRedisQueue;
import build.buildfarm.common.redis.RedisHashMap;
import build.buildfarm.common.redis.RedisMap;
import build.buildfarm.common.redis.RedisSortedSet;

/**
* @class DistributedState
Expand Down Expand Up @@ -127,4 +128,13 @@ public class DistributedState {
* and will refuse to run again.
*/
public RedisMap blockedActions;

/**
* @field casReadCount
* @brief Maintains the read count for each CAS entry in a sorted set.
* @details This data structure tracks the number of times CAS entries are accessed, allowing us
* to identify frequently accessed CAS entries. This information can be used to distribute hot
* keys across multiple workers or decide which CAS entries to keep in memory etc.
*/
public RedisSortedSet casReadCount;
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import build.buildfarm.common.redis.RedisHashtags;
import build.buildfarm.common.redis.RedisMap;
import build.buildfarm.common.redis.RedisNodeHashes;
import build.buildfarm.common.redis.RedisSortedSet;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.SetMultimap;
Expand Down Expand Up @@ -58,6 +59,7 @@ public static DistributedState create(RedisClient client) throws IOException {
new RedisHashMap(configs.getBackplane().getWorkersHashName() + "_execute");
state.storageWorkers =
new RedisHashMap(configs.getBackplane().getWorkersHashName() + "_storage");
state.casReadCount = new RedisSortedSet(configs.getBackplane().getCasReadCountSetName());

return state;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,25 @@ public Map<Digest, Set<String>> getBlobDigestsWorkers(Iterable<Digest> blobDiges
return state.casWorkerMap.getMap(client, blobDigests);
}

@Override
public Map<String, Integer> updateCasReadCount(
Stream<Map.Entry<Digest, Integer>> casReadCountStream) throws IOException {
return client.call(
jedis ->
state.casReadCount.incrementMembersScore(
jedis,
casReadCountStream.map(
entry ->
new AbstractMap.SimpleEntry<>(
entry.getKey().getHash(), entry.getValue()))));
}

@Override
public int removeCasReadCountEntries(Stream<Digest> digestsToBeRemoved) throws IOException {
return client.call(
jedis -> state.casReadCount.removeMembers(jedis, digestsToBeRemoved.map(Digest::getHash)));
}

public static WorkerChange parseWorkerChange(String workerChangeJson)
throws InvalidProtocolBufferException {
WorkerChange.Builder workerChange = WorkerChange.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.verify;
import static org.mockito.internal.verification.VerificationModeFactory.times;

import build.bazel.remote.execution.v2.Digest;
import build.buildfarm.backplane.Backplane;
Expand Down Expand Up @@ -93,6 +97,10 @@ public void testRecordRead() throws IOException, InterruptedException {
for (Thread thread : threads) {
thread.join();
}

verify(backplane, atLeast(numberOfDelayCycles + (int) (window / delay)))
.updateCasReadCount(any());
verify(backplane, times(0)).removeCasReadCountEntries(any());
}

@Test
Expand Down
2 changes: 2 additions & 0 deletions src/test/java/build/buildfarm/common/redis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ COMMON_DEPS = [
"//src/main/protobuf:build_buildfarm_v1test_buildfarm_java_proto",
"//src/test/java/build/buildfarm:test_runner",
"//third_party/jedis",
"@maven//:com_github_fppt_jedis_mock",
"@maven//:com_google_truth_truth",
"@maven//:io_grpc_grpc_api",
"@maven//:org_mockito_mockito_core",
Expand All @@ -17,6 +18,7 @@ NATIVE_REDIS_TESTS = [
"RedisQueueTest.java",
"RedisPriorityQueueTest.java",
"RedisHashMapTest.java",
"RedisSortedSetTest.java",
]

java_test(
Expand Down
49 changes: 49 additions & 0 deletions src/test/java/build/buildfarm/common/redis/RedisSortedSetTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package build.buildfarm.common.redis;

import static com.google.common.truth.Truth.assertThat;

import build.buildfarm.common.config.BuildfarmConfigs;
import build.buildfarm.instance.shard.JedisClusterFactory;
import java.io.IOException;
import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import redis.clients.jedis.JedisCluster;

@RunWith(JUnit4.class)
public class RedisSortedSetTest {
private BuildfarmConfigs configs = BuildfarmConfigs.getInstance();
private JedisCluster redis;
private RedisSortedSet redisSortedSet;

@Before
public void setUp() throws Exception {
this.configs.getBackplane().setRedisUri("redis://localhost:6379");
this.redis = JedisClusterFactory.createTest();
this.redisSortedSet = new RedisSortedSet("test_sorted_set");
}

@After
public void tearDown() throws IOException {
this.redis.close();
}

@Test
public void testIncrementMembersScore() {
Map<String, Integer> membersScoreInput =
Map.of("key1", 100, "key2", 50, "key3", 75, "key4", 25, "key5", 85);

Map<String, Integer> membersScoreResponse =
redisSortedSet.incrementMembersScore(redis, membersScoreInput.entrySet().stream());

membersScoreResponse.forEach(
(member, score) -> assertThat(membersScoreInput.get(member)).isEqualTo(score));
membersScoreResponse =
redisSortedSet.incrementMembersScore(redis, membersScoreInput.entrySet().stream());
membersScoreResponse.forEach(
(member, score) -> assertThat(2 * membersScoreInput.get(member)).isEqualTo(score));
}
}
Loading

0 comments on commit 28c89a1

Please sign in to comment.