diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java index 8436181d6a9239..5474f884233832 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java @@ -13,31 +13,39 @@ // limitations under the License. package com.google.devtools.build.lib.remote; -import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture; -import static com.google.devtools.build.lib.remote.util.Utils.waitForBulkTransfer; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static com.google.devtools.build.lib.remote.util.RxFutures.toCompletable; +import static com.google.devtools.build.lib.remote.util.RxFutures.toSingle; +import static com.google.devtools.build.lib.remote.util.RxUtils.mergeBulkTransfer; +import static com.google.devtools.build.lib.remote.util.RxUtils.toTransferResult; import static java.lang.String.format; import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.Directory; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient; import com.google.devtools.build.lib.remote.merkletree.MerkleTree; import com.google.devtools.build.lib.remote.merkletree.MerkleTree.PathOrBytes; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.DigestUtil; -import com.google.devtools.build.lib.remote.util.RxFutures; +import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult; import com.google.protobuf.Message; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subjects.AsyncSubject; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.HashSet; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.concurrent.GuardedBy; /** A {@link RemoteCache} with additional functionality needed for remote execution. */ public class RemoteExecutionCache extends RemoteCache { @@ -73,62 +81,58 @@ public void ensureInputsPresent( .addAll(additionalInputs.keySet()) .build(); - // Collect digests that are not being or already uploaded - ConcurrentHashMap> missingDigestSubjects = - new ConcurrentHashMap<>(); - - List> uploadFutures = new ArrayList<>(); - for (Digest digest : allDigests) { - Completable upload = - casUploadCache.execute( - digest, - Completable.defer( - () -> { - // The digest hasn't been processed, add it to the collection which will be used - // later for findMissingDigests call - AsyncSubject missingDigestSubject = AsyncSubject.create(); - missingDigestSubjects.put(digest, missingDigestSubject); - - return missingDigestSubject.flatMapCompletable( - missing -> { - if (!missing) { - return Completable.complete(); - } - return RxFutures.toCompletable( - () -> uploadBlob(context, digest, merkleTree, additionalInputs), - MoreExecutors.directExecutor()); - }); - }), - force); - uploadFutures.add(RxFutures.toListenableFuture(upload)); + if (allDigests.isEmpty()) { + return; } - ImmutableSet missingDigests; - try { - missingDigests = getFromFuture(findMissingDigests(context, missingDigestSubjects.keySet())); - } catch (IOException | InterruptedException e) { - for (Map.Entry> entry : missingDigestSubjects.entrySet()) { - entry.getValue().onError(e); - } + MissingDigestFinder missingDigestFinder = new MissingDigestFinder(context, allDigests.size()); + Flowable uploads = + Flowable.fromIterable(allDigests) + .flatMapSingle( + digest -> + uploadBlobIfMissing( + context, merkleTree, additionalInputs, force, missingDigestFinder, digest)); - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); + try { + mergeBulkTransfer(uploads).blockingAwait(); + } catch (RuntimeException e) { + Throwable cause = e.getCause(); + if (cause != null) { + Throwables.throwIfInstanceOf(cause, InterruptedException.class); + Throwables.throwIfInstanceOf(cause, IOException.class); } throw e; } + } - for (Map.Entry> entry : missingDigestSubjects.entrySet()) { - AsyncSubject missingSubject = entry.getValue(); - if (missingDigests.contains(entry.getKey())) { - missingSubject.onNext(true); - } else { - // The digest is already existed in the remote cache, skip the upload. - missingSubject.onNext(false); - } - missingSubject.onComplete(); - } - - waitForBulkTransfer(uploadFutures, /* cancelRemainingOnInterrupt=*/ false); + private Single uploadBlobIfMissing( + RemoteActionExecutionContext context, + MerkleTree merkleTree, + Map additionalInputs, + boolean force, + MissingDigestFinder missingDigestFinder, + Digest digest) { + Completable upload = + casUploadCache.execute( + digest, + Completable.defer( + () -> + // Only reach here if the digest is missing and is not being uploaded. + missingDigestFinder + .registerAndCount(digest) + .flatMapCompletable( + missingDigests -> { + if (missingDigests.contains(digest)) { + return toCompletable( + () -> uploadBlob(context, digest, merkleTree, additionalInputs), + directExecutor()); + } else { + return Completable.complete(); + } + })), + /* onIgnored= */ missingDigestFinder::count, + force); + return toTransferResult(upload); } private ListenableFuture uploadBlob( @@ -160,4 +164,93 @@ private ListenableFuture uploadBlob( "findMissingDigests returned a missing digest that has not been requested: %s", digest))); } + + /** + * A missing digest finder that initiates the request when the internal counter reaches an + * expected count. + */ + class MissingDigestFinder { + private final int expectedCount; + + private final AsyncSubject> digestsSubject; + private final Single> resultSingle; + + @GuardedBy("this") + private final Set digests; + + @GuardedBy("this") + private int currentCount = 0; + + MissingDigestFinder(RemoteActionExecutionContext context, int expectedCount) { + checkArgument(expectedCount > 0, "expectedCount should be greater than 0"); + this.expectedCount = expectedCount; + this.digestsSubject = AsyncSubject.create(); + this.digests = new HashSet<>(); + + AtomicBoolean findMissingDigestsCalled = new AtomicBoolean(false); + this.resultSingle = + Single.fromObservable( + digestsSubject + .flatMapSingle( + digests -> { + boolean wasCalled = findMissingDigestsCalled.getAndSet(true); + // Make sure we don't have re-subscription caused by refCount() below. + checkState(!wasCalled, "FindMissingDigests is called more than once"); + return toSingle( + () -> findMissingDigests(context, digests), directExecutor()); + }) + // Use replay here because we could have a race condition that downstream hasn't + // been added to the subscription list (to receive the upstream result) while + // upstream is completed. + .replay(1) + .refCount()); + } + + /** + * Register the {@code digest} and increase the counter. + * + *

Returned Single cannot be subscribed more than once. + * + * @return Single that emits the result of the {@code FindMissingDigest} request. + */ + Single> registerAndCount(Digest digest) { + AtomicBoolean subscribed = new AtomicBoolean(false); + // count() will potentially trigger the findMissingDigests call. Adding and counting before + // returning the Single could introduce a race that the result of findMissingDigests is + // available but the consumer doesn't get it because it hasn't subscribed the returned + // Single. In this case, it subscribes after upstream is completed resulting a re-run of + // findMissingDigests (due to refCount()). + // + // Calling count() inside doOnSubscribe to ensure the consumer already subscribed to the + // returned Single to avoid a re-execution of findMissingDigests. + return resultSingle.doOnSubscribe( + d -> { + boolean wasSubscribed = subscribed.getAndSet(true); + checkState(!wasSubscribed, "Single is subscribed more than once"); + synchronized (this) { + digests.add(digest); + } + count(); + }); + } + + /** Increase the counter. */ + void count() { + ImmutableSet digestsResult = null; + + synchronized (this) { + if (currentCount < expectedCount) { + currentCount++; + if (currentCount == expectedCount) { + digestsResult = ImmutableSet.copyOf(digests); + } + } + } + + if (digestsResult != null) { + digestsSubject.onNext(digestsResult); + digestsSubject.onComplete(); + } + } + } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index 8fb6f4ce20d49f..31369ef4ee1eab 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -24,6 +24,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.disposables.Disposable; +import io.reactivex.rxjava3.functions.Action; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -256,14 +257,25 @@ public boolean isDisposed() { /** * Executes a task. * + * @see #execute(Object, Single, Action, boolean). + */ + public Single execute(KeyT key, Single task, boolean force) { + return execute(key, task, () -> {}, force); + } + + /** + * Executes a task. If the task has already finished, this execution of the task is ignored unless + * `force` is true. If the task is in progress this execution of the task is always ignored. + * *

If the cache is already shutdown, a {@link CancellationException} will be emitted. * * @param key identifies the task. + * @param onIgnored callback called when provided task is ignored. * @param force re-execute a finished task if set to {@code true}. * @return a {@link Single} which turns to completed once the task is finished or propagates the * error if any. */ - public Single execute(KeyT key, Single task, boolean force) { + public Single execute(KeyT key, Single task, Action onIgnored, boolean force) { return Single.create( emitter -> { synchronized (lock) { @@ -273,14 +285,20 @@ public Single execute(KeyT key, Single task, boolean force) { } if (!force && finished.containsKey(key)) { + onIgnored.run(); emitter.onSuccess(finished.get(key)); return; } finished.remove(key); - Execution execution = - inProgress.computeIfAbsent(key, ignoredKey -> new Execution(key, task)); + Execution execution = inProgress.get(key); + if (execution != null) { + onIgnored.run(); + } else { + execution = new Execution(key, task); + inProgress.put(key, execution); + } // We must subscribe the execution within the scope of lock to avoid race condition // that: @@ -425,10 +443,15 @@ public Completable executeIfNot(KeyT key, Completable task) { cache.executeIfNot(key, task.toSingleDefault(Optional.empty()))); } - /** Same as {@link AsyncTaskCache#executeIfNot} but operates on {@link Completable}. */ + /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ public Completable execute(KeyT key, Completable task, boolean force) { + return execute(key, task, () -> {}, force); + } + + /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ + public Completable execute(KeyT key, Completable task, Action onIgnored, boolean force) { return Completable.fromSingle( - cache.execute(key, task.toSingleDefault(Optional.empty()), force)); + cache.execute(key, task.toSingleDefault(Optional.empty()), onIgnored, force)); } /** Returns a set of keys for tasks which is finished. */ diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java index 26615963cd5756..f57a5a6360c998 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java @@ -51,6 +51,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.eventbus.EventBus; import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.SettableFuture; import com.google.devtools.build.lib.actions.ActionInput; import com.google.devtools.build.lib.actions.ActionInputHelper; import com.google.devtools.build.lib.actions.ActionUploadFinishedEvent; @@ -110,6 +111,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; @@ -1435,19 +1437,18 @@ public void uploadInputsIfNotPresent_deduplicateFindMissingBlobCalls() throws Ex ActionInput input = ActionInputHelper.fromPath("inputs/foo"); Digest inputDigest = fakeFileCache.createScratchInput(input, "input-foo"); RemoteExecutionService service = newRemoteExecutionService(); + Spawn spawn = + newSpawn( + ImmutableMap.of(), + ImmutableSet.of(), + NestedSetBuilder.create(Order.STABLE_ORDER, input)); + FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); + RemoteAction action = service.buildRemoteAction(spawn, context); for (int i = 0; i < taskCount; ++i) { executorService.execute( () -> { try { - Spawn spawn = - newSpawn( - ImmutableMap.of(), - ImmutableSet.of(), - NestedSetBuilder.create(Order.STABLE_ORDER, input)); - FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); - RemoteAction action = service.buildRemoteAction(spawn, context); - service.uploadInputsIfNotPresent(action, /*force=*/ false); } catch (Throwable e) { if (e instanceof InterruptedException) { @@ -1468,6 +1469,72 @@ public void uploadInputsIfNotPresent_deduplicateFindMissingBlobCalls() throws Ex } } + @Test + public void uploadInputsIfNotPresent_sameInputs_interruptOne_keepOthers() throws Exception { + int taskCount = 100; + ExecutorService executorService = Executors.newFixedThreadPool(taskCount); + AtomicReference error = new AtomicReference<>(null); + Semaphore semaphore = new Semaphore(0); + ActionInput input = ActionInputHelper.fromPath("inputs/foo"); + fakeFileCache.createScratchInput(input, "input-foo"); + RemoteExecutionService service = newRemoteExecutionService(); + Spawn spawn = + newSpawn( + ImmutableMap.of(), + ImmutableSet.of(), + NestedSetBuilder.create(Order.STABLE_ORDER, input)); + FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); + RemoteAction action = service.buildRemoteAction(spawn, context); + Random random = new Random(); + + for (int i = 0; i < taskCount; ++i) { + boolean shouldInterrupt = random.nextBoolean(); + executorService.execute( + () -> { + try { + if (shouldInterrupt) { + Thread.currentThread().interrupt(); + } + service.uploadInputsIfNotPresent(action, /*force=*/ false); + } catch (Throwable e) { + if (!(shouldInterrupt && e instanceof InterruptedException)) { + error.set(e); + } + } finally { + semaphore.release(); + } + }); + } + semaphore.acquire(taskCount); + + assertThat(error.get()).isNull(); + } + + @Test + public void uploadInputsIfNotPresent_interrupted_requestCancelled() throws Exception { + SettableFuture> future = SettableFuture.create(); + doReturn(future).when(cache).findMissingDigests(any(), any()); + ActionInput input = ActionInputHelper.fromPath("inputs/foo"); + fakeFileCache.createScratchInput(input, "input-foo"); + RemoteExecutionService service = newRemoteExecutionService(); + Spawn spawn = + newSpawn( + ImmutableMap.of(), + ImmutableSet.of(), + NestedSetBuilder.create(Order.STABLE_ORDER, input)); + FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); + RemoteAction action = service.buildRemoteAction(spawn, context); + + try { + Thread.currentThread().interrupt(); + service.uploadInputsIfNotPresent(action, /*force=*/ false); + } catch (InterruptedException ignored) { + // Intentionally left empty + } + + assertThat(future.isCancelled()).isTrue(); + } + @Test public void buildMerkleTree_withMemoization_works() throws Exception { // Test that Merkle tree building can be memoized. diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java index c26629f2b3cf14..8925640c11ccbc 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java @@ -19,6 +19,8 @@ import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient; @@ -31,12 +33,15 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; /** A {@link RemoteCacheClient} that stores its contents in memory. */ public final class InMemoryCacheClient implements RemoteCacheClient { + private final ListeningExecutorService executorService = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(100)); private final ConcurrentMap downloadFailures = new ConcurrentHashMap<>(); private final ConcurrentMap ac = new ConcurrentHashMap<>(); private final ConcurrentMap cas; @@ -142,16 +147,19 @@ public ListenableFuture uploadBlob( @Override public ListenableFuture> findMissingDigests( RemoteActionExecutionContext context, Iterable digests) { - ImmutableSet.Builder missingBuilder = ImmutableSet.builder(); - for (Digest digest : digests) { - numFindMissingDigests - .computeIfAbsent(digest, (key) -> new AtomicInteger(0)) - .incrementAndGet(); - if (!cas.containsKey(digest)) { - missingBuilder.add(digest); - } - } - return Futures.immediateFuture(missingBuilder.build()); + return executorService.submit( + () -> { + ImmutableSet.Builder missingBuilder = ImmutableSet.builder(); + for (Digest digest : digests) { + numFindMissingDigests + .computeIfAbsent(digest, (key) -> new AtomicInteger(0)) + .incrementAndGet(); + if (!cas.containsKey(digest)) { + missingBuilder.add(digest); + } + } + return missingBuilder.build(); + }); } @Override