Skip to content

Commit

Permalink
Intercept capabilities and uploader requests and add custom grpc headers
Browse files Browse the repository at this point in the history
Following #10015. Some requests do not use the custom headers.

Closes #10634.

PiperOrigin-RevId: 298574179
  • Loading branch information
Alessandro Patti authored and copybara-github committed Mar 3, 2020
1 parent 247ca0c commit 52c8773
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public static ManagedChannel newChannel(
String target,
String proxy,
AuthAndTLSOptions options,
@Nullable ClientInterceptor interceptor)
@Nullable List<ClientInterceptor> interceptors)
throws IOException {
Preconditions.checkNotNull(target);
Preconditions.checkNotNull(options);
Expand All @@ -69,8 +69,8 @@ public static ManagedChannel newChannel(
newNettyChannelBuilder(targetUrl, proxy)
.negotiationType(
isTlsEnabled(target) ? NegotiationType.TLS : NegotiationType.PLAINTEXT);
if (interceptor != null) {
builder.intercept(interceptor);
if (interceptors != null) {
builder.intercept(interceptors);
}
if (sslContext != null) {
builder.sslContext(sslContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,31 +122,27 @@ private int computeMaxMissingBlobsDigestsPerMessage() {
private ContentAddressableStorageFutureStub casFutureStub() {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ByteStreamStub bsAsyncStub() {
return ByteStreamGrpc.newStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ActionCacheBlockingStub acBlockingStub() {
return ActionCacheGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ActionCacheFutureStub acFutureStub() {
return ActionCacheGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ public GrpcRemoteExecutor(
private ExecutionBlockingStub execBlockingStub() {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options))
.withCallCredentials(callCredentials);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.netty.channel.unix.DomainSocketAddress;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -59,10 +60,10 @@ public static ReferenceCountedChannel createGrpcChannel(
String target,
String proxyUri,
AuthAndTLSOptions authOptions,
@Nullable ClientInterceptor interceptor)
@Nullable List<ClientInterceptor> interceptors)
throws IOException {
return new ReferenceCountedChannel(
GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptor));
GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors));
}

public static RemoteCacheClient create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,37 @@ public static boolean shouldEnableRemoteExecution(RemoteOptions options) {
return !Strings.isNullOrEmpty(options.remoteExecutor);
}

private void verifyServerCapabilities(
RemoteOptions remoteOptions,
ReferenceCountedChannel channel,
CallCredentials credentials,
RemoteRetrier retrier,
CommandEnvironment env,
DigestUtil digestUtil)
throws AbruptExitException {
RemoteServerCapabilities rsc =
new RemoteServerCapabilities(
remoteOptions.remoteInstanceName,
channel,
credentials,
remoteOptions.remoteTimeout,
retrier);
ServerCapabilities capabilities = null;
try {
capabilities = rsc.get(env.getCommandId().toString(), env.getBuildRequestId());
} catch (IOException e) {
throw new AbruptExitException(
"Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
ExitCode.REMOTE_ERROR,
e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return;
}
checkClientServerCompatibility(
capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());
}

@Override
public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
Preconditions.checkState(actionContextProvider == null, "actionContextProvider must be null");
Expand Down Expand Up @@ -178,12 +209,17 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
ReferenceCountedChannel execChannel = null;
ReferenceCountedChannel cacheChannel = null;
if (enableRemoteExecution) {
ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
interceptors.add(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions));
if (loggingInterceptor != null) {
interceptors.add(loggingInterceptor);
}
execChannel =
RemoteCacheClientFactory.createGrpcChannel(
remoteOptions.remoteExecutor,
remoteOptions.remoteProxy,
authAndTlsOptions,
loggingInterceptor);
interceptors.build());
// Create a separate channel if --remote_executor and --remote_cache point to different
// endpoints.
if (Strings.isNullOrEmpty(remoteOptions.remoteCache)
Expand All @@ -193,12 +229,17 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
}

if (cacheChannel == null) {
ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
interceptors.add(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions));
if (loggingInterceptor != null) {
interceptors.add(loggingInterceptor);
}
cacheChannel =
RemoteCacheClientFactory.createGrpcChannel(
remoteOptions.remoteCache,
remoteOptions.remoteProxy,
authAndTlsOptions,
loggingInterceptor);
interceptors.build());
}

CallCredentials credentials = GoogleAuthUtils.newCallCredentials(authAndTlsOptions);
Expand All @@ -212,27 +253,13 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
// We always query the execution server for capabilities, if it is defined. A remote
// execution/cache system should have all its servers to return the capabilities pertaining
// to the system as a whole.
RemoteServerCapabilities rsc =
new RemoteServerCapabilities(
remoteOptions.remoteInstanceName,
(execChannel != null ? execChannel : cacheChannel),
credentials,
remoteOptions.remoteTimeout,
retrier);
ServerCapabilities capabilities = null;
try {
capabilities = rsc.get(buildRequestId, invocationId);
} catch (IOException e) {
throw new AbruptExitException(
"Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
ExitCode.REMOTE_ERROR,
e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return;
if (execChannel != null) {
verifyServerCapabilities(remoteOptions, execChannel, credentials, retrier, env, digestUtil);
}
if (cacheChannel != execChannel) {
verifyServerCapabilities(
remoteOptions, cacheChannel, credentials, retrier, env, digestUtil);
}
checkClientServerCompatibility(
capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());

ByteStreamUploader uploader =
new ByteStreamUploader(
Expand All @@ -241,6 +268,7 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
credentials,
remoteOptions.remoteTimeout,
retrier);

cacheChannel.release();
RemoteCacheClient cacheClient =
new GrpcCacheClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.util.MutableHandlerRegistry;
import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -689,6 +691,74 @@ public void queryWriteStatus(
withEmptyMetadata.detach(prevContext);
}

@Test
public void customHeadersAreAttachedToRequest() throws Exception {
RemoteRetrier retrier =
TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);

Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER), "Value1");
metadata.put(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER), "Value2");

ByteStreamUploader uploader =
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(
InProcessChannelBuilder.forName("Server for " + this.getClass())
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
.build()),
null, /* timeout seconds */
60,
retrier);

byte[] blob = new byte[CHUNK_SIZE];
Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());

serviceRegistry.addService(
ServerInterceptors.intercept(
new ByteStreamImplBase() {
@Override
public StreamObserver<WriteRequest> write(
StreamObserver<WriteResponse> streamObserver) {
return new StreamObserver<WriteRequest>() {
@Override
public void onNext(WriteRequest writeRequest) {}

@Override
public void onError(Throwable throwable) {
fail("onError should never be called.");
}

@Override
public void onCompleted() {
WriteResponse response =
WriteResponse.newBuilder().setCommittedSize(blob.length).build();
streamObserver.onNext(response);
streamObserver.onCompleted();
}
};
}
},
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata metadata,
ServerCallHandler<ReqT, RespT> next) {
assertThat(metadata.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("Value1");
assertThat(metadata.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("Value2");
assertThat(metadata.get(Metadata.Key.of("Key3", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo(null);
return next.startCall(call, metadata);
}
}));

uploader.uploadBlob(hash, chunker, true);
}

@Test
public void sameBlobShouldNotBeUploadedTwice() throws Exception {
// Test that uploading the same file concurrently triggers only one file upload.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ private GrpcCacheClient newClient(RemoteOptions remoteOptions, Supplier<Backoff>
InProcessChannelBuilder.forName(fakeServerName)
.directExecutor()
.intercept(new CallCredentialsInterceptor(creds))
.intercept(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions))
.build());
ByteStreamUploader uploader =
new ByteStreamUploader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ public PathFragment getExecPath() {
retryService);
ReferenceCountedChannel channel =
new ReferenceCountedChannel(
InProcessChannelBuilder.forName(fakeServerName).directExecutor().build());
InProcessChannelBuilder.forName(fakeServerName)
.intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
.directExecutor()
.build());
GrpcRemoteExecutor executor =
new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
CallCredentials creds =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import build.bazel.remote.execution.v2.PriorityCapabilities.PriorityRange;
import build.bazel.remote.execution.v2.RequestMetadata;
import build.bazel.remote.execution.v2.ServerCapabilities;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
Expand Down Expand Up @@ -107,6 +109,60 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
}
}

private static class RequestCustomHeadersValidator implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
assertThat(headers.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("Value1");
assertThat(headers.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("Value2");
return next.startCall(call, headers);
}
}

@Test
public void testCustomHeadersAreAttached() throws Exception {
ServerCapabilities caps =
ServerCapabilities.newBuilder()
.setExecutionCapabilities(
ExecutionCapabilities.newBuilder().setExecEnabled(true).build())
.build();
serviceRegistry.addService(
ServerInterceptors.intercept(
new CapabilitiesImplBase() {
@Override
public void getCapabilities(
GetCapabilitiesRequest request,
StreamObserver<ServerCapabilities> responseObserver) {
responseObserver.onNext(caps);
responseObserver.onCompleted();
}
},
new RequestCustomHeadersValidator()));

RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);
remoteOptions.remoteHeaders =
ImmutableList.of(
Maps.immutableEntry("Key1", "Value1"), Maps.immutableEntry("Key2", "Value2"));

RemoteRetrier retrier =
TestUtils.newRemoteRetrier(
() -> new ExponentialBackoff(remoteOptions),
RemoteRetrier.RETRIABLE_GRPC_ERRORS,
retryService);
ReferenceCountedChannel channel =
new ReferenceCountedChannel(
InProcessChannelBuilder.forName(fakeServerName)
.intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
.directExecutor()
.build());
RemoteServerCapabilities client =
new RemoteServerCapabilities("instance", channel.retain(), null, 3, retrier);

assertThat(client.get("build-req-id", "command-id")).isEqualTo(caps);
}

@Test
public void testGetCapabilitiesWithRetries() throws Exception {
ServerCapabilities caps =
Expand Down

0 comments on commit 52c8773

Please sign in to comment.