From c0ee67213c09375032c992b768d2e8842318efa7 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Wed, 16 Feb 2022 17:53:21 -0800 Subject: [PATCH] Adding secure grpc query server support (#8207) --- .../common/utils/grpc/GrpcQueryClient.java | 85 +++++++++++++++++-- .../core/transport/grpc/GrpcQueryServer.java | 41 ++++++++- .../OfflineGRPCServerIntegrationTest.java | 23 +++-- ...fflineSecureGRPCServerIntegrationTest.java | 62 ++++++++++++++ .../apache/pinot/server/conf/ServerConf.java | 4 + .../pinot/server/starter/ServerInstance.java | 12 +-- .../pinot/spi/utils/CommonConstants.java | 3 + 7 files changed, 205 insertions(+), 25 deletions(-) create mode 100644 pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineSecureGRPCServerIntegrationTest.java diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java index e41c20143a7b..bbb70d85e98b 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java @@ -18,13 +18,27 @@ */ package org.apache.pinot.common.utils.grpc; +import com.google.common.collect.ImmutableMap; +import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; import java.util.Iterator; +import java.util.Map; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.proto.PinotQueryServerGrpc; import org.apache.pinot.common.proto.Server; +import org.apache.pinot.common.utils.TlsUtils; +import org.apache.pinot.spi.env.PinotConfiguration; public class GrpcQueryClient { + private final ManagedChannel _managedChannel; private final PinotQueryServerGrpc.PinotQueryServerBlockingStub _blockingStub; public GrpcQueryClient(String host, int port) { @@ -32,32 +46,79 @@ public GrpcQueryClient(String host, int port) { } public GrpcQueryClient(String host, int port, Config config) { - ManagedChannelBuilder managedChannelBuilder = ManagedChannelBuilder - .forAddress(host, port) - .maxInboundMessageSize(config.getMaxInboundMessageSizeBytes()); if (config.isUsePlainText()) { - managedChannelBuilder.usePlaintext(); + _managedChannel = + ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes()) + .usePlaintext().build(); + } else { + try { + SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); + if (config.getTlsConfig().getKeyStorePath() != null) { + KeyManagerFactory keyManagerFactory = TlsUtils.createKeyManagerFactory(config.getTlsConfig()); + sslContextBuilder.keyManager(keyManagerFactory); + } + if (config.getTlsConfig().getTrustStorePath() != null) { + TrustManagerFactory trustManagerFactory = TlsUtils.createTrustManagerFactory(config.getTlsConfig()); + sslContextBuilder.trustManager(trustManagerFactory); + } + if (config.getTlsConfig().getSslProvider() != null) { + sslContextBuilder = + GrpcSslContexts.configure(sslContextBuilder, SslProvider.valueOf(config.getTlsConfig().getSslProvider())); + } else { + sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); + } + _managedChannel = + NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes()) + .sslContext(sslContextBuilder.build()).build(); + } catch (SSLException e) { + throw new RuntimeException("Failed to create Netty gRPC channel with SSL Context", e); + } } - _blockingStub = PinotQueryServerGrpc.newBlockingStub(managedChannelBuilder.build()); + _blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel); } public Iterator submit(Server.ServerRequest request) { return _blockingStub.submit(request); } + public void close() { + if (!_managedChannel.isShutdown()) { + _managedChannel.shutdownNow(); + } + } + public static class Config { + public static final String GRPC_TLS_PREFIX = "tls"; + public static final String CONFIG_USE_PLAIN_TEXT = "usePlainText"; + public static final String CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE = "maxInboundMessageSizeBytes"; // Default max message size to 128MB - private static final int DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE = 128 * 1024 * 1024; + public static final int DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE = 128 * 1024 * 1024; + private final int _maxInboundMessageSizeBytes; private final boolean _usePlainText; + private final TlsConfig _tlsConfig; + private final PinotConfiguration _pinotConfig; public Config() { this(DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE, true); } public Config(int maxInboundMessageSizeBytes, boolean usePlainText) { - _maxInboundMessageSizeBytes = maxInboundMessageSizeBytes; - _usePlainText = usePlainText; + this(ImmutableMap.of(CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE, maxInboundMessageSizeBytes, CONFIG_USE_PLAIN_TEXT, + usePlainText)); + } + + public Config(Map configMap) { + _pinotConfig = new PinotConfiguration(configMap); + _maxInboundMessageSizeBytes = + _pinotConfig.getProperty(CONFIG_MAX_INBOUND_MESSAGE_BYTES_SIZE, DEFAULT_MAX_INBOUND_MESSAGE_BYTES_SIZE); + _usePlainText = Boolean.valueOf(configMap.get(CONFIG_USE_PLAIN_TEXT).toString()); + _tlsConfig = TlsUtils.extractTlsConfig(_pinotConfig, GRPC_TLS_PREFIX); + } + + // Allow get customized configs. + public Object get(String key) { + return _pinotConfig.getProperty(key); } public int getMaxInboundMessageSizeBytes() { @@ -67,5 +128,13 @@ public int getMaxInboundMessageSizeBytes() { public boolean isUsePlainText() { return _usePlainText; } + + public TlsConfig getTlsConfig() { + return _tlsConfig; + } + + public PinotConfiguration getPinotConfig() { + return _pinotConfig; + } } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java index 1711ec78d2d1..611524464156 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java @@ -21,16 +21,24 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.Status; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.metrics.ServerMeter; import org.apache.pinot.common.metrics.ServerMetrics; import org.apache.pinot.common.proto.PinotQueryServerGrpc; import org.apache.pinot.common.proto.Server.ServerRequest; import org.apache.pinot.common.proto.Server.ServerResponse; import org.apache.pinot.common.utils.DataTable; +import org.apache.pinot.common.utils.TlsUtils; import org.apache.pinot.core.operator.streaming.StreamingResponseUtils; import org.apache.pinot.core.query.executor.QueryExecutor; import org.apache.pinot.core.query.request.ServerQueryRequest; @@ -52,16 +60,41 @@ public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBa Executors.newFixedThreadPool(ResourceManager.DEFAULT_QUERY_WORKER_THREADS); private final AccessControl _accessControl; - public GrpcQueryServer(int port, QueryExecutor queryExecutor, ServerMetrics serverMetrics, + public GrpcQueryServer(int port, TlsConfig tlsConfig, QueryExecutor queryExecutor, ServerMetrics serverMetrics, AccessControl accessControl) { _queryExecutor = queryExecutor; _serverMetrics = serverMetrics; - _server = ServerBuilder.forPort(port).addService(this).build(); + if (tlsConfig != null) { + try { + _server = NettyServerBuilder.forPort(port).sslContext(buildGRpcSslContext(tlsConfig)).addService(this).build(); + } catch (Exception e) { + throw new RuntimeException("Failed to start secure grpcQueryServer", e); + } + } else { + _server = ServerBuilder.forPort(port).addService(this).build(); + } _accessControl = accessControl; LOGGER.info("Initialized GrpcQueryServer on port: {} with numWorkerThreads: {}", port, ResourceManager.DEFAULT_QUERY_WORKER_THREADS); } + private SslContext buildGRpcSslContext(TlsConfig tlsConfig) + throws Exception { + LOGGER.info("Building gRPC SSL context"); + if (tlsConfig.getKeyStorePath() == null) { + throw new IllegalArgumentException("Must provide key store path for secured gRpc server"); + } + SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(TlsUtils.createKeyManagerFactory(tlsConfig)) + .sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider())); + if (tlsConfig.getTrustStorePath() != null) { + sslContextBuilder.trustManager(TlsUtils.createTrustManagerFactory(tlsConfig)); + } + if (tlsConfig.isClientAuthEnabled()) { + sslContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + return GrpcSslContexts.configure(sslContextBuilder).build(); + } + public void start() { LOGGER.info("Starting GrpcQueryServer"); try { @@ -98,8 +131,8 @@ public void submit(ServerRequest request, StreamObserver respons if (!_accessControl.hasDataAccess(requestIdentity, queryRequest.getTableNameWithType())) { Exception unsupportedOperationException = new UnsupportedOperationException( String.format("No access to table %s while processing request %d: %s from broker: %s", - queryRequest.getTableNameWithType(), queryRequest.getRequestId(), - queryRequest.getQueryContext(), queryRequest.getBrokerId())); + queryRequest.getTableNameWithType(), queryRequest.getRequestId(), queryRequest.getQueryContext(), + queryRequest.getBrokerId())); final String exceptionMsg = String.format("Table not found: %s", queryRequest.getTableNameWithType()); LOGGER.error(exceptionMsg, unsupportedOperationException); _serverMetrics.addMeteredGlobalValue(ServerMeter.NO_TABLE_ACCESS, 1); diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java index f3df4c9cd3fe..ee8ce590a2cf 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineGRPCServerIntegrationTest.java @@ -112,10 +112,14 @@ public void setUp() waitForAllDocsLoaded(600_000L); } + void setExtraServerConfigs(PinotConfiguration serverConfig) { + } + protected void startServers() { // Enable gRPC server PinotConfiguration serverConfig = getDefaultServerConfiguration(); serverConfig.setProperty(CommonConstants.Server.CONFIG_OF_ENABLE_GRPC_SERVER, true); + setExtraServerConfigs(serverConfig); startServer(serverConfig); } @@ -146,10 +150,14 @@ private void registerCallbackHandlers() { } } + public GrpcQueryClient getGrpcQueryClient() { + return new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT); + } + @Test public void testGrpcQueryServer() throws Exception { - GrpcQueryClient queryClient = new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT); + GrpcQueryClient queryClient = getGrpcQueryClient(); String sql = "SELECT * FROM mytable_OFFLINE LIMIT 1000000"; BrokerRequest brokerRequest = new Pql2Compiler().compileToBrokerRequest(sql); List segments = _helixResourceManager.getSegmentsFor("mytable_OFFLINE", false); @@ -161,12 +169,13 @@ public void testGrpcQueryServer() requestBuilder.setEnableStreaming(true); testStreamingRequest(queryClient.submit(requestBuilder.setSql(sql).build())); testStreamingRequest(queryClient.submit(requestBuilder.setBrokerRequest(brokerRequest).build())); + queryClient.close(); } @Test(dataProvider = "provideSqlTestCases") public void testQueryingGrpcServer(String sql) throws Exception { - GrpcQueryClient queryClient = new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT); + GrpcQueryClient queryClient = getGrpcQueryClient(); List segments = _helixResourceManager.getSegmentsFor("mytable_OFFLINE", false); GrpcRequestBuilder requestBuilder = new GrpcRequestBuilder().setSegments(segments); @@ -174,6 +183,7 @@ public void testQueryingGrpcServer(String sql) requestBuilder.setEnableStreaming(true); collectAndCompareResult(queryClient.submit(requestBuilder.setSql(sql).build()), dataTable); + queryClient.close(); } @DataProvider(name = "provideSqlTestCases") @@ -184,7 +194,8 @@ public Object[][] provideSqlAndResultRowsAndNumDocScanTestCases() { entries.add(new Object[]{"SELECT * FROM mytable_OFFLINE LIMIT 10000000"}); entries.add(new Object[]{"SELECT * FROM mytable_OFFLINE WHERE DaysSinceEpoch > 16312 LIMIT 10000000"}); entries.add(new Object[]{ - "SELECT timeConvert(DaysSinceEpoch,'DAYS','SECONDS') FROM mytable_OFFLINE LIMIT 10000000"}); + "SELECT timeConvert(DaysSinceEpoch,'DAYS','SECONDS') FROM mytable_OFFLINE LIMIT 10000000" + }); // aggregate entries.add(new Object[]{"SELECT count(*) FROM mytable_OFFLINE"}); @@ -194,8 +205,10 @@ public Object[][] provideSqlAndResultRowsAndNumDocScanTestCases() { entries.add(new Object[]{"SELECT DISTINCTCOUNT(AirlineID) FROM mytable_OFFLINE GROUP BY Carrier"}); // order by - entries.add(new Object[]{"SELECT DaysSinceEpoch, timeConvert(DaysSinceEpoch,'DAYS','SECONDS') " - + "FROM mytable_OFFLINE ORDER BY DaysSinceEpoch limit 10000"}); + entries.add(new Object[]{ + "SELECT DaysSinceEpoch, timeConvert(DaysSinceEpoch,'DAYS','SECONDS') " + + "FROM mytable_OFFLINE ORDER BY DaysSinceEpoch limit 10000" + }); return entries.toArray(new Object[entries.size()][]); } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineSecureGRPCServerIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineSecureGRPCServerIntegrationTest.java new file mode 100644 index 000000000000..1183b792028a --- /dev/null +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineSecureGRPCServerIntegrationTest.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.integration.tests; + +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import org.apache.pinot.common.utils.grpc.GrpcQueryClient; +import org.apache.pinot.spi.env.PinotConfiguration; +import org.apache.pinot.spi.utils.CommonConstants; + + +public class OfflineSecureGRPCServerIntegrationTest extends OfflineGRPCServerIntegrationTest { + private static final String JKS = "JKS"; + private static final String JDK = "JDK"; + private static final String PASSWORD = "changeit"; + private final URL _tlsStoreJKS = OfflineSecureGRPCServerIntegrationTest.class.getResource("/tlstest.jks"); + + @Override + protected void setExtraServerConfigs(PinotConfiguration serverConfig) { + serverConfig.setProperty(CommonConstants.Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED, true); + serverConfig.setProperty("pinot.server.grpctls.client.auth.enabled", true); + serverConfig.setProperty("pinot.server.grpctls.keystore.type", JKS); + serverConfig.setProperty("pinot.server.grpctls.keystore.path", _tlsStoreJKS); + serverConfig.setProperty("pinot.server.grpctls.keystore.password", PASSWORD); + serverConfig.setProperty("pinot.server.grpctls.truststore.type", JKS); + serverConfig.setProperty("pinot.server.grpctls.truststore.path", _tlsStoreJKS); + serverConfig.setProperty("pinot.server.grpctls.truststore.password", PASSWORD); + serverConfig.setProperty("pinot.server.grpctls.ssl.provider", JDK); + } + + @Override + public GrpcQueryClient getGrpcQueryClient() { + Map configMap = new HashMap<>(); + configMap.put("usePlainText", "false"); + configMap.put("tls.keystore.path", _tlsStoreJKS.getFile()); + configMap.put("tls.keystore.password", PASSWORD); + configMap.put("tls.keystore.type", JKS); + configMap.put("tls.truststore.path", _tlsStoreJKS.getFile()); + configMap.put("tls.truststore.password", PASSWORD); + configMap.put("tls.truststore.type", JKS); + configMap.put("tls.ssl.provider", JDK); + GrpcQueryClient.Config config = new GrpcQueryClient.Config(configMap); + return new GrpcQueryClient("localhost", CommonConstants.Server.DEFAULT_GRPC_PORT, config); + } +} diff --git a/pinot-server/src/main/java/org/apache/pinot/server/conf/ServerConf.java b/pinot-server/src/main/java/org/apache/pinot/server/conf/ServerConf.java index 0e700cd87f0e..d6300134194b 100644 --- a/pinot-server/src/main/java/org/apache/pinot/server/conf/ServerConf.java +++ b/pinot-server/src/main/java/org/apache/pinot/server/conf/ServerConf.java @@ -92,6 +92,10 @@ public boolean isEnableGrpcServer() { return _serverConf.getProperty(Server.CONFIG_OF_ENABLE_GRPC_SERVER, Server.DEFAULT_ENABLE_GRPC_SERVER); } + public boolean isGrpcTlsServerEnabled() { + return _serverConf.getProperty(Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED, Server.DEFAULT_GRPCTLS_SERVER_ENABLED); + } + public boolean isEnableSwagger() { return _serverConf.getProperty(CONFIG_OF_SWAGGER_SERVER_ENABLED, DEFAULT_SWAGGER_SERVER_ENABLED); } diff --git a/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java b/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java index a9b5f55030e2..78e00d525ba2 100644 --- a/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java +++ b/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java @@ -111,20 +111,16 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo if (serverConf.isNettyTlsServerEnabled()) { int nettySecPort = serverConf.getNettyTlsPort(); LOGGER.info("Initializing TLS-secured Netty query server on port: {}", nettySecPort); - _nettyTlsQueryServer = new QueryServer(nettySecPort, _queryScheduler, _serverMetrics, tlsConfig, - _accessControl); + _nettyTlsQueryServer = new QueryServer(nettySecPort, _queryScheduler, _serverMetrics, tlsConfig, _accessControl); } else { _nettyTlsQueryServer = null; } - if (serverConf.isEnableGrpcServer()) { - if (tlsConfig.isCustomized()) { - LOGGER.warn("gRPC query server does not support TLS yet"); - } - int grpcPort = serverConf.getGrpcPort(); LOGGER.info("Initializing gRPC query server on port: {}", grpcPort); - _grpcQueryServer = new GrpcQueryServer(grpcPort, _queryExecutor, _serverMetrics, _accessControl); + _grpcQueryServer = new GrpcQueryServer(grpcPort, + serverConf.isGrpcTlsServerEnabled() ? TlsUtils.extractTlsConfig(serverConf.getPinotConfig(), + CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, _queryExecutor, _serverMetrics, _accessControl); } else { _grpcQueryServer = null; } diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java index 4a8c77424e27..2dca2c0ff409 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java @@ -271,6 +271,8 @@ public static class Server { public static final boolean DEFAULT_ENABLE_GRPC_SERVER = false; public static final String CONFIG_OF_GRPC_PORT = "pinot.server.grpc.port"; public static final int DEFAULT_GRPC_PORT = 8090; + public static final String CONFIG_OF_GRPCTLS_SERVER_ENABLED = "pinot.server.grpctls.enabled"; + public static final boolean DEFAULT_GRPCTLS_SERVER_ENABLED = false; public static final String CONFIG_OF_NETTYTLS_SERVER_ENABLED = "pinot.server.nettytls.enabled"; public static final boolean DEFAULT_NETTYTLS_SERVER_ENABLED = false; public static final String CONFIG_OF_SWAGGER_SERVER_ENABLED = "pinot.server.swagger.enabled"; @@ -362,6 +364,7 @@ public static class Server { public static final String SERVER_TLS_PREFIX = "pinot.server.tls"; public static final String SERVER_NETTYTLS_PREFIX = "pinot.server.nettytls"; + public static final String SERVER_GRPCTLS_PREFIX = "pinot.server.grpctls"; // The complete config key is pinot.server.instance.segment.store.uri public static final String CONFIG_OF_SEGMENT_STORE_URI = "segment.store.uri";