Skip to content

Commit

Permalink
Adding secure grpc query server support (apache#8207)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Feb 23, 2022
1 parent c0eee8f commit c0ee672
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,46 +18,107 @@
*/
package org.apache.pinot.common.utils.grpc;

import com.google.common.collect.ImmutableMap;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import java.util.Iterator;
import java.util.Map;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.proto.PinotQueryServerGrpc;
import org.apache.pinot.common.proto.Server;
import org.apache.pinot.common.utils.TlsUtils;
import org.apache.pinot.spi.env.PinotConfiguration;


public class GrpcQueryClient {
private final ManagedChannel _managedChannel;
private final PinotQueryServerGrpc.PinotQueryServerBlockingStub _blockingStub;

public GrpcQueryClient(String host, int port) {
this(host, port, new Config());
}

public GrpcQueryClient(String host, int port, Config config) {
ManagedChannelBuilder managedChannelBuilder = ManagedChannelBuilder
.forAddress(host, port)
.maxInboundMessageSize(config.getMaxInboundMessageSizeBytes());
if (config.isUsePlainText()) {
managedChannelBuilder.usePlaintext();
_managedChannel =
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.usePlaintext().build();
} else {
try {
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (config.getTlsConfig().getKeyStorePath() != null) {
KeyManagerFactory keyManagerFactory = TlsUtils.createKeyManagerFactory(config.getTlsConfig());
sslContextBuilder.keyManager(keyManagerFactory);
}
if (config.getTlsConfig().getTrustStorePath() != null) {
TrustManagerFactory trustManagerFactory = TlsUtils.createTrustManagerFactory(config.getTlsConfig());
sslContextBuilder.trustManager(trustManagerFactory);
}
if (config.getTlsConfig().getSslProvider() != null) {
sslContextBuilder =
GrpcSslContexts.configure(sslContextBuilder, SslProvider.valueOf(config.getTlsConfig().getSslProvider()));
} else {
sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
}
_managedChannel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(sslContextBuilder.build()).build();
} catch (SSLException e) {
throw new RuntimeException("Failed to create Netty gRPC channel with SSL Context", e);
}
}
_blockingStub = PinotQueryServerGrpc.newBlockingStub(managedChannelBuilder.build());
_blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
}

public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
return _blockingStub.submit(request);
}

public void close() {
if (!_managedChannel.isShutdown()) {
_managedChannel.shutdownNow();
}
}

public static class Config {
public static final String GRPC_TLS_PREFIX = "tls";
public static final String CONFIG_USE_PLAIN_TEXT = "usePlainText";
public static final String CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE = "maxInboundMessageSizeBytes";
// Default max message size to 128MB
private static final int DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE = 128 * 1024 * 1024;
public static final int DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE = 128 * 1024 * 1024;

private final int _maxInboundMessageSizeBytes;
private final boolean _usePlainText;
private final TlsConfig _tlsConfig;
private final PinotConfiguration _pinotConfig;

public Config() {
this(DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE, true);
}

public Config(int maxInboundMessageSizeBytes, boolean usePlainText) {
_maxInboundMessageSizeBytes = maxInboundMessageSizeBytes;
_usePlainText = usePlainText;
this(ImmutableMap.of(CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE, maxInboundMessageSizeBytes, CONFIG_USE_PLAIN_TEXT,
usePlainText));
}

public Config(Map<String, Object> configMap) {
_pinotConfig = new PinotConfiguration(configMap);
_maxInboundMessageSizeBytes =
_pinotConfig.getProperty(CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE, DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE);
_usePlainText = Boolean.valueOf(configMap.get(CONFIG_USE_PLAIN_TEXT).toString());
_tlsConfig = TlsUtils.extractTlsConfig(_pinotConfig, GRPC_TLS_PREFIX);
}

// Allow get customized configs.
public Object get(String key) {
return _pinotConfig.getProperty(key);
}

public int getMaxInboundMessageSizeBytes() {
Expand All @@ -67,5 +128,13 @@ public int getMaxInboundMessageSizeBytes() {
public boolean isUsePlainText() {
return _usePlainText;
}

public TlsConfig getTlsConfig() {
return _tlsConfig;
}

public PinotConfiguration getPinotConfig() {
return _pinotConfig;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,24 @@
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.metrics.ServerMeter;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.common.proto.PinotQueryServerGrpc;
import org.apache.pinot.common.proto.Server.ServerRequest;
import org.apache.pinot.common.proto.Server.ServerResponse;
import org.apache.pinot.common.utils.DataTable;
import org.apache.pinot.common.utils.TlsUtils;
import org.apache.pinot.core.operator.streaming.StreamingResponseUtils;
import org.apache.pinot.core.query.executor.QueryExecutor;
import org.apache.pinot.core.query.request.ServerQueryRequest;
Expand All @@ -52,16 +60,41 @@ public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBa
Executors.newFixedThreadPool(ResourceManager.DEFAULT_QUERY_WORKER_THREADS);
private final AccessControl _accessControl;

public GrpcQueryServer(int port, QueryExecutor queryExecutor, ServerMetrics serverMetrics,
public GrpcQueryServer(int port, TlsConfig tlsConfig, QueryExecutor queryExecutor, ServerMetrics serverMetrics,
AccessControl accessControl) {
_queryExecutor = queryExecutor;
_serverMetrics = serverMetrics;
_server = ServerBuilder.forPort(port).addService(this).build();
if (tlsConfig != null) {
try {
_server = NettyServerBuilder.forPort(port).sslContext(buildGRpcSslContext(tlsConfig)).addService(this).build();
} catch (Exception e) {
throw new RuntimeException("Failed to start secure grpcQueryServer", e);
}
} else {
_server = ServerBuilder.forPort(port).addService(this).build();
}
_accessControl = accessControl;
LOGGER.info("Initialized GrpcQueryServer on port: {} with numWorkerThreads: {}", port,
ResourceManager.DEFAULT_QUERY_WORKER_THREADS);
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws Exception {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
}
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(TlsUtils.createKeyManagerFactory(tlsConfig))
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
if (tlsConfig.getTrustStorePath() != null) {
sslContextBuilder.trustManager(TlsUtils.createTrustManagerFactory(tlsConfig));
}
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
}

public void start() {
LOGGER.info("Starting GrpcQueryServer");
try {
Expand Down Expand Up @@ -98,8 +131,8 @@ public void submit(ServerRequest request, StreamObserver<ServerResponse> respons
if (!_accessControl.hasDataAccess(requestIdentity, queryRequest.getTableNameWithType())) {
Exception unsupportedOperationException = new UnsupportedOperationException(
String.format("No access to table %s while processing request %d: %s from broker: %s",
queryRequest.getTableNameWithType(), queryRequest.getRequestId(),
queryRequest.getQueryContext(), queryRequest.getBrokerId()));
queryRequest.getTableNameWithType(), queryRequest.getRequestId(), queryRequest.getQueryContext(),
queryRequest.getBrokerId()));
final String exceptionMsg = String.format("Table not found: %s", queryRequest.getTableNameWithType());
LOGGER.error(exceptionMsg, unsupportedOperationException);
_serverMetrics.addMeteredGlobalValue(ServerMeter.NO_TABLE_ACCESS, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,14 @@ public void setUp()
waitForAllDocsLoaded(600_000L);
}

void setExtraServerConfigs(PinotConfiguration serverConfig) {
}

protected void startServers() {
// Enable gRPC server
PinotConfiguration serverConfig = getDefaultServerConfiguration();
serverConfig.setProperty(CommonConstants.Server.CONFIG_OF_ENABLE_GRPC_SERVER, true);
setExtraServerConfigs(serverConfig);
startServer(serverConfig);
}

Expand Down Expand Up @@ -146,10 +150,14 @@ private void registerCallbackHandlers() {
}
}

public GrpcQueryClient getGrpcQueryClient() {
return new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT);
}

@Test
public void testGrpcQueryServer()
throws Exception {
GrpcQueryClient queryClient = new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT);
GrpcQueryClient queryClient = getGrpcQueryClient();
String sql = "SELECT * FROM mytable_OFFLINE LIMIT 1000000";
BrokerRequest brokerRequest = new Pql2Compiler().compileToBrokerRequest(sql);
List<String> segments = _helixResourceManager.getSegmentsFor("mytable_OFFLINE", false);
Expand All @@ -161,19 +169,21 @@ public void testGrpcQueryServer()
requestBuilder.setEnableStreaming(true);
testStreamingRequest(queryClient.submit(requestBuilder.setSql(sql).build()));
testStreamingRequest(queryClient.submit(requestBuilder.setBrokerRequest(brokerRequest).build()));
queryClient.close();
}

@Test(dataProvider = "provideSqlTestCases")
public void testQueryingGrpcServer(String sql)
throws Exception {
GrpcQueryClient queryClient = new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT);
GrpcQueryClient queryClient = getGrpcQueryClient();
List<String> segments = _helixResourceManager.getSegmentsFor("mytable_OFFLINE", false);

GrpcRequestBuilder requestBuilder = new GrpcRequestBuilder().setSegments(segments);
DataTable dataTable = collectNonStreamingRequestResult(queryClient.submit(requestBuilder.setSql(sql).build()));

requestBuilder.setEnableStreaming(true);
collectAndCompareResult(queryClient.submit(requestBuilder.setSql(sql).build()), dataTable);
queryClient.close();
}

@DataProvider(name = "provideSqlTestCases")
Expand All @@ -184,7 +194,8 @@ public Object[][] provideSqlAndResultRowsAndNumDocScanTestCases() {
entries.add(new Object[]{"SELECT * FROM mytable_OFFLINE LIMIT 10000000"});
entries.add(new Object[]{"SELECT * FROM mytable_OFFLINE WHERE DaysSinceEpoch > 16312 LIMIT 10000000"});
entries.add(new Object[]{
"SELECT timeConvert(DaysSinceEpoch,'DAYS','SECONDS') FROM mytable_OFFLINE LIMIT 10000000"});
"SELECT timeConvert(DaysSinceEpoch,'DAYS','SECONDS') FROM mytable_OFFLINE LIMIT 10000000"
});

// aggregate
entries.add(new Object[]{"SELECT count(*) FROM mytable_OFFLINE"});
Expand All @@ -194,8 +205,10 @@ public Object[][] provideSqlAndResultRowsAndNumDocScanTestCases() {
entries.add(new Object[]{"SELECT DISTINCTCOUNT(AirlineID) FROM mytable_OFFLINE GROUP BY Carrier"});

// order by
entries.add(new Object[]{"SELECT DaysSinceEpoch, timeConvert(DaysSinceEpoch,'DAYS','SECONDS') "
+ "FROM mytable_OFFLINE ORDER BY DaysSinceEpoch limit 10000"});
entries.add(new Object[]{
"SELECT DaysSinceEpoch, timeConvert(DaysSinceEpoch,'DAYS','SECONDS') "
+ "FROM mytable_OFFLINE ORDER BY DaysSinceEpoch limit 10000"
});

return entries.toArray(new Object[entries.size()][]);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.integration.tests;

import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import org.apache.pinot.common.utils.grpc.GrpcQueryClient;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.CommonConstants;


public class OfflineSecureGRPCServerIntegrationTest extends OfflineGRPCServerIntegrationTest {
private static final String JKS = "JKS";
private static final String JDK = "JDK";
private static final String PASSWORD = "changeit";
private final URL _tlsStoreJKS = OfflineSecureGRPCServerIntegrationTest.class.getResource("/tlstest.jks");

@Override
protected void setExtraServerConfigs(PinotConfiguration serverConfig) {
serverConfig.setProperty(CommonConstants.Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED, true);
serverConfig.setProperty("pinot.server.grpctls.client.auth.enabled", true);
serverConfig.setProperty("pinot.server.grpctls.keystore.type", JKS);
serverConfig.setProperty("pinot.server.grpctls.keystore.path", _tlsStoreJKS);
serverConfig.setProperty("pinot.server.grpctls.keystore.password", PASSWORD);
serverConfig.setProperty("pinot.server.grpctls.truststore.type", JKS);
serverConfig.setProperty("pinot.server.grpctls.truststore.path", _tlsStoreJKS);
serverConfig.setProperty("pinot.server.grpctls.truststore.password", PASSWORD);
serverConfig.setProperty("pinot.server.grpctls.ssl.provider", JDK);
}

@Override
public GrpcQueryClient getGrpcQueryClient() {
Map<String, Object> configMap = new HashMap<>();
configMap.put("usePlainText", "false");
configMap.put("tls.keystore.path", _tlsStoreJKS.getFile());
configMap.put("tls.keystore.password", PASSWORD);
configMap.put("tls.keystore.type", JKS);
configMap.put("tls.truststore.path", _tlsStoreJKS.getFile());
configMap.put("tls.truststore.password", PASSWORD);
configMap.put("tls.truststore.type", JKS);
configMap.put("tls.ssl.provider", JDK);
GrpcQueryClient.Config config = new GrpcQueryClient.Config(configMap);
return new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT, config);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ public boolean isEnableGrpcServer() {
return _serverConf.getProperty(Server.CONFIG_OF_ENABLE_GRPC_SERVER, Server.DEFAULT_ENABLE_GRPC_SERVER);
}

public boolean isGrpcTlsServerEnabled() {
return _serverConf.getProperty(Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED, Server.DEFAULT_GRPCTLS_SERVER_ENABLED);
}

public boolean isEnableSwagger() {
return _serverConf.getProperty(CONFIG_OF_SWAGGER_SERVER_ENABLED, DEFAULT_SWAGGER_SERVER_ENABLED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,16 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
if (serverConf.isNettyTlsServerEnabled()) {
int nettySecPort = serverConf.getNettyTlsPort();
LOGGER.info("Initializing TLS-secured Netty query server on port: {}", nettySecPort);
_nettyTlsQueryServer = new QueryServer(nettySecPort, _queryScheduler, _serverMetrics, tlsConfig,
_accessControl);
_nettyTlsQueryServer = new QueryServer(nettySecPort, _queryScheduler, _serverMetrics, tlsConfig, _accessControl);
} else {
_nettyTlsQueryServer = null;
}

if (serverConf.isEnableGrpcServer()) {
if (tlsConfig.isCustomized()) {
LOGGER.warn("gRPC query server does not support TLS yet");
}

int grpcPort = serverConf.getGrpcPort();
LOGGER.info("Initializing gRPC query server on port: {}", grpcPort);
_grpcQueryServer = new GrpcQueryServer(grpcPort, _queryExecutor, _serverMetrics, _accessControl);
_grpcQueryServer = new GrpcQueryServer(grpcPort,
serverConf.isGrpcTlsServerEnabled() ? TlsUtils.extractTlsConfig(serverConf.getPinotConfig(),
CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, _queryExecutor, _serverMetrics, _accessControl);
} else {
_grpcQueryServer = null;
}
Expand Down
Loading

0 comments on commit c0ee672

Please sign in to comment.