diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index 4d694f06cfe128..1c07c87c3f2502 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -122,6 +122,7 @@ private int computeMaxMissingBlobsDigestsPerMessage() { private ContentAddressableStorageFutureStub casFutureStub() { return ContentAddressableStorageGrpc.newFutureStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -129,6 +130,7 @@ private ContentAddressableStorageFutureStub casFutureStub() { private ByteStreamStub bsAsyncStub() { return ByteStreamGrpc.newStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -136,6 +138,7 @@ private ByteStreamStub bsAsyncStub() { private ActionCacheBlockingStub acBlockingStub() { return ActionCacheGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -143,6 +146,7 @@ private ActionCacheBlockingStub acBlockingStub() { private ActionCacheFutureStub acFutureStub() { return ActionCacheGrpc.newFutureStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java index 544b17b8e0068a..306643b0ac6d46 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java @@ -21,6 +21,7 @@ import build.bazel.remote.execution.v2.WaitExecutionRequest; import com.google.common.base.Preconditions; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; +import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.longrunning.Operation; import com.google.rpc.Status; @@ -42,19 +43,23 @@ class GrpcRemoteExecutor { private final RemoteRetrier retrier; private final AtomicBoolean closed = new AtomicBoolean(); + private final RemoteOptions options; public GrpcRemoteExecutor( ReferenceCountedChannel channel, @Nullable CallCredentials callCredentials, - RemoteRetrier retrier) { + RemoteRetrier retrier, + RemoteOptions options) { this.channel = channel; this.callCredentials = callCredentials; this.retrier = retrier; + this.options = options; } private ExecutionBlockingStub execBlockingStub() { return ExecutionGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) + .withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options)) .withCallCredentials(callCredentials); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java index 9f4a11ba6ddcda..24e295315e60b0 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java @@ -272,7 +272,8 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { new GrpcRemoteExecutor( execChannel.retain(), GoogleAuthUtils.newCallCredentials(authAndTlsOptions), - execRetrier); + execRetrier, + remoteOptions); execChannel.release(); RemoteExecutionCache remoteCache = new RemoteExecutionCache(cacheClient, remoteOptions, digestUtil); diff --git a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java index 88571221363f3a..460e3a9e8fbeec 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java +++ b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java @@ -92,13 +92,42 @@ public final class RemoteOptions extends OptionsBase { documentationCategory = OptionDocumentationCategory.REMOTE, effectTags = {OptionEffectTag.UNKNOWN}, help = - "Specify a HTTP header that will be included in requests: --remote_header=Name=Value. " + "Specify a header that will be included in requests: --remote_header=Name=Value. " + "Multiple headers can be passed by specifying the flag multiple times. Multiple " - + "values for the same name will be converted to a comma-separated list. This flag" - + "is currently only implemented for the HTTP protocol.", + + "values for the same name will be converted to a comma-separated list.", allowMultiple = true) public List> remoteHeaders; + + @Option( + name = "remote_cache_header", + converter = Converters.AssignmentConverter.class, + defaultValue = "", + documentationCategory = OptionDocumentationCategory.REMOTE, + effectTags = {OptionEffectTag.UNKNOWN}, + help = + "Specify a header that will be included in cache requests: " + + "--remote_cache_header=Name=Value. " + + "Multiple headers can be passed by specifying the flag multiple times. Multiple " + + "values for the same name will be converted to a comma-separated list.", + allowMultiple = true) + public List> remoteCacheHeaders; + + + @Option( + name = "remote_exec_header", + converter = Converters.AssignmentConverter.class, + defaultValue = "", + documentationCategory = OptionDocumentationCategory.REMOTE, + effectTags = {OptionEffectTag.UNKNOWN}, + help = + "Specify a header that will be included in execution requests: " + + "--remote_exec_header=Name=Value. " + + "Multiple headers can be passed by specifying the flag multiple times. Multiple " + + "values for the same name will be converted to a comma-separated list.", + allowMultiple = true) + public List> remoteExecHeaders; + @Option( name = "remote_timeout", defaultValue = "60", diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java index e70726060d4519..100da2b3515118 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java @@ -19,6 +19,7 @@ import com.google.common.base.Preconditions; import com.google.devtools.build.lib.analysis.BlazeVersionInfo; import com.google.devtools.build.lib.remote.common.RemoteCacheClient.ActionKey; +import com.google.devtools.build.lib.remote.options.RemoteOptions; import io.grpc.ClientInterceptor; import io.grpc.Context; import io.grpc.Contexts; @@ -29,6 +30,9 @@ import io.grpc.ServerInterceptor; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.MetadataUtils; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import javax.annotation.Nullable; /** Utility functions to handle Metadata for remote Grpc calls. */ @@ -118,6 +122,27 @@ public static ClientInterceptor attachMetadataFromContextInterceptor() { return MetadataUtils.newAttachHeadersInterceptor(headersFromCurrentContext()); } + private static Metadata newMetadataForHeaders(List> headers){ + Metadata metadata = new Metadata(); + headers.forEach(header -> metadata.put( + Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER), + header.getValue() + )); + return metadata; + } + + public static ClientInterceptor newCacheHeadersInterceptor(RemoteOptions options){ + Metadata metadata = newMetadataForHeaders(options.remoteHeaders); + metadata.merge(newMetadataForHeaders(options.remoteCacheHeaders)); + return MetadataUtils.newAttachHeadersInterceptor(metadata); + } + + public static ClientInterceptor newExecHeadersInterceptor(RemoteOptions options){ + Metadata metadata = newMetadataForHeaders(options.remoteHeaders); + metadata.merge(newMetadataForHeaders(options.remoteExecHeaders)); + return MetadataUtils.newAttachHeadersInterceptor(metadata); + } + /** GRPC interceptor to add logging metadata to the GRPC context. */ public static class ServerHeadersInterceptor implements ServerInterceptor { @Override diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java index 94975c5efbd3e9..cd5d13843335d4 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java @@ -49,6 +49,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSortedMap; +import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.actions.ActionInputHelper; @@ -75,14 +76,20 @@ import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; import com.google.devtools.common.options.Options; import com.google.protobuf.ByteString; +import io.grpc.BindableService; import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.Context; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; @@ -685,6 +692,72 @@ private ActionResult uploadDirectory(RemoteCache remoteCache, List outputs return remoteCache.upload(actionKey, action, cmd, execRoot, outputs, outErr); } + @Test + public void extraHeaders() throws Exception { + RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); + remoteOptions.remoteHeaders = ImmutableList.of( + Maps.immutableEntry("CommonKey1", "CommonValue1"), + Maps.immutableEntry("CommonKey2", "CommonValue2")); + remoteOptions.remoteExecHeaders = ImmutableList.of( + Maps.immutableEntry("ExecKey1", "ExecValue1"), + Maps.immutableEntry("ExecKey2", "ExecValue2")); + remoteOptions.remoteCacheHeaders = ImmutableList.of( + Maps.immutableEntry("CacheKey1", "CacheValue1"), + Maps.immutableEntry("CacheKey2", "CacheValue2")); + + ServerInterceptor interceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata metadata, + ServerCallHandler next) { + assertThat(metadata + .get(Metadata.Key.of("CommonKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CommonValue1"); + assertThat(metadata + .get(Metadata.Key.of("CommonKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CommonValue2"); + assertThat(metadata + .get(Metadata.Key.of("CacheKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CacheValue1"); + assertThat(metadata + .get(Metadata.Key.of("CacheKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CacheValue2"); + assertThat(metadata + .get(Metadata.Key.of("ExecKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo(null); + assertThat(metadata + .get(Metadata.Key.of("ExecKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo(null); + return next.startCall(call, metadata); + } + }; + + BindableService cas = new ContentAddressableStorageImplBase() { + @Override + public void findMissingBlobs( + FindMissingBlobsRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(FindMissingBlobsResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + serviceRegistry.addService(cas); + BindableService actionCache = new ActionCacheImplBase() { + @Override + public void getActionResult( + GetActionResultRequest request, StreamObserver responseObserver){ + responseObserver.onNext(ActionResult.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + serviceRegistry.addService(ServerInterceptors.intercept(actionCache, interceptor)); + + GrpcCacheClient client = newClient(remoteOptions); + RemoteCache remoteCache = new RemoteCache(client, remoteOptions, DIGEST_UTIL); + remoteCache.downloadActionResult(DIGEST_UTIL.asActionKey(DIGEST_UTIL.computeAsUtf8("key"))); + } + @Test public void testUpload() throws Exception { RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java index 1da3fc6b78b7f8..c3696617368121 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java @@ -46,6 +46,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.actions.ActionInput; @@ -208,6 +209,17 @@ public PathFragment getExecPath() { FileSystemUtils.createDirectoryAndParents(stderr.getParentDirectory()); outErr = new FileOutErr(stdout, stderr); RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); + + remoteOptions.remoteHeaders = ImmutableList.of( + Maps.immutableEntry("CommonKey1", "CommonValue1"), + Maps.immutableEntry("CommonKey2", "CommonValue2")); + remoteOptions.remoteExecHeaders = ImmutableList.of( + Maps.immutableEntry("ExecKey1", "ExecValue1"), + Maps.immutableEntry("ExecKey2", "ExecValue2")); + remoteOptions.remoteCacheHeaders = ImmutableList.of( + Maps.immutableEntry("CacheKey1", "CacheValue1"), + Maps.immutableEntry("CacheKey2", "CacheValue2")); + RemoteRetrier retrier = TestUtils.newRemoteRetrier( () -> new ExponentialBackoff(remoteOptions), @@ -217,7 +229,7 @@ public PathFragment getExecPath() { new ReferenceCountedChannel( InProcessChannelBuilder.forName(fakeServerName).directExecutor().build()); GrpcRemoteExecutor executor = - new GrpcRemoteExecutor(channel.retain(), null, retrier); + new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions); CallCredentials creds = GoogleAuthUtils.newCallCredentials(Options.getDefaults(AuthAndTLSOptions.class)); ByteStreamUploader uploader = @@ -509,6 +521,69 @@ public ServerCall.Listener interceptCall( } } + @Test + public void extraHeaders() throws Exception { + BindableService actionCache = new ActionCacheImplBase() { + @Override + public void getActionResult( + GetActionResultRequest request, StreamObserver responseObserver) { + responseObserver.onError(Status.NOT_FOUND.asRuntimeException()); + } + }; + serviceRegistry.addService(actionCache); + + BindableService cas = new ContentAddressableStorageImplBase() { + @Override + public void findMissingBlobs( + FindMissingBlobsRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(FindMissingBlobsResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + serviceRegistry.addService(cas); + + BindableService execService = new ExecutionImplBase() { + @Override + public void execute(ExecuteRequest request, StreamObserver responseObserver) { + responseObserver.onNext(Operation.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + ServerInterceptor interceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata metadata, + ServerCallHandler next) { + assertThat(metadata + .get(Metadata.Key.of("CommonKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CommonValue1"); + assertThat(metadata + .get(Metadata.Key.of("CommonKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("CommonValue2"); + assertThat(metadata + .get(Metadata.Key.of("ExecKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("ExecValue1"); + assertThat(metadata + .get(Metadata.Key.of("ExecKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("ExecValue2"); + assertThat(metadata + .get(Metadata.Key.of("CacheKey1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo(null); + assertThat(metadata + .get(Metadata.Key.of("CacheKey2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo(null); + return next.startCall(call, metadata); + } + }; + serviceRegistry.addService(ServerInterceptors.intercept(execService, interceptor)); + + FakeSpawnExecutionContext policy = + new FakeSpawnExecutionContext(simpleSpawn, fakeFileCache, execRoot, outErr); + client.exec(simpleSpawn, policy); + } + @Test public void remotelyExecute() throws Exception { BindableService actionCache =