diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 31bd02c36f..7f1c0b71e4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -678,7 +678,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] output.append(SYM_EMPTY.join(pieces)) return output - def _is_socket_empty(self): + def _socket_is_empty(self): """Check if the socket is empty""" return not self._reader.at_eof() @@ -692,10 +692,10 @@ def _cache_invalidation_process( (if the list of keys is None, then all keys are invalidated) """ if data[1] is not None: + self.client_cache.flush() + else: for key in data[1]: self.client_cache.invalidate(str_if_bytes(key)) - else: - self.client_cache.flush() async def _get_from_local_cache(self, command: str): """ @@ -707,7 +707,7 @@ async def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None - while not self._is_socket_empty(): + while not self._socket_is_empty(): await self.read_response(push_request=True) return self.client_cache.get(command) diff --git a/redis/connection.py b/redis/connection.py index fb308cd92d..a09fb3949c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -235,6 +235,10 @@ def __init__( _cache = None self.client_cache = client_cache if client_cache is not None else _cache if self.client_cache is not None: + if self.protocol not in [3, "3"]: + raise RedisError( + "client caching is only supported with protocol version 3 or higher" + ) self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist @@ -604,7 +608,7 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _is_socket_empty(self): + def _socket_is_empty(self): """Check if the socket is empty""" r, _, _ = select.select([self._sock], [], [], 0) return not bool(r) @@ -618,11 +622,11 @@ def _cache_invalidation_process( and the second string is the list of keys to invalidate. (if the list of keys is None, then all keys are invalidated) """ - if data[1] is not None: + if data[1] is None: + self.client_cache.flush() + else: for key in data[1]: self.client_cache.invalidate(str_if_bytes(key)) - else: - self.client_cache.flush() def _get_from_local_cache(self, command: str): """ @@ -634,7 +638,7 @@ def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None - while not self._is_socket_empty(): + while not self._socket_is_empty(): self.read_response(push_request=True) return self.client_cache.get(command)