Skip to content

Commit 3994c3d

Browse files
authored
Merge pull request #231 from Krukov/fix-for-few-issues
fix: check setup for disable not configured cache, feat: get_or_set, …
2 parents 65c6ce8 + 4379945 commit 3994c3d

21 files changed

+209
-41
lines changed

.pre-commit-config.yaml

-6
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@ repos:
1313
- id: end-of-file-fixer
1414
- id: trailing-whitespace
1515
- id: debug-statements
16-
#
17-
# - repo: https://github.com/pre-commit/mirrors-prettier
18-
# rev: v4.0.0-alpha.8
19-
# hooks:
20-
# - id: prettier
21-
# stages: [commit]
2216

2317
- repo: https://github.com/astral-sh/ruff-pre-commit
2418
rev: v0.3.2

Readme.md

+9-6
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,19 @@ from cashews import cache
196196

197197
cache.setup("mem://") # configure as in-memory cache
198198

199-
await cache.set(key="key", value=90, expire=60, exist=None) # -> bool
199+
await cache.set(key="key", value=90, expire="2h", exist=None) # -> bool
200200
await cache.set_raw(key="key", value="str") # -> bool
201+
await cache.set_many({"key1": value, "key2": value}) # -> None
201202

202203
await cache.get("key", default=None) # -> Any
203-
await cache.get_raw("key")
204-
await cache.get_many("key1", "key2", default=None)
204+
await cache.get_or_set("key", default=awaitable_or_callable, expire="1h") # -> Any
205+
await cache.get_raw("key") # -> Any
206+
await cache.get_many("key1", "key2", default=None) # -> tuple[Any]
205207
async for key, value in cache.get_match("pattern:*", batch_size=100):
206208
...
207209

208210
await cache.incr("key") # -> int
211+
await cache.exists("key") # -> bool
209212

210213
await cache.delete("key")
211214
await cache.delete_many("key1", "key2")
@@ -928,8 +931,8 @@ E.g. A simple middleware to use it in a web app:
928931
async def add_from_cache_headers(request: Request, call_next):
929932
with cache.detect as detector:
930933
response = await call_next(request)
931-
if detector.keys:
932-
key = list(detector.keys.keys())[0]
934+
if detector.calls:
935+
key = list(detector.calls.keys())[0]
933936
response.headers["X-From-Cache"] = key
934937
expire = await cache.get_expire(key)
935938
response.headers["X-From-Cache-Expire-In-Seconds"] = str(expire)
@@ -1004,7 +1007,7 @@ Here we want to have some way to protect our code from race conditions and do op
10041007

10051008
Cashews support transaction operations:
10061009

1007-
> :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr`
1010+
> :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr`
10081011
10091012
```python
10101013
from cashews import cache

cashews/_typing.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ def __call__(
3030

3131
CacheCondition = Union[CallableCacheCondition, str, None]
3232

33-
AsyncCallableResult_T = TypeVar("AsyncCallableResult_T")
34-
AsyncCallable_T = Callable[..., Awaitable[AsyncCallableResult_T]]
33+
Result_T = TypeVar("Result_T")
34+
AsyncCallable_T = Callable[..., Awaitable[Result_T]]
35+
Callable_T = Callable[..., Result_T]
3536

3637
DecoratedFunc = TypeVar("DecoratedFunc", bound=AsyncCallable_T)
3738

@@ -44,7 +45,7 @@ def __call__(
4445
backend: Backend,
4546
*args,
4647
**kwargs,
47-
) -> Awaitable[AsyncCallableResult_T | None]: # pragma: no cover
48+
) -> Awaitable[Result_T | None]: # pragma: no cover
4849
...
4950

5051

cashews/commands.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Command(Enum):
1919
EXPIRE = "expire"
2020
GET_EXPIRE = "get_expire"
2121
CLEAR = "clear"
22+
2223
SET_LOCK = "set_lock"
2324
UNLOCK = "unlock"
2425
IS_LOCKED = "is_locked"

cashews/contrib/fastapi.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ def __init__(
5555
cache_instance: Cache = cache,
5656
methods: Sequence[str] = ("get",),
5757
private=True,
58+
prefix_to_disable: str = "",
5859
):
5960
self._private = private
6061
self._cache = cache_instance
6162
self._methods = methods
63+
self._prefix_to_disable = prefix_to_disable
6264
super().__init__(app)
6365

6466
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
@@ -68,7 +70,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
6870
return await call_next(request)
6971
to_disable = _to_disable(cache_control_value)
7072
if to_disable:
71-
context = self._cache.disabling(*to_disable)
73+
context = self._cache.disabling(*to_disable, prefix=self._prefix_to_disable)
7274
with context, max_age(cache_control_value), self._cache.detect as detector:
7375
response = await call_next(request)
7476
calls = detector.calls_list

cashews/formatter.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,11 @@ def _upper(value: TemplateValue) -> TemplateValue:
158158

159159

160160
def default_format(template: KeyTemplate, **values) -> KeyOrTemplate:
161-
_template_context = key_context.get()
162-
_template_context.update(values)
161+
_template_context, rewrite = key_context.get()
162+
if rewrite:
163+
_template_context = {**values, **_template_context}
164+
else:
165+
_template_context = {**_template_context, **values}
163166
return default_formatter.format(template, **_template_context)
164167

165168

cashews/helpers.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
from typing import Optional
22

3-
from ._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware
3+
from ._typing import AsyncCallable_T, Middleware, Result_T
44
from .backends.interface import Backend
55
from .commands import PATTERN_CMDS, Command
66
from .key import get_call_values
77
from .utils import get_obj_size
88

99

1010
def add_prefix(prefix: str) -> Middleware:
11-
async def _middleware(
12-
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
13-
) -> AsyncCallableResult_T:
11+
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
1412
if cmd == Command.GET_MANY:
1513
return await call(*[prefix + key for key in args])
1614
call_values = get_call_values(call, args, kwargs)
@@ -29,9 +27,7 @@ async def _middleware(
2927

3028

3129
def all_keys_lower() -> Middleware:
32-
async def _middleware(
33-
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
34-
) -> AsyncCallableResult_T:
30+
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
3531
if cmd == Command.GET_MANY:
3632
return await call(*[key.lower() for key in args])
3733
call_values = get_call_values(call, args, kwargs)
@@ -54,7 +50,7 @@ async def _middleware(
5450
def memory_limit(min_bytes=0, max_bytes=None) -> Middleware:
5551
async def _middleware(
5652
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
57-
) -> Optional[AsyncCallableResult_T]:
53+
) -> Optional[Result_T]:
5854
if cmd != Command.SET:
5955
return await call(*args, **kwargs)
6056
call_values = get_call_values(call, args, kwargs)

cashews/key.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _default(name):
119119
return "*"
120120

121121
check = _ReplaceFormatter(default=_default)
122-
check.format(key, **{**get_key_context(), **func_params})
122+
check.format(key, **{**get_key_context()[0], **func_params})
123123
if errors:
124124
raise WrongKeyError(f"Wrong parameter placeholder '{errors}' in the key ")
125125

cashews/key_context.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,24 @@
44
from contextvars import ContextVar
55
from typing import Any, Iterator
66

7-
_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={})
7+
_REWRITE = "__rewrite"
8+
_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={_REWRITE: False})
89

910

1011
@contextmanager
11-
def context(**values) -> Iterator[None]:
12+
def context(rewrite=False, **values) -> Iterator[None]:
1213
new_context = {**_template_context.get(), **values}
14+
new_context[_REWRITE] = rewrite
1315
token = _template_context.set(new_context)
1416
try:
1517
yield
1618
finally:
1719
_template_context.reset(token)
1820

1921

20-
def get():
21-
return {**_template_context.get()}
22+
def get() -> tuple[dict[str, Any], bool]:
23+
_context = {**_template_context.get()}
24+
return _context, _context.pop(_REWRITE)
2225

2326

2427
def register(*names: str) -> None:

cashews/validation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .commands import RETRIEVE_CMDS, Command
99
from .formatter import default_format
1010
from .key import get_call_values
11+
from .key_context import context as template_context
1112

1213

1314
def invalidate(
@@ -29,7 +30,8 @@ async def _wrap(*args, **kwargs):
2930
if dest in _args:
3031
_args[source] = _args.pop(dest)
3132
key = default_format(key_template, **_args)
32-
await backend.delete_match(key)
33+
with template_context(**_args, rewrite=True):
34+
await backend.delete_match(key)
3335
return result
3436

3537
return _wrap

cashews/wrapper/auto_init.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import asyncio
22

3-
from cashews._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware
3+
from cashews._typing import AsyncCallable_T, Middleware, Result_T
44
from cashews.backends.interface import Backend
55
from cashews.commands import Command
66

77

88
def create_auto_init() -> Middleware:
99
lock = asyncio.Lock()
1010

11-
async def _auto_init(
12-
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
13-
) -> AsyncCallableResult_T:
11+
async def _auto_init(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
1412
if backend.is_init:
1513
return await call(*args, **kwargs)
1614
async with lock:

cashews/wrapper/commands.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import inspect
34
from functools import partial
45
from typing import TYPE_CHECKING, AsyncIterator, Iterable, Mapping, overload
56

@@ -10,7 +11,9 @@
1011
from .wrapper import Wrapper
1112

1213
if TYPE_CHECKING: # pragma: no cover
13-
from cashews._typing import TTL, Default, Key, Value
14+
from cashews._typing import TTL, AsyncCallable_T, Callable_T, Default, Key, Result_T, Value
15+
16+
_empty = object()
1417

1518

1619
class CommandWrapper(Wrapper):
@@ -40,6 +43,22 @@ async def get(self, key: Key, default: None = None) -> Value | None: ...
4043
async def get(self, key: Key, default: Default | None = None) -> Value | Default | None:
4144
return await self._with_middlewares(Command.GET, key)(key=key, default=default)
4245

46+
async def get_or_set(
47+
self, key: Key, default: Default | AsyncCallable_T | Callable_T, expire: TTL = None
48+
) -> Value | Default | Result_T:
49+
value = await self.get(key, default=_empty)
50+
if value is not _empty:
51+
return value
52+
if callable(default):
53+
if inspect.iscoroutinefunction(default):
54+
_default = await default()
55+
else:
56+
_default = default()
57+
else:
58+
_default = default
59+
await self.set(key, _default, expire=expire)
60+
return default
61+
4362
async def get_raw(self, key: Key) -> Value:
4463
return await self._with_middlewares(Command.GET_RAW, key)(key=key)
4564

cashews/wrapper/disable_control.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3-
from contextlib import contextmanager
3+
from contextlib import contextmanager, suppress
44
from typing import TYPE_CHECKING, Iterator
55

66
from cashews.commands import Command
7+
from cashews.exceptions import NotConfiguredError
78

89
from .wrapper import Wrapper
910

@@ -26,7 +27,8 @@ def __init__(self, name: str = ""):
2627
self.add_middleware(_is_disable_middleware)
2728

2829
def disable(self, *cmds: Command, prefix: str = "") -> None:
29-
return self._get_backend(prefix).disable(*cmds)
30+
with suppress(NotConfiguredError):
31+
return self._get_backend(prefix).disable(*cmds)
3032

3133
def enable(self, *cmds: Command, prefix: str = "") -> None:
3234
return self._get_backend(prefix).enable(*cmds)
@@ -37,7 +39,8 @@ def disabling(self, *cmds: Command, prefix: str = "") -> Iterator[None]:
3739
try:
3840
yield
3941
finally:
40-
self.enable(*cmds, prefix=prefix)
42+
with suppress(NotConfiguredError):
43+
self.enable(*cmds, prefix=prefix)
4144

4245
def is_disable(self, *cmds: Command, prefix: str = "") -> bool:
4346
return self._get_backend(prefix).is_disable(*cmds)

examples/bug.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import asyncio
2+
from collections.abc import Mapping
3+
from typing import Any
4+
5+
from cashews import cache, default_formatter
6+
7+
cache.setup("mem://?size=1000000&check_interval=5")
8+
9+
10+
@default_formatter.register("get_item", preformat=False)
11+
def _getitem_func(mapping: Mapping[str, Any], key: str) -> str:
12+
try:
13+
return str(mapping[key])
14+
except Exception as e:
15+
# when key/tag matching, this may be called with the rendered value
16+
raise RuntimeError(f"{mapping=}, {key=}") from e
17+
18+
19+
@cache(
20+
ttl="1h",
21+
key="prefix:keys:{mapping:get_item(bar)}",
22+
tags=["prefix:tags:{mapping:get_item(bar)}"],
23+
)
24+
async def foo(mapping: str) -> None:
25+
print("Foo", mapping)
26+
27+
28+
@cache.invalidate("prefix:keys:{mapping:get_item(bar)}")
29+
async def bar(mapping: str) -> None:
30+
print("Bar", mapping)
31+
32+
33+
async def main() -> None:
34+
await foo({"bar": "baz"})
35+
await bar({"bar": "baz"})
36+
37+
38+
if __name__ == "__main__":
39+
asyncio.run(main())
40+
41+
# prints Foo {'bar': 'baz'}
42+
# prints Bar {'bar': 'baz'}

examples/keys.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def _call(function, *args, **kwargs):
5656
await function(*args, **kwargs)
5757
with cache.detect as detector:
5858
await function(*args, **kwargs)
59-
key = list(detector.keys.keys())[-1]
59+
key = list(detector.calls.keys())[-1]
6060

6161
print(
6262
f"""

examples/simple.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ async def basic():
1818
await cache.set("key", 1)
1919
assert await cache.get("key") == 1
2020
await cache.set("key1", value={"any": True}, expire="1m")
21+
print(await cache.get_or_set("key200", default=lambda: "test"))
22+
print(await cache.get_or_set("key10", default="test"))
2123

2224
await cache.set_many({"key2": "test", "key3": Decimal("10.1")}, expire="1m")
2325
print("Get: ", await cache.get("key1")) # -> Any

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ line-length = 119
44

55
[tool.ruff.lint]
66
select = ["E", "F", "B", "I", "SIM", "UP", "C4"]
7+
ignore = ["SIM108"]
78

89
[tool.ruff.lint.per-file-ignores]
910
"tests/**/*.py" = [

0 commit comments

Comments
 (0)