From 80a2d7cc5f8a22816934dcd2ca9bdf87050f3d9f Mon Sep 17 00:00:00 2001 From: John Millikin Date: Mon, 9 Mar 2020 12:42:08 -0700 Subject: [PATCH] Implementation (but not plumbing) of the gRPC remote downloader Extracted from https://github.com/bazelbuild/bazel/pull/10622 Per discussion on that PR, there's still some unanswered questions about how exactly we plumb the new `Downloader` type into `RemoteModule`. And per https://github.com/bazelbuild/bazel/issues/10742#issuecomment-595633454, it is unlikely that even heroic effort from me will get the full end-to-end functionality into v3.0. Given this, to simplify the review, I'm taking some of the bits the reviewer is happy with and moving them to a separate PR. After merger, `GrpcRemoteDownloader` and its tests will exist in the source tree, but will not yet be available as CLI options. R: @michajlo CC: @adunham-stripe @dslomov @EricBurnett @philwo @sstriker Closes #10914. PiperOrigin-RevId: 299908615 --- .../downloader/DownloadManager.java | 2 +- .../repository/downloader/Downloader.java | 1 + .../downloader/HashOutputStream.java | 89 +++++ .../repository/downloader/HttpDownloader.java | 1 + .../UnrecoverableHttpException.java | 3 +- .../google/devtools/build/lib/remote/BUILD | 56 ++- .../lib/remote/ReferenceCountedChannel.java | 11 +- .../build/lib/remote/RemoteRetrier.java | 5 +- .../build/lib/remote/downloader/BUILD | 32 ++ .../downloader/GrpcRemoteDownloader.java | 201 +++++++++++ .../lib/remote/options/RemoteOptions.java | 16 + .../lib/remote/util/TracingMetadataUtils.java | 6 + .../google/devtools/build/lib/remote/BUILD | 1 + .../build/lib/remote/downloader/BUILD | 49 +++ .../downloader/GrpcRemoteDownloaderTest.java | 328 ++++++++++++++++++ .../downloader/RemoteDownloaderTestSuite.java | 26 ++ 16 files changed, 817 insertions(+), 10 deletions(-) create mode 100644 src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashOutputStream.java create mode 100644 src/main/java/com/google/devtools/build/lib/remote/downloader/BUILD create mode 100644 src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/downloader/BUILD create mode 100644 src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/downloader/RemoteDownloaderTestSuite.java diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java index 5bb107a3bbe9ec..2c289d8df35b1d 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java @@ -187,7 +187,7 @@ public Path download( try { downloader.download( - urls, authHeaders, checksum, destination, eventHandler, clientEnv); + urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv); } catch (InterruptedIOException e) { throw new InterruptedException(e.getMessage()); } diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java index 887d9b68e08a80..202ece226b9187 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Downloader.java @@ -43,6 +43,7 @@ void download( List urls, Map> authHeaders, Optional checksum, + String canonicalId, Path output, ExtendedEventHandler eventHandler, Map clientEnv) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashOutputStream.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashOutputStream.java new file mode 100644 index 00000000000000..9235fc73af48c8 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashOutputStream.java @@ -0,0 +1,89 @@ +// Copyright 2020 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.bazel.repository.downloader; + +import com.google.common.hash.HashCode; +import com.google.common.hash.Hasher; +import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadCompatible; +import java.io.IOException; +import java.io.OutputStream; +import javax.annotation.Nullable; +import javax.annotation.WillCloseWhenClosed; + +/** + * Output stream that guarantees its contents matches a hash code. + * + *

The actual checksum is computed gradually as the output is written. If it doesn't match, then + * an {@link IOException} will be thrown when {@link #close()} is called. This error will be thrown + * multiple times if these methods are called again for some reason. + * + *

Note that as the checksum can only be computed once the stream is closed, data will be written + * to the underlying stream regardless of whether it matches the expected checksum. + * + *

This class is not thread safe, but it is safe to message pass this object between threads. + */ +@ThreadCompatible +public final class HashOutputStream extends OutputStream { + + private final OutputStream delegate; + private final Hasher hasher; + private final HashCode code; + @Nullable private volatile HashCode actual; + + public HashOutputStream(@WillCloseWhenClosed OutputStream delegate, Checksum checksum) { + this.delegate = delegate; + this.hasher = checksum.getKeyType().newHasher(); + this.code = checksum.getHashCode(); + } + + @Override + public void write(int buffer) throws IOException { + hasher.putByte((byte) buffer); + delegate.write(buffer); + } + + @Override + public void write(byte[] buffer) throws IOException { + hasher.putBytes(buffer); + delegate.write(buffer); + } + + @Override + public void write(byte[] buffer, int offset, int length) throws IOException { + hasher.putBytes(buffer, offset, length); + delegate.write(buffer, offset, length); + } + + @Override + public void flush() throws IOException { + delegate.flush(); + } + + @Override + public void close() throws IOException { + delegate.close(); + check(); + } + + private void check() throws IOException { + if (actual == null) { + actual = hasher.hash(); + } + if (!code.equals(actual)) { + throw new UnrecoverableHttpException( + String.format("Checksum was %s but wanted %s", actual, code)); + } + } +} diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java index 98dd4be5e6d63c..5691fd474d1896 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java @@ -62,6 +62,7 @@ public void download( List urls, Map> authHeaders, Optional checksum, + String canonicalId, Path destination, ExtendedEventHandler eventHandler, Map clientEnv) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/UnrecoverableHttpException.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/UnrecoverableHttpException.java index 3ccd2f4a2c96e3..0b05e4cd977e6d 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/UnrecoverableHttpException.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/UnrecoverableHttpException.java @@ -16,7 +16,8 @@ import java.io.IOException; -final class UnrecoverableHttpException extends IOException { +/** Indicates an HTTP error that cannot be recovered from. */ +public final class UnrecoverableHttpException extends IOException { UnrecoverableHttpException(String message) { super(message); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD index 3e19f88dc70404..69bc4de19eaff8 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD @@ -6,6 +6,7 @@ filegroup( name = "srcs", srcs = glob(["**"]) + [ "//src/main/java/com/google/devtools/build/lib/remote/common:srcs", + "//src/main/java/com/google/devtools/build/lib/remote/downloader:srcs", "//src/main/java/com/google/devtools/build/lib/remote/disk:srcs", "//src/main/java/com/google/devtools/build/lib/remote/http:srcs", "//src/main/java/com/google/devtools/build/lib/remote/logging:srcs", @@ -18,13 +19,30 @@ filegroup( java_library( name = "remote", - srcs = glob(["*.java"]), + srcs = glob( + ["*.java"], + exclude = [ + "ExecutionStatusException.java", + "ReferenceCountedChannel.java", + "RemoteRetrier.java", + "RemoteRetrierUtils.java", + "Retrier.java", + ], + ), tags = ["bazel"], + exports = [ + ":ExecutionStatusException", + ":ReferenceCountedChannel", + ":Retrier", + ], runtime_deps = [ # This is required for client TLS. "//third_party:netty_tcnative", ], deps = [ + ":ExecutionStatusException", + ":ReferenceCountedChannel", + ":Retrier", "//src/main/java/com/google/devtools/build/lib:build-base", "//src/main/java/com/google/devtools/build/lib:events", "//src/main/java/com/google/devtools/build/lib:packages-internal", @@ -65,3 +83,39 @@ java_library( "@remoteapis//:build_bazel_semver_semver_java_proto", ], ) + +java_library( + name = "ExecutionStatusException", + srcs = ["ExecutionStatusException.java"], + deps = [ + "//third_party:jsr305", + "//third_party/grpc:grpc-jar", + "@googleapis//:google_rpc_status_java_proto", + "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto", + ], +) + +java_library( + name = "ReferenceCountedChannel", + srcs = ["ReferenceCountedChannel.java"], + deps = [ + "//third_party:netty", + "//third_party/grpc:grpc-jar", + ], +) + +java_library( + name = "Retrier", + srcs = [ + "RemoteRetrier.java", + "RemoteRetrierUtils.java", + "Retrier.java", + ], + deps = [ + ":ExecutionStatusException", + "//src/main/java/com/google/devtools/build/lib/remote/options", + "//third_party:guava", + "//third_party:jsr305", + "//third_party/grpc:grpc-jar", + ], +) diff --git a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java index eff9621da14948..1d948cc5e28037 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java @@ -21,13 +21,14 @@ import io.netty.util.ReferenceCounted; import java.util.concurrent.TimeUnit; -/** A wrapper around a {@link io.grpc.ManagedChannel} exposing a reference count. - * When instantiated the reference count is 1. {@link ManagedChannel#shutdown()} will be called - * on the wrapped channel when the reference count reaches 0. +/** + * A wrapper around a {@link io.grpc.ManagedChannel} exposing a reference count. When instantiated + * the reference count is 1. {@link ManagedChannel#shutdown()} will be called on the wrapped channel + * when the reference count reaches 0. * - * See {@link ReferenceCounted} for more information about reference counting. + *

See {@link ReferenceCounted} for more information about reference counting. */ -class ReferenceCountedChannel extends ManagedChannel implements ReferenceCounted { +public class ReferenceCountedChannel extends ManagedChannel implements ReferenceCounted { private final ManagedChannel channel; private final AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() { diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteRetrier.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteRetrier.java index fc8064f4657ef9..cb4b7bf8cd170f 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteRetrier.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteRetrier.java @@ -120,7 +120,8 @@ public T execute(Callable call) throws IOException, InterruptedException } } - static class ExponentialBackoff implements Backoff { + /** Backoff strategy that backs off exponentially. */ + public static class ExponentialBackoff implements Backoff { private final long maxMillis; private long nextDelayMillis; @@ -152,7 +153,7 @@ static class ExponentialBackoff implements Backoff { this.maxAttempts = maxAttempts; } - ExponentialBackoff(RemoteOptions options) { + public ExponentialBackoff(RemoteOptions options) { this( /* initial = */ Duration.ofMillis(100), /* max = */ Duration.ofSeconds(5), diff --git a/src/main/java/com/google/devtools/build/lib/remote/downloader/BUILD b/src/main/java/com/google/devtools/build/lib/remote/downloader/BUILD new file mode 100644 index 00000000000000..1035c7d75db487 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/downloader/BUILD @@ -0,0 +1,32 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_visibility = ["//src:__subpackages__"], +) + +filegroup( + name = "srcs", + srcs = glob(["*"]), +) + +java_library( + name = "downloader", + srcs = glob(["*.java"]), + deps = [ + "//src/main/java/com/google/devtools/build/lib:events", + "//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader", + "//src/main/java/com/google/devtools/build/lib/remote", + "//src/main/java/com/google/devtools/build/lib/remote:ReferenceCountedChannel", + "//src/main/java/com/google/devtools/build/lib/remote:Retrier", + "//src/main/java/com/google/devtools/build/lib/remote/common", + "//src/main/java/com/google/devtools/build/lib/remote/options", + "//src/main/java/com/google/devtools/build/lib/remote/util", + "//src/main/java/com/google/devtools/build/lib/vfs", + "//third_party:gson", + "//third_party:guava", + "//third_party/grpc:grpc-jar", + "@remoteapis//:build_bazel_remote_asset_v1_remote_asset_java_grpc", + "@remoteapis//:build_bazel_remote_asset_v1_remote_asset_java_proto", + "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto", + ], +) diff --git a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java new file mode 100644 index 00000000000000..81ccf51ca806d5 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java @@ -0,0 +1,201 @@ +// Copyright 2020 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.remote.downloader; + +import build.bazel.remote.asset.v1.FetchBlobRequest; +import build.bazel.remote.asset.v1.FetchBlobResponse; +import build.bazel.remote.asset.v1.FetchGrpc; +import build.bazel.remote.asset.v1.FetchGrpc.FetchBlockingStub; +import build.bazel.remote.asset.v1.Qualifier; +import build.bazel.remote.execution.v2.Digest; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.devtools.build.lib.bazel.repository.downloader.Checksum; +import com.google.devtools.build.lib.bazel.repository.downloader.Downloader; +import com.google.devtools.build.lib.bazel.repository.downloader.HashOutputStream; +import com.google.devtools.build.lib.events.ExtendedEventHandler; +import com.google.devtools.build.lib.remote.ReferenceCountedChannel; +import com.google.devtools.build.lib.remote.RemoteRetrier; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient; +import com.google.devtools.build.lib.remote.options.RemoteOptions; +import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; +import com.google.devtools.build.lib.remote.util.Utils; +import com.google.devtools.build.lib.vfs.Path; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import io.grpc.CallCredentials; +import io.grpc.Context; +import io.grpc.StatusRuntimeException; +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.net.URL; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A Downloader implementation that uses Bazel's Remote Execution APIs to delegate downloads of + * external files to a remote service. + * + *

See https://github.com/bazelbuild/remote-apis for more details on the exact capabilities and + * semantics of the Remote Execution API. + */ +public class GrpcRemoteDownloader implements AutoCloseable, Downloader { + + private final ReferenceCountedChannel channel; + private final Optional credentials; + private final RemoteRetrier retrier; + private final Context requestCtx; + private final RemoteCacheClient cacheClient; + private final RemoteOptions options; + + private final AtomicBoolean closed = new AtomicBoolean(); + + // The `Qualifier::name` field uses well-known string keys to attach arbitrary + // key-value metadata to download requests. These are the qualifier names + // supported by Bazel. + private static final String QUALIFIER_CHECKSUM_SRI = "checksum.sri"; + private static final String QUALIFIER_CANONICAL_ID = "bazel.canonical_id"; + private static final String QUALIFIER_AUTH_HEADERS = "bazel.auth_headers"; + + public GrpcRemoteDownloader( + ReferenceCountedChannel channel, + Optional credentials, + RemoteRetrier retrier, + Context requestCtx, + RemoteCacheClient cacheClient, + RemoteOptions options) { + this.channel = channel; + this.credentials = credentials; + this.retrier = retrier; + this.cacheClient = cacheClient; + this.requestCtx = requestCtx; + this.options = options; + } + + @Override + public void close() { + if (closed.getAndSet(true)) { + return; + } + cacheClient.close(); + channel.release(); + } + + @Override + public void download( + List urls, + Map> authHeaders, + com.google.common.base.Optional checksum, + String canonicalId, + Path destination, + ExtendedEventHandler eventHandler, + Map clientEnv) + throws IOException, InterruptedException { + final FetchBlobRequest request = + newFetchBlobRequest(options.remoteInstanceName, urls, authHeaders, checksum, canonicalId); + try { + FetchBlobResponse response = + retrier.execute(() -> requestCtx.call(() -> fetchBlockingStub().fetchBlob(request))); + final Digest blobDigest = response.getBlobDigest(); + + retrier.execute( + () -> + requestCtx.call( + () -> { + try (OutputStream out = newOutputStream(destination, checksum)) { + Utils.getFromFuture(cacheClient.downloadBlob(blobDigest, out)); + } + return null; + })); + } catch (StatusRuntimeException e) { + throw new IOException(e); + } + } + + @VisibleForTesting + static FetchBlobRequest newFetchBlobRequest( + String instanceName, + List urls, + Map> authHeaders, + com.google.common.base.Optional checksum, + String canonicalId) { + FetchBlobRequest.Builder requestBuilder = + FetchBlobRequest.newBuilder().setInstanceName(instanceName); + for (URL url : urls) { + requestBuilder.addUris(url.toString()); + } + if (checksum.isPresent()) { + requestBuilder.addQualifiers( + Qualifier.newBuilder() + .setName(QUALIFIER_CHECKSUM_SRI) + .setValue(checksum.get().toSubresourceIntegrity()) + .build()); + } + if (!Strings.isNullOrEmpty(canonicalId)) { + requestBuilder.addQualifiers( + Qualifier.newBuilder().setName(QUALIFIER_CANONICAL_ID).setValue(canonicalId).build()); + } + if (!authHeaders.isEmpty()) { + requestBuilder.addQualifiers( + Qualifier.newBuilder() + .setName(QUALIFIER_AUTH_HEADERS) + .setValue(authHeadersJson(authHeaders)) + .build()); + } + + return requestBuilder.build(); + } + + private FetchBlockingStub fetchBlockingStub() { + return FetchGrpc.newBlockingStub(channel) + .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newDownloaderHeadersInterceptor(options)) + .withCallCredentials(credentials.orElse(null)) + .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); + } + + private OutputStream newOutputStream( + Path destination, com.google.common.base.Optional checksum) throws IOException { + OutputStream out = destination.getOutputStream(); + if (checksum.isPresent()) { + out = new HashOutputStream(out, checksum.get()); + } + return out; + } + + private static String authHeadersJson(Map> authHeaders) { + Map subObjects = new TreeMap<>(); + for (Map.Entry> entry : authHeaders.entrySet()) { + JsonObject subObject = new JsonObject(); + Map orderedHeaders = new TreeMap<>(entry.getValue()); + for (Map.Entry subEntry : orderedHeaders.entrySet()) { + subObject.addProperty(subEntry.getKey(), subEntry.getValue()); + } + subObjects.put(entry.getKey().toString(), subObject); + } + + JsonObject authHeadersJson = new JsonObject(); + for (Map.Entry entry : subObjects.entrySet()) { + authHeadersJson.add(entry.getKey(), entry.getValue()); + } + + return (new Gson()).toJson(authHeadersJson); + } +} diff --git a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java index 290934ee03c009..4e4e27ee4512fc 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java +++ b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java @@ -85,6 +85,8 @@ public final class RemoteOptions extends OptionsBase { + " https://docs.bazel.build/versions/master/remote-caching.html") public String remoteCache; + public final String remoteDownloader = ""; + @Option( name = "remote_header", converter = Converters.AssignmentConverter.class, @@ -126,6 +128,20 @@ public final class RemoteOptions extends OptionsBase { allowMultiple = true) public List> remoteExecHeaders; + @Option( + name = "remote_downloader_header", + converter = Converters.AssignmentConverter.class, + defaultValue = "", + documentationCategory = OptionDocumentationCategory.REMOTE, + effectTags = {OptionEffectTag.UNKNOWN}, + help = + "Specify a header that will be included in remote downloader requests: " + + "--remote_downloader_header=Name=Value. " + + "Multiple headers can be passed by specifying the flag multiple times. Multiple " + + "values for the same name will be converted to a comma-separated list.", + allowMultiple = true) + public List> remoteDownloaderHeaders; + @Option( name = "remote_timeout", defaultValue = "60", diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java index df134a294216ca..a4b5511688cf19 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java @@ -137,6 +137,12 @@ public static ClientInterceptor newCacheHeadersInterceptor(RemoteOptions options return MetadataUtils.newAttachHeadersInterceptor(metadata); } + public static ClientInterceptor newDownloaderHeadersInterceptor(RemoteOptions options) { + Metadata metadata = newMetadataForHeaders(options.remoteHeaders); + metadata.merge(newMetadataForHeaders(options.remoteDownloaderHeaders)); + return MetadataUtils.newAttachHeadersInterceptor(metadata); + } + public static ClientInterceptor newExecHeadersInterceptor(RemoteOptions options) { Metadata metadata = newMetadataForHeaders(options.remoteHeaders); metadata.merge(newMetadataForHeaders(options.remoteExecHeaders)); diff --git a/src/test/java/com/google/devtools/build/lib/remote/BUILD b/src/test/java/com/google/devtools/build/lib/remote/BUILD index a506f825cba916..9910b5141e1d20 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/BUILD @@ -9,6 +9,7 @@ filegroup( name = "srcs", testonly = 0, srcs = glob(["**"]) + [ + "//src/test/java/com/google/devtools/build/lib/remote/downloader:srcs", "//src/test/java/com/google/devtools/build/lib/remote/http:srcs", "//src/test/java/com/google/devtools/build/lib/remote/logging:srcs", "//src/test/java/com/google/devtools/build/lib/remote/merkletree:srcs", diff --git a/src/test/java/com/google/devtools/build/lib/remote/downloader/BUILD b/src/test/java/com/google/devtools/build/lib/remote/downloader/BUILD new file mode 100644 index 00000000000000..5fb26195a19770 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/downloader/BUILD @@ -0,0 +1,49 @@ +load("@rules_java//java:defs.bzl", "java_test") + +package( + default_testonly = 1, + default_visibility = ["//src:__subpackages__"], +) + +filegroup( + name = "srcs", + testonly = 0, + srcs = glob(["**"]), + visibility = ["//src/test/java/com/google/devtools/build/lib/remote:__pkg__"], +) + +java_test( + name = "RemoteDownloaderTestSuite", + srcs = glob(["*.java"]), + tags = [ + "requires-network", + "rules", + ], + deps = [ + "//src/main/java/com/google/devtools/build/lib:events", + "//src/main/java/com/google/devtools/build/lib:util", + "//src/main/java/com/google/devtools/build/lib/bazel/repository/cache", + "//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader", + "//src/main/java/com/google/devtools/build/lib/remote:ReferenceCountedChannel", + "//src/main/java/com/google/devtools/build/lib/remote:Retrier", + "//src/main/java/com/google/devtools/build/lib/remote/common", + "//src/main/java/com/google/devtools/build/lib/remote/downloader", + "//src/main/java/com/google/devtools/build/lib/remote/options", + "//src/main/java/com/google/devtools/build/lib/remote/util", + "//src/main/java/com/google/devtools/build/lib/vfs", + "//src/main/java/com/google/devtools/common/options", + "//src/test/java/com/google/devtools/build/lib:foundations_testutil", + "//src/test/java/com/google/devtools/build/lib:test_runner", + "//src/test/java/com/google/devtools/build/lib:testutil", + "//src/test/java/com/google/devtools/build/lib/remote/util", + "//third_party:guava", + "//third_party:junit4", + "//third_party:mockito", + "//third_party:truth", + "//third_party/grpc:grpc-jar", + "//third_party/protobuf:protobuf_java", + "@remoteapis//:build_bazel_remote_asset_v1_remote_asset_java_grpc", + "@remoteapis//:build_bazel_remote_asset_v1_remote_asset_java_proto", + "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto", + ], +) diff --git a/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java new file mode 100644 index 00000000000000..e4e61345a95a93 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloaderTest.java @@ -0,0 +1,328 @@ +// Copyright 2019 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.remote.downloader; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture; +import static com.google.devtools.build.lib.testutil.MoreAsserts.assertThrows; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.mockito.Mockito.mock; + +import build.bazel.remote.asset.v1.FetchBlobRequest; +import build.bazel.remote.asset.v1.FetchBlobResponse; +import build.bazel.remote.asset.v1.FetchGrpc.FetchImplBase; +import build.bazel.remote.asset.v1.Qualifier; +import build.bazel.remote.execution.v2.Digest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.ByteStreams; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; +import com.google.devtools.build.lib.bazel.repository.downloader.Checksum; +import com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException; +import com.google.devtools.build.lib.events.ExtendedEventHandler; +import com.google.devtools.build.lib.remote.ReferenceCountedChannel; +import com.google.devtools.build.lib.remote.RemoteRetrier; +import com.google.devtools.build.lib.remote.RemoteRetrier.ExponentialBackoff; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient; +import com.google.devtools.build.lib.remote.options.RemoteOptions; +import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.InMemoryCacheClient; +import com.google.devtools.build.lib.remote.util.TestUtils; +import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; +import com.google.devtools.build.lib.testutil.Scratch; +import com.google.devtools.build.lib.vfs.DigestHashFunction; +import com.google.devtools.build.lib.vfs.Path; +import com.google.devtools.common.options.Options; +import com.google.protobuf.ByteString; +import io.grpc.CallCredentials; +import io.grpc.Context; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.util.MutableHandlerRegistry; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URL; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link GrpcRemoteDownloader}. */ +@RunWith(JUnit4.class) +public class GrpcRemoteDownloaderTest { + + private static final DigestUtil DIGEST_UTIL = new DigestUtil(DigestHashFunction.SHA256); + + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + private final String fakeServerName = "fake server for " + getClass(); + private Server fakeServer; + private Context withEmptyMetadata; + private Context prevContext; + private static ListeningScheduledExecutorService retryService; + + @BeforeClass + public static void beforeEverything() { + retryService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1)); + } + + @Before + public final void setUp() throws Exception { + // Use a mutable service registry for later registering the service impl for each test case. + fakeServer = + InProcessServerBuilder.forName(fakeServerName) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + withEmptyMetadata = + TracingMetadataUtils.contextWithMetadata( + "none", "none", DIGEST_UTIL.asActionKey(Digest.getDefaultInstance())); + prevContext = withEmptyMetadata.attach(); + } + + @After + public void tearDown() throws Exception { + withEmptyMetadata.detach(prevContext); + fakeServer.shutdownNow(); + fakeServer.awaitTermination(); + } + + @AfterClass + public static void afterEverything() { + retryService.shutdownNow(); + } + + private GrpcRemoteDownloader newDownloader(RemoteCacheClient cacheClient) throws IOException { + final RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); + final RemoteRetrier retrier = + TestUtils.newRemoteRetrier( + () -> new ExponentialBackoff(remoteOptions), + RemoteRetrier.RETRIABLE_GRPC_ERRORS, + retryService); + final ReferenceCountedChannel channel = + new ReferenceCountedChannel( + InProcessChannelBuilder.forName(fakeServerName).directExecutor().build()); + return new GrpcRemoteDownloader( + channel.retain(), + Optional.empty(), + retrier, + withEmptyMetadata, + cacheClient, + remoteOptions); + } + + private static byte[] downloadBlob( + GrpcRemoteDownloader downloader, URL url, Optional checksum) + throws IOException, InterruptedException { + final List urls = ImmutableList.of(url); + com.google.common.base.Optional guavaChecksum = + com.google.common.base.Optional.absent(); + if (checksum.isPresent()) { + guavaChecksum = com.google.common.base.Optional.of(checksum.get()); + } + + final Map> authHeaders = ImmutableMap.of(); + final String canonicalId = ""; + final ExtendedEventHandler eventHandler = mock(ExtendedEventHandler.class); + final Map clientEnv = ImmutableMap.of(); + + Scratch scratch = new Scratch(); + final Path destination = scratch.resolve("output file path"); + downloader.download( + urls, authHeaders, guavaChecksum, canonicalId, destination, eventHandler, clientEnv); + + try (InputStream in = destination.getInputStream()) { + return ByteStreams.toByteArray(in); + } + } + + @Test + public void testDownload() throws Exception { + final byte[] content = "example content".getBytes(UTF_8); + final Digest contentDigest = DIGEST_UTIL.compute(content); + + serviceRegistry.addService( + new FetchImplBase() { + @Override + public void fetchBlob( + FetchBlobRequest request, StreamObserver responseObserver) { + assertThat(request) + .isEqualTo( + FetchBlobRequest.newBuilder() + .addUris("http://example.com/content.txt") + .build()); + responseObserver.onNext( + FetchBlobResponse.newBuilder().setBlobDigest(contentDigest).build()); + responseObserver.onCompleted(); + } + }); + + final RemoteCacheClient cacheClient = new InMemoryCacheClient(); + final GrpcRemoteDownloader downloader = newDownloader(cacheClient); + + getFromFuture(cacheClient.uploadBlob(contentDigest, ByteString.copyFrom(content))); + final byte[] downloaded = + downloadBlob( + downloader, new URL("http://example.com/content.txt"), Optional.empty()); + + assertThat(downloaded).isEqualTo(content); + } + + @Test + public void testPropagateChecksum() throws Exception { + final byte[] content = "example content".getBytes(UTF_8); + final Digest contentDigest = DIGEST_UTIL.compute(content); + + serviceRegistry.addService( + new FetchImplBase() { + @Override + public void fetchBlob( + FetchBlobRequest request, StreamObserver responseObserver) { + assertThat(request) + .isEqualTo( + FetchBlobRequest.newBuilder() + .addUris("http://example.com/content.txt") + .addQualifiers( + Qualifier.newBuilder() + .setName("checksum.sri") + .setValue("sha256-ot7ke6YmiSXal3UKt0K69n8C4vtUziPUmftmpbAiKQM=")) + .build()); + responseObserver.onNext( + FetchBlobResponse.newBuilder().setBlobDigest(contentDigest).build()); + responseObserver.onCompleted(); + } + }); + + final RemoteCacheClient cacheClient = new InMemoryCacheClient(); + final GrpcRemoteDownloader downloader = newDownloader(cacheClient); + + getFromFuture(cacheClient.uploadBlob(contentDigest, ByteString.copyFrom(content))); + final byte[] downloaded = + downloadBlob( + downloader, + new URL("http://example.com/content.txt"), + Optional.of(Checksum.fromString(KeyType.SHA256, contentDigest.getHash()))); + + assertThat(downloaded).isEqualTo(content); + } + + @Test + public void testRejectChecksumMismatch() throws Exception { + final byte[] content = "example content".getBytes(UTF_8); + final Digest contentDigest = DIGEST_UTIL.compute(content); + + serviceRegistry.addService( + new FetchImplBase() { + @Override + public void fetchBlob( + FetchBlobRequest request, StreamObserver responseObserver) { + assertThat(request) + .isEqualTo( + FetchBlobRequest.newBuilder() + .addUris("http://example.com/content.txt") + .addQualifiers( + Qualifier.newBuilder() + .setName("checksum.sri") + .setValue("sha256-ot7ke6YmiSXal3UKt0K69n8C4vtUziPUmftmpbAiKQM=")) + .build()); + responseObserver.onNext( + FetchBlobResponse.newBuilder().setBlobDigest(contentDigest).build()); + responseObserver.onCompleted(); + } + }); + + final RemoteCacheClient cacheClient = new InMemoryCacheClient(); + final GrpcRemoteDownloader downloader = newDownloader(cacheClient); + + getFromFuture(cacheClient.uploadBlob(contentDigest, ByteString.copyFromUtf8("wrong content"))); + + IOException e = + assertThrows( + UnrecoverableHttpException.class, + () -> + downloadBlob( + downloader, + new URL("http://example.com/content.txt"), + Optional.of( + Checksum.fromString(KeyType.SHA256, contentDigest.getHash())))); + + assertThat(e).hasMessageThat().contains(contentDigest.getHash()); + assertThat(e).hasMessageThat().contains(DIGEST_UTIL.computeAsUtf8("wrong content").getHash()); + } + + @Test + public void testFetchBlobRequest() throws Exception { + FetchBlobRequest request = + GrpcRemoteDownloader.newFetchBlobRequest( + "instance name", + ImmutableList.of( + new URL("http://example.com/a"), + new URL("http://example.com/b"), + new URL("file:/not/limited/to/http")), + ImmutableMap.of( + new URI("http://example.com"), + ImmutableMap.of( + "Some-Header", "some header content", + "Another-Header", "another header content"), + new URI("http://example.org"), + ImmutableMap.of("Org-Header", "org header content")), + com.google.common.base.Optional.of( + Checksum.fromSubresourceIntegrity( + "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")), + "canonical ID"); + + final String expectedAuthHeadersJson = + "{" + + "\"http://example.com\":{" + + "\"Another-Header\":\"another header content\"," + + "\"Some-Header\":\"some header content\"" + + "}," + + "\"http://example.org\":{" + + "\"Org-Header\":\"org header content\"" + + "}" + + "}"; + + assertThat(request) + .isEqualTo( + FetchBlobRequest.newBuilder() + .setInstanceName("instance name") + .addUris("http://example.com/a") + .addUris("http://example.com/b") + .addUris("file:/not/limited/to/http") + .addQualifiers( + Qualifier.newBuilder() + .setName("checksum.sri") + .setValue("sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")) + .addQualifiers( + Qualifier.newBuilder().setName("bazel.canonical_id").setValue("canonical ID")) + .addQualifiers( + Qualifier.newBuilder() + .setName("bazel.auth_headers") + .setValue(expectedAuthHeadersJson)) + .build()); + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/downloader/RemoteDownloaderTestSuite.java b/src/test/java/com/google/devtools/build/lib/remote/downloader/RemoteDownloaderTestSuite.java new file mode 100644 index 00000000000000..14bce367a2672f --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/downloader/RemoteDownloaderTestSuite.java @@ -0,0 +1,26 @@ +// Copyright 2020 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.remote.downloader; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test suite for remote/downloader package. */ +@RunWith(Suite.class) +@SuiteClasses({ + GrpcRemoteDownloaderTest.class, +}) +public class RemoteDownloaderTestSuite {}