refactor(redis): switch to redis-py
We have to vendor the redis-py typing because types-redis ships incomplete
typing for the redis.asyncio module.

I proposed my typing here: python/typeshed#7820

The change also makes mypy discover some previously uncaught typing issues.
sileht committed May 10, 2022
1 parent e8df844 commit 00273cd
Showing 73 changed files with 4,746 additions and 145 deletions.
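The practical effect of the vendored typing: a client created with decode_responses=True is typed as returning str, one created without as returning bytes, and mypy then propagates that choice through every command. A minimal sketch of what the stubs buy (assuming they behave like the typeshed proposal; the hash name and field are illustrative):

    import typing

    import redis.asyncio


    async def cached_login(client: "redis.asyncio.Redis[str]") -> typing.Optional[str]:
        # On a str-typed client the stubs resolve hget() to Optional[str];
        # the same call on "redis.asyncio.Redis[bytes]" would be Optional[bytes].
        return await client.hget("logins", "42")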
4 changes: 2 additions & 2 deletions mergify_engine/context.py
@@ -620,7 +620,7 @@ async def clear_user_permission_cache_for_user(
     ) -> None:
         await redis.hdel(
             cls._users_permission_cache_key_for_repo(owner["id"], repo["id"]),
-            user["id"],
+            str(user["id"]),
         )

     @classmethod
@@ -655,7 +655,7 @@ async def get_user_permission(
         key = self._users_permission_cache_key
         cached_permission = typing.cast(
             typing.Optional[github_types.GitHubRepositoryPermission],
-            await self.installation.redis.cache.hget(key, user["id"]),
+            await self.installation.redis.cache.hget(key, str(user["id"])),
         )
         if cached_permission is None:
             permission = typing.cast(
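Both hunks in this file follow from that typing: the cache client is created with decode_responses=True, so its hash commands are typed to take str field names, and the integer GitHub account id must now be stringified explicitly. A small sketch of the pattern (hypothetical key names, not the engine's real schema):

    import asyncio

    import redis.asyncio


    async def main() -> None:
        client = redis.asyncio.Redis.from_url("redis://localhost", decode_responses=True)
        user_id = 42
        # Field names must be str on a decoded client -- hence str(user["id"]).
        await client.hset("permissions~1~2", str(user_id), "write")
        print(await client.hget("permissions~1~2", str(user_id)))  # "write"
        await client.close()


    asyncio.run(main())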
4 changes: 2 additions & 2 deletions mergify_engine/count_seats.py
@@ -181,7 +181,7 @@ def _add_user(user: github_types.GitHubAccount) -> None:
     transaction = await redis_cache.pipeline()
     for user_id, user_login in users.items():
         user_key = f"{user_id}~{user_login}"
-        await transaction.zadd(repo_key, **{user_key: time.time()})
+        await transaction.zadd(repo_key, {user_key: time.time()})

     await transaction.execute()

@@ -425,7 +425,7 @@ async def count_and_send(redis_cache: redis_utils.RedisCache) -> None:


 async def report(args: argparse.Namespace) -> None:
-    redis_links = redis_utils.RedisLinks()
+    redis_links = redis_utils.RedisLinks(name="report")
     if args.daemon:
         service.setup("count-seats")
         await count_and_send(redis_links.cache)
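The zadd change is a plain API difference between the two libraries: yaaredis spread the member-to-score pairs into keyword arguments, while redis-py takes a single mapping as a positional argument. A minimal sketch (key and member names are made up):

    import time

    import redis.asyncio


    async def record_seat(client: "redis.asyncio.Redis[bytes]") -> None:
        # redis-py signature: zadd(name, mapping) -- the dict is positional, not **kwargs.
        await client.zadd("active-users~repo", {"12345~alice": time.time()})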
21 changes: 14 additions & 7 deletions mergify_engine/debug.py
@@ -77,10 +77,13 @@ async def report_dashboard_synchro(
     )


-async def report_worker_status(owner: github_types.GitHubLogin) -> None:
+async def report_worker_status(
+    redis_links: redis_utils.RedisLinks, owner: github_types.GitHubLogin
+) -> None:
     stream_name = f"stream~{owner}".encode()
-    redis_links = redis_utils.RedisLinks()
-    streams = await redis_links.stream.zrangebyscore(
+    streams: typing.List[
+        typing.Tuple[bytes, float]
+    ] = await redis_links.stream.zrangebyscore(
         "streams", min=0, max="+inf", withscores=True
     )

@@ -91,9 +94,13 @@ async def report_worker_status(
         print("* WORKER: Installation not queued to process")
         return

-    planned = datetime.datetime.utcfromtimestamp(streams[pos]).isoformat()
+    planned = datetime.datetime.utcfromtimestamp(streams[pos][1]).isoformat()

-    attempts = await redis_links.stream.hget("attempts", stream_name) or 0
+    attempts_raw = await redis_links.stream.hget("attempts", stream_name)
+    if attempts_raw is None:
+        attempts = 0
+    else:
+        attempts = int(attempts_raw)
     print(
         "* WORKER: Installation queued, "
         f" pos: {pos}/{len(streams)},"
@@ -170,7 +177,7 @@ def _url_parser(
 async def report(
     url: str,
 ) -> typing.Union[context.Context, github.AsyncGithubInstallationClient, None]:
-    redis_links = redis_utils.RedisLinks(max_idle_time=0)
+    redis_links = redis_utils.RedisLinks(name="debug")

     try:
         owner_login, repo, pull_number = _url_parser(url)
@@ -234,7 +241,7 @@ async def report(
         installation.installation["id"], db_sub, db_tokens, "DASHBOARD", slug
     )

-    await report_worker_status(owner_login)
+    await report_worker_status(redis_links, owner_login)

     if repo is not None:
         repository = await installation.get_repository_by_name(repo)
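The streams[pos][1] fix and the explicit typing.List[typing.Tuple[bytes, float]] annotation both reflect redis-py's return shape: with withscores=True a sorted-set read yields (member, score) tuples rather than bare members. A sketch of the access pattern (the "streams" key mirrors the code above, the rest is illustrative):

    import typing

    import redis.asyncio


    async def first_scheduled(
        client: "redis.asyncio.Redis[bytes]",
    ) -> typing.Optional[float]:
        entries: typing.List[typing.Tuple[bytes, float]] = await client.zrangebyscore(
            "streams", min=0, max="+inf", withscores=True
        )
        if not entries:
            return None
        # Each entry is (member, score); the score sits at index 1.
        return entries[0][1]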
8 changes: 5 additions & 3 deletions mergify_engine/delayed_refresh.py
@@ -61,7 +61,7 @@ async def _set_current_refresh_datetime(
 ) -> None:
     await repository.installation.redis.cache.zadd(
         DELAYED_REFRESH_KEY,
-        **{_redis_key(repository, pull_number): at.timestamp()},
+        {_redis_key(repository, pull_number): at.timestamp()},
     )


@@ -129,16 +129,18 @@ async def send(
     for subkey in keys:
         (
             owner_id_str,
-            owner_login,
+            owner_login_str,
             repository_id_str,
-            repository_name,
+            repository_name_str,
             pull_request_number_str,
         ) = subkey.split("~")
         owner_id = github_types.GitHubAccountIdType(int(owner_id_str))
         repository_id = github_types.GitHubRepositoryIdType(int(repository_id_str))
         pull_request_number = github_types.GitHubPullRequestNumber(
             int(pull_request_number_str)
         )
+        repository_name = github_types.GitHubRepositoryName(repository_name_str)
+        owner_login = github_types.GitHubLogin(owner_login_str)

         LOG.info(
             "sending delayed pull request refresh",
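The renamed unpacking targets exist so the raw split results can be wrapped in their NewType forms (GitHubLogin, GitHubRepositoryName) before use. For reference, the member being parsed is a "~"-joined tuple; a toy round-trip of that layout (inferred from the unpacking above, illustrative values):

    def build_member(owner_id: int, owner_login: str, repo_id: int, repo_name: str, pr: int) -> str:
        return f"{owner_id}~{owner_login}~{repo_id}~{repo_name}~{pr}"


    parts = build_member(1, "octocat", 2, "hello-world", 3).split("~")
    owner_id_str, owner_login_str, repo_id_str, repo_name_str, pr_str = parts
    assert (int(owner_id_str), int(repo_id_str), int(pr_str)) == (1, 2, 3)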
6 changes: 3 additions & 3 deletions mergify_engine/exceptions.py
@@ -15,7 +15,7 @@
 import datetime
 import typing

-import yaaredis
+from redis import exceptions as redis_exceptions

 from mergify_engine.clients import http

@@ -118,11 +118,11 @@ def need_retry(
     elif exception.response.status_code == 403:
         return datetime.timedelta(minutes=3)

-    elif isinstance(exception, yaaredis.exceptions.ResponseError):
+    elif isinstance(exception, redis_exceptions.ResponseError):
         # Redis script bug or OOM
         return datetime.timedelta(minutes=1)

-    elif isinstance(exception, yaaredis.exceptions.ConnectionError):
+    elif isinstance(exception, redis_exceptions.ConnectionError):
         # Redis down
         return datetime.timedelta(minutes=1)

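Since the aioredis merge, redis-py's asyncio client raises the same exception classes as the sync one, so a single redis.exceptions import covers both; yaaredis shipped its own hierarchy. A sketch of catching the two retried errors around an async call:

    import redis.asyncio
    from redis import exceptions as redis_exceptions


    async def safe_ping(client: "redis.asyncio.Redis[str]") -> bool:
        try:
            return await client.ping()
        except (redis_exceptions.ConnectionError, redis_exceptions.ResponseError):
            # The same retry-worthy failures need_retry() maps to a one-minute delay.
            return False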
11 changes: 5 additions & 6 deletions mergify_engine/migrations/__init__.py
@@ -12,6 +12,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+
 import typing

 import pkg_resources
@@ -33,14 +34,12 @@ async def run(redis_cache: redis_utils.RedisCache) -> None:
     await _run_scripts("cache", redis_cache)


-async def _run_scripts(
-    dirname: str, redis: typing.Union[redis_utils.RedisCache, redis_utils.RedisStream]
-) -> None:
-    current_version = await redis.get(MIGRATION_STAMPS_KEY)
-    if current_version is None:
+async def _run_scripts(dirname: str, redis: redis_utils.RedisCache) -> None:
+    current_version_raw: typing.Optional[str] = await redis.get(MIGRATION_STAMPS_KEY)
+    if current_version_raw is None:
         current_version = 0
     else:
-        current_version = int(current_version)
+        current_version = int(current_version_raw)

     files = pkg_resources.resource_listdir(__name__, dirname)
     for script in sorted(files):
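This hunk is one of the issues mypy surfaced: GET returns None for a missing key, and the old code reused a single variable for both the Optional result and the int, which the stricter stubs reject. The narrowed pattern in isolation (a sketch; the key name is arbitrary):

    import typing

    import redis.asyncio


    async def read_version(client: "redis.asyncio.Redis[str]", key: str) -> int:
        raw: typing.Optional[str] = await client.get(key)
        # Narrow the Optional before converting; int(None) would raise TypeError.
        return 0 if raw is None else int(raw)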
2 changes: 1 addition & 1 deletion mergify_engine/queue/merge_train.py
@@ -1330,7 +1330,7 @@ async def iter_trains(
         await train.load(train_raw)
         yield train

-    async def load(self, train_raw: typing.Optional[bytes] = None) -> None:
+    async def load(self, train_raw: typing.Optional[str] = None) -> None:
         if train_raw is None:
             train_raw = await self.repository.installation.redis.cache.hget(
                 self._get_redis_key(), self._get_redis_hash_key()
131 changes: 75 additions & 56 deletions mergify_engine/redis_utils.py
@@ -13,33 +13,32 @@
 # License for the specific language governing permissions and limitations
 # under the License.

-import asyncio
 import dataclasses
 import functools
 import hashlib
-import ssl
 import typing
 import uuid

 import daiquiri
-import yaaredis
+import redis.asyncio as redispy

 from mergify_engine import config
+from mergify_engine import service


 LOG = daiquiri.getLogger(__name__)


-# NOTE(sileht): I wonder why mypy thinks `yaaredis.StrictRedis` is `typing.Any`...
-RedisCache = typing.NewType("RedisCache", yaaredis.StrictRedis)  # type: ignore
-RedisStream = typing.NewType("RedisStream", yaaredis.StrictRedis)  # type: ignore
-RedisQueue = typing.NewType("RedisQueue", yaaredis.StrictRedis)  # type: ignore
+RedisCache = typing.NewType("RedisCache", "redispy.Redis[str]")
+RedisStream = typing.NewType("RedisStream", "redispy.Redis[bytes]")
+RedisQueue = typing.NewType("RedisQueue", "redispy.Redis[bytes]")

 ScriptIdT = typing.NewType("ScriptIdT", uuid.UUID)

 SCRIPTS: typing.Dict[ScriptIdT, typing.Tuple[str, str]] = {}


+# TODO(sileht): Redis script management can be moved back to Redis.register_script() mechanism
 def register_script(script: str) -> ScriptIdT:
     global SCRIPTS
     # NOTE(sileht): We don't use sha, in case of something server side change the script sha
@@ -54,10 +53,16 @@ def register_script(script: str) -> ScriptIdT:
     return script_id


-async def load_script(redis: RedisStream, script_id: ScriptIdT) -> None:
+# FIXME(sileht): We store Cache and Stream script into the same global object
+# it works but if a script is loaded into two redis, this won't works as expected
+# as the app will think it's already loaded while it's not...
+async def load_script(
+    redis: typing.Union[RedisCache, RedisStream], script_id: ScriptIdT
+) -> None:
     global SCRIPTS
     sha, script = SCRIPTS[script_id]
-    newsha = await redis.script_load(script)
+    # FIXME(sileht): weird, this method is typed on redis-py
+    newsha = await redis.script_load(script)  # type: ignore[no-untyped-call]
     if newsha != sha:
         LOG.error(
             "wrong redis script sha cached",
@@ -68,22 +73,24 @@ async def load_script(
     SCRIPTS[script_id] = (newsha, script)


-async def load_scripts(redis: RedisStream) -> None:
+async def load_scripts(redis: typing.Union[RedisCache, RedisStream]) -> None:
     # TODO(sileht): cleanup unused script, this is tricky, because during
     # deployment we have running in parallel due to the rolling upgrade:
     # * an old version of the asgi server
     # * a new version of the asgi server
     # * a new version of the backend
     global SCRIPTS
-    ids = list(SCRIPTS.keys())
-    exists = await redis.script_exists(*ids)
+    scripts = list(SCRIPTS.items())  # order matter for zip bellow
+    shas = [sha for _, (sha, _) in scripts]
+    ids = [_id for _id, _ in scripts]
+    exists = await redis.script_exists(*shas)  # type: ignore[no-untyped-call]
     for script_id, exist in zip(ids, exists):
         if not exist:
             await load_script(redis, script_id)


 async def run_script(
-    redis: RedisStream,
+    redis: typing.Union[RedisCache, RedisStream],
     script_id: ScriptIdT,
     keys: typing.Tuple[str, ...],
     args: typing.Optional[typing.Tuple[typing.Union[str], ...]] = None,
@@ -94,79 +101,91 @@ async def run_script(
         args = keys
     else:
         args = keys + args
-    return await redis.evalsha(sha, len(keys), *args)
-
-
-async def stop_pending_yaaredis_tasks() -> None:
-    tasks = [
-        task
-        for task in asyncio.all_tasks()
-        if (
-            getattr(task.get_coro(), "__qualname__", None)
-            == "ConnectionPool.disconnect_on_idle_time_exceeded"
-        )
-    ]
-
-    if tasks:
-        for task in tasks:
-            task.cancel()
-        await asyncio.wait(tasks)
+    return await redis.evalsha(sha, len(keys), *args)  # type: ignore[no-untyped-call]


 @dataclasses.dataclass
 class RedisLinks:
-    max_idle_time: int = 60
+    name: str
     cache_max_connections: typing.Optional[int] = None
     stream_max_connections: typing.Optional[int] = None
     queue_max_connections: typing.Optional[int] = None

     @functools.cached_property
     def queue(self) -> RedisQueue:
         client = self.redis_from_url(
+            "queue",
             config.QUEUE_URL,
-            max_idle_time=self.max_idle_time,
+            decode_responses=False,
             max_connections=self.queue_max_connections,
         )
         return RedisQueue(client)

     @functools.cached_property
     def stream(self) -> RedisStream:
         client = self.redis_from_url(
+            "stream",
             config.STREAM_URL,
-            max_idle_time=self.max_idle_time,
+            decode_responses=False,
             max_connections=self.stream_max_connections,
         )
         return RedisStream(client)

     @functools.cached_property
     def cache(self) -> RedisCache:
         client = self.redis_from_url(
+            "cache",
             config.STORAGE_URL,
             decode_responses=True,
-            max_idle_time=self.max_idle_time,
             max_connections=self.cache_max_connections,
         )
         return RedisCache(client)

-    @staticmethod
-    def redis_from_url(url: str, **options: typing.Any) -> yaaredis.StrictRedis:
-        ssl_scheme = "rediss://"
-        if config.REDIS_SSL_VERIFY_MODE_CERT_NONE and url.startswith(ssl_scheme):
-            final_url = f"redis://{url[len(ssl_scheme):]}"
-            ctx = ssl.create_default_context()
-            ctx.check_hostname = False
-            ctx.verify_mode = (
-                ssl.CERT_NONE  # nosemgrep contrib.dlint.dlint-equivalent.insecure-ssl-use
-            )
-            options["ssl_context"] = ctx
-        else:
-            final_url = url
-        return yaaredis.StrictRedis.from_url(final_url, **options)
-
-    def shutdown_all(self) -> None:
-        self.cache.connection_pool.max_idle_time = 0
-        self.cache.connection_pool.disconnect()
-        self.stream.connection_pool.max_idle_time = 0
-        self.stream.connection_pool.disconnect()
-        self.queue.connection_pool.max_idle_time = 0
-        self.queue.connection_pool.disconnect()
+    @typing.overload
+    def redis_from_url(
+        self,  # FIXME(sileht): mypy is lost if the method is static...
+        name: str,
+        url: str,
+        decode_responses: typing.Literal[True],
+        max_connections: typing.Optional[int] = None,
+    ) -> "redispy.Redis[str]":
+        ...
+
+    @typing.overload
+    def redis_from_url(
+        self,  # FIXME(sileht): mypy is lost if the method is static...
+        name: str,
+        url: str,
+        decode_responses: typing.Literal[False],
+        max_connections: typing.Optional[int] = None,
+    ) -> "redispy.Redis[bytes]":
+        ...
+
+    def redis_from_url(
+        self,  # FIXME(sileht): mypy is lost if the method is static...
+        name: str,
+        url: str,
+        decode_responses: bool,
+        max_connections: typing.Optional[int] = None,
+    ) -> typing.Union["redispy.Redis[bytes]", "redispy.Redis[str]"]:
+
+        options: typing.Dict[str, typing.Any] = {}
+        if config.REDIS_SSL_VERIFY_MODE_CERT_NONE and url.startswith("rediss://"):
+            options["ssl_check_hostname"] = False
+            options["ssl_cert_reqs"] = None
+
+        return redispy.Redis.from_url(
+            url,
+            max_connections=max_connections,
+            decode_responses=decode_responses,
+            client_name=f"{service.SERVICE_NAME}/{self.name}/{name}",
+            **options,
+        )
+
+    async def shutdown_all(self) -> None:
+        if "cache" in self.__dict__:
+            await self.cache.close(close_connection_pool=True)
+        if "stream" in self.__dict__:
+            await self.stream.close(close_connection_pool=True)
+        if "queue" in self.__dict__:
+            await self.queue.close(close_connection_pool=True)
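For context, the module keeps a (sha, source) pair per registered script and re-uploads on demand through EVALSHA. Condensed to its essentials, the mechanism behaves roughly like this (a sketch, not the module verbatim):

    import hashlib
    import typing
    import uuid

    import redis.asyncio

    SCRIPTS: typing.Dict[uuid.UUID, typing.Tuple[str, str]] = {}


    def register_script(script: str) -> uuid.UUID:
        script_id = uuid.uuid4()
        # The sha1 of the source is what EVALSHA is later called with.
        SCRIPTS[script_id] = (hashlib.sha1(script.encode()).hexdigest(), script)
        return script_id


    async def run_script(
        client: "redis.asyncio.Redis[str]",
        script_id: uuid.UUID,
        keys: typing.Tuple[str, ...],
    ) -> typing.Any:
        sha, script = SCRIPTS[script_id]
        # SCRIPT EXISTS asks the server-side cache; SCRIPT LOAD is the fallback.
        if not (await client.script_exists(sha))[0]:
            sha = await client.script_load(script)
            SCRIPTS[script_id] = (sha, script)
        return await client.evalsha(sha, len(keys), *keys)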
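The typing.overload block at the end is the trick that lets one factory be precisely typed: decode_responses is declared as typing.Literal[True] in one overload and typing.Literal[False] in the other, so mypy selects Redis[str] or Redis[bytes] from the literal argument at each call site. Reduced to a standalone sketch:

    import typing

    import redis.asyncio


    @typing.overload
    def make_client(
        url: str, decode_responses: typing.Literal[True]
    ) -> "redis.asyncio.Redis[str]":
        ...


    @typing.overload
    def make_client(
        url: str, decode_responses: typing.Literal[False]
    ) -> "redis.asyncio.Redis[bytes]":
        ...


    def make_client(
        url: str, decode_responses: bool
    ) -> typing.Union["redis.asyncio.Redis[str]", "redis.asyncio.Redis[bytes]"]:
        # from_url builds a lazy connection pool; nothing connects yet.
        return redis.asyncio.Redis.from_url(url, decode_responses=decode_responses)


    cache = make_client("redis://localhost", decode_responses=True)  # Redis[str]
    stream = make_client("redis://localhost", decode_responses=False)  # Redis[bytes]

Also worth noting: the new shutdown_all() tests "cache" in self.__dict__ because functools.cached_property stores its computed value in the instance __dict__ on first access, so only the clients that were actually created get closed.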