diff --git a/CHANGELOG.md b/CHANGELOG.md
index 774fc07..f78b95f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## 4.5.0 - 2025-01-31
+### Updated
+- `cached` and `cachedmethod` improved:
+  we now use a `threading.Lock` for sync functions and an `asyncio.Lock` for async functions to avoid [cache stampede](https://en.wikipedia.org/wiki/Cache_stampede). This change fixes issues [#15](https://github.com/awolverp/cachebox/issues/15) and [#20](https://github.com/awolverp/cachebox/issues/20). Special thanks to [@AlePiccin](https://github.com/AlePiccin).
+
 ## 4.4.2 - 2024-12-19
 ### Updated
 - Update `pyo3` to v0.23.4
diff --git a/Cargo.lock b/Cargo.lock
index 6b378f9..5f2ecd8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -22,7 +22,7 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
 
 [[package]]
 name = "cachebox"
-version = "4.4.2"
+version = "4.5.0"
 dependencies = [
  "cfg-if",
  "fastrand",
diff --git a/Cargo.toml b/Cargo.toml
index 609c383..666ed24 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,12 +1,13 @@
 [package]
 name = "cachebox"
-version = "4.4.2"
+version = "4.5.0"
 edition = "2021"
 description = "The fastest memoizing and caching Python library written in Rust"
 readme = "README.md"
 license = "MIT"
 homepage = "https://github.com/awolverp/cachebox"
 repository = "https://github.com/awolverp/cachebox.git"
+authors = ["awolverp"]
 
 [lib]
 name = "cachebox"
diff --git a/Makefile b/Makefile
index e031b02..5d2e21b 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,7 @@ build-prod:
 
 .PHONY: test-py
 test-py:
-	maturin develop
+	maturin develop
 	RUST_BACKTRACE=1 pytest -vv
 	rm -rf .pytest_cache
 	ruff check .
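The changelog entry above describes the core of this release: a per-key lock so that concurrent callers of a `cached` or `cachedmethod` function do not all recompute the same missing value. The following is a rough usage sketch of that behavior, patterned on the new `tests/test_concurrency.py` added later in this diff; the function name and values are illustrative and not part of the patch.

```python
# Illustrative sketch only (not part of the patch): concurrent callers of a
# cached function should trigger exactly one underlying call per key.
import time
from concurrent import futures

from cachebox import cached, LRUCache

calls = 0


@cached(LRUCache(0))  # LRUCache(0), as used in the new tests
def slow_lookup(x: int) -> int:
    global calls
    calls += 1
    time.sleep(1)  # simulate an expensive computation
    return x * 2


# Ten threads ask for the same key at once; with the per-key lock only the
# first caller runs slow_lookup, the others wait and reuse its cached result.
with futures.ThreadPoolExecutor(max_workers=10) as executor:
    futs = [executor.submit(slow_lookup, 21) for _ in range(10)]
    results = [f.result() for f in futures.as_completed(futs)]

assert results == [42] * 10
assert calls == 1  # the function body ran once despite ten concurrent calls
```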
diff --git a/cachebox/utils.py b/cachebox/utils.py
index 60cb7a8..072ff8b 100644
--- a/cachebox/utils.py
+++ b/cachebox/utils.py
@@ -1,7 +1,9 @@
 from ._cachebox import BaseCacheImpl, FIFOCache
-from collections import namedtuple
+from collections import namedtuple, defaultdict
 import functools
 import warnings
+import asyncio
+import _thread
 import inspect
 import typing
 
@@ -146,6 +148,34 @@ def items(self) -> typing.Iterable[typing.Tuple[KT, VT]]:
         return self.__cache.items()
 
 
+class _LockWithCounter:
+    """
+    A threading/asyncio lock which counts the waiters
+    """
+
+    __slots__ = ("lock", "waiters")
+
+    def __init__(self, is_async: bool = False):
+        self.lock = _thread.allocate_lock() if not is_async else asyncio.Lock()
+        self.waiters = 0
+
+    async def __aenter__(self) -> None:
+        self.waiters += 1
+        await self.lock.acquire()
+
+    async def __aexit__(self, *args, **kwds) -> None:
+        self.waiters -= 1
+        self.lock.release()
+
+    def __enter__(self) -> None:
+        self.waiters += 1
+        self.lock.acquire()
+
+    def __exit__(self, *args, **kwds) -> None:
+        self.waiters -= 1
+        self.lock.release()
+
+
 def _copy_if_need(obj, tocopy=(dict, list, set), level: int = 1):
     from copy import copy
 
@@ -203,14 +233,18 @@ def _cached_wrapper(
     hits = 0
     misses = 0
+    locks = defaultdict(_LockWithCounter)
+    exceptions = {}
 
     def _wrapped(*args, **kwds):
-        nonlocal hits, misses
+        nonlocal hits, misses, locks, exceptions
 
         if kwds.pop("cachebox__ignore", False):
            return func(*args, **kwds)
 
         key = _key_maker(args, kwds)
+
+        # try to get result from cache
         try:
             result = cache[key]
             hits += 1
@@ -220,13 +254,31 @@ def _wrapped(*args, **kwds):
             return _copy_if_need(result, level=copy_level)
 
         except KeyError:
-            misses += 1
-
-            result = func(*args, **kwds)
-            cache[key] = result
+            pass
+
+        with locks[key]:
+            if exceptions.get(key, None) is not None:
+                e = exceptions[key] if locks[key].waiters > 1 else exceptions.pop(key)
+                raise e
+
+            try:
+                result = cache[key]
+                hits += 1
+                event = EVENT_HIT
+            except KeyError:
+                try:
+                    result = func(*args, **kwds)
+                except Exception as e:
+                    exceptions[key] = e
+                    raise e
+
+                else:
+                    cache[key] = result
+                    misses += 1
+                    event = EVENT_MISS
 
             if callback is not None:
-                callback(EVENT_MISS, key, result)
+                callback(event, key, result)
 
         return _copy_if_need(result, level=copy_level)
 
@@ -237,10 +289,11 @@ def _wrapped(*args, **kwds):
     )
 
     def cache_clear():
-        nonlocal misses, hits
+        nonlocal misses, hits, locks
         cache.clear(reuse=clear_reuse)
         misses = 0
         hits = 0
+        locks.clear()
 
     _wrapped.cache_clear = cache_clear
 
@@ -260,14 +313,18 @@ def _async_cached_wrapper(
     hits = 0
     misses = 0
+    locks = defaultdict(lambda: _LockWithCounter(True))
+    exceptions = {}
 
     async def _wrapped(*args, **kwds):
-        nonlocal hits, misses
+        nonlocal hits, misses, locks, exceptions
 
         if kwds.pop("cachebox__ignore", False):
             return await func(*args, **kwds)
 
         key = _key_maker(args, kwds)
+
+        # try to get result from cache
         try:
             result = cache[key]
             hits += 1
@@ -279,13 +336,31 @@ async def _wrapped(*args, **kwds):
             return _copy_if_need(result, level=copy_level)
 
         except KeyError:
-            misses += 1
-
-            result = await func(*args, **kwds)
-            cache[key] = result
+            pass
+
+        async with locks[key]:
+            if exceptions.get(key, None) is not None:
+                e = exceptions[key] if locks[key].waiters > 1 else exceptions.pop(key)
+                raise e
+
+            try:
+                result = cache[key]
+                hits += 1
+                event = EVENT_HIT
+            except KeyError:
+                try:
+                    result = await func(*args, **kwds)
+                except Exception as e:
+                    exceptions[key] = e
+                    raise e
+
+                else:
+                    cache[key] = result
+                    misses += 1
+                    event = EVENT_MISS
 
             if callback is not None:
-                awaitable = callback(EVENT_MISS, key, result)
+                awaitable = callback(event, key, result)
                 if inspect.isawaitable(awaitable):
                     await awaitable
 
@@ -298,10 +373,11 @@ async def _wrapped(*args, **kwds):
     )
 
     def cache_clear():
-        nonlocal misses, hits
+        nonlocal misses, hits, locks
         cache.clear(reuse=clear_reuse)
         misses = 0
         hits = 0
+        locks.clear()
 
     _wrapped.cache_clear = cache_clear
 
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 0000000..2935ee1
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,108 @@
+from cachebox import cached, LRUCache
+from concurrent import futures
+import asyncio
+import pytest
+import time
+
+
+def test_threading_return():
+    calls = 0
+
+    @cached(LRUCache(0))
+    def func():
+        nonlocal calls
+        time.sleep(1)
+        calls += 1
+        return "Hello"
+
+    with futures.ThreadPoolExecutor(max_workers=10) as executor:
+        future_list = [executor.submit(func) for _ in range(10)]
+        for future in futures.as_completed(future_list):
+            assert future.result() == "Hello"
+
+    assert calls == 1
+
+
+def test_threading_exc():
+    calls = 0
+
+    @cached(LRUCache(0))
+    def func():
+        nonlocal calls
+        time.sleep(1)
+        calls += 1
+        raise RuntimeError
+
+    with futures.ThreadPoolExecutor(max_workers=5) as executor:
+        future_list = [executor.submit(func) for _ in range(5)]
+        for future in futures.as_completed(future_list):
+            assert isinstance(future.exception(), RuntimeError)
+
+    assert calls == 1
+
+    with futures.ThreadPoolExecutor(max_workers=5) as executor:
+        future_list = [executor.submit(func) for _ in range(5)]
+        for future in futures.as_completed(future_list):
+            assert isinstance(future.exception(), RuntimeError)
+
+    assert calls == 2
+
+
+@pytest.mark.asyncio
+async def test_asyncio_return():
+    calls = 0
+
+    @cached(LRUCache(0))
+    async def func():
+        nonlocal calls
+        await asyncio.sleep(1)
+        calls += 1
+        return "Hello"
+
+    await asyncio.gather(
+        func(),
+        func(),
+        func(),
+        func(),
+        func(),
+    )
+
+    assert calls == 1
+
+
+@pytest.mark.asyncio
+async def test_asyncio_exc():
+    calls = 0
+
+    @cached(LRUCache(0))
+    async def func():
+        nonlocal calls
+        await asyncio.sleep(1)
+        calls += 1
+        raise RuntimeError
+
+    tasks = await asyncio.gather(
+        func(),
+        func(),
+        func(),
+        func(),
+        func(),
+        return_exceptions=True,
+    )
+    for future in tasks:
+        assert isinstance(future, RuntimeError)
+
+    assert calls == 1
+
+    tasks = await asyncio.gather(
+        func(),
+        func(),
+        func(),
+        func(),
+        func(),
+        return_exceptions=True,
+    )
+    for future in tasks:
+        assert isinstance(future, RuntimeError)
+
+    assert calls == 2
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 63bc1fd..a6ba7aa 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -100,7 +100,8 @@ def func(a, b, c):
     assert len(func.cache) == 3
 
 
-async def _test_async_cached():
+@pytest.mark.asyncio
+async def test_async_cached():
     obj = LRUCache(3)  # type: LRUCache[int, int]
 
     @cached(obj)
@@ -142,15 +143,6 @@ async def factorial(n: int, _: str):
     assert len(factorial.cache) == 0
 
 
-def test_async_cached():
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-
-    loop.run_until_complete(_test_async_cached())
-
-
 def test_cachedmethod():
     class TestCachedMethod:
         def __init__(self, num) -> None:
@@ -165,7 +157,8 @@ def method(self, char: str):
     assert cls.method("a") == ("a" * 10)
 
 
-async def _test_async_cachedmethod():
+@pytest.mark.asyncio
+async def test_async_cachedmethod():
     class TestCachedMethod:
         def __init__(self, num) -> None:
             self.num = num
@@ -179,15 +172,6 @@ async def method(self, char: str):
     assert (await cls.method("a")) == ("a" * 10)
 
 
-def test_async_cachedmethod():
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-
-    loop.run_until_complete(_test_async_cachedmethod())
-
-
 def test_callback():
     obj = LRUCache(3)
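For the async path, the per-key `asyncio.Lock` also hands an exception raised by the first caller to the coroutines already waiting on the same key, and the exception is not cached for later batches, as `test_asyncio_exc` above exercises. A minimal sketch of that behavior follows; the names are illustrative and not part of the patch.

```python
# Illustrative sketch only: an exception propagates to concurrent waiters on
# the same key, but is not cached, so a later batch re-runs the function.
import asyncio

from cachebox import cached, LRUCache

calls = 0


@cached(LRUCache(0))
async def flaky_fetch(key: str) -> str:
    global calls
    calls += 1
    await asyncio.sleep(1)  # simulate slow I/O
    raise RuntimeError("upstream failed")


async def main() -> None:
    # First batch: one call runs the body, the other four wait on the per-key
    # asyncio.Lock and receive the same RuntimeError.
    results = await asyncio.gather(
        *(flaky_fetch("a") for _ in range(5)), return_exceptions=True
    )
    assert all(isinstance(r, RuntimeError) for r in results)
    assert calls == 1

    # Second batch: the exception was not cached, so the body runs again.
    results = await asyncio.gather(
        *(flaky_fetch("a") for _ in range(5)), return_exceptions=True
    )
    assert all(isinstance(r, RuntimeError) for r in results)
    assert calls == 2


asyncio.run(main())
```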