diff --git a/CHANGELOG.md b/CHANGELOG.md index 9591eb7..7fcf257 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,15 +2,17 @@ All notable changes to this project will be documented in this file. -## [Unreleased] +## [0.4.0] - 2023-11-17 ### Breaking - Scripts now need explicitly named arguments. ([#3]) - Script errors are now `ValueError` instead of `redis.exceptions.ResponseError`. ([#3]) +- Add name argument to `redipy.plugin.LuaRedisPatch`. ([#7]) ### Added +- Added ZRANGE (partially) and LRANGE. ([#7]) - Inferring backend in Redis constructor. ([#3]) - Allow access to raw runtime and redis connections. ([#3]) - Executing scripts from a pipeline. ([#3]) @@ -29,6 +31,7 @@ All notable changes to this project will be documented in this file. ### Notable Internal Changes -- Remove usage of `nil` in lua scripts (in favor of `cjson.null`). +- Removed some usage of `nil` in lua scripts (in favor of `cjson.null`). ([#3]) [#3]: https://github.com/JosuaKrause/redipy/pull/3 +[#7]: https://github.com/JosuaKrause/redipy/pull/7 diff --git a/README.md b/README.md index e989be4..a00f12d 100644 --- a/README.md +++ b/README.md @@ -198,15 +198,6 @@ class RStack: """ return f"{base}:{name}" - def init(self, base: str) -> None: - """ - Initializes the stack. - - Args: - base (str): The base key. - """ - self._rt.set(self.key(base, "size"), "0") - def push_frame(self, base: str) -> None: """ Pushes a new stack frame. @@ -216,7 +207,7 @@ class RStack: """ self._rt.incrby(self.key(base, "size"), 1) - def pop_frame(self, base: str) -> dict[str, str] | None: + def pop_frame(self, base: str) -> dict[str, str]: """ Pops the current stack frame and returns its values. @@ -233,7 +224,7 @@ class RStack: }, args={}) if res is None: - return None + return {} return cast(dict, res) def set_value(self, base: str, field: str, value: str) -> None: @@ -300,7 +291,7 @@ class RStack: rframe = RedisHash(Strs( ctx.add_key("frame"), ":", - ToIntStr(rsize.get(no_adjust=True)))) + ToIntStr(rsize.get(default=0)))) field = ctx.add_arg("field") value = ctx.add_arg("value") ctx.add(rframe.hset({ @@ -315,7 +306,7 @@ class RStack: rframe = RedisHash(Strs( ctx.add_key("frame"), ":", - ToIntStr(rsize.get(no_adjust=True)))) + ToIntStr(rsize.get(default=0)))) field = ctx.add_arg("field") ctx.set_return_value(rframe.hget(field)) return self._rt.register_script(ctx) @@ -324,10 +315,14 @@ class RStack: ctx = FnContext() rsize = RedisVar(ctx.add_key("size")) rframe = RedisHash( - Strs(ctx.add_key("frame"), ":", ToIntStr(rsize.get()))) + Strs(ctx.add_key("frame"), ":", ToIntStr(rsize.get(default=0)))) lcl = ctx.add_local(rframe.hgetall()) ctx.add(rframe.delete()) - ctx.add(rsize.incrby(-1)) + + b_then, b_else = ctx.if_(ToNum(rsize.get(default=0)).gt_(0)) + b_then.add(rsize.incrby(-1)) + b_else.add(rsize.delete()) + ctx.set_return_value(lcl) return self._rt.register_script(ctx) @@ -336,7 +331,7 @@ class RStack: rsize = RedisVar(ctx.add_key("size")) base = ctx.add_local(ctx.add_key("frame")) field = ctx.add_arg("field") - pos = ctx.add_local(ToNum(rsize.get())) + pos = ctx.add_local(ToNum(rsize.get(default=0))) res = ctx.add_local(None) cur = ctx.add_local(None) rframe = RedisHash(cur) @@ -445,7 +440,8 @@ For a full implementation follow these steps: 1. Add the signature of the function to `redipy.api.RedisAPI`. Adjust as necessary from the redis spec to get a pythonic feel. Also, add the signature - to `redipy.api.PipelineAPI` but with `None` as return value. + to `redipy.api.PipelineAPI` but with `None` as return value. Additionally, + add the redirect to the backend in `redipy.main.Redis`. 2. Implement the function in `redipy.redis.conn.RedisConnection` and `redipy.redis.conn.PipelineConnection`. This should be straightforward as there are not too many changes expected. Don't forget @@ -466,9 +462,10 @@ For a full implementation follow these steps: a class in `redipy.memory.rfun`. 7. Add the approriate class or method in the right `redipy.symbolic.r...py` file. If it is a new class / file add an import to `redipy.script`. -8. Add a new test to verify the new function works inside a script for all - backends. You can run `make pytest FILE=test/...py` to execute the test and - `make coverage-report` to verify that the new code is executed. +8. Add a new test in `test/test_api.py` to verify the new function works inside + a script for all backends. You can run `make pytest FILE=test/test_api.py` + to execute the test and `make coverage-report` to verify that the new code + is executed. 9. Make sure `make lint-all` passes, as well as, all tests (`make pytest`) run without issue. diff --git a/src/redipy/api.py b/src/redipy/api.py index d158b31..ddeb0ac 100644 --- a/src/redipy/api.py +++ b/src/redipy/api.py @@ -38,6 +38,32 @@ def execute(self) -> list: """ raise NotImplementedError() + def exists(self, *keys: str) -> None: + """ + Determines whether specified keys exist. + + See also the redis documentation: https://redis.io/commands/exists/ + + The pipeline value is set to the number of keys that exist. + + Args: + *keys (str): The keys. + """ + raise NotImplementedError() + + def delete(self, *keys: str) -> None: + """ + Deletes keys. + + See also the redis documentation: https://redis.io/commands/del/ + + The pipeline value is set to the number of keys that got removed. + + Args: + *keys (str): The keys. + """ + raise NotImplementedError() + def set( self, key: str, @@ -93,6 +119,27 @@ def get(self, key: str) -> None: """ raise NotImplementedError() + def incrby(self, key: str, inc: float | int) -> None: + """ + Updates the value associated with the given key by a relative amount. + The value is interpreted as number. If the value doesn't exist zero is + used as starting point. + + See also the redis documentation: + https://redis.io/commands/incrby/ + https://redis.io/commands/incrbyfloat/ + + The pipeline value is set to the new value as float. + If the value cannot be interpreted as float while executing the + pipeline a ValueError exception is raised. + + Args: + key (str): The key. + + inc (float | int): The relative change. + """ + raise NotImplementedError() + def lpush(self, key: str, *values: str) -> None: """ Pushes values to the left side of the list associated with the key. @@ -169,6 +216,25 @@ def rpop( """ raise NotImplementedError() + def lrange(self, key: str, start: int, stop: int) -> None: + """ + Returns a number of values from the list specified by the given range. + Negative numbers are interpreted as index from the back of the list. + Out of range indices are ignored, potentially returning an empty list. + + See also the redis documentation: https://redis.io/commands/lrange/ + + The pipeline value is the resulting elements. + + Args: + key (str): The key. + + start (int): The start index. + + stop (int): The stop index (inclusive). + """ + raise NotImplementedError() + def llen(self, key: str) -> None: """ Computes the length of the list associated with the key. @@ -242,64 +308,39 @@ def zpop_min( """ raise NotImplementedError() - def zcard(self, key: str) -> None: + def zrange(self, key: str, start: int, stop: int) -> None: """ - Computes the cardinality of the sorted set associated with the given - key. + Returns a number of values from the sorted set specified by the given + range. As of now the indices are based on the order of the set. + Negative numbers are interpreted as index from the back of the set. + Out of range indices are ignored, potentially returning an empty set. - See also the redis documentation: https://redis.io/commands/zcard/ + See also the redis documentation: https://redis.io/commands/zrange/ - The number of members in the set is set as pipeline value. + NOTE: not all modes are implemented yet. - Args: - key (str): The key. - """ - raise NotImplementedError() - - def incrby(self, key: str, inc: float | int) -> None: - """ - Updates the value associated with the given key by a relative amount. - The value is interpreted as number. If the value doesn't exist zero is - used as starting point. - - See also the redis documentation: - https://redis.io/commands/incrby/ - https://redis.io/commands/incrbyfloat/ - - The pipeline value is set to the new value as float. - If the value cannot be interpreted as float while executing the - pipeline a ValueError exception is raised. + The members names are set as pipeline value. Args: key (str): The key. - inc (float | int): The relative change. - """ - raise NotImplementedError() - - def exists(self, *keys: str) -> None: - """ - Determines whether specified keys exist. - - See also the redis documentation: https://redis.io/commands/exists/ - - The pipeline value is set to the number of keys that exist. + start (int): The start index. - Args: - *keys (str): The keys. + stop (int): The stop index (inclusive). """ raise NotImplementedError() - def delete(self, *keys: str) -> None: + def zcard(self, key: str) -> None: """ - Deletes keys. + Computes the cardinality of the sorted set associated with the given + key. - See also the redis documentation: https://redis.io/commands/del/ + See also the redis documentation: https://redis.io/commands/zcard/ - The pipeline value is set to the number of keys that got removed. + The number of members in the set is set as pipeline value. Args: - *keys (str): The keys. + key (str): The key. """ raise NotImplementedError() @@ -429,6 +470,34 @@ def hgetall(self, key: str) -> None: class RedisAPI: """The redis API.""" + def exists(self, *keys: str) -> int: + """ + Determines whether specified keys exist. + + See also the redis documentation: https://redis.io/commands/exists/ + + Args: + *keys (str): The keys. + + Returns: + int: The number of keys that exist. + """ + raise NotImplementedError() + + def delete(self, *keys: str) -> int: + """ + Deletes keys. + + See also the redis documentation: https://redis.io/commands/del/ + + Args: + *keys (str): The keys. + + Returns: + int: The number of keys that got removed. + """ + raise NotImplementedError() + @overload def set( self, @@ -526,6 +595,29 @@ def get(self, key: str) -> str | None: """ raise NotImplementedError() + def incrby(self, key: str, inc: float | int) -> float: + """ + Updates the value associated with the given key by a relative amount. + The value is interpreted as number. If the value doesn't exist zero is + used as starting point. + + See also the redis documentation: + https://redis.io/commands/incrby/ + https://redis.io/commands/incrbyfloat/ + + Args: + key (str): The key. + + inc (float | int): The relative change. + + Raises: + ValueError: If the value cannot be interpreted as float. + + Returns: + float: The new value as float. + """ + raise NotImplementedError() + def lpush(self, key: str, *values: str) -> int: """ Pushes values to the left side of the list associated with the key. @@ -634,6 +726,26 @@ def rpop( """ raise NotImplementedError() + def lrange(self, key: str, start: int, stop: int) -> list[str]: + """ + Returns a number of values from the list specified by the given range. + Negative numbers are interpreted as index from the back of the list. + Out of range indices are ignored, potentially returning an empty list. + + See also the redis documentation: https://redis.io/commands/lrange/ + + Args: + key (str): The key. + + start (int): The start index. + + stop (int): The stop index (inclusive). + + Returns: + list[str]: The elements. + """ + raise NotImplementedError() + def llen(self, key: str) -> int: """ Computes the length of the list associated with the key. @@ -711,69 +823,41 @@ def zpop_min( """ raise NotImplementedError() - def zcard(self, key: str) -> int: + def zrange(self, key: str, start: int, stop: int) -> list[str]: """ - Computes the cardinality of the sorted set associated with the given - key. + Returns a number of values from the sorted set specified by the given + range. As of now the indices are based on the order of the set. + Negative numbers are interpreted as index from the back of the set. + Out of range indices are ignored, potentially returning an empty set. - See also the redis documentation: https://redis.io/commands/zcard/ + See also the redis documentation: https://redis.io/commands/zrange/ - Args: - key (str): The key. - - Returns: - int: The number of members in the set. - """ - raise NotImplementedError() - - def incrby(self, key: str, inc: float | int) -> float: - """ - Updates the value associated with the given key by a relative amount. - The value is interpreted as number. If the value doesn't exist zero is - used as starting point. - - See also the redis documentation: - https://redis.io/commands/incrby/ - https://redis.io/commands/incrbyfloat/ + NOTE: not all modes are implemented yet. Args: key (str): The key. - inc (float | int): The relative change. - - Raises: - ValueError: If the value cannot be interpreted as float. - - Returns: - float: The new value as float. - """ - raise NotImplementedError() + start (int): The start index. - def exists(self, *keys: str) -> int: - """ - Determines whether specified keys exist. - - See also the redis documentation: https://redis.io/commands/exists/ - - Args: - *keys (str): The keys. + stop (int): The stop index (inclusive). Returns: - int: The number of keys that exist. + list[str]: The members names. """ raise NotImplementedError() - def delete(self, *keys: str) -> int: + def zcard(self, key: str) -> int: """ - Deletes keys. + Computes the cardinality of the sorted set associated with the given + key. - See also the redis documentation: https://redis.io/commands/del/ + See also the redis documentation: https://redis.io/commands/zcard/ Args: - *keys (str): The keys. + key (str): The key. Returns: - int: The number of keys that got removed. + int: The number of members in the set. """ raise NotImplementedError() diff --git a/src/redipy/main.py b/src/redipy/main.py index e16b397..e662138 100644 --- a/src/redipy/main.py +++ b/src/redipy/main.py @@ -212,6 +212,12 @@ def pipeline(self) -> Iterator[PipelineAPI]: with self._rt.pipeline() as pipe: yield pipe + def exists(self, *keys: str) -> int: + return self._rt.exists(*keys) + + def delete(self, *keys: str) -> int: + return self._rt.delete(*keys) + @overload def set( self, @@ -273,6 +279,9 @@ def set( def get(self, key: str) -> str | None: return self._rt.get(key) + def incrby(self, key: str, inc: float | int) -> float: + return self._rt.incrby(key, inc) + def lpush(self, key: str, *values: str) -> int: return self._rt.lpush(key, *values) @@ -319,6 +328,9 @@ def rpop( count: int | None = None) -> str | list[str] | None: return self._rt.rpop(key, count) + def lrange(self, key: str, start: int, stop: int) -> list[str]: + return self._rt.lrange(key, start, stop) + def llen(self, key: str) -> int: return self._rt.llen(key) @@ -339,18 +351,12 @@ def zpop_min( ) -> list[tuple[str, float]]: return self._rt.zpop_min(key, count) + def zrange(self, key: str, start: int, stop: int) -> list[str]: + return self._rt.zrange(key, start, stop) + def zcard(self, key: str) -> int: return self._rt.zcard(key) - def incrby(self, key: str, inc: float | int) -> float: - return self._rt.incrby(key, inc) - - def exists(self, *keys: str) -> int: - return self._rt.exists(*keys) - - def delete(self, *keys: str) -> int: - return self._rt.delete(*keys) - def hset(self, key: str, mapping: dict[str, str]) -> int: return self._rt.hset(key, mapping) diff --git a/src/redipy/memory/rfun.py b/src/redipy/memory/rfun.py index c8f1d5f..f6859b2 100644 --- a/src/redipy/memory/rfun.py +++ b/src/redipy/memory/rfun.py @@ -7,6 +7,40 @@ from redipy.plugin import ArgcSpec, LocalRedisFunction +class RExistsFn(LocalRedisFunction): + """Implements the exists function.""" + @staticmethod + def name() -> str: + return "exists" + + @staticmethod + def argc() -> ArgcSpec: + return { + "count": 0, + } + + @staticmethod + def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: + return sm.exists(key) + + +class RDelFn(LocalRedisFunction): + """Implements the del function.""" + @staticmethod + def name() -> str: + return "del" + + @staticmethod + def argc() -> ArgcSpec: + return { + "count": 0, + } + + @staticmethod + def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: + return sm.delete(key) + + class RSetFn(LocalRedisFunction): """Implements the set function.""" @staticmethod @@ -69,7 +103,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.get(key) -class RIncrBy(LocalRedisFunction): +class RIncrByFn(LocalRedisFunction): """Implements the incrby and incrbyfloat functions.""" @staticmethod def name() -> str: @@ -160,6 +194,23 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: key, None if len(args) < 1 else int(cast(int, args[0]))) +class RLRangeFn(LocalRedisFunction): + """Implements the lrange function.""" + @staticmethod + def name() -> str: + return "lrange" + + @staticmethod + def argc() -> ArgcSpec: + return { + "count": 2, + } + + @staticmethod + def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: + return sm.lrange(key, int(cast(int, args[0])), int(cast(int, args[1]))) + + class RLLenFn(LocalRedisFunction): """Implements the llen function.""" @staticmethod @@ -232,45 +283,28 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: key, 1 if len(args) < 1 else int(cast(int, args[0])))) -class RZCard(LocalRedisFunction): - """Implements the zcard function.""" - @staticmethod - def name() -> str: - return "zcard" - - @staticmethod - def argc() -> ArgcSpec: - return { - "count": 0, - } - - @staticmethod - def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: - return sm.zcard(key) - - -class RExists(LocalRedisFunction): - """Implements the exists function.""" +class RZRangeFn(LocalRedisFunction): + """Implements the zrange function.""" @staticmethod def name() -> str: - return "exists" + return "zrange" @staticmethod def argc() -> ArgcSpec: return { - "count": 0, + "count": 2, } @staticmethod def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: - return sm.exists(key) + return sm.zrange(key, int(cast(int, args[0])), int(cast(int, args[1]))) -class RDel(LocalRedisFunction): - """Implements the del function.""" +class RZCardFn(LocalRedisFunction): + """Implements the zcard function.""" @staticmethod def name() -> str: - return "del" + return "zcard" @staticmethod def argc() -> ArgcSpec: @@ -280,10 +314,10 @@ def argc() -> ArgcSpec: @staticmethod def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: - return sm.delete(key) + return sm.zcard(key) -class RHSet(LocalRedisFunction): +class RHSetFn(LocalRedisFunction): """Implements the hset function.""" @staticmethod def name() -> str: @@ -311,7 +345,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hset(key, mapping) -class RHDel(LocalRedisFunction): +class RHDelFn(LocalRedisFunction): """Implements the hdel function.""" @staticmethod def name() -> str: @@ -329,7 +363,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hdel(key, *(f"{arg}" for arg in args)) -class RHGet(LocalRedisFunction): +class RHGetFn(LocalRedisFunction): """Implements the hget function.""" @staticmethod def name() -> str: @@ -346,7 +380,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hget(key, f"{args[0]}") -class RHMGet(LocalRedisFunction): +class RHMGetFn(LocalRedisFunction): """Implements the hmget function.""" @staticmethod def name() -> str: @@ -364,7 +398,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hmget(key, *(f"{arg}" for arg in args)) -class RHIncrBy(LocalRedisFunction): +class RHIncrByFn(LocalRedisFunction): """Implements the hincrby and hincrbyfloat functions.""" @staticmethod def name() -> str: @@ -381,7 +415,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hincrby(key, f"{args[0]}", cast(float, args[1])) -class RHKeys(LocalRedisFunction): +class RHKeysFn(LocalRedisFunction): """Implements the hkeys function.""" @staticmethod def name() -> str: @@ -398,7 +432,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hkeys(key) -class RHVals(LocalRedisFunction): +class RHValsFn(LocalRedisFunction): """Implements the hvals function.""" @staticmethod def name() -> str: @@ -415,7 +449,7 @@ def call(sm: Machine, key: str, args: list[JSONType]) -> JSONType: return sm.hvals(key) -class RHGetAll(LocalRedisFunction): +class RHGetAllFn(LocalRedisFunction): """Implements the hgetall function.""" @staticmethod def name() -> str: diff --git a/src/redipy/memory/rt.py b/src/redipy/memory/rt.py index e19fef1..5c362ef 100644 --- a/src/redipy/memory/rt.py +++ b/src/redipy/memory/rt.py @@ -204,6 +204,14 @@ def get_constant(self, raw: str) -> JSONType: """ return CONST[raw] + def exists(self, *keys: str) -> int: + with self.lock(): + return self._sm.exists(*keys) + + def delete(self, *keys: str) -> int: + with self.lock(): + return self._sm.delete(*keys) + @overload def set( self, @@ -267,6 +275,10 @@ def get(self, key: str) -> str | None: with self.lock(): return self._sm.get(key) + def incrby(self, key: str, inc: float | int) -> float: + with self.lock(): + return self._sm.incrby(key, inc) + def lpush(self, key: str, *values: str) -> int: with self.lock(): return self._sm.lpush(key, *values) @@ -317,6 +329,10 @@ def rpop( with self.lock(): return self._sm.rpop(key, count) + def lrange(self, key: str, start: int, stop: int) -> list[str]: + with self.lock(): + return self._sm.lrange(key, start, stop) + def llen(self, key: str) -> int: with self.lock(): return self._sm.llen(key) @@ -341,21 +357,13 @@ def zpop_min( with self.lock(): return self._sm.zpop_min(key, count) - def zcard(self, key: str) -> int: - with self.lock(): - return self._sm.zcard(key) - - def incrby(self, key: str, inc: float | int) -> float: - with self.lock(): - return self._sm.incrby(key, inc) - - def exists(self, *keys: str) -> int: + def zrange(self, key: str, start: int, stop: int) -> list[str]: with self.lock(): - return self._sm.exists(*keys) + return self._sm.zrange(key, start, stop) - def delete(self, *keys: str) -> int: + def zcard(self, key: str) -> int: with self.lock(): - return self._sm.delete(*keys) + return self._sm.zcard(key) def hset(self, key: str, mapping: dict[str, str]) -> int: with self.lock(): @@ -469,6 +477,12 @@ def add_cmd(self, cb: Callable[[], Any]) -> None: """ self._cmd_queue.append(cb) + def exists(self, *keys: str) -> None: + self.add_cmd(lambda: self._sm.exists(*keys)) + + def delete(self, *keys: str) -> None: + self.add_cmd(lambda: self._sm.delete(*keys)) + def set( self, key: str, @@ -491,6 +505,9 @@ def set( def get(self, key: str) -> None: self.add_cmd(lambda: self._sm.get(key)) + def incrby(self, key: str, inc: float | int) -> None: + self.add_cmd(lambda: self._sm.incrby(key, inc)) + def lpush(self, key: str, *values: str) -> None: self.add_cmd(lambda: self._sm.lpush(key, *values)) @@ -509,6 +526,9 @@ def rpop( count: int | None = None) -> None: self.add_cmd(lambda: self._sm.rpop(key, count)) + def lrange(self, key: str, start: int, stop: int) -> None: + self.add_cmd(lambda: self._sm.lrange(key, start, stop)) + def llen(self, key: str) -> None: self.add_cmd(lambda: self._sm.llen(key)) @@ -533,18 +553,12 @@ def zpop_min( ) -> None: self.add_cmd(lambda: self._sm.zpop_min(key, count)) + def zrange(self, key: str, start: int, stop: int) -> None: + self.add_cmd(lambda: self._sm.zrange(key, start, stop)) + def zcard(self, key: str) -> None: self.add_cmd(lambda: self._sm.zcard(key)) - def incrby(self, key: str, inc: float | int) -> None: - self.add_cmd(lambda: self._sm.incrby(key, inc)) - - def exists(self, *keys: str) -> None: - self.add_cmd(lambda: self._sm.exists(*keys)) - - def delete(self, *keys: str) -> None: - self.add_cmd(lambda: self._sm.delete(*keys)) - def hset(self, key: str, mapping: dict[str, str]) -> None: self.add_cmd(lambda: self._sm.hset(key, mapping)) diff --git a/src/redipy/memory/state.py b/src/redipy/memory/state.py index ba5fa83..8367088 100644 --- a/src/redipy/memory/state.py +++ b/src/redipy/memory/state.py @@ -1,6 +1,7 @@ """This module handles the internal state of the memory runtime.""" import collections import datetime +import itertools import time from typing import Literal, overload @@ -585,6 +586,18 @@ def get_state(self) -> State: """ return self._state + def exists(self, *keys: str) -> int: + res = 0 + for key in keys: + if self._state.exists(key): + res += 1 + return res + + def delete(self, *keys: str) -> int: + res = self.exists(*keys) + self._state.delete(set(keys)) + return res + @overload def set( self, @@ -665,6 +678,17 @@ def get(self, key: str) -> str | None: value, _ = res return value + def incrby(self, key: str, inc: float | int) -> float: + res = self._state.get_value(key) + if res is None: + val = "0" + expire = None + else: + val, expire = res + num = float(val) + inc + self._state.set_value(key, to_number_str(num), expire) + return num + def lpush(self, key: str, *values: str) -> int: queue = self._state.get_queue(key) queue.extendleft(values) @@ -745,6 +769,24 @@ def rpop( self.delete(key) return res if res else None + def lrange(self, key: str, start: int, stop: int) -> list[str]: + queue = self._state.readonly_queue(key) + if queue is None: + return [] + if start >= len(queue): + return [] + if start < 0: + start = max(0, start + len(queue)) + if stop < 0: + stop += len(queue) + if stop < 0: + return [] + stop += 1 + queue.rotate(-start) + res = list(itertools.islice(queue, 0, stop - start, 1)) + queue.rotate(start) + return res + def llen(self, key: str) -> int: return self._state.queue_len(key) @@ -789,31 +831,19 @@ def zpop_min( remain -= 1 return res - def zcard(self, key: str) -> int: - return self._state.zorder_len(key) - - def incrby(self, key: str, inc: float | int) -> float: - res = self._state.get_value(key) - if res is None: - val = "0" - expire = None + def zrange(self, key: str, start: int, stop: int) -> list[str]: + zorder = self._state.readonly_zorder(key) + if zorder is None: + return [] + astop: int | None = stop + if astop == -1 or astop is None: # NOTE: mypy workaround + astop = None else: - val, expire = res - num = float(val) + inc - self._state.set_value(key, to_number_str(num), expire) - return num + astop += 1 + return zorder[start:astop] - def exists(self, *keys: str) -> int: - res = 0 - for key in keys: - if self._state.exists(key): - res += 1 - return res - - def delete(self, *keys: str) -> int: - res = self.exists(*keys) - self._state.delete(set(keys)) - return res + def zcard(self, key: str) -> int: + return self._state.zorder_len(key) def hset(self, key: str, mapping: dict[str, str]) -> int: obj = self._state.get_hash(key) diff --git a/src/redipy/plugin.py b/src/redipy/plugin.py index 707ed09..562cc44 100644 --- a/src/redipy/plugin.py +++ b/src/redipy/plugin.py @@ -119,6 +119,7 @@ class LuaRedisPatch(LuaPatch): """Patches a lua redis function call.""" def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, @@ -127,6 +128,8 @@ def patch( Applies the patch on the expression graph for the given redis call. Args: + name (str): The redis function name. + expr (CallObj): The function call. args (list[ExprObj]): The arguments of the function call. diff --git a/src/redipy/redis/conn.py b/src/redipy/redis/conn.py index 98b39ab..20fad31 100644 --- a/src/redipy/redis/conn.py +++ b/src/redipy/redis/conn.py @@ -248,6 +248,16 @@ def execute(self) -> list: for val, fixup in zip(res, fixes) ] + def exists(self, *keys: str) -> None: + self._pipe.exists(*( + self.with_prefix(key) for key in keys)) + self.add_fixup(int) + + def delete(self, *keys: str) -> None: + self._pipe.delete(*( + self.with_prefix(key) for key in keys)) + self.add_fixup(int) + def set( self, key: str, @@ -280,6 +290,10 @@ def get(self, key: str) -> None: self._pipe.get(self.with_prefix(key)) self.add_fixup(to_maybe_str) + def incrby(self, key: str, inc: float | int) -> None: + self._pipe.incrbyfloat(self.with_prefix(key), inc) + self.add_fixup(float) + def lpush(self, key: str, *values: str) -> None: self._pipe.lpush(self.with_prefix(key), *values) self.add_fixup(int) @@ -308,6 +322,10 @@ def rpop( else: self.add_fixup(to_list_str) + def lrange(self, key: str, start: int, stop: int) -> None: + self._pipe.lrange(self.with_prefix(key), start, stop) + self.add_fixup(to_list_str) + def llen(self, key: str) -> None: self._pipe.llen(self.with_prefix(key)) self.add_fixup(int) @@ -332,24 +350,14 @@ def zpop_min( self._pipe.zpopmin(self.with_prefix(key), count) self.add_fixup(normalize_values) + def zrange(self, key: str, start: int, stop: int) -> None: + self._pipe.zrange(self.with_prefix(key), start, stop) + self.add_fixup(to_list_str) + def zcard(self, key: str) -> None: self._pipe.zcard(self.with_prefix(key)) self.add_fixup(int) - def incrby(self, key: str, inc: float | int) -> None: - self._pipe.incrbyfloat(self.with_prefix(key), inc) - self.add_fixup(float) - - def exists(self, *keys: str) -> None: - self._pipe.exists(*( - self.with_prefix(key) for key in keys)) - self.add_fixup(int) - - def delete(self, *keys: str) -> None: - self._pipe.delete(*( - self.with_prefix(key) for key in keys)) - self.add_fixup(int) - def hset(self, key: str, mapping: dict[str, str]) -> None: self._pipe.hset(self.with_prefix(key), mapping=mapping) # type: ignore self.add_fixup(int) @@ -385,7 +393,7 @@ def hgetall(self, key: str) -> None: self._pipe.hgetall(self.with_prefix(key)) self.add_fixup(lambda res: { to_maybe_str(field): to_maybe_str(val) - for field, val in res + for field, val in res.items() }) @@ -680,6 +688,16 @@ def prefix_exists( if count < 1000: count = int(min(1000, count * 1.2)) + def exists(self, *keys: str) -> int: + with self.get_connection() as conn: + return conn.exists(*( + self.with_prefix(key) for key in keys)) + + def delete(self, *keys: str) -> int: + with self.get_connection() as conn: + return conn.delete(*( + self.with_prefix(key) for key in keys)) + @overload def set( self, @@ -754,6 +772,10 @@ def get(self, key: str) -> str | None: with self.get_connection() as conn: return to_maybe_str(conn.get(self.with_prefix(key))) + def incrby(self, key: str, inc: float | int) -> float: + with self.get_connection() as conn: + return conn.incrbyfloat(self.with_prefix(key), inc) + def lpush(self, key: str, *values: str) -> int: with self.get_connection() as conn: return conn.lpush(self.with_prefix(key), *values) @@ -810,6 +832,10 @@ def rpop( return to_maybe_str(res) return to_list_str(res) + def lrange(self, key: str, start: int, stop: int) -> list[str]: + with self.get_connection() as conn: + return to_list_str(conn.lrange(self.with_prefix(key), start, stop)) + def llen(self, key: str) -> int: with self.get_connection() as conn: return conn.llen(self.with_prefix(key)) @@ -843,23 +869,14 @@ def zpop_min( for name, score in res ] - def zcard(self, key: str) -> int: - with self.get_connection() as conn: - return int(conn.zcard(self.with_prefix(key))) - - def incrby(self, key: str, inc: float | int) -> float: - with self.get_connection() as conn: - return conn.incrbyfloat(self.with_prefix(key), inc) - - def exists(self, *keys: str) -> int: + def zrange(self, key: str, start: int, stop: int) -> list[str]: with self.get_connection() as conn: - return conn.exists(*( - self.with_prefix(key) for key in keys)) + res = conn.zrange(self.with_prefix(key), start, stop) + return to_list_str(res) - def delete(self, *keys: str) -> int: + def zcard(self, key: str) -> int: with self.get_connection() as conn: - return conn.delete(*( - self.with_prefix(key) for key in keys)) + return int(conn.zcard(self.with_prefix(key))) def hset(self, key: str, mapping: dict[str, str]) -> int: with self.get_connection() as conn: diff --git a/src/redipy/redis/lua.py b/src/redipy/redis/lua.py index 81c7fe9..80922e2 100644 --- a/src/redipy/redis/lua.py +++ b/src/redipy/redis/lua.py @@ -164,7 +164,7 @@ def adjust_redis_fn( patch_fn = self._redis_patch_fns.get(name) if patch_fn is None: return expr - return patch_fn.patch(expr, args, is_expr_stmt=is_expr_stmt) + return patch_fn.patch(name, expr, args, is_expr_stmt=is_expr_stmt) def indent_str(code: Iterable[str], add_indent: int) -> list[str]: diff --git a/src/redipy/redis/rpatch.py b/src/redipy/redis/rpatch.py index afef59d..6223a07 100644 --- a/src/redipy/redis/rpatch.py +++ b/src/redipy/redis/rpatch.py @@ -1,11 +1,5 @@ """Module for patching lua redis function calls.""" -from redipy.graph.expr import ( - CallObj, - ExprObj, - find_literal, - get_literal, - LiteralValObj, -) +from redipy.graph.expr import CallObj, ExprObj, find_literal, LiteralValObj from redipy.plugin import LuaRedisPatch @@ -17,6 +11,7 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, @@ -47,12 +42,16 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, is_expr_stmt: bool) -> ExprObj: if is_expr_stmt: return expr + # check if 2nd argument (count) exists for lpop or rpop + if len(args) > 1 and name in ["lpop", "rpop"]: + return expr return { "kind": "binary", "op": "or", @@ -73,6 +72,7 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, @@ -96,11 +96,12 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, is_expr_stmt: bool) -> ExprObj: - name = f"{get_literal(expr['args'][0], 'str')}float" + name = f"{name}float" literal: LiteralValObj = { "kind": "val", "type": "str", @@ -125,6 +126,7 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, @@ -147,6 +149,7 @@ def names() -> set[str]: def patch( self, + name: str, expr: CallObj, args: list[ExprObj], *, diff --git a/src/redipy/symbolic/rlist.py b/src/redipy/symbolic/rlist.py index 81ba858..c17ee7b 100644 --- a/src/redipy/symbolic/rlist.py +++ b/src/redipy/symbolic/rlist.py @@ -67,6 +67,22 @@ def rpop( return self.redis_fn("rpop", no_adjust=no_adjust) return self.redis_fn("rpop", count, no_adjust=no_adjust) + def lrange(self, start: MixedType, stop: MixedType) -> Expr: + """ + Returns a number of values from the list specified by the given range. + Negative numbers are interpreted as index from the back of the list. + Out of range indices are ignored, potentially returning an empty list. + + Args: + start (MixedType): The start index. + + stop (MixedType): The stop index (inclusive). + + Returns: + Expr: The expression. + """ + return self.redis_fn("lrange", start, stop) + def llen(self) -> Expr: """ The length of the list. @@ -75,5 +91,3 @@ def llen(self) -> Expr: Expr: The expression. """ return self.redis_fn("llen") - - # FIXME implement lrange diff --git a/src/redipy/symbolic/rvar.py b/src/redipy/symbolic/rvar.py index 87ba943..0ca56ea 100644 --- a/src/redipy/symbolic/rvar.py +++ b/src/redipy/symbolic/rvar.py @@ -50,17 +50,26 @@ def set( args.append("KEEPTTL") return self.redis_fn("set", value, *args) - def get(self, *, no_adjust: bool = False) -> Expr: + def get( + self, + *, + default: MixedType = None, + no_adjust: bool = False) -> Expr: """ Returns the value. Args: + default (MixedType, optional): The default value to return if + the key does not exist. + no_adjust (bool, optional): Whether to prevent patching the function call. This should not be neccessary. Defaults to False. Returns: Expr: The expression. """ + if default is not None: + return self.redis_fn("get", no_adjust=True).or_(default) return self.redis_fn("get", no_adjust=no_adjust) def incrby(self, inc: MixedType) -> Expr: diff --git a/src/redipy/symbolic/rzset.py b/src/redipy/symbolic/rzset.py index c154d0b..c728327 100644 --- a/src/redipy/symbolic/rzset.py +++ b/src/redipy/symbolic/rzset.py @@ -49,6 +49,21 @@ def pop_min(self, count: MixedType = None) -> Expr: return self.redis_fn("zpopmin") return self.redis_fn("zpopmin", count) + def range(self, start: MixedType, stop: MixedType) -> Expr: + """ + Returns a range of member names. + + Args: + start (MixedType): The start index. + + stop (MixedType): The stop index (inclusive). + + Returns: + Expr: The expression. + """ + # FIXME add all arguments + return self.redis_fn("zrange", start, stop) + def card(self) -> Expr: """ Computes the cardinality of the sorted set. diff --git a/test/test_api.py b/test/test_api.py index 6e6e932..8364ec8 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -94,16 +94,16 @@ def lua_patch_id(res: JSONType) -> JSONType: assert setup(key) == output_setup result = normal(key) - assert teardown(key) == output_teardown assert result == output + assert teardown(key) == output_teardown with redis.pipeline() as pipe: assert setup_pipe(pipe, key) is None assert pipeline(pipe, key) is None setup_result, result = pipe.execute() - assert teardown(key) == output_teardown assert setup_result == output_setup assert result == output + assert teardown(key) == output_teardown ctx = FnContext() key_var = ctx.add_key("key") @@ -125,27 +125,27 @@ def lua_patch_id(res: JSONType) -> JSONType: assert setup(key) == output_setup result = lua_patch(fun(keys={"key": key}, args={})) - assert teardown(key) == output_teardown assert result == output + assert teardown(key) == output_teardown assert setup(key) == output_setup result = lua_patch(fun(keys={"key": key}, args={}, client=redis)) - assert teardown(key) == output_teardown assert result == output + assert teardown(key) == output_teardown assert setup(key) == output_setup result = lua_patch( fun(keys={"key": key}, args={}, client=redis.get_runtime())) - assert teardown(key) == output_teardown assert result == output + assert teardown(key) == output_teardown with redis.pipeline() as pipe: assert setup_pipe(pipe, key) is None assert fun(keys={"key": key}, args={}, client=pipe) is None setup_result, result = pipe.execute() - assert teardown(key) == output_teardown assert setup_result == output_setup assert lua_patch(result) == output + assert teardown(key) == output_teardown check( "exists", @@ -161,7 +161,20 @@ def lua_patch_id(res: JSONType) -> JSONType: output_teardown=["a", 1]) check( - "lpop", + "incrby", + setup=lambda key: redis.set(key, "0.25"), + normal=lambda key: redis.incrby(key, 0.5), + setup_pipe=lambda pipe, key: pipe.set(key, "0.25"), + pipeline=lambda pipe, key: pipe.incrby(key, 0.5), + lua=lambda ctx, key: RedisVar(key).incrby(0.5), + code="tonumber(redis.call(\"incrbyfloat\", key_0, 0.5))", + teardown=lambda key: [redis.get(key), redis.delete(key)], + output_setup=True, + output=0.75, + output_teardown=["0.75", 1]) + + check( + "lpop_0", setup=lambda key: redis.lpush(key, "a"), normal=lambda key: redis.lpop(key), setup_pipe=lambda pipe, key: pipe.lpush(key, "a"), @@ -175,20 +188,48 @@ def lua_patch_id(res: JSONType) -> JSONType: output_teardown=0) check( - "rpop", + "lpop_1", + setup=lambda key: redis.lpush(key, "a", "b", "c"), + normal=lambda key: redis.lpop(key, 2), + setup_pipe=lambda pipe, key: pipe.lpush(key, "a", "b", "c"), + pipeline=lambda pipe, key: pipe.lpop(key, 2), + lua=lambda ctx, key: RedisList(key).lpop(2), + code="redis.call(\"lpop\", key_0, 2)", + teardown=lambda key: [ + redis.llen(key), redis.delete(key), redis.exists(key)], + output_setup=3, + output=["c", "b"], + output_teardown=[1, 1, 0]) + + check( + "rpop_0", setup=lambda key: redis.rpush(key, "a"), normal=lambda key: redis.rpop(key), setup_pipe=lambda pipe, key: pipe.rpush(key, "a"), pipeline=lambda pipe, key: pipe.rpop(key), lua=lambda ctx, key: RedisList(key).rpop(), code="(redis.call(\"rpop\", key_0) or nil)", - teardown=lambda key: redis.llen(key), + teardown=lambda key: redis.exists(key), output_setup=1, output="a", output_teardown=0) check( - "zpopmax", + "rpop_1", + setup=lambda key: redis.rpush(key, "a", "b", "c"), + normal=lambda key: redis.rpop(key, 2), + setup_pipe=lambda pipe, key: pipe.rpush(key, "a", "b", "c"), + pipeline=lambda pipe, key: pipe.rpop(key, 2), + lua=lambda ctx, key: RedisList(key).rpop(2), + code="redis.call(\"rpop\", key_0, 2)", + teardown=lambda key: [ + redis.llen(key), redis.delete(key), redis.exists(key)], + output_setup=3, + output=["c", "b"], + output_teardown=[1, 1, 0]) + + check( + "zpopmax_0", setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), normal=lambda key: redis.zpop_max(key, 2), setup_pipe=lambda pipe, key: pipe.zadd( @@ -204,7 +245,23 @@ def lua_patch_id(res: JSONType) -> JSONType: output_teardown=[1, True, 0]) check( - "zpopmin", + "zpopmax_1", + setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), + normal=lambda key: redis.zpop_max(key), + setup_pipe=lambda pipe, key: pipe.zadd( + key, {"a": 0.25, "b": 0.5, "c": 0.75}), + pipeline=lambda pipe, key: pipe.zpop_max(key), + lua=lambda ctx, key: RedisSortedSet(key).pop_max(), + code="redipy.pairlist_scores(redis.call(\"zpopmax\", key_0))", + lua_patch=lambda res: [tuple(elem) for elem in cast(list, res)], + teardown=lambda key: [ + redis.zcard(key), redis.delete(key), redis.zcard(key)], + output_setup=3, + output=[("c", 0.75)], + output_teardown=[2, True, 0]) + + check( + "zpopmin_0", setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), normal=lambda key: redis.zpop_min(key, 2), setup_pipe=lambda pipe, key: pipe.zadd( @@ -219,6 +276,50 @@ def lua_patch_id(res: JSONType) -> JSONType: output=[("a", 0.25), ("b", 0.5)], output_teardown=[1, True, 0]) + check( + "zpopmin_1", + setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), + normal=lambda key: redis.zpop_min(key), + setup_pipe=lambda pipe, key: pipe.zadd( + key, {"a": 0.25, "b": 0.5, "c": 0.75}), + pipeline=lambda pipe, key: pipe.zpop_min(key), + lua=lambda ctx, key: RedisSortedSet(key).pop_min(), + code="redipy.pairlist_scores(redis.call(\"zpopmin\", key_0))", + lua_patch=lambda res: [tuple(elem) for elem in cast(list, res)], + teardown=lambda key: [ + redis.zcard(key), redis.delete(key), redis.zcard(key)], + output_setup=3, + output=[("a", 0.25)], + output_teardown=[2, True, 0]) + + check( + "zrange_0", + setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), + normal=lambda key: redis.zrange(key, 1, 2), + setup_pipe=lambda pipe, key: pipe.zadd( + key, {"a": 0.25, "b": 0.5, "c": 0.75}), + pipeline=lambda pipe, key: pipe.zrange(key, 1, 2), + lua=lambda ctx, key: RedisSortedSet(key).range(1, 2), + code="redis.call(\"zrange\", key_0, 1, 2)", + teardown=lambda key: [redis.delete(key), redis.zcard(key)], + output_setup=3, + output=["b", "c"], + output_teardown=[True, 0]) + + check( + "zrange_1", + setup=lambda key: redis.zadd(key, {"a": 0.25, "b": 0.5, "c": 0.75}), + normal=lambda key: redis.zrange(key, 0, -2), + setup_pipe=lambda pipe, key: pipe.zadd( + key, {"a": 0.25, "b": 0.5, "c": 0.75}), + pipeline=lambda pipe, key: pipe.zrange(key, 0, -2), + lua=lambda ctx, key: RedisSortedSet(key).range(0, -2), + code="redis.call(\"zrange\", key_0, 0, -2)", + teardown=lambda key: [redis.delete(key), redis.zcard(key)], + output_setup=3, + output=["a", "b"], + output_teardown=[True, 0]) + check( "hget", setup=lambda key: redis.hset(key, {"a": "0", "b": "1", "c": "2"}), @@ -278,6 +379,20 @@ def lua_patch_id(res: JSONType) -> JSONType: output=6, output_teardown=[{"d": "3", "e": "6", "f": "5"}, 1, 0]) + check( + "hdel", + setup=lambda key: redis.hset(key, {"d": "3", "e": "4", "f": "5"}), + normal=lambda key: redis.hdel(key, "c", "d", "e"), + setup_pipe=lambda pipe, key: pipe.hset( + key, {"d": "3", "e": "4", "f": "5"}), + pipeline=lambda pipe, key: pipe.hdel(key, "c", "d", "e"), + lua=lambda ctx, key: RedisHash(key).hdel("c", "d", "e"), + code="redis.call(\"hdel\", key_0, \"c\", \"d\", \"e\")", + teardown=lambda key: [redis.hgetall(key), redis.delete(key)], + output_setup=3, + output=2, + output_teardown=[{"f": "5"}, 1]) + check( "hkeys", setup=lambda key: redis.hset(key, {"d": "3", "e": "4", "f": "5"}), @@ -305,3 +420,69 @@ def lua_patch_id(res: JSONType) -> JSONType: output_setup=3, output=["3", "4", "5"], output_teardown=1) + + check( + "hgetall", + setup=lambda key: redis.hset(key, {"d": "3", "e": "4", "f": "5"}), + normal=lambda key: redis.hgetall(key), + setup_pipe=lambda pipe, key: pipe.hset( + key, {"d": "3", "e": "4", "f": "5"}), + pipeline=lambda pipe, key: pipe.hgetall(key), + lua=lambda ctx, key: RedisHash(key).hgetall(), + code="redipy.pairlist_dict(redis.call(\"hgetall\", key_0))", + teardown=lambda key: redis.delete(key), + output_setup=3, + output={"d": "3", "e": "4", "f": "5"}, + output_teardown=1) + + check( + "lrange_0", + setup=lambda key: redis.rpush(key, "a", "b", "c"), + normal=lambda key: redis.lrange(key, 0, 0), + setup_pipe=lambda pipe, key: pipe.rpush(key, "a", "b", "c"), + pipeline=lambda pipe, key: pipe.lrange(key, 0, 0), + lua=lambda ctx, key: RedisList(key).lrange(0, 0), + code="redis.call(\"lrange\", key_0, 0, 0)", + teardown=lambda key: redis.delete(key), + output_setup=3, + output=["a"], + output_teardown=1) + + check( + "lrange_1", + setup=lambda key: redis.rpush(key, "a", "b", "c"), + normal=lambda key: redis.lrange(key, -3, 2), + setup_pipe=lambda pipe, key: pipe.rpush(key, "a", "b", "c"), + pipeline=lambda pipe, key: pipe.lrange(key, -3, 2), + lua=lambda ctx, key: RedisList(key).lrange(-3, 2), + code="redis.call(\"lrange\", key_0, -3, 2)", + teardown=lambda key: redis.delete(key), + output_setup=3, + output=["a", "b", "c"], + output_teardown=1) + + check( + "lrange_2", + setup=lambda key: redis.rpush(key, "a", "b", "c", "d", "e"), + normal=lambda key: redis.lrange(key, 1, -2), + setup_pipe=lambda pipe, key: pipe.rpush(key, "a", "b", "c", "d", "e"), + pipeline=lambda pipe, key: pipe.lrange(key, 1, -2), + lua=lambda ctx, key: RedisList(key).lrange(1, -2), + code="redis.call(\"lrange\", key_0, 1, -2)", + teardown=lambda key: redis.delete(key), + output_setup=5, + output=["b", "c", "d"], + output_teardown=1) + + check( + "lrange_3", + setup=lambda key: redis.rpush(key, "a", "b", "c", "d", "e"), + normal=lambda key: redis.lrange(key, -100, 100), + setup_pipe=lambda pipe, key: pipe.rpush(key, "a", "b", "c", "d", "e"), + pipeline=lambda pipe, key: pipe.lrange(key, -100, 100), + lua=lambda ctx, key: RedisList(key).lrange(-100, 100), + code="redis.call(\"lrange\", key_0, -100, 100)", + teardown=lambda key: redis.delete(key), + output_setup=5, + output=["a", "b", "c", "d", "e"], + output_teardown=1) diff --git a/test/test_stack.py b/test/test_stack.py index a255f37..77ffece 100644 --- a/test/test_stack.py +++ b/test/test_stack.py @@ -17,7 +17,7 @@ from redipy.util import code_fmt, lua_fmt -GET_KEY_0 = "redis.call(\"get\", key_0)" +GET_KEY_0 = "(redis.call(\"get\", key_0) or 0)" RET = "return cjson.encode" RC = "redis.call" RP = "redipy.asintstr" @@ -28,7 +28,7 @@ LUA_SET_VALUE = f""" -- HELPERS START -- -local redipy = {{}} +local redipy = {EMPTY_OBJ} function redipy.asintstr (val) return math.floor(val) end @@ -98,16 +98,20 @@ ]] local key_0 = (KEYS[1]) -- size local key_1 = (KEYS[2]) -- frame -local var_0 = {PLD}({RC}("hgetall", {KEY_1_P} .. ({RP}(({GET_KEY_0} or nil))))) -redis.call("del", {KEY_1_P} .. ({RP}(({GET_KEY_0} or nil)))) -redis.call("incrbyfloat", key_0, -1) +local var_0 = {PLD}({RC}("hgetall", {KEY_1_P} .. ({RP}({GET_KEY_0})))) +redis.call("del", {KEY_1_P} .. ({RP}({GET_KEY_0}))) +if (tonumber({GET_KEY_0}) > 0) then + redis.call("incrbyfloat", key_0, -1) +else + redis.call("del", key_0) +end return cjson.encode(var_0) """ -LUA_GET_CASCADING = """ +LUA_GET_CASCADING = f""" -- HELPERS START -- -local redipy = {} +local redipy = {EMPTY_OBJ} function redipy.asintstr (val) return math.floor(val) end @@ -123,7 +127,7 @@ local key_1 = (KEYS[2]) -- frame local var_0 = key_1 local arg_0 = cjson.decode(ARGV[1]) -- field -local var_1 = tonumber((redis.call("get", key_0) or nil)) +local var_1 = tonumber({GET_KEY_0}) local var_2 = nil local var_3 = nil while ((var_2 == nil) and (var_1 >= 0)) do @@ -167,15 +171,6 @@ def key(self, base: str, name: str) -> str: """ return f"{base}:{name}" - def init(self, base: str) -> None: - """ - Initializes the stack. - - Args: - base (str): The base key. - """ - self._rt.set(self.key(base, "size"), "0") - def push_frame(self, base: str) -> None: """ Pushes a new stack frame. @@ -185,7 +180,7 @@ def push_frame(self, base: str) -> None: """ self._rt.incrby(self.key(base, "size"), 1) - def pop_frame(self, base: str) -> dict[str, str] | None: + def pop_frame(self, base: str) -> dict[str, str]: """ Pops the current stack frame and returns its values. @@ -202,7 +197,7 @@ def pop_frame(self, base: str) -> dict[str, str] | None: }, args={}) if res is None: - return None + return {} return cast(dict, res) def set_value(self, base: str, field: str, value: str) -> None: @@ -269,7 +264,7 @@ def _set_value_script(self) -> ExecFunction: rframe = RedisHash(Strs( ctx.add_key("frame"), ":", - ToIntStr(rsize.get(no_adjust=True)))) + ToIntStr(rsize.get(default=0)))) field = ctx.add_arg("field") value = ctx.add_arg("value") ctx.add(rframe.hset({ @@ -284,7 +279,7 @@ def _get_value_script(self) -> ExecFunction: rframe = RedisHash(Strs( ctx.add_key("frame"), ":", - ToIntStr(rsize.get(no_adjust=True)))) + ToIntStr(rsize.get(default=0)))) field = ctx.add_arg("field") ctx.set_return_value(rframe.hget(field)) return self._rt.register_script(ctx) @@ -293,10 +288,14 @@ def _pop_frame_script(self) -> ExecFunction: ctx = FnContext() rsize = RedisVar(ctx.add_key("size")) rframe = RedisHash( - Strs(ctx.add_key("frame"), ":", ToIntStr(rsize.get()))) + Strs(ctx.add_key("frame"), ":", ToIntStr(rsize.get(default=0)))) lcl = ctx.add_local(rframe.hgetall()) ctx.add(rframe.delete()) - ctx.add(rsize.incrby(-1)) + + b_then, b_else = ctx.if_(ToNum(rsize.get(default=0)).gt_(0)) + b_then.add(rsize.incrby(-1)) + b_else.add(rsize.delete()) + ctx.set_return_value(lcl) return self._rt.register_script(ctx) @@ -305,7 +304,7 @@ def _get_cascading_script(self) -> ExecFunction: rsize = RedisVar(ctx.add_key("size")) base = ctx.add_local(ctx.add_key("frame")) field = ctx.add_arg("field") - pos = ctx.add_local(ToNum(rsize.get())) + pos = ctx.add_local(ToNum(rsize.get(default=0))) res = ctx.add_local(None) cur = ctx.add_local(None) rframe = RedisHash(cur) @@ -359,9 +358,6 @@ def code_hook(code: list[str]) -> None: lua_code_hook=code_hook) stack = RStack(redis, set_lua_script) - stack.init("foo") - stack.init("bar") - stack.set_value("foo", "a", "hi") assert stack.get_value("bar", "a") is None assert stack.get_value("foo", "a") == "hi" @@ -495,3 +491,43 @@ def code_hook(code: list[str]) -> None: assert stack.get_cascading("bar", "b") == "bar" assert stack.get_cascading("bar", "c") is None assert stack.get_cascading("bar", "d") is None + + assert redis.exists("foo:size", "foo:frame:0") == 2 + assert redis.exists("foo:frame:1") == 0 + assert redis.exists("bar:size", "bar:frame:0") == 2 + assert redis.exists("bar:frame:1") == 0 + + assert stack.pop_frame("foo") == { + "a": "hi", + "b": "foo", + } + assert stack.pop_frame("bar") == { + "a": "bye", + "b": "bar", + } + + assert stack.get_cascading("foo", "a") is None + assert stack.get_cascading("foo", "b") is None + assert stack.get_cascading("foo", "c") is None + assert stack.get_cascading("foo", "d") is None + + assert stack.get_cascading("bar", "a") is None + assert stack.get_cascading("bar", "b") is None + assert stack.get_cascading("bar", "c") is None + assert stack.get_cascading("bar", "d") is None + + assert stack.pop_frame("foo") == {} + assert stack.pop_frame("bar") == {} + + assert stack.get_cascading("foo", "a") is None + assert stack.get_cascading("foo", "b") is None + assert stack.get_cascading("foo", "c") is None + assert stack.get_cascading("foo", "d") is None + + assert stack.get_cascading("bar", "a") is None + assert stack.get_cascading("bar", "b") is None + assert stack.get_cascading("bar", "c") is None + assert stack.get_cascading("bar", "d") is None + + assert redis.exists("bar:size", "bar:frame:0") == 0 + assert redis.exists("bar:frame:1") == 0