diff --git a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java index d709cfe0fef7c8..52c81cd5683ca0 100644 --- a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java +++ b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java @@ -54,7 +54,7 @@ public static ManagedChannel newChannel( String target, String proxy, AuthAndTLSOptions options, - @Nullable ClientInterceptor interceptor) + @Nullable List interceptors) throws IOException { Preconditions.checkNotNull(target); Preconditions.checkNotNull(options); @@ -69,8 +69,8 @@ public static ManagedChannel newChannel( newNettyChannelBuilder(targetUrl, proxy) .negotiationType( isTlsEnabled(target) ? NegotiationType.TLS : NegotiationType.PLAINTEXT); - if (interceptor != null) { - builder.intercept(interceptor); + if (interceptors != null) { + builder.intercept(interceptors); } if (sslContext != null) { builder.sslContext(sslContext); diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index 1750526490239c..c53186a6c0c7e3 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -122,7 +122,6 @@ private int computeMaxMissingBlobsDigestsPerMessage() { private ContentAddressableStorageFutureStub casFutureStub() { return ContentAddressableStorageGrpc.newFutureStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) - .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -130,7 +129,6 @@ private ContentAddressableStorageFutureStub casFutureStub() { private ByteStreamStub bsAsyncStub() { return ByteStreamGrpc.newStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) - .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -138,7 +136,6 @@ private ByteStreamStub bsAsyncStub() { private ActionCacheBlockingStub acBlockingStub() { return ActionCacheGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) - .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } @@ -146,7 +143,6 @@ private ActionCacheBlockingStub acBlockingStub() { private ActionCacheFutureStub acFutureStub() { return ActionCacheGrpc.newFutureStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) - .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options)) .withCallCredentials(credentials) .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java index 306643b0ac6d46..52fc07f0135e25 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java @@ -59,7 +59,6 @@ public GrpcRemoteExecutor( private ExecutionBlockingStub execBlockingStub() { return ExecutionGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor()) - .withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options)) .withCallCredentials(callCredentials); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java index ff00eb6fac8c58..34b0bfe201c631 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java @@ -32,6 +32,7 @@ import io.netty.channel.unix.DomainSocketAddress; import java.io.IOException; import java.net.URI; +import java.util.List; import javax.annotation.Nullable; /** @@ -59,10 +60,10 @@ public static ReferenceCountedChannel createGrpcChannel( String target, String proxyUri, AuthAndTLSOptions authOptions, - @Nullable ClientInterceptor interceptor) + @Nullable List interceptors) throws IOException { return new ReferenceCountedChannel( - GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptor)); + GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors)); } public static RemoteCacheClient create( diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java index 26e54f3fb641e6..7a342f1b1d98ca 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java @@ -107,6 +107,37 @@ public static boolean shouldEnableRemoteExecution(RemoteOptions options) { return !Strings.isNullOrEmpty(options.remoteExecutor); } + private void verifyServerCapabilities( + RemoteOptions remoteOptions, + ReferenceCountedChannel channel, + CallCredentials credentials, + RemoteRetrier retrier, + CommandEnvironment env, + DigestUtil digestUtil) + throws AbruptExitException { + RemoteServerCapabilities rsc = + new RemoteServerCapabilities( + remoteOptions.remoteInstanceName, + channel, + credentials, + remoteOptions.remoteTimeout, + retrier); + ServerCapabilities capabilities = null; + try { + capabilities = rsc.get(env.getCommandId().toString(), env.getBuildRequestId()); + } catch (IOException e) { + throw new AbruptExitException( + "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e), + ExitCode.REMOTE_ERROR, + e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + checkClientServerCompatibility( + capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter()); + } + @Override public void beforeCommand(CommandEnvironment env) throws AbruptExitException { Preconditions.checkState(actionContextProvider == null, "actionContextProvider must be null"); @@ -178,12 +209,17 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { ReferenceCountedChannel execChannel = null; ReferenceCountedChannel cacheChannel = null; if (enableRemoteExecution) { + ImmutableList.Builder interceptors = ImmutableList.builder(); + interceptors.add(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions)); + if (loggingInterceptor != null) { + interceptors.add(loggingInterceptor); + } execChannel = RemoteCacheClientFactory.createGrpcChannel( remoteOptions.remoteExecutor, remoteOptions.remoteProxy, authAndTlsOptions, - loggingInterceptor); + interceptors.build()); // Create a separate channel if --remote_executor and --remote_cache point to different // endpoints. if (Strings.isNullOrEmpty(remoteOptions.remoteCache) @@ -193,12 +229,17 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { } if (cacheChannel == null) { + ImmutableList.Builder interceptors = ImmutableList.builder(); + interceptors.add(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions)); + if (loggingInterceptor != null) { + interceptors.add(loggingInterceptor); + } cacheChannel = RemoteCacheClientFactory.createGrpcChannel( remoteOptions.remoteCache, remoteOptions.remoteProxy, authAndTlsOptions, - loggingInterceptor); + interceptors.build()); } CallCredentials credentials = GoogleAuthUtils.newCallCredentials(authAndTlsOptions); @@ -212,27 +253,13 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { // We always query the execution server for capabilities, if it is defined. A remote // execution/cache system should have all its servers to return the capabilities pertaining // to the system as a whole. - RemoteServerCapabilities rsc = - new RemoteServerCapabilities( - remoteOptions.remoteInstanceName, - (execChannel != null ? execChannel : cacheChannel), - credentials, - remoteOptions.remoteTimeout, - retrier); - ServerCapabilities capabilities = null; - try { - capabilities = rsc.get(buildRequestId, invocationId); - } catch (IOException e) { - throw new AbruptExitException( - "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e), - ExitCode.REMOTE_ERROR, - e); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; + if (execChannel != null) { + verifyServerCapabilities(remoteOptions, execChannel, credentials, retrier, env, digestUtil); + } + if (cacheChannel != execChannel) { + verifyServerCapabilities( + remoteOptions, cacheChannel, credentials, retrier, env, digestUtil); } - checkClientServerCompatibility( - capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter()); ByteStreamUploader uploader = new ByteStreamUploader( @@ -241,6 +268,7 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { credentials, remoteOptions.remoteTimeout, retrier); + cacheChannel.release(); RemoteCacheClient cacheClient = new GrpcCacheClient( diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java index 2b47d949f697d7..8592412b0cf78b 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java @@ -46,6 +46,7 @@ import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.ServerServiceDefinition; import io.grpc.Status; @@ -53,6 +54,7 @@ import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; import io.grpc.util.MutableHandlerRegistry; import java.io.ByteArrayInputStream; @@ -689,6 +691,74 @@ public void queryWriteStatus( withEmptyMetadata.detach(prevContext); } + @Test + public void customHeadersAreAttachedToRequest() throws Exception { + RemoteRetrier retrier = + TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService); + + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER), "Value1"); + metadata.put(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER), "Value2"); + + ByteStreamUploader uploader = + new ByteStreamUploader( + INSTANCE_NAME, + new ReferenceCountedChannel( + InProcessChannelBuilder.forName("Server for " + this.getClass()) + .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)) + .build()), + null, /* timeout seconds */ + 60, + retrier); + + byte[] blob = new byte[CHUNK_SIZE]; + Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build(); + HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash()); + + serviceRegistry.addService( + ServerInterceptors.intercept( + new ByteStreamImplBase() { + @Override + public StreamObserver write( + StreamObserver streamObserver) { + return new StreamObserver() { + @Override + public void onNext(WriteRequest writeRequest) {} + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + WriteResponse response = + WriteResponse.newBuilder().setCommittedSize(blob.length).build(); + streamObserver.onNext(response); + streamObserver.onCompleted(); + } + }; + } + }, + new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata metadata, + ServerCallHandler next) { + assertThat(metadata.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Value1"); + assertThat(metadata.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Value2"); + assertThat(metadata.get(Metadata.Key.of("Key3", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo(null); + return next.startCall(call, metadata); + } + })); + + uploader.uploadBlob(hash, chunker, true); + } + @Test public void sameBlobShouldNotBeUploadedTwice() throws Exception { // Test that uploading the same file concurrently triggers only one file upload. diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java index 9c551d8b3a40e7..240947dec29659 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java @@ -226,6 +226,7 @@ private GrpcCacheClient newClient(RemoteOptions remoteOptions, Supplier InProcessChannelBuilder.forName(fakeServerName) .directExecutor() .intercept(new CallCredentialsInterceptor(creds)) + .intercept(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions)) .build()); ByteStreamUploader uploader = new ByteStreamUploader( diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java index 43fec08970ed5a..fab371cc17b434 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java @@ -233,7 +233,10 @@ public PathFragment getExecPath() { retryService); ReferenceCountedChannel channel = new ReferenceCountedChannel( - InProcessChannelBuilder.forName(fakeServerName).directExecutor().build()); + InProcessChannelBuilder.forName(fakeServerName) + .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions)) + .directExecutor() + .build()); GrpcRemoteExecutor executor = new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions); CallCredentials creds = diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java index 2c47b443ce86f9..357446c373fa05 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java @@ -25,6 +25,8 @@ import build.bazel.remote.execution.v2.PriorityCapabilities.PriorityRange; import build.bazel.remote.execution.v2.RequestMetadata; import build.bazel.remote.execution.v2.ServerCapabilities; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.analysis.BlazeVersionInfo; @@ -107,6 +109,60 @@ public ServerCall.Listener interceptCall( } } + private static class RequestCustomHeadersValidator implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + assertThat(headers.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Value1"); + assertThat(headers.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Value2"); + return next.startCall(call, headers); + } + } + + @Test + public void testCustomHeadersAreAttached() throws Exception { + ServerCapabilities caps = + ServerCapabilities.newBuilder() + .setExecutionCapabilities( + ExecutionCapabilities.newBuilder().setExecEnabled(true).build()) + .build(); + serviceRegistry.addService( + ServerInterceptors.intercept( + new CapabilitiesImplBase() { + @Override + public void getCapabilities( + GetCapabilitiesRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(caps); + responseObserver.onCompleted(); + } + }, + new RequestCustomHeadersValidator())); + + RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); + remoteOptions.remoteHeaders = + ImmutableList.of( + Maps.immutableEntry("Key1", "Value1"), Maps.immutableEntry("Key2", "Value2")); + + RemoteRetrier retrier = + TestUtils.newRemoteRetrier( + () -> new ExponentialBackoff(remoteOptions), + RemoteRetrier.RETRIABLE_GRPC_ERRORS, + retryService); + ReferenceCountedChannel channel = + new ReferenceCountedChannel( + InProcessChannelBuilder.forName(fakeServerName) + .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions)) + .directExecutor() + .build()); + RemoteServerCapabilities client = + new RemoteServerCapabilities("instance", channel.retain(), null, 3, retrier); + + assertThat(client.get("build-req-id", "command-id")).isEqualTo(caps); + } + @Test public void testGetCapabilitiesWithRetries() throws Exception { ServerCapabilities caps =