Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Redis Cluster) - Fixes for using redis cluster + pipeline #8442

Merged
merged 23 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions litellm/_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
)

verbose_logger.debug(
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
"init_redis_cluster: startup nodes are being initialized."
)
from redis.cluster import ClusterNode

Expand Down Expand Up @@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
return redis.Redis(**redis_kwargs)


def get_redis_async_client(**env_overrides) -> async_redis.Redis:
def get_redis_async_client(
**env_overrides,
) -> async_redis.Redis:
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
Expand Down
1 change: 1 addition & 0 deletions litellm/caching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
26 changes: 18 additions & 8 deletions litellm/caching/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache

Expand Down Expand Up @@ -158,14 +159,23 @@ def __init__(
None. Cache is set as a litellm param
"""
if type == LiteLLMCacheType.REDIS:
self.cache: BaseCache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
if redis_startup_nodes:
self.cache: BaseCache = RedisClusterCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
else:
self.cache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
**kwargs,
)
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
Expand Down
49 changes: 35 additions & 14 deletions litellm/caching/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import litellm
from litellm._logging import print_verbose, verbose_logger
Expand All @@ -26,15 +26,20 @@

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
from redis.asyncio.cluster import ClusterPipeline

pipeline = Pipeline
cluster_pipeline = ClusterPipeline
async_redis_client = Redis
async_redis_cluster_client = RedisCluster
Span = _Span
else:
pipeline = Any
cluster_pipeline = Any
async_redis_client = Any
async_redis_cluster_client = Any
Span = Any


Expand Down Expand Up @@ -122,7 +127,9 @@ def __init__(
else:
super().__init__() # defaults to 60s

def init_async_client(self):
def init_async_client(
self,
) -> Union[async_redis_client, async_redis_cluster_client]:
from .._redis import get_redis_async_client

return get_redis_async_client(
Expand Down Expand Up @@ -345,8 +352,14 @@ async def async_set_cache(self, key, value, **kwargs):
)

async def _pipeline_helper(
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
self,
pipe: Union[pipeline, cluster_pipeline],
cache_list: List[Tuple[Any, Any]],
ttl: Optional[float],
) -> List:
"""
Helper function for executing a pipeline of set operations on Redis
"""
ttl = self.get_ttl(ttl=ttl)
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list:
Expand All @@ -359,7 +372,11 @@ async def _pipeline_helper(
_td: Optional[timedelta] = None
if ttl is not None:
_td = timedelta(seconds=ttl)
pipe.set(cache_key, json_cache_value, ex=_td)
pipe.set( # type: ignore
name=cache_key,
value=json_cache_value,
ex=_td,
)
# Execute the pipeline and return the results.
results = await pipe.execute()
return results
Expand All @@ -373,9 +390,8 @@ async def async_set_cache_pipeline(
# don't waste a network request if there's nothing to set
if len(cache_list) == 0:
return
from redis.asyncio import Redis

_redis_client: Redis = self.init_async_client() # type: ignore
_redis_client = self.init_async_client()
start_time = time.time()

print_verbose(
Expand All @@ -384,7 +400,7 @@ async def async_set_cache_pipeline(
cache_value: Any = None
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
async with redis_client.pipeline(transaction=False) as pipe:
results = await self._pipeline_helper(pipe, cache_list, ttl)

print_verbose(f"pipeline results: {results}")
Expand Down Expand Up @@ -730,7 +746,8 @@ async def async_batch_get_cache(
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
_redis_client: Any = self.init_async_client()
key_value_dict = {}
start_time = time.time()
try:
Expand Down Expand Up @@ -822,7 +839,8 @@ def sync_ping(self) -> bool:
raise e

async def ping(self) -> bool:
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
_redis_client: Any = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client:
print_verbose("Pinging Async Redis Cache")
Expand Down Expand Up @@ -858,7 +876,8 @@ async def ping(self) -> bool:
raise e

async def delete_cache_keys(self, keys):
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
_redis_client: Any = self.init_async_client()
# keys is a list, unpack it so it gets passed as individual elements to delete
async with _redis_client as redis_client:
await redis_client.delete(*keys)
Expand All @@ -881,7 +900,8 @@ async def disconnect(self):
await self.async_redis_conn_pool.disconnect(inuse_connections=True)

async def async_delete_cache(self, key: str):
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
_redis_client: Any = self.init_async_client()
# keys is str
async with _redis_client as redis_client:
await redis_client.delete(key)
Expand Down Expand Up @@ -936,7 +956,7 @@ async def async_increment_pipeline(

try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
async with redis_client.pipeline(transaction=False) as pipe:
results = await self._pipeline_increment_helper(
pipe, increment_list
)
Expand Down Expand Up @@ -991,7 +1011,8 @@ async def async_get_ttl(self, key: str) -> Optional[int]:
Redis ref: https://redis.io/docs/latest/commands/ttl/
"""
try:
_redis_client = await self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
_redis_client: Any = self.init_async_client()
async with _redis_client as redis_client:
ttl = await redis_client.ttl(key)
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
Expand Down
44 changes: 44 additions & 0 deletions litellm/caching/redis_cluster_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Redis Cluster Cache implementation

Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""

from typing import TYPE_CHECKING, Any, Optional

from litellm.caching.redis_cache import RedisCache

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline

pipeline = Pipeline
async_redis_client = Redis
Span = _Span
else:
pipeline = Any
async_redis_client = Any
Span = Any


class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_cluster_client: Optional[RedisCluster] = None

def init_async_client(self):
from redis.asyncio import RedisCluster

from .._redis import get_redis_async_client

if self.redis_cluster_client:
return self.redis_cluster_client

_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_cluster_client = _redis_client
return _redis_client
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[mypy]
warn_return_any = False
ignore_missing_imports = True
mypy_path = litellm/stubs

[mypy-google.*]
ignore_missing_imports = True
12 changes: 9 additions & 3 deletions tests/local_testing/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache

from redis.asyncio import RedisCluster
from litellm.caching.redis_cluster_cache import RedisClusterCache
from unittest.mock import AsyncMock, patch, MagicMock, call
import datetime
from datetime import timedelta
Expand Down Expand Up @@ -2328,8 +2329,12 @@ async def test_redis_caching_ttl_pipeline():
# Verify that the set method was called on the mock Redis instance
mock_set.assert_has_calls(
[
call.set("test_key1", '"test_value1"', ex=expected_timedelta),
call.set("test_key2", '"test_value2"', ex=expected_timedelta),
call.set(
name="test_key1", value='"test_value1"', ex=expected_timedelta
),
call.set(
name="test_key2", value='"test_value2"', ex=expected_timedelta
),
]
)

Expand Down Expand Up @@ -2388,6 +2393,7 @@ async def test_redis_increment_pipeline():
from litellm.caching.redis_cache import RedisCache

litellm.set_verbose = True
litellm._turn_on_debug()
redis_cache = RedisCache(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
Expand Down