Skip to content

Commit

Permalink
Merge pull request redis-rs#278 from barshaul/transaction-py
Browse files Browse the repository at this point in the history
python - transaction
  • Loading branch information
shachlanAmazon authored Jun 18, 2023
2 parents 184de57 + 16fcf2b commit 4c90e6a
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 4 deletions.
63 changes: 63 additions & 0 deletions python/python/pybushka/async_commands/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from datetime import datetime, timedelta
from enum import Enum
from typing import List, Type, Union, get_args
Expand Down Expand Up @@ -100,6 +101,60 @@ def set_direct(self, key, value):
return self.connection.set(key, value)


class Transaction:
def __init__(self) -> None:
self.commands = []
self.lock = threading.Lock()

def execute_command(self, request_type: RequestType, args: List[str]):
self.lock.acquire()
try:
self.commands.append([request_type, args])
finally:
self.lock.release()

def get(self, key: str):
self.execute_command(RequestType.GetString, [key])

def set(
self,
key: str,
value: str,
conditional_set: Union[ConditionalSet, None] = None,
expiry: Union[ExpirySet, None] = None,
return_old_value: bool = False,
):
args = [key, value]
if conditional_set:
if conditional_set == ConditionalSet.ONLY_IF_EXISTS:
args.append("XX")
if conditional_set == ConditionalSet.ONLY_IF_DOES_NOT_EXIST:
args.append("NX")
if return_old_value:
args.append("GET")
if expiry is not None:
args.extend(expiry.get_cmd_args())
self.execute_command(RequestType.SetString, args)

def custom_command(self, command_args: List[str]):
"""Executes a single command, without checking inputs.
@example - Return a list of all pub/sub clients:
connection.customCommand(["CLIENT", "LIST","TYPE", "PUBSUB"])
Args:
command_args (List[str]): List of strings of the command's arguements.
Every part of the command, including the command name and subcommands, should be added as a separate value in args.
Returns:
TResult: The returning value depends on the executed command
"""
self.execute_command(RequestType.CustomCommand, command_args)

def dispose(self):
with self.lock:
self.commands.clear()


class CoreCommands:
async def set(
self,
Expand Down Expand Up @@ -155,6 +210,14 @@ async def get(self, key: str) -> Union[str, None]:
"""
return await self.execute_command(RequestType.GetString, [key])

async def multi(self) -> Transaction:
return Transaction()

async def exec(self, transaction: Transaction) -> List[Union[str, None, TResult]]:
commands = transaction.commands[:]
transaction.dispose()
return await self.execute_command_Transaction(commands)

async def custom_command(self, command_args: List[str]) -> TResult:
"""Executes a single command, without checking inputs.
@example - Return a list of all pub/sub clients:
Expand Down
23 changes: 21 additions & 2 deletions python/python/pybushka/async_socket_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import threading
from typing import Awaitable, List, Optional, Type
from typing import Awaitable, List, Optional, Type, Union

import async_timeout
from google.protobuf.internal.decoder import _DecodeVarint32
Expand All @@ -18,7 +18,7 @@
)
from pybushka.Logger import Level as LogLevel
from pybushka.Logger import Logger
from pybushka.protobuf.redis_request_pb2 import RedisRequest
from pybushka.protobuf.redis_request_pb2 import Command, RedisRequest
from pybushka.protobuf.response_pb2 import Response
from typing_extensions import Self

Expand Down Expand Up @@ -156,6 +156,25 @@ async def execute_command(
await response_future
return response_future.result()

async def execute_command_Transaction(
self, commands: List[Union[TRequestType, List[str]]]
) -> TResult:
request = RedisRequest()
request.callback_idx = self._get_callback_index()
transaction_commands = []
for requst_type, args in commands:
command = Command()
command.request_type = requst_type
command.args_array.args[:] = args
transaction_commands.append(command)
request.transaction.commands.extend(transaction_commands)
# Create a response future for this request and add it to the available
# futures map
response_future = self._get_future(request.callback_idx)
await self._write_or_buffer_request(request)
await response_future
return response_future.result()

def _get_callback_index(self) -> int:
if not self._available_callbackIndexes:
# Set is empty
Expand Down
96 changes: 94 additions & 2 deletions python/python/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def parse_info_response(res: str) -> Dict[str, str]:


def get_random_string(length):
letters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(letters) for i in range(length))
result_str = "".join(random.choice(string.ascii_letters) for i in range(length))
return result_str


Expand Down Expand Up @@ -197,6 +196,99 @@ async def test_request_error_raises_exception(
assert "WRONGTYPE" in str(e)


@pytest.mark.asyncio
class TestTransaction:
@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_set_get(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
value = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
transaction = await async_socket_client.multi()
transaction.set(key, value)
transaction.get(key)
result = await async_socket_client.exec(transaction)
assert result == [OK, value]

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_multiple_commands(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
key2 = "{{{}}}:{}".format(key, get_random_string(3)) # to get the same slot
value = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
transaction = await async_socket_client.multi()
transaction.set(key, value)
transaction.get(key)
transaction.set(key2, value)
transaction.get(key2)
result = await async_socket_client.exec(transaction)
assert result == [OK, value, OK, value]

@pytest.mark.parametrize("cluster_mode", [True])
async def test_transaction_with_different_slots(
self, async_socket_client: RedisAsyncSocketClient
):
transaction = await async_socket_client.multi()
transaction.set("key1", "value1")
transaction.set("key2", "value2")
with pytest.raises(Exception) as e:
await async_socket_client.exec(transaction)
assert "Moved" in str(e)

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_custom_command(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
transaction = await async_socket_client.multi()
transaction.custom_command(["HSET", key, "foo", "bar"])
transaction.custom_command(["HGET", key, "foo"])
result = await async_socket_client.exec(transaction)
assert result == [1, "bar"]

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_custom_unsupported_command(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
transaction = await async_socket_client.multi()
transaction.custom_command(["WATCH", key])
with pytest.raises(Exception) as e:
await async_socket_client.exec(transaction)
assert "WATCH inside MULTI is not allowed" in str(
e
) # TODO : add an assert on EXEC ABORT

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_discard_command(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
await async_socket_client.set(key, "1")
transaction = await async_socket_client.multi()
transaction.custom_command(["INCR", key])
transaction.custom_command(["DISCARD"])
with pytest.raises(Exception) as e:
await async_socket_client.exec(transaction)
assert "EXEC without MULTI" in str(e) # TODO : add an assert on EXEC ABORT
value = await async_socket_client.get(key)
assert value == "1"

@pytest.mark.parametrize("cluster_mode", [True, False])
async def test_transaction_exec_abort(
self, async_socket_client: RedisAsyncSocketClient
):
key = get_random_string(10)
transaction = await async_socket_client.multi()
transaction.custom_command(["INCR", key, key, key])
with pytest.raises(Exception) as e:
await async_socket_client.exec(transaction)
assert "wrong number of arguments" in str(
e
) # TODO : add an assert on EXEC ABORT


class CommandsUnitTests:
def test_expiry_cmd_args(self):
exp_sec = ExpirySet(ExpiryType.SEC, 5)
Expand Down

0 comments on commit 4c90e6a

Please sign in to comment.