Improve cached and cachedmethod
awolverp committed Jan 31, 2025
1 parent 884f2c9 commit 1c08346
Showing 7 changed files with 212 additions and 38 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 4.5.0 - 2025-01-31
### Updated
- `cached` and `cachedmethod` improved:
  we now use `threading.Lock` for sync functions and `asyncio.Lock` for async functions to avoid a [cache stampede](https://en.wikipedia.org/wiki/Cache_stampede). This change fixes issues [#15](https://github.com/awolverp/cachebox/issues/15) and [#20](https://github.com/awolverp/cachebox/issues/20). Special thanks to [@AlePiccin](https://github.com/AlePiccin). (A usage sketch follows this file's diff.)

## 4.4.2 - 2024-12-19
### Updated
- Update `pyo3` to v0.23.4
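For illustration, here is a small usage sketch of the new behaviour (it mirrors `tests/test_concurrency.py` added in this commit; names such as `fetch` are illustrative). Because every cache key now gets its own lock, concurrent callers of a `@cached` function wait for the first caller to store the result instead of recomputing it:

```python
import time
from concurrent import futures
from cachebox import cached, LRUCache

calls = 0

@cached(LRUCache(0))
def fetch():
    global calls
    time.sleep(0.5)  # simulate slow work
    calls += 1
    return "Hello"

with futures.ThreadPoolExecutor(max_workers=10) as executor:
    results = [f.result() for f in futures.as_completed(executor.submit(fetch) for _ in range(10))]

assert all(r == "Hello" for r in results)
assert calls == 1  # with 4.5.0 the body runs once, despite 10 concurrent callers
```

The same holds for `async def` functions, where an `asyncio.Lock` is used instead of a thread lock.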
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -1,12 +1,13 @@
[package]
name = "cachebox"
version = "4.4.2"
version = "4.5.0"
edition = "2021"
description = "The fastest memoizing and caching Python library written in Rust"
readme = "README.md"
license = "MIT"
homepage = "https://github.com/awolverp/cachebox"
repository = "https://github.com/awolverp/cachebox.git"
authors = ["awolverp"]

[lib]
name = "cachebox"
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ build-prod:

.PHONY: test-py
test-py:
maturin develop
maturin develop
RUST_BACKTRACE=1 pytest -vv
rm -rf .pytest_cache
ruff check .
106 changes: 91 additions & 15 deletions cachebox/utils.py
@@ -1,7 +1,9 @@
from ._cachebox import BaseCacheImpl, FIFOCache
from collections import namedtuple
from collections import namedtuple, defaultdict
import functools
import warnings
import asyncio
import _thread
import inspect
import typing

@@ -146,6 +148,34 @@ def items(self) -> typing.Iterable[typing.Tuple[KT, VT]]:
return self.__cache.items()


class _LockWithCounter:
"""
    A threading/asyncio lock which counts the waiters
"""

__slots__ = ("lock", "waiters")

def __init__(self, is_async: bool = False):
self.lock = _thread.allocate_lock() if not is_async else asyncio.Lock()
self.waiters = 0

async def __aenter__(self) -> None:
self.waiters += 1
await self.lock.acquire()

async def __aexit__(self, *args, **kwds) -> None:
self.waiters -= 1
self.lock.release()

def __enter__(self) -> None:
self.waiters += 1
self.lock.acquire()

def __exit__(self, *args, **kwds) -> None:
self.waiters -= 1
self.lock.release()


def _copy_if_need(obj, tocopy=(dict, list, set), level: int = 1):
from copy import copy

@@ -203,14 +233,18 @@ def _cached_wrapper(

hits = 0
misses = 0
locks = defaultdict(_LockWithCounter)
exceptions = {}

def _wrapped(*args, **kwds):
nonlocal hits, misses
nonlocal hits, misses, locks, exceptions

if kwds.pop("cachebox__ignore", False):
return func(*args, **kwds)

key = _key_maker(args, kwds)

# try to get result from cache
try:
result = cache[key]
hits += 1
@@ -220,13 +254,31 @@ def _wrapped(*args, **kwds):

return _copy_if_need(result, level=copy_level)
except KeyError:
misses += 1

result = func(*args, **kwds)
cache[key] = result
pass

with locks[key]:
if exceptions.get(key, None) is not None:
e = exceptions[key] if locks[key].waiters > 1 else exceptions.pop(key)
raise e

try:
result = cache[key]
hits += 1
event = EVENT_HIT
except KeyError:
try:
result = func(*args, **kwds)
except Exception as e:
exceptions[key] = e
raise e

else:
cache[key] = result
misses += 1
event = EVENT_MISS

if callback is not None:
callback(EVENT_MISS, key, result)
callback(event, key, result)

return _copy_if_need(result, level=copy_level)

@@ -237,10 +289,11 @@ def _wrapped(*args, **kwds):
)

def cache_clear():
nonlocal misses, hits
nonlocal misses, hits, locks
cache.clear(reuse=clear_reuse)
misses = 0
hits = 0
locks.clear()

_wrapped.cache_clear = cache_clear

@@ -260,14 +313,18 @@ def _async_cached_wrapper(

hits = 0
misses = 0
locks = defaultdict(lambda: _LockWithCounter(True))
exceptions = {}

async def _wrapped(*args, **kwds):
nonlocal hits, misses
nonlocal hits, misses, locks, exceptions

if kwds.pop("cachebox__ignore", False):
return await func(*args, **kwds)

key = _key_maker(args, kwds)

# try to get result from cache
try:
result = cache[key]
hits += 1
@@ -279,13 +336,31 @@ async def _wrapped(*args, **kwds):

return _copy_if_need(result, level=copy_level)
except KeyError:
misses += 1

result = await func(*args, **kwds)
cache[key] = result
pass

async with locks[key]:
if exceptions.get(key, None) is not None:
e = exceptions[key] if locks[key].waiters > 1 else exceptions.pop(key)
raise e

try:
result = cache[key]
hits += 1
event = EVENT_HIT
except KeyError:
try:
result = await func(*args, **kwds)
except Exception as e:
exceptions[key] = e
raise e

else:
cache[key] = result
misses += 1
event = EVENT_MISS

if callback is not None:
awaitable = callback(EVENT_MISS, key, result)
awaitable = callback(event, key, result)
if inspect.isawaitable(awaitable):
await awaitable

@@ -298,10 +373,11 @@ async def _wrapped(*args, **kwds):
)

def cache_clear():
nonlocal misses, hits
nonlocal misses, hits, locks
cache.clear(reuse=clear_reuse)
misses = 0
hits = 0
locks.clear()

_wrapped.cache_clear = cache_clear

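Condensed, the stampede protection added to `_cached_wrapper` above works as follows: each key gets its own `_LockWithCounter` from a `defaultdict`; a caller that misses the cache takes the per-key lock, re-checks the cache, and only computes if the value is still missing; an exception raised by the computing caller is stored per key so that blocked waiters re-raise it, and the last waiter (detected through the `waiters` counter) pops it. Below is a standalone sketch of that pattern using a plain `dict` and `threading.Lock`, not the library's actual wrapper:

```python
import threading
from collections import defaultdict


class LockWithCounter:
    """A lock that also counts how many callers currently hold or wait for it."""

    __slots__ = ("lock", "waiters")

    def __init__(self):
        self.lock = threading.Lock()
        self.waiters = 0

    def __enter__(self):
        self.waiters += 1
        self.lock.acquire()

    def __exit__(self, *exc):
        self.waiters -= 1
        self.lock.release()


def stampede_safe(func):
    """Illustrative decorator: at most one computation per key under concurrency."""
    cache = {}
    locks = defaultdict(LockWithCounter)
    exceptions = {}

    def wrapper(*args):
        key = args
        if key in cache:                    # fast path: hits never touch the lock
            return cache[key]
        with locks[key]:
            if key in exceptions:           # the computing caller failed:
                if locks[key].waiters > 1:  # keep the exception while others still wait,
                    raise exceptions[key]
                raise exceptions.pop(key)   # the last waiter removes it
            if key in cache:                # filled in while we waited for the lock
                return cache[key]
            try:
                result = func(*args)
            except Exception as e:
                exceptions[key] = e
                raise
            cache[key] = result
            return result

    return wrapper
```

The async wrapper follows the same shape with `asyncio.Lock` and `async with`, and `cache_clear()` also clears the per-key locks.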
108 changes: 108 additions & 0 deletions tests/test_concurrency.py
@@ -0,0 +1,108 @@
from cachebox import cached, LRUCache
from concurrent import futures
import asyncio
import pytest
import time


def test_threading_return():
calls = 0

@cached(LRUCache(0))
def func():
nonlocal calls
time.sleep(1)
calls += 1
return "Hello"

with futures.ThreadPoolExecutor(max_workers=10) as executor:
future_list = [executor.submit(func) for _ in range(10)]
for future in futures.as_completed(future_list):
assert future.result() == "Hello"

assert calls == 1


def test_threading_exc():
calls = 0

@cached(LRUCache(0))
def func():
nonlocal calls
time.sleep(1)
calls += 1
raise RuntimeError

with futures.ThreadPoolExecutor(max_workers=5) as executor:
future_list = [executor.submit(func) for _ in range(5)]
for future in futures.as_completed(future_list):
assert isinstance(future.exception(), RuntimeError)

assert calls == 1

with futures.ThreadPoolExecutor(max_workers=5) as executor:
future_list = [executor.submit(func) for _ in range(5)]
for future in futures.as_completed(future_list):
assert isinstance(future.exception(), RuntimeError)

assert calls == 2


@pytest.mark.asyncio
async def test_asyncio_return():
calls = 0

@cached(LRUCache(0))
async def func():
nonlocal calls
await asyncio.sleep(1)
calls += 1
return "Hello"

await asyncio.gather(
func(),
func(),
func(),
func(),
func(),
)

assert calls == 1


@pytest.mark.asyncio
async def test_asyncio_exc():
calls = 0

@cached(LRUCache(0))
async def func():
nonlocal calls
await asyncio.sleep(1)
calls += 1
raise RuntimeError

tasks = await asyncio.gather(
func(),
func(),
func(),
func(),
func(),
return_exceptions=True,
)
for future in tasks:
assert isinstance(future, RuntimeError)

assert calls == 1

tasks = await asyncio.gather(
func(),
func(),
func(),
func(),
func(),
return_exceptions=True,
)
for future in tasks:
assert isinstance(future, RuntimeError)

assert calls == 2
24 changes: 4 additions & 20 deletions tests/test_utils.py
@@ -100,7 +100,8 @@ def func(a, b, c):
assert len(func.cache) == 3


async def _test_async_cached():
@pytest.mark.asyncio
async def test_async_cached():
obj = LRUCache(3) # type: LRUCache[int, int]

@cached(obj)
@@ -142,15 +143,6 @@ async def factorial(n: int, _: str):
assert len(factorial.cache) == 0


def test_async_cached():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()

loop.run_until_complete(_test_async_cached())


def test_cachedmethod():
class TestCachedMethod:
def __init__(self, num) -> None:
@@ -165,7 +157,8 @@ def method(self, char: str):
assert cls.method("a") == ("a" * 10)


async def _test_async_cachedmethod():
@pytest.mark.asyncio
async def test_async_cachedmethod():
class TestCachedMethod:
def __init__(self, num) -> None:
self.num = num
@@ -179,15 +172,6 @@ async def method(self, char: str):
assert (await cls.method("a")) == ("a" * 10)


def test_async_cachedmethod():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()

loop.run_until_complete(_test_async_cachedmethod())


def test_callback():
obj = LRUCache(3)

