diff --git a/src/vocutouts/uws/storage.py b/src/vocutouts/uws/storage.py index 5c5151f..cad96e4 100644 --- a/src/vocutouts/uws/storage.py +++ b/src/vocutouts/uws/storage.py @@ -2,10 +2,10 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine from datetime import datetime, timedelta from functools import wraps -from typing import Any, TypeVar, cast +from typing import ParamSpec, TypeVar from safir.arq import JobMetadata, JobResult from safir.database import datetime_from_db, datetime_to_db @@ -31,8 +31,8 @@ from .schema.job_parameter import JobParameter as SQLJobParameter from .schema.job_result import JobResult as SQLJobResult -F = TypeVar("F", bound=Callable[..., Any]) -G = TypeVar("G", bound=Callable[..., Awaitable[Any]]) +T = TypeVar("T") +P = ParamSpec("P") __all__ = ["JobStore"] @@ -85,7 +85,9 @@ def _convert_job(job: SQLJob) -> UWSJob: ) -def retry_async_transaction(g: G) -> G: +def retry_async_transaction( + f: Callable[P, Coroutine[None, None, T]], +) -> Callable[P, Coroutine[None, None, T]]: """Retry once if a transaction failed. Notes @@ -104,16 +106,16 @@ def retry_async_transaction(g: G) -> G: multiple times. """ - @wraps(g) - async def wrapper(*args: Any, **kwargs: Any) -> Any: + @wraps(f) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: for _ in range(1, 5): try: - return await g(*args, **kwargs) + return await f(*args, **kwargs) except (DBAPIError, OperationalError): continue - return await g(*args, **kwargs) + return await f(*args, **kwargs) - return cast(G, wrapper) + return wrapper class JobStore: