Skip to content

Commit

Permalink
fix: Pin redis-py to 4.5.5 and enhance logging for Redis reconnection and failover (#1620)
Browse files Browse the repository at this point in the history

This commit only includes pinning redis-py to 4.5.5, the fix for the
connection pool mis-creation in `AgentRegistry.handle_kernel_log()`, and a
limit on `max_connections` for the default connection pool.

Added an explicit warning if the native sentinel client is used, because the rework
on the sentinel connection pool (#1586) targets 23.09 only.

Backported-from: main
Backported-to: 23.03
  • Loading branch information
achimnol committed Oct 16, 2023
1 parent a69fccf commit 502a6c4
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 175 deletions.
1 change: 1 addition & 0 deletions changes/1620.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve logging when retrying redis connections during failover and use explicit names for all redis connection pools
247 changes: 123 additions & 124 deletions python.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ dataclasses-json~=0.5.7
etcetra==0.1.17
faker~=13.12.0
graphene~=2.1.9
hiredis~=2.2.3
humanize>=3.1.0
ifaddr~=0.2
inquirer~=2.9.2
Expand All @@ -55,7 +54,8 @@ pyzmq~=24.0.1
PyJWT~=2.0
PyYAML~=6.0
packaging>=21.3
redis[hiredis]~=4.6.0
hiredis>=2.2.3
redis[hiredis]==4.5.5
rich~=12.2
SQLAlchemy[postgresql_asyncpg]~=1.4.40
setproctitle~=1.3.2
Expand All @@ -69,7 +69,7 @@ tqdm>=4.61
trafaret~=2.1
typeguard~=2.10
typing_extensions~=4.3
uvloop>=0.17; sys_platform != "Windows"
uvloop~=0.17.0; sys_platform != "Windows" # 0.18 breaks the API and adds Python 3.12 support
yarl~=1.8.2 # FIXME: revert to >=1.7 after aio-libs/yarl#862 is resolved
zipstream-new~=1.1.8

Expand Down
14 changes: 11 additions & 3 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.config import model_definition_iv
from ai.backend.common.defs import REDIS_STREAM_DB
from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
from ai.backend.common.events import (
AbstractEvent,
Expand Down Expand Up @@ -616,8 +616,16 @@ async def __ainit__(self) -> None:
node_id=self.local_config["agent"]["id"],
consumer_group=EVENT_DISPATCHER_CONSUMER_GROUP,
)
self.redis_stream_pool = redis_helper.get_redis_object(self.local_config["redis"], db=4)
self.redis_stat_pool = redis_helper.get_redis_object(self.local_config["redis"], db=0)
self.redis_stream_pool = redis_helper.get_redis_object(
self.local_config["redis"],
name="stream",
db=REDIS_STREAM_DB,
)
self.redis_stat_pool = redis_helper.get_redis_object(
self.local_config["redis"],
name="stat",
db=REDIS_STAT_DB,
)

alloc_map_mod.log_alloc_map = self.local_config["debug"]["log-alloc-map"]
computers, self.slots = await self.detect_resources()
Expand Down
36 changes: 18 additions & 18 deletions src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import socket
import uuid
from collections import defaultdict
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Coroutine,
Expand All @@ -32,8 +30,8 @@
from aiotools.context import aclosing
from aiotools.server import process_index
from aiotools.taskgroup import PersistentTaskGroup
from aiotools.taskgroup.types import AsyncExceptionHandler
from redis.asyncio import ConnectionPool
from typing_extensions import TypeAlias

from . import msgpack, redis_helper
from .logging import BraceStyleAdapter
Expand All @@ -57,10 +55,6 @@

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

PTGExceptionHandler: TypeAlias = Callable[
[Type[Exception], Exception, TracebackType], Awaitable[None]
]


class AbstractEvent(metaclass=abc.ABCMeta):
# derivatives should define the fields.
Expand Down Expand Up @@ -688,16 +682,18 @@ def __init__(
log_events: bool = False,
*,
consumer_group: str,
service_name: str = None,
service_name: str | None = None,
stream_key: str = "events",
node_id: str = None,
consumer_exception_handler: PTGExceptionHandler = None,
subscriber_exception_handler: PTGExceptionHandler = None,
node_id: str | None = None,
consumer_exception_handler: AsyncExceptionHandler | None = None,
subscriber_exception_handler: AsyncExceptionHandler | None = None,
) -> None:
_redis_config = redis_config.copy()
if service_name:
_redis_config["service_name"] = service_name
self.redis_client = redis_helper.get_redis_object(_redis_config, db=db)
self.redis_client = redis_helper.get_redis_object(
_redis_config, name="event_dispatcher.stream", db=db
)
self._log_events = log_events
self._closed = False
self.consumers = defaultdict(set)
Expand Down Expand Up @@ -743,7 +739,7 @@ def consume(
callback: EventCallback[TContext, TEvent],
coalescing_opts: CoalescingOptions = None,
*,
name: str = None,
name: str | None = None,
) -> EventHandler[TContext, TEvent]:
if name is None:
name = f"evh-{secrets.token_urlsafe(16)}"
Expand All @@ -766,9 +762,9 @@ def subscribe(
event_cls: Type[TEvent],
context: TContext,
callback: EventCallback[TContext, TEvent],
coalescing_opts: CoalescingOptions = None,
coalescing_opts: CoalescingOptions | None = None,
*,
name: str = None,
name: str | None = None,
) -> EventHandler[TContext, TEvent]:
if name is None:
name = f"evh-{secrets.token_urlsafe(16)}"
Expand Down Expand Up @@ -892,15 +888,19 @@ def __init__(
redis_config: EtcdRedisConfig,
db: int = 0,
*,
service_name: str = None,
service_name: str | None = None,
stream_key: str = "events",
log_events: bool = False,
) -> None:
_redis_config = redis_config.copy()
if service_name:
_redis_config["service_name"] = service_name
self._closed = False
self.redis_client = redis_helper.get_redis_object(_redis_config, db=db)
self.redis_client = redis_helper.get_redis_object(
_redis_config,
name="event_producer.stream",
db=db,
)
self._log_events = log_events
self._stream_key = stream_key

Expand Down Expand Up @@ -930,7 +930,7 @@ async def produce_event(
)


def _generate_consumer_id(node_id: str = None) -> str:
def _generate_consumer_id(node_id: str | None = None) -> str:
h = hashlib.sha1()
h.update(str(node_id or socket.getfqdn()).encode("utf8"))
hostname_hash = h.hexdigest()
Expand Down
32 changes: 27 additions & 5 deletions src/ai/backend/common/redis_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import redis.exceptions
import yarl
from redis.asyncio import Redis
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.sentinel import (
MasterNotFoundError,
Expand Down Expand Up @@ -59,8 +59,6 @@


_default_conn_opts: Mapping[str, Any] = {
"socket_timeout": 5.0,
"socket_connect_timeout": 2.0,
"socket_keepalive": True,
"socket_keepalive_options": _keepalive_options,
"retry": Retry(ExponentialBackoff(), 10),
Expand All @@ -69,7 +67,10 @@
redis.exceptions.TimeoutError,
],
}

_default_conn_pool_opts: Mapping[str, Any] = {
"max_connections": 16,
# "timeout": 20.0, # for redis-py 5.0+
}

_scripts: Dict[str, str] = {}

Expand Down Expand Up @@ -471,10 +472,24 @@ async def read_stream_by_group(

def get_redis_object(
redis_config: EtcdRedisConfig,
name: str, # placeholder for backported code
db: int = 0,
**kwargs,
) -> RedisConnectionInfo:
conn_opts = {
**_default_conn_opts,
**kwargs,
# "lib_name": None, # disable implicit "CLIENT SETINFO" (for redis-py 5.0+)
# "lib_version": None, # disable implicit "CLIENT SETINFO" (for redis-py 5.0+)
}
conn_pool_opts = {
**_default_conn_pool_opts,
}
if _sentinel_addresses := redis_config.get("sentinel"):
log.warning(
"Native sentinel client in 23.03 has imperfect implementation. "
"It is not recommended to use it."
)
sentinel_addresses: Any = None
if isinstance(_sentinel_addresses, str):
sentinel_addresses = DelimiterSeperatedList(HostPortPair).check_and_return(
Expand Down Expand Up @@ -503,7 +518,14 @@ def get_redis_object(
redis_url[1]
).with_password(redis_config.get("password")) / str(db)
return RedisConnectionInfo(
client=Redis.from_url(str(url), **kwargs),
client=Redis(
connection_pool=ConnectionPool.from_url(
str(url),
**conn_pool_opts,
),
**conn_opts,
auto_close_connection_pool=True,
),
service_name=None,
)

Expand Down
12 changes: 3 additions & 9 deletions src/ai/backend/manager/api/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
import sqlalchemy as sa
import trafaret as t
from aiohttp import web
from dateutil.relativedelta import relativedelta

from ai.backend.common import redis_helper
from ai.backend.common import validators as tx
from ai.backend.common.defs import REDIS_LIVE_DB
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AgentId, LogSeverity, RedisConnectionInfo
from ai.backend.common.types import AgentId, LogSeverity

from ..defs import LockID
from ..models import UserRole, error_logs, groups
Expand Down Expand Up @@ -219,6 +218,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean
raw_lifetime = await etcd.get("config/logs/error/retention")
if raw_lifetime is None:
raw_lifetime = "90d"
lifetime: dt.timedelta | relativedelta
try:
lifetime = tx.TimeDuration().check(raw_lifetime)
except ValueError:
Expand All @@ -239,7 +239,6 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean
@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
log_cleanup_timer: GlobalTimer
log_cleanup_timer_redis: RedisConnectionInfo
log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


Expand All @@ -251,10 +250,6 @@ async def init(app: web.Application) -> None:
app,
log_cleanup_task,
)
app_ctx.log_cleanup_timer_redis = redis_helper.get_redis_object(
root_ctx.shared_config.data["redis"],
db=REDIS_LIVE_DB,
)
app_ctx.log_cleanup_timer = GlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
Expand All @@ -270,7 +265,6 @@ async def shutdown(app: web.Application) -> None:
app_ctx: PrivateContext = app["logs.context"]
await app_ctx.log_cleanup_timer.leave()
root_ctx.event_dispatcher.unconsume(app_ctx.log_cleanup_timer_evh)
await app_ctx.log_cleanup_timer_redis.close()


def create_app(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def init(app: web.Application) -> None:
root_ctx: RootContext = app["_root.context"]
app_ctx: PrivateContext = app["ratelimit.context"]
app_ctx.redis_rlim = redis_helper.get_redis_object(
root_ctx.shared_config.data["redis"], db=REDIS_RLIM_DB
root_ctx.shared_config.data["redis"], name="ratelimit", db=REDIS_RLIM_DB
)
app_ctx.redis_rlim_script = await redis_helper.execute(
app_ctx.redis_rlim, lambda r: r.script_load(_rlim_script)
Expand Down
14 changes: 12 additions & 2 deletions src/ai/backend/manager/cli/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,24 @@ async def redis_ctx(cli_ctx: CLIContext) -> AsyncIterator[RedisConnectionSet]:
await shared_config.reload()
raw_redis_config = await shared_config.etcd.get_prefix("config/redis")
local_config["redis"] = redis_config_iv.check(raw_redis_config)
redis_live = redis_helper.get_redis_object(shared_config.data["redis"], db=REDIS_LIVE_DB)
redis_stat = redis_helper.get_redis_object(shared_config.data["redis"], db=REDIS_STAT_DB)
redis_live = redis_helper.get_redis_object(
shared_config.data["redis"],
name="mgr_cli.live",
db=REDIS_LIVE_DB,
)
redis_stat = redis_helper.get_redis_object(
shared_config.data["redis"],
name="mgr_cli.stat",
db=REDIS_STAT_DB,
)
redis_image = redis_helper.get_redis_object(
shared_config.data["redis"],
name="mgr_cli.image",
db=REDIS_IMAGE_DB,
)
redis_stream = redis_helper.get_redis_object(
shared_config.data["redis"],
name="mgr_cli.stream",
db=REDIS_STREAM_DB,
)
yield RedisConnectionSet(
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/idle.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,12 @@ def __init__(
self._lock_factory = lock_factory
self._redis_live = redis_helper.get_redis_object(
self._shared_config.data["redis"],
name="idle.live",
db=REDIS_LIVE_DB,
)
self._redis_stat = redis_helper.get_redis_object(
self._shared_config.data["redis"],
name="idle.stat",
db=REDIS_STAT_DB,
)
self._grace_period_checker: NewUserGracePeriodChecker = NewUserGracePeriodChecker(
Expand Down
Loading

0 comments on commit 502a6c4

Please sign in to comment.