Skip to content

Commit

Permalink
chore: update datetime objects to be datetime aware (#983)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Feb 2, 2024
1 parent 25a740a commit 7d34c59
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 20 deletions.
12 changes: 9 additions & 3 deletions google/cloud/sql/connector/refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,15 @@ async def _get_ephemeral(

# decode cert to read expiration
x509 = load_pem_x509_certificate(ephemeral_cert.encode("UTF-8"), default_backend())
expiration = x509.not_valid_after
expiration = x509.not_valid_after_utc
# for IAM authentication OAuth2 token is embedded in cert so it
# must still be valid for successful connection
if enable_iam_auth:
token_expiration: datetime.datetime = login_creds.expiry
# google.auth library strips timezone info for backwards compatibality
# reasons with Python 2. Add it back to allow timezone aware datetimes.
# Ref: https://github.com/googleapis/google-auth-library-python/blob/49a5ff7411a2ae4d32a7d11700f9f961c55406a9/google/auth/_helpers.py#L93-L99
token_expiration = token_expiration.replace(tzinfo=datetime.timezone.utc)
if expiration > token_expiration:
expiration = token_expiration
return ephemeral_cert, expiration
Expand All @@ -213,7 +217,9 @@ def _seconds_until_refresh(
:returns: Time in seconds to wait before performing next refresh.
"""

duration = int((expiration - datetime.datetime.utcnow()).total_seconds())
duration = int(
(expiration - datetime.datetime.now(datetime.timezone.utc)).total_seconds()
)

# if certificate duration is less than 1 hour
if duration < 3600:
Expand All @@ -230,7 +236,7 @@ async def _is_valid(task: asyncio.Task) -> bool:
try:
metadata = await task
# only valid if now is before the cert expires
if datetime.datetime.utcnow() < metadata.expiration:
if datetime.datetime.now(datetime.timezone.utc) < metadata.expiration:
return True
except Exception:
# supress any errors from task
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"aiohttp",
"cryptography>=38.0.3",
"cryptography>=42.0.0",
"Requests",
"google-auth",
]
Expand Down
28 changes: 19 additions & 9 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def __class__(self) -> Credentials:
def refresh(self, request: Callable) -> None:
"""Refreshes the access token."""
self.token = "12345"
self.expiry = datetime.datetime.utcnow() + datetime.timedelta(minutes=60)
self.expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
minutes=60
)

@property
def expired(self) -> bool:
Expand All @@ -62,7 +64,9 @@ def expired(self) -> bool:
"""
if self.expiry is None:
return False
return False if self.expiry > datetime.datetime.utcnow() else True
if self.expiry > datetime.datetime.now(datetime.timezone.utc):
return False
return True

@property
def valid(self) -> bool:
Expand Down Expand Up @@ -108,11 +112,15 @@ def __init__(


async def instance_metadata_success(*args: Any, **kwargs: Any) -> MockMetadata:
return MockMetadata(datetime.datetime.utcnow() + datetime.timedelta(minutes=10))
return MockMetadata(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=10)
)


async def instance_metadata_expired(*args: Any, **kwargs: Any) -> MockMetadata:
return MockMetadata(datetime.datetime.utcnow() - datetime.timedelta(minutes=10))
return MockMetadata(
datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=10)
)


async def instance_metadata_error(*args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -145,10 +153,10 @@ def generate_cert(
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(
# cert valid for 10 mins
datetime.datetime.utcnow()
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=60)
)
)
Expand Down Expand Up @@ -189,7 +197,7 @@ def client_key_signed_cert(
.issuer_name(issuer)
.public_key(client_key)
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(cert._not_valid_after) # type: ignore
)
return (
Expand Down Expand Up @@ -258,7 +266,8 @@ def connect_settings(self, ip_addrs: Optional[Dict] = None) -> str:
"cert": server_ca_cert,
"instance": self.name,
"expirationTime": str(
datetime.datetime.utcnow() + datetime.timedelta(minutes=10)
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=10)
),
},
"dnsName": "abcde.12345.us-central1.sql.goog",
Expand All @@ -284,7 +293,8 @@ def generate_ephemeral(self, client_bytes: str) -> str:
"kind": "sql#sslCert",
"cert": ephemeral_cert,
"expirationTime": str(
datetime.datetime.utcnow() + datetime.timedelta(minutes=10)
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=10)
),
}
}
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ async def test_perform_refresh(
assert isinstance(instance_metadata, ConnectionInfo)
# verify instance metadata expiration
assert (
mock_instance.cert._not_valid_after.replace(microsecond=0) # type: ignore
mock_instance.cert._not_valid_after.replace(
tzinfo=datetime.timezone.utc, microsecond=0 # type: ignore
)
== instance_metadata.expiration
)

Expand All @@ -273,7 +275,9 @@ async def test_perform_refresh_expiration(
credentials expiration should be used.
"""
# set credentials expiration to 1 minute from now
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=1)
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
minutes=1
)
credentials = mocks.FakeCredentials(token="my-token", expiry=expiration)
setattr(instance, "_enable_iam_auth", True)
# set downscoped credential to mock
Expand Down
22 changes: 17 additions & 5 deletions tests/unit/test_refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
limitations under the License.
"""
import asyncio
from datetime import datetime
from datetime import timedelta
import datetime
from typing import Any, no_type_check

import aiohttp
Expand Down Expand Up @@ -226,7 +225,11 @@ def test_seconds_until_refresh_over_1_hour() -> None:
# using pytest.approx since sometimes can be off by a second
assert (
pytest.approx(
_seconds_until_refresh(datetime.utcnow() + timedelta(minutes=62)), 1
_seconds_until_refresh(
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=62)
),
1,
)
== 31 * 60
)
Expand All @@ -242,7 +245,11 @@ def test_seconds_until_refresh_under_1_hour_over_4_mins() -> None:
# using pytest.approx since sometimes can be off by a second
assert (
pytest.approx(
_seconds_until_refresh(datetime.utcnow() + timedelta(minutes=5)), 1
_seconds_until_refresh(
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(minutes=5)
),
1,
)
== 60
)
Expand All @@ -254,4 +261,9 @@ def test_seconds_until_refresh_under_4_mins() -> None:
If expiration is under 4 minutes, should return 0.
"""
assert _seconds_until_refresh(datetime.utcnow() + timedelta(minutes=3)) == 0
assert (
_seconds_until_refresh(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=3)
)
== 0
)

0 comments on commit 7d34c59

Please sign in to comment.