Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SSL transport #149

Merged
merged 10 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions .github/workflows/run_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ jobs:
- name: Test with pytest
run: |
docker-compose up -d
sleep 60
pytest -s -v
sleep 45
pytest -s -v -k "not TestSSLConnection and not TestSSLConnectionSelfSigned"
docker-compose down -v
working-directory: tests
- name: Test SSL connection with pytest
run: |
enable_ssl=true docker-compose up -d
sleep 45
pytest -s -v test_ssl_connection.py::TestSSLConnection
working-directory: tests
- name: Test self-signed SSL connection with pytest
run: |
enable_ssl=true docker-compose up -d
sleep 45
pytest -s -v test_ssl_connection.py::TestSSLConnectionSelfSigned
working-directory: tests
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ nebula/graph.thrift

# ide
.idea/
.vscode/

# CI data
tests/data
tests/logs
2 changes: 2 additions & 0 deletions example/GraphClientSimpleExample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import json

from nebula2.gclient.net import ConnectionPool

from nebula2.Config import Config
from nebula2.common import *
from FormatResp import print_resp

if __name__ == '__main__':
Expand Down
49 changes: 49 additions & 0 deletions nebula2/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# This source code is licensed under Apache 2.0 License,
# attached with Common Clause Condition 1.0, found in the LICENSES directory.

import ssl

class Config(object):
# the min connection always in pool
Expand All @@ -22,3 +23,51 @@ class Config(object):

# the interval to check idle time connection, unit second, -1 means no check
interval_check = -1


class SSL_config(object):
"""configs used to Initialize a TSSLSocket.
@ ssl_version(int) protocol version. see ssl module. If none is
specified, we will default to the most
reasonably secure and compatible configuration
if possible.
For Python versions >= 2.7.9, we will default
to at least TLS 1.1.
For Python versions < 2.7.9, we can only
default to TLS 1.0, which is the best that
Python guarantees to offers at this version.
If you specify ssl.PROTOCOL_SSLv23, and
the OpenSSL linked with Python is new enough,
it is possible for a TLS 1.2 connection be
established; however, there is no way in
< Python 2.7.9 to explicitly disable SSLv2
and SSLv3. For that reason, we default to
TLS 1.0.

@ cert_reqs(int) whether to verify peer certificate. see ssl
module.

@ ca_certs(str) filename containing trusted root certs.

@ verify_name if False, no peer name validation is performed
if True, verify subject name of peer vs 'host'
if a str, verify subject name of peer vs given
str

@ keyfile filename containing the client's private key

@ certfile filename containing the client's cert and
optionally the private key

@ allow_weak_ssl_versions(bool) By default, we try to disable older
protocol versions. Only set this
if you know what you are doing.
"""
unix_socket = None
ssl_version = None
cert_reqs = ssl.CERT_NONE
ca_certs = None
verify_name = False
keyfile = None
certfile = None
allow_weak_ssl_versions = False
50 changes: 38 additions & 12 deletions nebula2/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import time

from nebula2.fbthrift.transport import TSocket, TTransport
from nebula2.fbthrift.transport import TSocket, TTransport, TSSLSocket
from nebula2.fbthrift.transport.TTransport import TTransportException
from nebula2.fbthrift.protocol import TBinaryProtocol

Expand Down Expand Up @@ -43,20 +43,46 @@ def open(self, ip, port, timeout):
:param timeout: the timeout for connect and execute
:return: void
"""
self.open_SSL(ip, port, timeout, None)

def open_SSL(self, ip, port, timeout, ssl_config=None):
"""open the SSL connection

:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:ssl_config: configs for SSL
:return: void
"""
self._ip = ip
self._port = port
self._timeout = timeout
s = TSocket.TSocket(self._ip, self._port)
if timeout > 0:
s.setTimeout(timeout)
transport = TTransport.TBufferedTransport(s)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
try:
if ssl_config is not None:
s = TSSLSocket.TSSLSocket(self._ip, self._port,
ssl_config.unix_socket,
ssl_config.ssl_version,
ssl_config.cert_reqs,
ssl_config.ca_certs,
ssl_config.verify_name,
ssl_config.keyfile,
ssl_config.certfile,
ssl_config.allow_weak_ssl_versions)
else:
s = TSocket.TSocket(self._ip, self._port)
if timeout > 0:
s.setTimeout(timeout)
transport = TTransport.TBufferedTransport(s)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(
VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
except Exception:
raise

def _reopen(self):
"""reopen the connection
Expand Down
43 changes: 32 additions & 11 deletions nebula2/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ def __init__(self):
# all connections
self._connections = dict()
self._configs = None
self._ssl_configs = None
self._lock = RLock()
self._pos = -1
self._close = False

def __del__(self):
self.close()

def init(self, addresses, configs):
def init(self, addresses, configs, ssl_conf=None):
"""init the connection pool

:param addresses: the graphd servers' addresses
:param configs: the config of the pool
:param ssl_conf: the config of SSL socket
:return: if all addresses are ok, return True else return False.
"""
if self._close:
Expand All @@ -72,14 +74,25 @@ def init(self, addresses, configs):
# init min connections
ok_num = self.get_ok_servers_num()
if ok_num < len(self._addresses):
raise RuntimeError('The services status exception: {}'.format(self._get_services_status()))

conns_per_address = int(self._configs.min_connection_pool_size / ok_num)
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open(addr[0], addr[1], self._configs.timeout)
self._connections[addr].append(connection)
raise RuntimeError('The services status exception: {}'.format(
self._get_services_status()))

conns_per_address = int(
self._configs.min_connection_pool_size / ok_num)

if ssl_conf is None:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open(addr[0], addr[1], self._configs.timeout)
self._connections[addr].append(connection)
else:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  if ssl
      openSSL
   else 
      open

connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs)
self._connections[addr].append(connection)
return True

def get_session(self, user_name, password, retry_connect=True):
Expand Down Expand Up @@ -152,7 +165,12 @@ def get_connection(self):

if len(self._connections[addr]) < max_con_per_address:
connection = Connection()
connection.open(addr[0], addr[1], self._configs.timeout)
if self._ssl_configs is None:
connection.open(
addr[0], addr[1], self._configs.timeout)
else:
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs)
connection.is_used = True
self._connections[addr].append(connection)
logging.info('Get connection to {}'.format(addr))
Expand All @@ -175,7 +193,10 @@ def ping(self, address):
"""
try:
conn = Connection()
conn.open(address[0], address[1], 1000)
if self._ssl_configs is None:
conn.open(address[0], address[1], 1000)
else:
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.close()
return True
except Exception as ex:
Expand Down
5 changes: 5 additions & 0 deletions tests/.env
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
TZ=UTC

ca_path=/secrets/test.ca.pem
cert_path=/secrets/test.client.crt
key_path=/secrets/test.client.key
enable_ssl=false
Loading