Skip to content

Commit

Permalink
Pass --remote_header's headers to grpc endpoints
Browse files Browse the repository at this point in the history
Extends bazelbuild#8245 to gRPC.

Closes bazelbuild#10015.

PiperOrigin-RevId: 286175898
  • Loading branch information
Alessandro Patti committed Jan 13, 2020
1 parent a8dcf15 commit ce44f66
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,31 @@ private int computeMaxMissingBlobsDigestsPerMessage() {
private ContentAddressableStorageFutureStub casFutureStub() {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ByteStreamStub bsAsyncStub() {
return ByteStreamGrpc.newStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ActionCacheBlockingStub acBlockingStub() {
return ActionCacheGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}

private ActionCacheFutureStub acFutureStub() {
return ActionCacheGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import build.bazel.remote.execution.v2.WaitExecutionRequest;
import com.google.common.base.Preconditions;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.longrunning.Operation;
import com.google.rpc.Status;
Expand All @@ -42,19 +43,23 @@ class GrpcRemoteExecutor {
private final RemoteRetrier retrier;

private final AtomicBoolean closed = new AtomicBoolean();
private final RemoteOptions options;

public GrpcRemoteExecutor(
ReferenceCountedChannel channel,
@Nullable CallCredentials callCredentials,
RemoteRetrier retrier) {
RemoteRetrier retrier,
RemoteOptions options) {
this.channel = channel;
this.callCredentials = callCredentials;
this.retrier = retrier;
this.options = options;
}

private ExecutionBlockingStub execBlockingStub() {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options))
.withCallCredentials(callCredentials);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
new GrpcRemoteExecutor(
execChannel.retain(),
GoogleAuthUtils.newCallCredentials(authAndTlsOptions),
execRetrier);
execRetrier,
remoteOptions);
execChannel.release();
RemoteExecutionCache remoteCache =
new RemoteExecutionCache(cacheClient, remoteOptions, digestUtil);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,40 @@ public final class RemoteOptions extends OptionsBase {
documentationCategory = OptionDocumentationCategory.REMOTE,
effectTags = {OptionEffectTag.UNKNOWN},
help =
"Specify a HTTP header that will be included in requests: --remote_header=Name=Value. "
"Specify a header that will be included in requests: --remote_header=Name=Value. "
+ "Multiple headers can be passed by specifying the flag multiple times. Multiple "
+ "values for the same name will be converted to a comma-separated list. This flag"
+ "is currently only implemented for the HTTP protocol.",
+ "values for the same name will be converted to a comma-separated list.",
allowMultiple = true)
public List<Entry<String, String>> remoteHeaders;

@Option(
name = "remote_cache_header",
converter = Converters.AssignmentConverter.class,
defaultValue = "",
documentationCategory = OptionDocumentationCategory.REMOTE,
effectTags = {OptionEffectTag.UNKNOWN},
help =
"Specify a header that will be included in cache requests: "
+ "--remote_cache_header=Name=Value. "
+ "Multiple headers can be passed by specifying the flag multiple times. Multiple "
+ "values for the same name will be converted to a comma-separated list.",
allowMultiple = true)
public List<Entry<String, String>> remoteCacheHeaders;

@Option(
name = "remote_exec_header",
converter = Converters.AssignmentConverter.class,
defaultValue = "",
documentationCategory = OptionDocumentationCategory.REMOTE,
effectTags = {OptionEffectTag.UNKNOWN},
help =
"Specify a header that will be included in execution requests: "
+ "--remote_exec_header=Name=Value. "
+ "Multiple headers can be passed by specifying the flag multiple times. Multiple "
+ "values for the same name will be converted to a comma-separated list.",
allowMultiple = true)
public List<Entry<String, String>> remoteExecHeaders;

@Option(
name = "remote_timeout",
defaultValue = "60",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.base.Preconditions;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.remote.common.RemoteCacheClient.ActionKey;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import io.grpc.ClientInterceptor;
import io.grpc.Context;
import io.grpc.Contexts;
Expand All @@ -29,6 +30,8 @@
import io.grpc.ServerInterceptor;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.MetadataUtils;
import java.util.List;
import java.util.Map.Entry;
import javax.annotation.Nullable;

/** Utility functions to handle Metadata for remote Grpc calls. */
Expand Down Expand Up @@ -118,6 +121,28 @@ public static ClientInterceptor attachMetadataFromContextInterceptor() {
return MetadataUtils.newAttachHeadersInterceptor(headersFromCurrentContext());
}

private static Metadata newMetadataForHeaders(List<Entry<String, String>> headers) {
Metadata metadata = new Metadata();
headers.forEach(
header ->
metadata.put(
Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER),
header.getValue()));
return metadata;
}

public static ClientInterceptor newCacheHeadersInterceptor(RemoteOptions options) {
Metadata metadata = newMetadataForHeaders(options.remoteHeaders);
metadata.merge(newMetadataForHeaders(options.remoteCacheHeaders));
return MetadataUtils.newAttachHeadersInterceptor(metadata);
}

public static ClientInterceptor newExecHeadersInterceptor(RemoteOptions options) {
Metadata metadata = newMetadataForHeaders(options.remoteHeaders);
metadata.merge(newMetadataForHeaders(options.remoteExecHeaders));
return MetadataUtils.newAttachHeadersInterceptor(metadata);
}

/** GRPC interceptor to add logging metadata to the GRPC context. */
public static class ServerHeadersInterceptor implements ServerInterceptor {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.actions.ActionInputHelper;
Expand All @@ -75,14 +76,20 @@
import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
import com.google.devtools.common.options.Options;
import com.google.protobuf.ByteString;
import io.grpc.BindableService;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
Expand Down Expand Up @@ -685,6 +692,74 @@ private ActionResult uploadDirectory(RemoteCache remoteCache, List<Path> outputs
return remoteCache.upload(actionKey, action, cmd, execRoot, outputs, outErr);
}

@Test
public void extraHeaders() throws Exception {
RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);
remoteOptions.remoteHeaders =
ImmutableList.of(
Maps.immutableEntry("CommonKey1", "CommonValue1"),
Maps.immutableEntry("CommonKey2", "CommonValue2"));
remoteOptions.remoteExecHeaders =
ImmutableList.of(
Maps.immutableEntry("ExecKey1", "ExecValue1"),
Maps.immutableEntry("ExecKey2", "ExecValue2"));
remoteOptions.remoteCacheHeaders =
ImmutableList.of(
Maps.immutableEntry("CacheKey1", "CacheValue1"),
Maps.immutableEntry("CacheKey2", "CacheValue2"));

ServerInterceptor interceptor =
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata metadata,
ServerCallHandler<ReqT, RespT> next) {
assertThat(
metadata.get(Metadata.Key.of("CommonKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CommonValue1");
assertThat(
metadata.get(Metadata.Key.of("CommonKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CommonValue2");
assertThat(metadata.get(Metadata.Key.of("CacheKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CacheValue1");
assertThat(metadata.get(Metadata.Key.of("CacheKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CacheValue2");
assertThat(metadata.get(Metadata.Key.of("ExecKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo(null);
assertThat(metadata.get(Metadata.Key.of("ExecKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo(null);
return next.startCall(call, metadata);
}
};

BindableService cas =
new ContentAddressableStorageImplBase() {
@Override
public void findMissingBlobs(
FindMissingBlobsRequest request,
StreamObserver<FindMissingBlobsResponse> responseObserver) {
responseObserver.onNext(FindMissingBlobsResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};
serviceRegistry.addService(cas);
BindableService actionCache =
new ActionCacheImplBase() {
@Override
public void getActionResult(
GetActionResultRequest request, StreamObserver<ActionResult> responseObserver) {
responseObserver.onNext(ActionResult.getDefaultInstance());
responseObserver.onCompleted();
}
};
serviceRegistry.addService(ServerInterceptors.intercept(actionCache, interceptor));

GrpcCacheClient client = newClient(remoteOptions);
RemoteCache remoteCache = new RemoteCache(client, remoteOptions, DIGEST_UTIL);
remoteCache.downloadActionResult(DIGEST_UTIL.asActionKey(DIGEST_UTIL.computeAsUtf8("key")));
}

@Test
public void testUpload() throws Exception {
RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.actions.ActionInput;
Expand Down Expand Up @@ -208,6 +209,20 @@ public PathFragment getExecPath() {
FileSystemUtils.createDirectoryAndParents(stderr.getParentDirectory());
outErr = new FileOutErr(stdout, stderr);
RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);

remoteOptions.remoteHeaders =
ImmutableList.of(
Maps.immutableEntry("CommonKey1", "CommonValue1"),
Maps.immutableEntry("CommonKey2", "CommonValue2"));
remoteOptions.remoteExecHeaders =
ImmutableList.of(
Maps.immutableEntry("ExecKey1", "ExecValue1"),
Maps.immutableEntry("ExecKey2", "ExecValue2"));
remoteOptions.remoteCacheHeaders =
ImmutableList.of(
Maps.immutableEntry("CacheKey1", "CacheValue1"),
Maps.immutableEntry("CacheKey2", "CacheValue2"));

RemoteRetrier retrier =
TestUtils.newRemoteRetrier(
() -> new ExponentialBackoff(remoteOptions),
Expand All @@ -217,7 +232,7 @@ public PathFragment getExecPath() {
new ReferenceCountedChannel(
InProcessChannelBuilder.forName(fakeServerName).directExecutor().build());
GrpcRemoteExecutor executor =
new GrpcRemoteExecutor(channel.retain(), null, retrier);
new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
CallCredentials creds =
GoogleAuthUtils.newCallCredentials(Options.getDefaults(AuthAndTLSOptions.class));
ByteStreamUploader uploader =
Expand Down Expand Up @@ -510,6 +525,69 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
}
}

@Test
public void extraHeaders() throws Exception {
BindableService actionCache =
new ActionCacheImplBase() {
@Override
public void getActionResult(
GetActionResultRequest request, StreamObserver<ActionResult> responseObserver) {
responseObserver.onError(Status.NOT_FOUND.asRuntimeException());
}
};
serviceRegistry.addService(actionCache);

BindableService cas =
new ContentAddressableStorageImplBase() {
@Override
public void findMissingBlobs(
FindMissingBlobsRequest request,
StreamObserver<FindMissingBlobsResponse> responseObserver) {
responseObserver.onNext(FindMissingBlobsResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};
serviceRegistry.addService(cas);

BindableService execService =
new ExecutionImplBase() {
@Override
public void execute(ExecuteRequest request, StreamObserver<Operation> responseObserver) {
responseObserver.onNext(Operation.getDefaultInstance());
responseObserver.onCompleted();
}
};
ServerInterceptor interceptor =
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata metadata,
ServerCallHandler<ReqT, RespT> next) {
assertThat(
metadata.get(Metadata.Key.of("CommonKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CommonValue1");
assertThat(
metadata.get(Metadata.Key.of("CommonKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("CommonValue2");
assertThat(metadata.get(Metadata.Key.of("ExecKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("ExecValue1");
assertThat(metadata.get(Metadata.Key.of("ExecKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("ExecValue2");
assertThat(metadata.get(Metadata.Key.of("CacheKey1", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo(null);
assertThat(metadata.get(Metadata.Key.of("CacheKey2", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo(null);
return next.startCall(call, metadata);
}
};
serviceRegistry.addService(ServerInterceptors.intercept(execService, interceptor));

FakeSpawnExecutionContext policy =
new FakeSpawnExecutionContext(simpleSpawn, fakeFileCache, execRoot, outErr);
client.exec(simpleSpawn, policy);
}

@Test
public void remotelyExecute() throws Exception {
BindableService actionCache =
Expand Down

0 comments on commit ce44f66

Please sign in to comment.