From ac282f2a8feb4b07264e70543d58db8fc52f9d4b Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 20 Dec 2023 12:16:04 +0200 Subject: [PATCH 01/13] cache invalidations --- redis/_parsers/resp3.py | 42 +++++++++++++++++++++-------------- redis/client.py | 49 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 569e7ee679..ac52e63448 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -12,9 +12,10 @@ class _RESP3Parser(_RESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.push_handler_func = self.handle_push_response + self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.inalidations_push_handler_func = None - def handle_push_response(self, response): + def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -114,13 +115,7 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - res = self.push_handler_func(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + self.handle_push_response(response, disable_decoding, push_request) else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -128,16 +123,31 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.encoder.decode(response) return response - def set_push_handler(self, push_handler_func): - self.push_handler_func = push_handler_func + def handle_push_response(self, response, disable_decoding, push_request): + if response[0] == b"invalidate": + res = self.invalidation_push_handler_func(response) + else: + res = self.pubsub_push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + + def set_pubsub_push_handler(self, pubsub_push_handler_func): + self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidations_push_handler_func): + self.invalidation_push_handler_func = invalidations_push_handler_func class _AsyncRESP3Parser(_AsyncRESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.push_handler_func = self.handle_push_response + self.pubsub_push_handler_func = self.handle_pubsub_push_response - def handle_push_response(self, response): + def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -246,7 +256,7 @@ async def _read_response( ) for _ in range(int(response)) ] - res = self.push_handler_func(response) + res = self.pubsub_push_handler_func(response) if not push_request: return await self._read_response( disable_decoding=disable_decoding, push_request=push_request @@ -260,5 +270,5 @@ async def _read_response( response = self.encoder.decode(response) return response - def set_push_handler(self, push_handler_func): - self.push_handler_func = push_handler_func + def set_push_handler(self, pubsub_push_handler_func): + self.pubsub_push_handler_func = pubsub_push_handler_func diff --git a/redis/client.py b/redis/client.py index 0af7e050d6..f04144e043 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,6 +1,7 @@ import copy import re import threading +import socket import time import warnings 
from itertools import chain @@ -325,6 +326,10 @@ def __init__( self.response_callbacks.update(_RedisCallbacksRESP2) self.client_cache = client_cache + self.connection_lock = threading.Lock() + self.invalidations_listener_thread = threading.Thread( + target=self._invalidations_listener + ) if cache_enable: self.client_cache = _LocalChace( cache_max_size, cache_ttl, cache_eviction_policy @@ -332,6 +337,11 @@ def __init__( if self.client_cache is not None: self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist + self.execute_command("CLIENT", "TRACKING", "ON") + self.connection._parser.set_invalidation_push_handler( + self._cache_invalidation_process + ) + self.invalidations_listener_thread.start() def __repr__(self) -> str: return ( @@ -358,6 +368,30 @@ def set_response_callback(self, command: str, callback: Callable) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + if data[1] is not None: + for key in data[1]: + self.client_cache.invalidate(key) + else: + self.client_cache.flush() + + def _invalidations_listener(self) -> None: + connection_lock = threading.Lock() + sock = self.connection._parser._sock + # TODO: socket keepalive + while self.connection is not None: + with connection_lock: + try: + data_peek = sock.recv(65536, socket.MSG_PEEK) + if data_peek: + self.connection.read_response(push_request=True) + except (ConnectionError, ValueError): + self.client_cache.flush() + time.sleep(0.5) + self.client_cache.flush() + def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, @@ -530,6 +564,8 @@ def close(self): if self.auto_close_connection_pool: self.connection_pool.disconnect() + if self.client_cache: + self.invalidations_listener_thread.join() def _send_command_parse_response(self, conn, command_name, *args, **options): """ @@ -597,12 +633,13 @@ def execute_command(self, *args, **options): conn = self.connection or pool.get_connection(command_name, **options) try: - response = conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) + with self.connection_lock: + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) self._add_to_local_cache(args, response, keys) return response finally: From 5638acb69da143a18c3a465c952499e9ffa7645e Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 20 Dec 2023 12:29:16 +0200 Subject: [PATCH 02/13] isort --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index f04144e043..e720951292 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,7 +1,7 @@ import copy import re -import threading import socket +import threading import time import warnings from itertools import chain From 494d1bc82ff7419b405619d75e69858ad1048baa Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 24 Dec 2023 13:07:27 +0200 Subject: [PATCH 03/13] daemon thread --- redis/__init__.py | 2 -- redis/cache.py | 2 +- redis/client.py | 12 +++++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/redis/__init__.py b/redis/__init__.py index 7bf6839453..495d2d99bb 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -2,7 +2,6 @@ 
from redis import asyncio # noqa from redis.backoff import default_backoff -from redis.cache import _LocalChace from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -62,7 +61,6 @@ def int_or_str(value): VERSION = tuple([99, 99, 99]) __all__ = [ - "_LocalChace", "AuthenticationError", "AuthenticationWrongNumberOfArgsError", "BlockingConnectionPool", diff --git a/redis/cache.py b/redis/cache.py index 5a689d0ebd..defc5fc90f 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -159,7 +159,7 @@ class EvictionPolicy(Enum): RANDOM = "random" -class _LocalChace: +class _LocalCache: """ A caching mechanism for storing redis commands and their responses. diff --git a/redis/client.py b/redis/client.py index e720951292..6f8b810215 100755 --- a/redis/client.py +++ b/redis/client.py @@ -18,7 +18,7 @@ DEFAULT_BLACKLIST, DEFAULT_EVICTION_POLICY, DEFAULT_WHITELIST, - _LocalChace, + _LocalCache, ) from redis.commands import ( CoreCommands, @@ -212,7 +212,7 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, cache_enable: bool = False, - client_cache: Optional[_LocalChace] = None, + client_cache: Optional[_LocalCache] = None, cache_max_size: int = 100, cache_ttl: int = 0, cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, @@ -330,8 +330,9 @@ def __init__( self.invalidations_listener_thread = threading.Thread( target=self._invalidations_listener ) + self.invalidations_listener_thread.daemon = True if cache_enable: - self.client_cache = _LocalChace( + self.client_cache = _LocalCache( cache_max_size, cache_ttl, cache_eviction_policy ) if self.client_cache is not None: @@ -382,6 +383,7 @@ def _invalidations_listener(self) -> None: sock = self.connection._parser._sock # TODO: socket keepalive while self.connection is not None: + print("listening for invalidations") with connection_lock: try: data_peek = sock.recv(65536, socket.MSG_PEEK) @@ -564,8 +566,8 @@ def close(self): if self.auto_close_connection_pool: self.connection_pool.disconnect() - if self.client_cache: - self.invalidations_listener_thread.join() + # if self.client_cache: + # self.invalidations_listener_thread.join() def _send_command_parse_response(self, conn, command_name, *args, **options): """ From 3cf4f14c1f60400207358baaf6624d29bb17d90c Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 26 Dec 2023 17:00:19 +0200 Subject: [PATCH 04/13] remove threads --- redis/client.py | 58 +++++++++++++++++++++------------------------ redis/connection.py | 6 +++++ 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/redis/client.py b/redis/client.py index 6f8b810215..0742b8ed64 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,6 +1,5 @@ import copy import re -import socket import threading import time import warnings @@ -326,11 +325,6 @@ def __init__( self.response_callbacks.update(_RedisCallbacksRESP2) self.client_cache = client_cache - self.connection_lock = threading.Lock() - self.invalidations_listener_thread = threading.Thread( - target=self._invalidations_listener - ) - self.invalidations_listener_thread.daemon = True if cache_enable: self.client_cache = _LocalCache( cache_max_size, cache_ttl, cache_eviction_policy @@ -342,7 +336,6 @@ def __init__( self.connection._parser.set_invalidation_push_handler( self._cache_invalidation_process ) - self.invalidations_listener_thread.start() def __repr__(self) -> str: return ( @@ -378,21 +371,23 @@ def _cache_invalidation_process( else: self.client_cache.flush() - def 
_invalidations_listener(self) -> None: - connection_lock = threading.Lock() - sock = self.connection._parser._sock - # TODO: socket keepalive - while self.connection is not None: - print("listening for invalidations") - with connection_lock: - try: - data_peek = sock.recv(65536, socket.MSG_PEEK) - if data_peek: - self.connection.read_response(push_request=True) - except (ConnectionError, ValueError): - self.client_cache.flush() - time.sleep(0.5) - self.client_cache.flush() + # def _invalidations_listener(self) -> None: + # connection_lock = threading.Lock() + # sock = self.connection._parser._sock + # # TODO: socket keepalive + # while self.connection is not None: + # print("listening for invalidations") + # with connection_lock: + # try: + # sock.setblocking(0) + # data_peek = sock.recv(65536, socket.MSG_PEEK) + # sock.setblocking(1) + # if data_peek: + # self.connection.read_response(push_request=True) + # except (ConnectionError, ValueError): + # self.client_cache.flush() + # time.sleep(0.5) + # self.client_cache.flush() def load_external_module(self, funcname, func) -> None: """ @@ -566,8 +561,8 @@ def close(self): if self.auto_close_connection_pool: self.connection_pool.disconnect() - # if self.client_cache: - # self.invalidations_listener_thread.join() + if self.client_cache: + self.client_cache.flush() def _send_command_parse_response(self, conn, command_name, *args, **options): """ @@ -599,6 +594,8 @@ def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None + while not self.connection._is_socket_empty(): + self.connection.read_response(push_request=True) return self.client_cache.get(command) def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): @@ -635,13 +632,12 @@ def execute_command(self, *args, **options): conn = self.connection or pool.get_connection(command_name, **options) try: - with self.connection_lock: - response = conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) self._add_to_local_cache(args, response, keys) return response finally: diff --git a/redis/connection.py b/redis/connection.py index c201224e35..35a4ff4a37 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,5 +1,6 @@ import copy import os +import select import socket import ssl import sys @@ -572,6 +573,11 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output + def _is_socket_empty(self): + """Check if the socket is empty""" + r, _, _ = select.select([self._sock], [], [], 0) + return not bool(r) + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" From e5f2bbaccdb0def5aa14d166035d6fdcc340d35c Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 26 Dec 2023 17:02:20 +0200 Subject: [PATCH 05/13] delete comment --- redis/client.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/redis/client.py b/redis/client.py index 0742b8ed64..423a70fd5f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -371,24 +371,6 @@ def _cache_invalidation_process( else: self.client_cache.flush() - # def _invalidations_listener(self) -> None: - # connection_lock = threading.Lock() - # sock = self.connection._parser._sock - # # TODO: socket 
keepalive - # while self.connection is not None: - # print("listening for invalidations") - # with connection_lock: - # try: - # sock.setblocking(0) - # data_peek = sock.recv(65536, socket.MSG_PEEK) - # sock.setblocking(1) - # if data_peek: - # self.connection.read_response(push_request=True) - # except (ConnectionError, ValueError): - # self.client_cache.flush() - # time.sleep(0.5) - # self.client_cache.flush() - def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, From 455e89493e002a59bc36e48f30dd5ae5f52f54fe Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 28 Dec 2023 12:40:18 +0200 Subject: [PATCH 06/13] tests --- redis/cache.py | 13 ++++--- tests/test_cache.py | 91 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 tests/test_cache.py diff --git a/redis/cache.py b/redis/cache.py index defc5fc90f..d0337997f8 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -220,6 +220,7 @@ def get(self, command: str) -> ResponseT: if command in self.cache: if self._is_expired(command): self.delete(command) + return self._update_access(command) return self.cache[command]["response"] @@ -266,28 +267,28 @@ def _update_access(self, command: str): Args: command (str): The redis command. """ - if self.eviction_policy == EvictionPolicy.LRU: + if self.eviction_policy == EvictionPolicy.LRU.value: self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.LFU: + elif self.eviction_policy == EvictionPolicy.LFU.value: self.cache[command]["access_count"] = ( self.cache.get(command, {}).get("access_count", 0) + 1 ) self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.RANDOM: + elif self.eviction_policy == EvictionPolicy.RANDOM.value: pass # Random eviction doesn't require updates def _evict(self): """Evict a redis command from the cache based on the eviction policy.""" if self._is_expired(self.commands_ttl_list[0]): self.delete(self.commands_ttl_list[0]) - elif self.eviction_policy == EvictionPolicy.LRU: + elif self.eviction_policy == EvictionPolicy.LRU.value: self.cache.popitem(last=False) - elif self.eviction_policy == EvictionPolicy.LFU: + elif self.eviction_policy == EvictionPolicy.LFU.value: min_access_command = min( self.cache, key=lambda k: self.cache[k].get("access_count", 0) ) self.cache.pop(min_access_command) - elif self.eviction_policy == EvictionPolicy.RANDOM: + elif self.eviction_policy == EvictionPolicy.RANDOM.value: random_command = random.choice(list(self.cache.keys())) self.cache.pop(random_command) diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000000..88a5c89bd8 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,91 @@ +# from redis.cache import _LocalCache, EvictionPolicy +import time + +import redis + + +def test_get_from_cache(): + r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + r2 = redis.Redis(protocol=3) + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar" + + +def test_cache_max_size(): 
+ r = redis.Redis( + cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 + ) + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) == b"bar2" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + +def test_cache_ttl(): + r = redis.Redis( + cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 + ) + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + +def test_cache_lfu_eviction(): + r = redis.Redis( + cache_enable=True, + cache_max_size=3, + cache_eviction_policy="lfu", + single_connection_client=True, + protocol=3, + ) + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # test the eviction policy + assert len(r.client_cache.cache) == 3 + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) is None From 814f78a4c9df5f89cf06ea4a58e0d31fd30846f9 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 28 Dec 2023 16:31:19 +0200 Subject: [PATCH 07/13] skip if hiredis available --- tests/test_cache.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 88a5c89bd8..ba945c9780 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,9 +1,11 @@ -# from redis.cache import _LocalCache, EvictionPolicy import time +import pytest import redis +from redis.utils import HIREDIS_AVAILABLE +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_get_from_cache(): r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) r2 = redis.Redis(protocol=3) @@ -23,6 +25,7 @@ def test_get_from_cache(): assert r.get("foo") == b"barbar" +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_max_size(): r = redis.Redis( cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 @@ -46,6 +49,7 @@ def test_cache_max_size(): assert r.client_cache.get(("GET", "foo")) is None +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_ttl(): r = redis.Redis( cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 @@ -62,6 +66,7 @@ def test_cache_ttl(): assert r.client_cache.get(("GET", "foo")) is None 
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_cache_lfu_eviction(): r = redis.Redis( cache_enable=True, From 8dc575e18a78d8c8f793944cea0bd237c889268d Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 31 Dec 2023 04:07:47 +0200 Subject: [PATCH 08/13] async --- redis/_parsers/resp3.py | 28 ++++--- redis/asyncio/client.py | 123 ++++++++++++++++++++++++++----- redis/asyncio/connection.py | 4 + redis/client.py | 2 +- redis/cluster.py | 2 +- tests/test_asyncio/test_cache.py | 104 ++++++++++++++++++++++++++ 6 files changed, 234 insertions(+), 29 deletions(-) create mode 100644 tests/test_asyncio/test_cache.py diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index ac52e63448..b12b84e345 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -13,7 +13,7 @@ class _RESP3Parser(_RESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response - self.inalidations_push_handler_func = None + self.invalidations_push_handler_func = None def handle_pubsub_push_response(self, response): logger = getLogger("push_response") @@ -146,6 +146,7 @@ class _AsyncRESP3Parser(_AsyncRESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.invalidations_push_handler_func = None def handle_pubsub_push_response(self, response): logger = getLogger("push_response") @@ -256,13 +257,7 @@ async def _read_response( ) for _ in range(int(response)) ] - res = self.pubsub_push_handler_func(response) - if not push_request: - return await self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + await self.handle_push_response(response, disable_decoding, push_request) else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -270,5 +265,20 @@ async def _read_response( response = self.encoder.decode(response) return response - def set_push_handler(self, pubsub_push_handler_func): + async def handle_push_response(self, response, disable_decoding, push_request): + if response[0] == b"invalidate": + res = self.invalidation_push_handler_func(response) + else: + res = self.pubsub_push_handler_func(response) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + + def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidations_push_handler_func): + self.invalidation_push_handler_func = invalidations_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9e0491f810..5e5d02c446 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -37,6 +37,12 @@ ) from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry +from redis.cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from redis.client import ( EMPTY_RESPONSE, NEVER_DECODE, @@ -60,7 +66,7 @@ TimeoutError, WatchError, ) -from redis.typing import ChannelT, EncodableT, KeyT +from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, @@ -231,6 +237,13 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + cache_enable: bool = False, + 
client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ): """ Initialize a new Redis client. @@ -336,6 +349,16 @@ def __init__( # on a set of redis commands self._single_conn_lock = asyncio.Lock() + self.client_cache = client_cache + if cache_enable: + self.client_cache = _LocalCache( + cache_max_size, cache_ttl, cache_eviction_policy + ) + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist + self.client_cache_initialized = False + def __repr__(self): return ( f"<{self.__class__.__module__}.{self.__class__.__name__}" @@ -350,6 +373,10 @@ async def initialize(self: _RedisT) -> _RedisT: async with self._single_conn_lock: if self.connection is None: self.connection = await self.connection_pool.get_connection("_") + if self.client_cache is not None: + self.connection._parser.set_invalidation_push_handler( + self._cache_invalidation_process + ) return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -568,6 +595,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: close_connection_pool is None and self.auto_close_connection_pool ): await self.connection_pool.disconnect() + if self.client_cache: + self.client_cache.flush() @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") async def close(self, close_connection_pool: Optional[bool] = None) -> None: @@ -596,29 +625,87 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + if data[1] is not None: + for key in data[1]: + self.client_cache.invalidate(key) + else: + self.client_cache.flush() + + async def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + while not self.connection._is_socket_empty(): + await self.connection.read_response(push_request=True) + return self.client_cache.get(command) + + def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response, keys) + + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: + self.client_cache.delete(command) + except AttributeError: + pass + # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - options.pop("keys", None) # the keys are used only for client side caching - pool = self.connection_pool command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) + keys = options.pop("keys", None) # keys are used only for client side caching + response_from_cache = await self._get_from_local_cache(args) + if response_from_cache is not None: + return 
response_from_cache + else: + pool = self.connection_pool + conn = self.connection or await pool.get_connection(command_name, **options) - if self.single_connection_client: - await self._single_conn_lock.acquire() - try: - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - finally: if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await pool.release(conn) + await self._single_conn_lock.acquire() + try: + if self.client_cache is not None and not self.client_cache_initialized: + await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, "CLIENT", *("CLIENT", "TRACKING", "ON") + ), + lambda error: self._disconnect_raise(conn, error), + ) + self.client_cache_initialized = True + response = await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + self._add_to_local_cache(args, response, keys) + return response + finally: + if self.single_connection_client: + self._single_conn_lock.release() + if not self.connection: + await pool.release(conn) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -866,7 +953,7 @@ async def connect(self): else: await self.connection.connect() if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bbd438fc0b..39f75a5f13 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -645,6 +645,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] output.append(SYM_EMPTY.join(pieces)) return output + def _is_socket_empty(self): + """Check if the socket is empty""" + return not self._reader.at_eof() + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" diff --git a/redis/client.py b/redis/client.py index 423a70fd5f..4bd6306d6a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -836,7 +836,7 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: diff --git a/redis/cluster.py b/redis/cluster.py index 0405b0547c..8032173e66 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1778,7 +1778,7 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py new file mode 100644 index 
0000000000..b029555f0e --- /dev/null +++ b/tests/test_asyncio/test_cache.py @@ -0,0 +1,104 @@ +import time + +import pytest +import redis.asyncio as redis +from redis.utils import HIREDIS_AVAILABLE + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_get_from_cache(): + r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + r2 = redis.Redis(protocol=3) + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == b"barbar" + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_max_size(): + r = redis.Redis( + cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 + ) + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) == b"bar2" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_ttl(): + r = redis.Redis( + cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 + ) + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_lfu_eviction(): + r = redis.Redis( + cache_enable=True, + cache_max_size=3, + cache_eviction_policy="lfu", + single_connection_client=True, + protocol=3, + ) + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # test the eviction policy + assert len(r.client_cache.cache) == 3 + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) is None + + await 
r.aclose() From da924026cdc90aebdd082aa9e1c29222e1c2a0b8 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 31 Dec 2023 14:28:24 +0200 Subject: [PATCH 09/13] review comments --- redis/_parsers/resp3.py | 6 ++++-- redis/asyncio/client.py | 4 +++- redis/client.py | 8 +++++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index b12b84e345..521bc42fb9 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -6,6 +6,8 @@ from .base import _AsyncRESPBase, _RESPBase from .socket import SERVER_CLOSED_CONNECTION_ERROR +_INVALIDATION_MESSAGE = b"invalidate" + class _RESP3Parser(_RESPBase): """RESP3 protocol implementation""" @@ -124,7 +126,7 @@ def _read_response(self, disable_decoding=False, push_request=False): return response def handle_push_response(self, response, disable_decoding, push_request): - if response[0] == b"invalidate": + if response[0] == _INVALIDATION_MESSAGE: res = self.invalidation_push_handler_func(response) else: res = self.pubsub_push_handler_func(response) @@ -266,7 +268,7 @@ async def _read_response( return response async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] == b"invalidate": + if response[0] == _INVALIDATION_MESSAGE: res = self.invalidation_push_handler_func(response) else: res = self.pubsub_push_handler_func(response) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 5e5d02c446..3ca695651f 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -648,7 +648,9 @@ async def _get_from_local_cache(self, command: str): await self.connection.read_response(push_request=True) return self.client_cache.get(command) - def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): + def _add_to_local_cache( + self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + ): """ Add the command and response to the local cache if the command is allowed to be cached diff --git a/redis/client.py b/redis/client.py index 4bd6306d6a..b030516e73 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,7 +4,7 @@ import time import warnings from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( @@ -332,7 +332,7 @@ def __init__( if self.client_cache is not None: self.cache_blacklist = cache_blacklist self.cache_whitelist = cache_whitelist - self.execute_command("CLIENT", "TRACKING", "ON") + self.client_tracking_on() self.connection._parser.set_invalidation_push_handler( self._cache_invalidation_process ) @@ -580,7 +580,9 @@ def _get_from_local_cache(self, command: str): self.connection.read_response(push_request=True) return self.client_cache.get(command) - def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]): + def _add_to_local_cache( + self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + ): """ Add the command and response to the local cache if the command is allowed to be cached From 8fe83ab905ba796166598bb915a4b2d4994ff8b0 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 31 Dec 2023 15:03:45 +0200 Subject: [PATCH 10/13] docstring --- redis/asyncio/client.py | 6 ++++++ redis/client.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3ca695651f..36c1b35a4c 100644 --- a/redis/asyncio/client.py +++ 
b/redis/asyncio/client.py @@ -628,6 +628,12 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. + (if the list of keys is None, then all keys are invalidated) + """ if data[1] is not None: for key in data[1]: self.client_cache.invalidate(key) diff --git a/redis/client.py b/redis/client.py index b030516e73..4f7a4cf015 100755 --- a/redis/client.py +++ b/redis/client.py @@ -365,6 +365,12 @@ def set_response_callback(self, command: str, callback: Callable) -> None: def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. + (if the list of keys is None, then all keys are invalidated) + """ if data[1] is not None: for key in data[1]: self.client_cache.invalidate(key) From a9965b27f29e94f8e724e727b434f298f96cf060 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 31 Dec 2023 15:35:29 +0200 Subject: [PATCH 11/13] decode test --- tests/test_cache.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_cache.py b/tests/test_cache.py index ba945c9780..45621fe77e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -94,3 +94,26 @@ def test_cache_lfu_eviction(): assert len(r.client_cache.cache) == 3 assert r.client_cache.get(("GET", "foo")) == b"bar" assert r.client_cache.get(("GET", "foo2")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_decode_response(): + r = redis.Redis( + decode_responses=True, + cache_enable=True, + single_connection_client=True, + protocol=3, + ) + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == "bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == "barbar" From f9089546804f5975e6e45263799724c67a103ba8 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 31 Dec 2023 15:44:01 +0200 Subject: [PATCH 12/13] fix test --- tests/test_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 45621fe77e..926521d099 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -104,13 +104,14 @@ def test_cache_decode_response(): single_connection_client=True, protocol=3, ) + r2 = redis.Redis(protocol=3) r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == "bar" # get key from local cache assert r.client_cache.get(("GET", "foo")) == "bar" # change key in redis (cause invalidation) - r.set("foo", "barbar") + r2.set("foo", "barbar") # send any command to redis (process invalidation in background) r.ping() # the command is not in the local cache anymore From 5fd31c7d518413aed1c9410f7b38525408b63df5 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 1 Jan 
2024 01:26:40 +0200 Subject: [PATCH 13/13] fix decode response test --- redis/_parsers/resp3.py | 6 +++--- redis/asyncio/client.py | 2 +- redis/cache.py | 3 ++- redis/client.py | 2 +- tests/test_asyncio/test_cache.py | 25 +++++++++++++++++++++++++ tests/test_cache.py | 3 +-- 6 files changed, 33 insertions(+), 8 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 521bc42fb9..13aa1ffccb 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -6,7 +6,7 @@ from .base import _AsyncRESPBase, _RESPBase from .socket import SERVER_CLOSED_CONNECTION_ERROR -_INVALIDATION_MESSAGE = b"invalidate" +_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] class _RESP3Parser(_RESPBase): @@ -126,7 +126,7 @@ def _read_response(self, disable_decoding=False, push_request=False): return response def handle_push_response(self, response, disable_decoding, push_request): - if response[0] == _INVALIDATION_MESSAGE: + if response[0] in _INVALIDATION_MESSAGE: res = self.invalidation_push_handler_func(response) else: res = self.pubsub_push_handler_func(response) @@ -268,7 +268,7 @@ async def _read_response( return response async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] == _INVALIDATION_MESSAGE: + if response[0] in _INVALIDATION_MESSAGE: res = self.invalidation_push_handler_func(response) else: res = self.pubsub_push_handler_func(response) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 36c1b35a4c..eea9612f4a 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -636,7 +636,7 @@ def _cache_invalidation_process( """ if data[1] is not None: for key in data[1]: - self.client_cache.invalidate(key) + self.client_cache.invalidate(str_if_bytes(key)) else: self.client_cache.flush() diff --git a/redis/cache.py b/redis/cache.py index d0337997f8..d920702339 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -323,5 +323,6 @@ def invalidate(self, key: KeyT): """ if key not in self.key_commands_map: return - for command in self.key_commands_map[key]: + commands = list(self.key_commands_map[key]) + for command in commands: self.delete(command) diff --git a/redis/client.py b/redis/client.py index 4f7a4cf015..7f2c8d290d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -373,7 +373,7 @@ def _cache_invalidation_process( """ if data[1] is not None: for key in data[1]: - self.client_cache.invalidate(key) + self.client_cache.invalidate(str_if_bytes(key)) else: self.client_cache.flush() diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index b029555f0e..c837acfed1 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -102,3 +102,28 @@ async def test_cache_lfu_eviction(): assert r.client_cache.get(("GET", "foo2")) is None await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_decode_response(): + r = redis.Redis( + decode_responses=True, + cache_enable=True, + single_connection_client=True, + protocol=3, + ) + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == "bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert await 
r.get("foo") == "barbar" + + await r.aclose() diff --git a/tests/test_cache.py b/tests/test_cache.py index 926521d099..45621fe77e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -104,14 +104,13 @@ def test_cache_decode_response(): single_connection_client=True, protocol=3, ) - r2 = redis.Redis(protocol=3) r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == "bar" # get key from local cache assert r.client_cache.get(("GET", "foo")) == "bar" # change key in redis (cause invalidation) - r2.set("foo", "barbar") + r.set("foo", "barbar") # send any command to redis (process invalidation in background) r.ping() # the command is not in the local cache anymore