From 5a4a36b16723a42071a5fd75012695c02046061f Mon Sep 17 00:00:00 2001 From: shylock <33566796+Shylock-Hg@users.noreply.github.com> Date: Tue, 28 Dec 2021 20:28:41 +0800 Subject: [PATCH] Fix/fix tck ssl (#3570) * Fix some errors. * add tck ssl * fix ssl client Co-authored-by: HarrisChu <1726587+HarrisChu@users.noreply.github.com> --- src/common/ssl/SSLConfig.cpp | 1 + tests/common/nebula_service.py | 35 ++++++++++++++++++++++++++-------- tests/common/utils.py | 30 ++++++++++++++++++++++++----- tests/conftest.py | 30 ++++++++++++++++++----------- tests/job/test_session.py | 16 ++++++---------- tests/nebula-test-run.py | 25 ++++++++++++++++++++---- tests/tck/conftest.py | 5 +++-- 7 files changed, 102 insertions(+), 40 deletions(-) diff --git a/src/common/ssl/SSLConfig.cpp b/src/common/ssl/SSLConfig.cpp index 0a3592dd834..f08d819e324 100644 --- a/src/common/ssl/SSLConfig.cpp +++ b/src/common/ssl/SSLConfig.cpp @@ -18,6 +18,7 @@ namespace nebula { std::shared_ptr sslContextConfig() { auto sslCfg = std::make_shared(); sslCfg->addCertificate(FLAGS_cert_path, FLAGS_key_path, FLAGS_password_path); + sslCfg->clientVerification = folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST; sslCfg->isDefault = true; return sslCfg; } diff --git a/tests/common/nebula_service.py b/tests/common/nebula_service.py index e688dc56f4d..3c4364974b8 100644 --- a/tests/common/nebula_service.py +++ b/tests/common/nebula_service.py @@ -14,10 +14,12 @@ import signal import copy import fcntl +import logging from pathlib import Path from contextlib import closing from tests.common.constants import TMP_DIR +from tests.common.utils import get_ssl_config from nebula2.gclient.net import ConnectionPool from nebula2.Config import Config @@ -133,6 +135,11 @@ def __init__( self.storaged_port = 0 self.graphd_port = 0 self.ca_signed = ca_signed + self.is_graph_ssl = ( + kwargs.get("enable_graph_ssl", "false").upper() == "TRUE" + or kwargs.get("enable_ssl", "false").upper() == "TRUE" + ) + self.debug_log = debug_log self.ports_per_process = 4 self.lock_file = os.path.join(TMP_DIR, "cluster_port.lock") @@ -200,14 +207,14 @@ def _make_params(self, **kwargs): 'expired_time_factor': 60, } if self.ca_signed: - _params['ca_path'] = 'share/resources/test.ca.pem' _params['cert_path'] = 'share/resources/test.derive.crt' _params['key_path'] = 'share/resources/test.derive.key' + _params['ca_path'] = 'share/resources/test.ca.pem' else: - _params['ca_path'] = 'share/resources/test.ca.pem' - _params['cert_path'] = 'share/resources/test.ca.key' - _params['key_path'] = 'share/resources/test.ca.password' + _params['cert_path'] = 'share/resources/test.ca.pem' + _params['key_path'] = 'share/resources/test.ca.key' + _params['password_path'] = 'share/resources/test.ca.password' if self.debug_log: _params['v'] = '4' @@ -218,6 +225,7 @@ def _make_params(self, **kwargs): self.graphd_param['system_memory_high_watermark_ratio'] = '0.95' self.graphd_param['num_rows_to_check_memory'] = '4' self.graphd_param['session_reclaim_interval_secs'] = '2' + self.storaged_param = copy.copy(_params) self.storaged_param['local_config'] = 'false' self.storaged_param['raft_heartbeat_interval_secs'] = '30' @@ -244,7 +252,9 @@ def _copy_nebula_conf(self): os.makedirs(resources_dir) # timezone file - shutil.copy(self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir) + shutil.copy( + self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir + ) shutil.copy(self.build_dir + '/../resources/gflags.json', resources_dir) # cert files shutil.copy(self.src_dir + '/tests/cert/test.ca.key', resources_dir) @@ -365,14 +375,23 @@ def start(self): # init connection pool client_pool = ConnectionPool() # assert client_pool.init([("127.0.0.1", int(self.graphd_port))], config) - assert client_pool.init([("127.0.0.1", self.graphd_processes[0].tcp_port)], config) + ssl_config = get_ssl_config(self.is_graph_ssl, self.ca_signed) + print("begin to add hosts") + assert client_pool.init( + [("127.0.0.1", self.graphd_processes[0].tcp_port)], config, ssl_config + ) - cmd = "ADD HOSTS 127.0.0.1:" + str(self.storaged_processes[0].tcp_port) + " INTO NEW ZONE \"default_zone\"" - print(cmd) + cmd = ( + "ADD HOSTS 127.0.0.1:" + + str(self.storaged_processes[0].tcp_port) + + " INTO NEW ZONE \"default_zone\"" + ) + print("add hosts cmd is {}".format(cmd)) # get session from the pool client = client_pool.get_session('root', 'nebula') resp = client.execute(cmd) + assert resp.is_succeeded(), resp.error_msg() client.release() # wait nebula start diff --git a/tests/common/utils.py b/tests/common/utils.py index 4f434299b9e..6f2f62ed0db 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -6,17 +6,19 @@ import os import re +import json import random import string import time import yaml from typing import Pattern -from nebula2.Config import Config +from nebula2.Config import Config, SSL_config from nebula2.common import ttypes as CommonTtypes from nebula2.gclient.net import Session from nebula2.gclient.net import ConnectionPool +from tests.common.constants import NB_TMP_PATH, NEBULA_HOME from tests.common.csv_import import CSVImporter from tests.common.path_value import PathVal from tests.common.types import SpaceDesc @@ -113,9 +115,12 @@ def compare_value(real, expect): if eedge.type < 0: esrc, edst = edst, esrc # ignore props comparison - return rsrc == esrc and rdst == edst \ - and redge.ranking == eedge.ranking \ + return ( + rsrc == esrc + and rdst == edst + and redge.ranking == eedge.ranking and redge.name == eedge.name + ) return real == expect @@ -433,13 +438,13 @@ def load_csv_data( return space_desc -def get_conn_pool(host: str, port: int): +def get_conn_pool(host: str, port: int, ssl_config: SSL_config): config = Config() config.max_connection_pool_size = 20 config.timeout = 180000 # init connection pool pool = ConnectionPool() - if not pool.init([(host, port)], config): + if not pool.init([(host, port)], config, ssl_config): raise Exception("Fail to init connection pool.") return pool @@ -450,3 +455,18 @@ def parse_service_index(name: str): if m and len(m.groups()) == 2: return int(m.groups()[1]) return None + +def get_ssl_config(is_graph_ssl: bool, ca_signed: bool): + if not is_graph_ssl: + return None + ssl_config = SSL_config() + + if ca_signed: + ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem') + ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt') + ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key') + else: + ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem') + ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt') + ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key') + return ssl_config diff --git a/tests/conftest.py b/tests/conftest.py index 661154f0440..7802f31238d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from tests.common.configs import all_configs from tests.common.types import SpaceDesc -from tests.common.utils import get_conn_pool +from tests.common.utils import get_conn_pool, get_ssl_config from tests.common.constants import NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR, NEBULA_HOME from tests.common.nebula_service import NebulaService @@ -110,6 +110,17 @@ def get_ports(): raise Exception(f"Invalid port: {port}") return port +def get_ssl_config_from_tmp(): + with open(NB_TMP_PATH, "r") as f: + data = json.loads(f.readline()) + is_graph_ssl = ( + data.get("enable_ssl", "false").upper() == "TRUE" + or data.get("enable_graph_ssl", "false").upper() == "TRUE" + ) + ca_signed = data.get("ca_signed", "false").upper() == "TRUE" + return get_ssl_config(is_graph_ssl, ca_signed) + + @pytest.fixture(scope="class") def class_fixture_variables(): """save class scope fixture, used for session update. @@ -140,7 +151,8 @@ def conn_pool_to_first_graph_service(pytestconfig): addr = pytestconfig.getoption("address") host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]] assert len(host_addr) == 2 - pool = get_conn_pool(host_addr[0], host_addr[1]) + ssl_config = get_ssl_config_from_tmp() + pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config) yield pool pool.close() @@ -150,7 +162,8 @@ def conn_pool_to_second_graph_service(pytestconfig): addr = pytestconfig.getoption("address") host_addr = ["localhost", get_ports()[1]] assert len(host_addr) == 2 - pool = get_conn_pool(host_addr[0], host_addr[1]) + ssl_config = get_ssl_config_from_tmp() + pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config) yield pool pool.close() @@ -246,11 +259,6 @@ def workarround_for_class( request.cls.drop_data() @pytest.fixture(scope="class") -def establish_a_rare_connection(pytestconfig): - addr = pytestconfig.getoption("address") - host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]] - socket = TSocket.TSocket(host_addr[0], host_addr[1]) - transport = TTransport.TBufferedTransport(socket) - protocol = TBinaryProtocol.TBinaryProtocol(transport) - transport.open() - return GraphService.Client(protocol) +def establish_a_rare_connection(conn_pool, pytestconfig): + conn = conn_pool.get_connection() + return conn._connection diff --git a/tests/job/test_session.py b/tests/job/test_session.py index 5c9736c1dbf..c08605fa919 100644 --- a/tests/job/test_session.py +++ b/tests/job/test_session.py @@ -139,22 +139,19 @@ def test_sessions(self): def test_the_same_id_to_different_graphd(self): def get_connection(ip, port): + ssl_config = self.client_pool._ssl_configs try: - socket = TSocket.TSocket(ip, port) - transport = TTransport.TBufferedTransport(socket) - protocol = TBinaryProtocol.TBinaryProtocol(transport) - transport.open() - connection = GraphService.Client(protocol) + conn = Connection() + conn.open_SSL(ip, port, 0, ssl_config) except Exception as ex: assert False, 'Create connection to {}:{} failed'.format(ip, port) - return connection + return conn conn1 = get_connection(self.addr_host1, self.addr_port1) conn2 = get_connection(self.addr_host2, self.addr_port2) resp = conn1.authenticate('root', 'nebula') - assert resp.error_code == ttypes.ErrorCode.SUCCEEDED - session_id = resp.session_id + session_id = resp.get_session_id() resp = conn1.execute(session_id, 'CREATE SPACE IF NOT EXISTS aSpace(partition_num=1, vid_type=FIXED_STRING(8));USE aSpace;') self.check_resp_succeeded(ResultSet(resp, 0)) @@ -217,8 +214,7 @@ def test_out_of_max_connections(self): def test_signout_and_execute(self): try: - conn = Connection() - conn.open(self.addr_host1, self.addr_port1, 3000) + conn = self.client_pool.get_connection() auth_result = conn.authenticate(self.user, self.password) session_id = auth_result.get_session_id() conn.signout(session_id) diff --git a/tests/nebula-test-run.py b/tests/nebula-test-run.py index c4986336e92..d7a83838c57 100755 --- a/tests/nebula-test-run.py +++ b/tests/nebula-test-run.py @@ -9,8 +9,14 @@ import os import shutil from tests.common.nebula_service import NebulaService -from tests.common.utils import get_conn_pool, load_csv_data -from tests.common.constants import NEBULA_HOME, TMP_DIR, NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR +from tests.common.utils import get_conn_pool, load_csv_data, get_ssl_config +from tests.common.constants import ( + NEBULA_HOME, + TMP_DIR, + NB_TMP_PATH, + SPACE_TMP_PATH, + BUILD_DIR, +) CURR_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -100,8 +106,12 @@ def start_nebula(nb, configs): address = "localhost" ports = nb.start() + is_graph_ssl = opt_is(configs.enable_ssl, "true") or opt_is( + configs.enable_graph_ssl, "true" + ) + ca_signed = opt_is(configs.enable_ssl, "true") # Load csv data - pool = get_conn_pool(address, ports[0]) + pool = get_conn_pool(address, ports[0], get_ssl_config(is_graph_ssl, ca_signed)) sess = pool.get_session(configs.user, configs.password) if not os.path.exists(TMP_DIR): @@ -119,7 +129,14 @@ def start_nebula(nb, configs): f.write(json.dumps(spaces)) with open(NB_TMP_PATH, "w") as f: - data = {"ip": "localhost", "port": ports, "work_dir": nb.work_dir} + data = { + "ip": "localhost", + "port": ports, + "work_dir": nb.work_dir, + "enable_ssl": configs.enable_ssl, + "enable_graph_ssl": configs.enable_graph_ssl, + "ca_signed": configs.ca_signed, + } f.write(json.dumps(data)) print('Start nebula successfully') diff --git a/tests/tck/conftest.py b/tests/tck/conftest.py index 7e8644aad89..bfe85af7734 100644 --- a/tests/tck/conftest.py +++ b/tests/tck/conftest.py @@ -336,7 +336,8 @@ def given_nebulacluster_with_param( nebula_svc.start() graph_ip = nebula_svc.graphd_processes[0].host graph_port = nebula_svc.graphd_processes[0].tcp_port - pool = get_conn_pool(graph_ip, graph_port) + # TODO add ssl pool if tests needed + pool = get_conn_pool(graph_ip, graph_port, None) sess = pool.get_session(user, password) class_fixture_variables["current_session"] = sess class_fixture_variables["sessions"].append(sess) @@ -352,7 +353,7 @@ def when_login_graphd(graph, user, password, class_fixture_variables, pytestconf assert index < len(nebula_svc.graphd_processes) graphd_process = nebula_svc.graphd_processes[index] graph_ip, graph_port = graphd_process.host, graphd_process.tcp_port - pool = get_conn_pool(graph_ip, graph_port) + pool = get_conn_pool(graph_ip, graph_port, None) sess = pool.get_session(user, password) # do not release original session, as we may have cases to test multiple sessions. # connection could be released after cluster stopped.