Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancements in PostgreSQL and Redis Management #13

Merged
merged 16 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
pull_request:
branches: ["main"]

env:
UV_FROZEN: 1

jobs:
lint:
name: Lint
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ jobs:
publish-docs:
runs-on: ubuntu-latest
needs: [bump-version]
env:
UV_FROZEN: 1
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -91,6 +93,8 @@ jobs:
publish-pypi:
needs: [bump-version]
runs-on: ubuntu-latest
env:
UV_FROZEN: 1
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
23 changes: 19 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,29 @@ repos:
types: [python]
require_serial: true

- id: pytest
name: pytest
- id: pytest-unit
name: pytest-unit
description: "Run 'pytest' for unit testing"
entry: uv run pytest --cov-fail-under=90
entry: uv run pytest -m "not integration"
language: system
pass_filenames: false

- id: pytest-integration
name: pytest-integration
description: "Run 'pytest' for integration testing"
entry: uv run pytest -m "integration" --cov-append
language: system
pass_filenames: false

- id: coverage-report
name: coverage-report
description: "Generate coverage report"
entry: uv run coverage report --fail-under=100
language: system
pass_filenames: false


ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
skip: [uv-lock, mypy, pytest]
skip: [uv-lock, mypy, pytest-unit, pytest-integration, coverage-report]
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"python.terminal.activateEnvironment": true,
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": ["--no-cov", "--color=yes"],
"python.testing.pytestArgs": ["--color=yes"],
"python.analysis.inlayHints.pytestParameters": true,

// Python editor settings
Expand Down
72 changes: 40 additions & 32 deletions grelmicro/sync/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,52 @@


class _PostgresSettings(BaseSettings):
"""PostgreSQL settings from the environment variables."""

POSTGRES_HOST: str | None = None
POSTGRES_PORT: int = 5432
POSTGRES_DB: str | None = None
POSTGRES_USER: str | None = None
POSTGRES_PASSWORD: str | None = None
POSTGRES_URL: PostgresDsn | None = None

def url(self) -> str:
"""Generate the Postgres URL from the parts."""
if self.POSTGRES_URL:
return self.POSTGRES_URL.unicode_string()

if all(
(
self.POSTGRES_HOST,
self.POSTGRES_DB,
self.POSTGRES_USER,
self.POSTGRES_PASSWORD,
)
):
return MultiHostUrl.build(
scheme="postgresql",
username=self.POSTGRES_USER,
password=self.POSTGRES_PASSWORD,
host=self.POSTGRES_HOST,
port=self.POSTGRES_PORT,
path=self.POSTGRES_DB,
).unicode_string()

msg = (
"Either POSTGRES_URL or all of POSTGRES_HOST, POSTGRES_DB, POSTGRES_USER, and "
"POSTGRES_PASSWORD must be set"
)
raise SyncSettingsValidationError(msg)

def _get_postgres_url() -> str:
"""Get the PostgreSQL URL from the environment variables.

Raises:
SyncSettingsValidationError: If the URL or all of the host, database, user, and password
"""
try:
settings = _PostgresSettings()
except ValidationError as error:
raise SyncSettingsValidationError(error) from None

required_parts = [
settings.POSTGRES_HOST,
settings.POSTGRES_DB,
settings.POSTGRES_USER,
settings.POSTGRES_PASSWORD,
]

if settings.POSTGRES_URL and not any(required_parts):
return settings.POSTGRES_URL.unicode_string()

if all(required_parts) and not settings.POSTGRES_URL:
return MultiHostUrl.build(
scheme="postgresql",
username=settings.POSTGRES_USER,
password=settings.POSTGRES_PASSWORD,
host=settings.POSTGRES_HOST,
port=settings.POSTGRES_PORT,
path=settings.POSTGRES_DB,
).unicode_string()

msg = (
"Either POSTGRES_URL or all of POSTGRES_HOST, POSTGRES_DB, POSTGRES_USER, and "
"POSTGRES_PASSWORD must be set"
)
raise SyncSettingsValidationError(msg)


class PostgresSyncBackend(SyncBackend):
Expand Down Expand Up @@ -120,11 +132,7 @@ def __init__(
msg = f"Table name '{table_name}' is not a valid identifier"
raise ValueError(msg)

try:
self._url = url or _PostgresSettings().url()
except ValidationError as error:
raise SyncSettingsValidationError(error) from None

self._url = url or _get_postgres_url()
self._table_name = table_name
self._acquire_sql = self._SQL_ACQUIRE_OR_EXTEND.format(
table_name=table_name
Expand Down
73 changes: 65 additions & 8 deletions grelmicro/sync/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,52 @@
from types import TracebackType
from typing import Annotated, Self

from pydantic import RedisDsn
from pydantic import RedisDsn, ValidationError
from pydantic_core import Url
from pydantic_settings import BaseSettings
from redis.asyncio.client import Redis
from typing_extensions import Doc

from grelmicro.sync._backends import loaded_backends
from grelmicro.sync.abc import SyncBackend
from grelmicro.sync.errors import SyncSettingsValidationError


class _RedisSettings(BaseSettings):
"""Redis settings from the environment variables."""

REDIS_HOST: str | None = None
REDIS_PORT: int = 6379
REDIS_DB: int = 0
REDIS_PASSWORD: str | None = None
REDIS_URL: RedisDsn | None = None


def _get_redis_url() -> str:
"""Get the Redis URL from the environment variables.

Raises:
SyncSettingsValidationError: If the URL or host is not set.
"""
try:
settings = _RedisSettings()
except ValidationError as error:
raise SyncSettingsValidationError(error) from None

if settings.REDIS_URL and not settings.REDIS_HOST:
return settings.REDIS_URL.unicode_string()

if settings.REDIS_HOST and not settings.REDIS_URL:
return Url.build(
scheme="redis",
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
path=str(settings.REDIS_DB),
password=settings.REDIS_PASSWORD,
).unicode_string()

msg = "Either REDIS_URL or REDIS_HOST must be set"
raise SyncSettingsValidationError(msg)


class RedisSyncBackend(SyncBackend):
Expand Down Expand Up @@ -37,8 +77,24 @@ class RedisSyncBackend(SyncBackend):

def __init__(
self,
url: Annotated[RedisDsn | str, Doc("The Redis database URL.")],
url: Annotated[
RedisDsn | str | None,
Doc("""
The Redis URL.

If not provided, the URL will be taken from the environment variables REDIS_URL
or REDIS_HOST, REDIS_PORT, REDIS_DB, and REDIS_PASSWORD.
"""),
] = None,
*,
prefix: Annotated[
str,
Doc("""
The prefix to add on redis keys to avoid conflicts with other keys.

By default no prefix is added.
"""),
] = "",
auto_register: Annotated[
bool,
Doc(
Expand All @@ -47,8 +103,9 @@ def __init__(
] = True,
) -> None:
"""Initialize the lock backend."""
self._url = url
self._redis: Redis = Redis.from_url(str(url))
self._url = url or _get_redis_url()
self._redis: Redis = Redis.from_url(str(self._url))
self._prefix = prefix
self._lua_release = self._redis.register_script(self._LUA_RELEASE)
self._lua_acquire = self._redis.register_script(
self._LUA_ACQUIRE_OR_EXTEND
Expand All @@ -73,7 +130,7 @@ async def acquire(self, *, name: str, token: str, duration: float) -> bool:
"""Acquire the lock."""
return bool(
await self._lua_acquire(
keys=[name],
keys=[f"{self._prefix}{name}"],
args=[token, int(duration * 1000)],
client=self._redis,
)
Expand All @@ -83,16 +140,16 @@ async def release(self, *, name: str, token: str) -> bool:
"""Release the lock."""
return bool(
await self._lua_release(
keys=[name], args=[token], client=self._redis
keys=[f"{self._prefix}{name}"], args=[token], client=self._redis
)
)

async def locked(self, *, name: str) -> bool:
"""Check if the lock is acquired."""
return bool(await self._redis.get(name))
return bool(await self._redis.get(f"{self._prefix}{name}"))

async def owned(self, *, name: str, token: str) -> bool:
"""Check if the lock is owned."""
return bool(
(await self._redis.get(name)) == token.encode()
(await self._redis.get(f"{self._prefix}{name}")) == token.encode()
) # redis returns bytes
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,9 @@ disallow_untyped_defs = false
[tool.pytest.ini_options]
addopts = """
--cov=grelmicro
--cov-report term:skip-covered
--cov-report xml:cov.xml
--strict-config
--strict-markers
-m "not integration"
"""
markers = """
integration: mark a test as an integration test (disabled by default).
Expand Down
19 changes: 11 additions & 8 deletions tests/sync/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

pytestmark = [pytest.mark.anyio, pytest.mark.timeout(1)]

URL = "postgres://user:password@localhost:5432/db"
URL = "postgresql://test_user:test_password@test_host:1234/test_db"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -51,9 +51,7 @@ async def test_sync_backend_out_of_context_errors() -> None:
@pytest.mark.parametrize(
("environs"),
[
{
"POSTGRES_URL": "postgresql://test_user:test_password@test_host:1234/test_db"
},
{"POSTGRES_URL": URL},
{
"POSTGRES_USER": "test_user",
"POSTGRES_PASSWORD": "test_password",
Expand All @@ -75,10 +73,7 @@ def test_postgres_env_var_settings(
backend = PostgresSyncBackend()

# Assert
assert (
backend._url
== "postgresql://test_user:test_password@test_host:1234/test_db"
)
assert backend._url == URL


@pytest.mark.parametrize(
Expand All @@ -88,6 +83,14 @@ def test_postgres_env_var_settings(
"POSTGRES_URL": "test://test_user:test_password@test_host:1234/test_db"
},
{"POSTGRES_USER": "test_user"},
{
"POSTGRES_URL": URL,
"POSTGRES_USER": "test_user",
"POSTGRES_PASSWORD": "test_password",
"POSTGRES_HOST": "test_host",
"POSTGRES_PORT": "1234",
"POSTGRES_DB": "test_db",
},
],
)
def test_postgres_env_var_settings_validation_error(
Expand Down
Loading
Loading