From a10aa5ad1f5c4f372914ee11f1180ad0c5f3b703 Mon Sep 17 00:00:00 2001 From: Shubha Rajan Date: Wed, 26 May 2021 01:00:01 -0700 Subject: [PATCH] fix: force use of TLSv1.3 when IAM auth enabled (#108) * fix: force use of TLSv1.3 when IAM auth enabled --- .kokoro/tests/run_tests_windows.sh | 1 + .../connector/instance_connection_manager.py | 27 ++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/.kokoro/tests/run_tests_windows.sh b/.kokoro/tests/run_tests_windows.sh index 259190ef..53176fec 100755 --- a/.kokoro/tests/run_tests_windows.sh +++ b/.kokoro/tests/run_tests_windows.sh @@ -27,6 +27,7 @@ fi # Add python and pip to PATH export PATH=/c/python37:/c/python37/scripts:$PATH +python --version # install nox for testing pip install --user -q nox diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index c3a3eb69..38667231 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -80,6 +80,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super(ConnectionSSLContext, self).__init__(*args, **kwargs) +class TLSVersionError(Exception): + """ + Raised when the required TLS protocol version is not supported. + """ + + def __init__(self, *args: Any) -> None: + super(TLSVersionError, self).__init__(self, *args) + + class CloudSQLConnectionError(Exception): """ Raised when the provided connection string is not formatted @@ -111,8 +120,16 @@ def __init__( private_key: bytes, server_ca_cert: str, expiration: datetime.datetime, + enable_iam_auth: bool, ) -> None: self.ip_addrs = ip_addrs + + if enable_iam_auth and not ssl.HAS_TLSv1_3: # type: ignore + raise TLSVersionError( + "Your current version of OpenSSL does not support TLSv1.3, " + "which is required to use IAM Authentication." + ) + self.context = ConnectionSSLContext() self.expiration = expiration @@ -293,11 +310,12 @@ async def _get_instance_data(self) -> InstanceMetadata: expiration = datetime.datetime.strptime( x509.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ" ) - if self._credentials is not None: - token_expiration: datetime.datetime = self._credentials.expiry - if expiration > token_expiration: - expiration = token_expiration + if self._enable_iam_auth: + if self._credentials is not None: + token_expiration: datetime.datetime = self._credentials.expiry + if expiration > token_expiration: + expiration = token_expiration return InstanceMetadata( ephemeral_cert, @@ -305,6 +323,7 @@ async def _get_instance_data(self) -> InstanceMetadata: priv_key, metadata["server_ca_cert"], expiration, + self._enable_iam_auth, ) def _auth_init(self) -> None: