From 16fcf2b99c255fa05937f49a8714b346b9e9c9af Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 11 Jun 2023 07:38:59 +0000 Subject: [PATCH] Transaction - Python change get_random_string --- python/python/pybushka/async_commands/core.py | 63 ++++++++++++ python/python/pybushka/async_socket_client.py | 23 ++++- python/python/tests/test_async_client.py | 96 ++++++++++++++++++- 3 files changed, 178 insertions(+), 4 deletions(-) diff --git a/python/python/pybushka/async_commands/core.py b/python/python/pybushka/async_commands/core.py index c3d12830c..d0fd12718 100644 --- a/python/python/pybushka/async_commands/core.py +++ b/python/python/pybushka/async_commands/core.py @@ -1,3 +1,4 @@ +import threading from datetime import datetime, timedelta from enum import Enum from typing import List, Type, Union, get_args @@ -100,6 +101,60 @@ def set_direct(self, key, value): return self.connection.set(key, value) +class Transaction: + def __init__(self) -> None: + self.commands = [] + self.lock = threading.Lock() + + def execute_command(self, request_type: RequestType, args: List[str]): + self.lock.acquire() + try: + self.commands.append([request_type, args]) + finally: + self.lock.release() + + def get(self, key: str): + self.execute_command(RequestType.GetString, [key]) + + def set( + self, + key: str, + value: str, + conditional_set: Union[ConditionalSet, None] = None, + expiry: Union[ExpirySet, None] = None, + return_old_value: bool = False, + ): + args = [key, value] + if conditional_set: + if conditional_set == ConditionalSet.ONLY_IF_EXISTS: + args.append("XX") + if conditional_set == ConditionalSet.ONLY_IF_DOES_NOT_EXIST: + args.append("NX") + if return_old_value: + args.append("GET") + if expiry is not None: + args.extend(expiry.get_cmd_args()) + self.execute_command(RequestType.SetString, args) + + def custom_command(self, command_args: List[str]): + """Executes a single command, without checking inputs. + @example - Return a list of all pub/sub clients: + + connection.customCommand(["CLIENT", "LIST","TYPE", "PUBSUB"]) + Args: + command_args (List[str]): List of strings of the command's arguements. + Every part of the command, including the command name and subcommands, should be added as a separate value in args. + + Returns: + TResult: The returning value depends on the executed command + """ + self.execute_command(RequestType.CustomCommand, command_args) + + def dispose(self): + with self.lock: + self.commands.clear() + + class CoreCommands: async def set( self, @@ -155,6 +210,14 @@ async def get(self, key: str) -> Union[str, None]: """ return await self.execute_command(RequestType.GetString, [key]) + async def multi(self) -> Transaction: + return Transaction() + + async def exec(self, transaction: Transaction) -> List[Union[str, None, TResult]]: + commands = transaction.commands[:] + transaction.dispose() + return await self.execute_command_Transaction(commands) + async def custom_command(self, command_args: List[str]) -> TResult: """Executes a single command, without checking inputs. @example - Return a list of all pub/sub clients: diff --git a/python/python/pybushka/async_socket_client.py b/python/python/pybushka/async_socket_client.py index a89d6ec4e..ebf1f5b2d 100644 --- a/python/python/pybushka/async_socket_client.py +++ b/python/python/pybushka/async_socket_client.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import Awaitable, List, Optional, Type +from typing import Awaitable, List, Optional, Type, Union import async_timeout from google.protobuf.internal.decoder import _DecodeVarint32 @@ -18,7 +18,7 @@ ) from pybushka.Logger import Level as LogLevel from pybushka.Logger import Logger -from pybushka.protobuf.redis_request_pb2 import RedisRequest +from pybushka.protobuf.redis_request_pb2 import Command, RedisRequest from pybushka.protobuf.response_pb2 import Response from typing_extensions import Self @@ -156,6 +156,25 @@ async def execute_command( await response_future return response_future.result() + async def execute_command_Transaction( + self, commands: List[Union[TRequestType, List[str]]] + ) -> TResult: + request = RedisRequest() + request.callback_idx = self._get_callback_index() + transaction_commands = [] + for requst_type, args in commands: + command = Command() + command.request_type = requst_type + command.args_array.args[:] = args + transaction_commands.append(command) + request.transaction.commands.extend(transaction_commands) + # Create a response future for this request and add it to the available + # futures map + response_future = self._get_future(request.callback_idx) + await self._write_or_buffer_request(request) + await response_future + return response_future.result() + def _get_callback_index(self) -> int: if not self._available_callbackIndexes: # Set is empty diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index d0b89b263..9eb7428a7 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -49,8 +49,7 @@ def parse_info_response(res: str) -> Dict[str, str]: def get_random_string(length): - letters = string.ascii_letters + string.digits + string.punctuation - result_str = "".join(random.choice(letters) for i in range(length)) + result_str = "".join(random.choice(string.ascii_letters) for i in range(length)) return result_str @@ -197,6 +196,99 @@ async def test_request_error_raises_exception( assert "WRONGTYPE" in str(e) +@pytest.mark.asyncio +class TestTransaction: + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_set_get( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + value = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") + transaction = await async_socket_client.multi() + transaction.set(key, value) + transaction.get(key) + result = await async_socket_client.exec(transaction) + assert result == [OK, value] + + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_multiple_commands( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + key2 = "{{{}}}:{}".format(key, get_random_string(3)) # to get the same slot + value = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") + transaction = await async_socket_client.multi() + transaction.set(key, value) + transaction.get(key) + transaction.set(key2, value) + transaction.get(key2) + result = await async_socket_client.exec(transaction) + assert result == [OK, value, OK, value] + + @pytest.mark.parametrize("cluster_mode", [True]) + async def test_transaction_with_different_slots( + self, async_socket_client: RedisAsyncSocketClient + ): + transaction = await async_socket_client.multi() + transaction.set("key1", "value1") + transaction.set("key2", "value2") + with pytest.raises(Exception) as e: + await async_socket_client.exec(transaction) + assert "Moved" in str(e) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_custom_command( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + transaction = await async_socket_client.multi() + transaction.custom_command(["HSET", key, "foo", "bar"]) + transaction.custom_command(["HGET", key, "foo"]) + result = await async_socket_client.exec(transaction) + assert result == [1, "bar"] + + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_custom_unsupported_command( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + transaction = await async_socket_client.multi() + transaction.custom_command(["WATCH", key]) + with pytest.raises(Exception) as e: + await async_socket_client.exec(transaction) + assert "WATCH inside MULTI is not allowed" in str( + e + ) # TODO : add an assert on EXEC ABORT + + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_discard_command( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + await async_socket_client.set(key, "1") + transaction = await async_socket_client.multi() + transaction.custom_command(["INCR", key]) + transaction.custom_command(["DISCARD"]) + with pytest.raises(Exception) as e: + await async_socket_client.exec(transaction) + assert "EXEC without MULTI" in str(e) # TODO : add an assert on EXEC ABORT + value = await async_socket_client.get(key) + assert value == "1" + + @pytest.mark.parametrize("cluster_mode", [True, False]) + async def test_transaction_exec_abort( + self, async_socket_client: RedisAsyncSocketClient + ): + key = get_random_string(10) + transaction = await async_socket_client.multi() + transaction.custom_command(["INCR", key, key, key]) + with pytest.raises(Exception) as e: + await async_socket_client.exec(transaction) + assert "wrong number of arguments" in str( + e + ) # TODO : add an assert on EXEC ABORT + + class CommandsUnitTests: def test_expiry_cmd_args(self): exp_sec = ExpirySet(ExpiryType.SEC, 5)