diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 2cd2daddcc..18fdf94174 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -671,13 +671,13 @@ def __init__( if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: - self.health_check_response: Iterable[Union[str, bytes]] = [ - "pong", + self.health_check_response = [ + ["pong", self.HEALTH_CHECK_MESSAGE], self.HEALTH_CHECK_MESSAGE, ] else: self.health_check_response = [ - b"pong", + [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)], self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] if self.push_handler_func is None: @@ -807,7 +807,7 @@ async def parse_response(self, block: bool = True, timeout: float = 0): conn, conn.read_response, timeout=read_timeout, push_request=True ) - if conn.health_check_interval and response == self.health_check_response: + if conn.health_check_interval and response in self.health_check_response: # ignore the health check message as user might not expect it return None return response diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 525c17b22d..4a606ad38f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -319,6 +319,8 @@ def __init__( kwargs.update({"retry": self.retry}) kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + if kwargs.get("protocol") in ["3", 3]: + kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS) self.connection_kwargs = kwargs if startup_nodes: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bc872ff358..b51e4fd8ce 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -333,7 +333,9 @@ def _error_message(self, exception): async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -341,8 +343,26 @@ async def on_connect(self) -> None: or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = await self.read_response() + if response.get(b"proto") not in [2, "2"] and response.get("proto") not in [ + 2, + "2", + ]: + raise ConnectionError("Invalid RESP version") + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + elif auth_args: await self.send_command("AUTH", *auth_args, check_health=False) try: @@ -359,9 +379,11 @@ async def on_connect(self) -> None: raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - if self.protocol != 2: + elif self.protocol != 2: if isinstance(self._parser, _AsyncRESP2Parser): self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) await self.send_command("HELLO", self.protocol) response = await self.read_response() diff --git a/redis/client.py b/redis/client.py index ef327b5922..e4e82981e9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -331,9 +331,15 @@ def parse_xinfo_stream(response, **options): data["last-entry"] = (last[0], pairs_to_dict(last[1])) else: data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] return data @@ -581,14 +587,15 @@ def parse_command_resp3(response, **options): cmd_name = str_if_bytes(command[0]) cmd_dict["name"] = cmd_name cmd_dict["arity"] = command[1] - cmd_dict["flags"] = command[2] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} cmd_dict["first_key_pos"] = command[3] cmd_dict["last_key_pos"] = command[4] cmd_dict["step_count"] = command[5] cmd_dict["acl_categories"] = command[6] - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] commands[cmd_name] = cmd_dict return commands @@ -626,17 +633,20 @@ def parse_acl_getuser(response, **options): if data["channels"] == [""]: data["channels"] = [] if "selectors" in data: - data["selectors"] = [ - list(map(str_if_bytes, selector)) for selector in data["selectors"] - ] + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] # split 'commands' into separate 'categories' and 'commands' lists commands, categories = [], [] for command in data["commands"].split(" "): - if "@" in command: - categories.append(command) - else: - commands.append(command) + categories.append(command) if "@" in command else commands.append(command) data["commands"] = commands data["categories"] = categories diff --git a/redis/cluster.py b/redis/cluster.py index d3956e45f5..898db29cdc 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -33,6 +33,7 @@ from redis.parsers import CommandsParser, Encoder from redis.retry import Retry from redis.utils import ( + HIREDIS_AVAILABLE, dict_merge, list_keys_to_dict, merge_result, @@ -1608,7 +1609,15 @@ class ClusterPubSub(PubSub): https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html """ - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): + def __init__( + self, + redis_cluster, + node=None, + host=None, + port=None, + push_handler_func=None, + **kwargs, + ): """ When a pubsub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the @@ -1633,7 +1642,10 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): self.node_pubsub_mapping = {} self._pubsubs_generator = self._pubsubs_generator() super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + **kwargs, ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): @@ -1717,6 +1729,8 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -1724,7 +1738,9 @@ def _get_node_pubsub(self, node): try: return self.node_pubsub_mapping[node.name] except KeyError: - pubsub = node.redis_connection.pubsub() + pubsub = node.redis_connection.pubsub( + push_handler_func=self.push_handler_func + ) self.node_pubsub_mapping[node.name] = pubsub return pubsub diff --git a/redis/connection.py b/redis/connection.py index 19c80e08f5..ee3bece11c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -276,7 +276,9 @@ def _error_message(self, exception): def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -284,6 +286,23 @@ def on_connect(self): or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol != 2: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + elif auth_args: # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -302,9 +321,11 @@ def on_connect(self): raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - if self.protocol != 2: + elif self.protocol != 2: if isinstance(self._parser, _RESP2Parser): self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) self.send_command("HELLO", self.protocol) response = self.read_response() diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py index 0586016a61..6cc32e3cae 100644 --- a/redis/parsers/__init__.py +++ b/redis/parsers/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseParser +from .base import BaseParser, _AsyncRESPBase from .commands import AsyncCommandsParser, CommandsParser from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser @@ -8,6 +8,7 @@ __all__ = [ "AsyncCommandsParser", "_AsyncHiredisParser", + "_AsyncRESPBase", "_AsyncRESP2Parser", "_AsyncRESP3Parser", "CommandsParser", diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index a04f054e24..b443e45ae6 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -69,9 +69,12 @@ def _read_response(self, disable_decoding=False, push_request=False): # bool value elif byte == b"#": return response == b"t" - # bulk response and verbatim strings - elif byte in (b"$", b"="): + # bulk response + elif byte == b"$": response = self._buffer.read(int(response)) + # verbatim string response + elif byte == b"=": + response = self._buffer.read(int(response))[4:] # array response elif byte == b"*": response = [ @@ -195,9 +198,12 @@ async def _read_response( # bool value elif byte == b"#": return response == b"t" - # bulk response and verbatim strings - elif byte in (b"$", b"="): + # bulk response + elif byte == b"$": response = await self._read(int(response)) + # verbatim string response + elif byte == b"=": + response = (await self._read(int(response)))[4:] # array response elif byte == b"*": response = [ diff --git a/tests/conftest.py b/tests/conftest.py index c471f3d837..6454750353 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -475,8 +475,31 @@ def wait_for_command(client, monitor, command, key=None): def is_resp2_connection(r): - if isinstance(r, redis.Redis): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): protocol = r.connection_pool.connection_kwargs.get("protocol") - elif isinstance(r, redis.RedisCluster): + elif isinstance(r, redis.cluster.AbstractRedisCluster): protocol = r.nodes_manager.connection_kwargs.get("protocol") return protocol in ["2", 2, None] + + +def get_protocol_version(r): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response == resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response in resp2_expected + else: + assert response in resp3_expected diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index e8ab6b297f..28a6f0626f 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -236,29 +236,6 @@ async def wait_for_command( return None -def get_protocol_version(r): - if isinstance(r, redis.Redis): - return r.connection_pool.connection_kwargs.get("protocol") - elif isinstance(r, redis.RedisCluster): - return r.nodes_manager.connection_kwargs.get("protocol") - - -def assert_resp_response(r, response, resp2_expected, resp3_expected): - protocol = get_protocol_version(r) - if protocol in [2, "2", None]: - assert response == resp2_expected - else: - assert response == resp3_expected - - -def assert_resp_response_in(r, response, resp2_expected, resp3_expected): - protocol = get_protocol_version(r) - if protocol in [2, "2", None]: - assert response in resp2_expected - else: - assert response in resp3_expected - - # python 3.6 doesn't have the asynccontextmanager decorator. Provide it here. class AsyncContextManager: def __init__(self, async_generator): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a80fa30cb9..58c0e0b0c7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -31,6 +31,7 @@ from redis.parsers import AsyncCommandsParser from redis.utils import str_if_bytes from tests.conftest import ( + assert_resp_response, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1613,7 +1614,8 @@ async def test_cluster_zdiff(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = await r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: @@ -1621,7 +1623,8 @@ async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert await r.zrange("{foo}out", 0, -1) == [b"a3"] - assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = await r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zinter(self, r: RedisCluster) -> None: @@ -1635,32 +1638,41 @@ async def test_cluster_zinter(self, r: RedisCluster) -> None: ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True ) # aggregate with SUM - assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8], [b"a1", 9]] + ) # aggregate with MAX - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] + ) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) # aggregate with MIN - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] + ) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 1)], [[b"a1", 1], [b"a3", 1]] + ) # with weights - assert await r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] + res = await r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) + assert_resp_response( + r, res, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1672,10 +1684,12 @@ async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1687,10 +1701,12 @@ async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1699,10 +1715,12 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: @@ -1767,10 +1785,12 @@ async def test_cluster_zrangestore(self, r: RedisCluster) -> None: assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1797,36 +1817,49 @@ async def test_cluster_zunion(self, r: RedisCluster) -> None: b"a3", b"a1", ] - assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert await r.zunion( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + await r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1838,12 +1871,12 @@ async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1855,12 +1888,12 @@ async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1869,12 +1902,12 @@ async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") async def test_cluster_pfcount(self, r: RedisCluster) -> None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 866929b2e4..78376fd0e9 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -13,13 +13,14 @@ from redis import exceptions from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info from tests.conftest import ( + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, ) -from .conftest import assert_resp_response, assert_resp_response_in - REDIS_6_VERSION = "5.9.0" @@ -73,7 +74,11 @@ class TestResponseCallbacks: """Tests for the response callback system""" async def test_response_callbacks(self, r: redis.Redis): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS + resp3_callbacks = redis.Redis.RESPONSE_CALLBACKS.copy() + resp3_callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + assert_resp_response( + r, r.response_callbacks, redis.Redis.RESPONSE_CALLBACKS, resp3_callbacks + ) assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) r.set_response_callback("GET", lambda x: "static") await r.set("a", "foo") @@ -123,27 +128,24 @@ async def test_acl_getuser_setuser(self, r_teardown): r = r_teardown(username) # test enabled=False assert await r.acl_setuser(username, enabled=False, reset=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": False, - "flags": ["off", "allchannels", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "off" in acl["flags"] + assert acl["enabled"] is False # test nopass=True assert await r.acl_setuser(username, enabled=True, reset=True, nopass=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": True, - "flags": ["on", "allchannels", "nopass", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "on" in acl["flags"] + assert "nopass" in acl["flags"] + assert acl["enabled"] is True # test all args assert await r.acl_setuser( @@ -160,8 +162,8 @@ async def test_acl_getuser_setuser(self, r_teardown): assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert set(acl["flags"]) == {"on", "allchannels", "sanitize-payload"} + assert acl["keys"] == [b"cache:*", b"objects:*"] assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top @@ -187,7 +189,7 @@ async def test_acl_getuser_setuser(self, r_teardown): assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] + assert set(acl["flags"]) == {"on", "allchannels", "sanitize-payload"} assert set(acl["keys"]) == {b"cache:*", b"objects:*"} assert len(acl["passwords"]) == 2 @@ -2912,16 +2914,16 @@ async def test_xreadgroup(self, r: redis.Redis): # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - # res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) - # empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) - # if is_resp2_connection(r): - # assert len(res[0][1]) == 2 - # # now there should be nothing pending - # assert len(empty_res[0][1]) == 0 - # else: - # assert len(res[strem_name][0]) == 2 - # # now there should be nothing pending - # assert len(empty_res[strem_name][0]) == 0 + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[strem_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[strem_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 3a8cf8d9c2..c5b21055e0 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -11,7 +11,12 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser, _AsyncRESP3Parser +from redis.parsers import ( + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, + _AsyncRESPBase, +) from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt @@ -26,11 +31,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: _AsyncRESP2Parser = r.connection._parser + parser: _AsyncRESPBase = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, _AsyncRESP2Parser): + if isinstance(parser, _AsyncRESPBase): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 3df57eb90f..b29aa53487 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -21,7 +21,6 @@ async def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert await pipe.execute() == [ True, @@ -29,7 +28,6 @@ async def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] async def test_pipeline_memoryview(self, r): diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 412398f37b..8160b3b0f1 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -17,10 +17,9 @@ from redis.exceptions import ConnectionError from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import skip_if_server_version_lt +from tests.conftest import get_protocol_version, skip_if_server_version_lt from .compat import create_task, mock -from .conftest import get_protocol_version def with_timeout(t): @@ -422,6 +421,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): assert expect in info.exconly() +@pytest.mark.onlynoncluster class TestPubSubRESP3Handler: def my_handler(self, message): self.message = ["my handler", message] diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 4a43eaea21..2ca323eaf5 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -39,7 +39,7 @@ from .conftest import ( _get_client, - is_resp2_connection, + assert_resp_response, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1725,10 +1725,13 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - if is_resp2_connection(r): - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] - else: - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [[b"a3", 3.0]] + response = r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response( + r, + response, + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1736,10 +1739,8 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - if is_resp2_connection(r): - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] - else: - assert r.zrange("{foo}out", 0, -1, withscores=True) == [[b"a3", 3.0]] + response = r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): @@ -1750,49 +1751,42 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - if is_resp2_connection(r): - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] - else: - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - [b"a3", 8], - [b"a1", 9], - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [[b"a3", 5], [b"a1", 6]] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [[b"a1", 1], [b"a3", 1]] - # with weights - assert r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [[b"a3", 2], [b"a1", 2]] + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MAX"), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MIN"), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + assert_resp_response( + r, + r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a3", 20.0), (b"a1", 23.0)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zinterstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1802,7 +1796,12 @@ def test_cluster_zinterstore_max(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zinterstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1812,14 +1811,24 @@ def test_cluster_zinterstore_min(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) def test_cluster_zinterstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): @@ -1852,7 +1861,12 @@ def test_cluster_zrangestore(self, r): assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("{foo}b", "{foo}a", 1, 2) assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}b", 0, 1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1874,39 +1888,45 @@ def test_cluster_zunion(self, r): r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zunionstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zunionstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1916,12 +1936,12 @@ def test_cluster_zunionstore_max(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zunionstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1931,24 +1951,24 @@ def test_cluster_zunionstore_min(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) def test_cluster_zunionstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): @@ -2970,7 +2990,12 @@ def test_pipeline_readonly(self, r): with r.pipeline() as readonly_pipe: readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] + assert_resp_response( + r, + readonly_pipe.execute(), + [b"a1", [(b"z1", 1.0), (b"z2", 4)]], + [b"a1", [[b"z1", 1.0], [b"z2", 4.0]]], + ) def test_moved_redirection_on_slave_with_default(self, r): """ diff --git a/tests/test_commands.py b/tests/test_commands.py index 1af69c83c0..97fbb34925 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -13,6 +13,8 @@ from .conftest import ( _get_client, + assert_resp_response, + assert_resp_response_in, is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, @@ -56,7 +58,10 @@ class TestResponseCallbacks: "Tests for the response callback system" def test_response_callbacks(self, r): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS + callbacks = redis.Redis.RESPONSE_CALLBACKS + if not is_resp2_connection(r): + callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + assert r.response_callbacks == callbacks assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) r.set_response_callback("GET", lambda x: "static") r["a"] = "foo" @@ -67,6 +72,7 @@ def test_case_insensitive_command_names(self, r): class TestRedisCommands: + @pytest.mark.onlynoncluster @skip_if_redis_enterprise() def test_auth(self, r, request): # sending an AUTH command before setting a user/password on the @@ -101,7 +107,6 @@ def teardown(): # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster pass - r.auth(temp_pass) r.config_set("requirepass", "") r.acl_deluser(username) @@ -317,9 +322,12 @@ def teardown(): assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 assert set(acl["channels"]) == {"&message:*"} - assert acl["selectors"] == [ - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] - ] + assert_resp_response( + r, + acl["selectors"], + ["commands", "-@all +set", "keys", "%W~app*", "channels", ""], + [{"commands": "-@all +set", "keys": "%W~app*", "channels": ""}], + ) @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): @@ -381,11 +389,13 @@ def teardown(): assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - if is_resp2_connection(r): - assert "client-info" in r.acl_log(count=1)[0] - else: - assert "client-info" in r.acl_log(count=1)[0].keys() - assert r.acl_log_reset() + expected = r.acl_log(count=1)[0] + assert_resp_response_in( + r, + "client-info", + expected, + expected.keys(), + ) @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -1124,8 +1134,12 @@ def test_lcs(self, r): r.mset({"foo": "ohmytext", "bar": "mynewtext"}) assert r.lcs("foo", "bar") == b"mytext" assert r.lcs("foo", "bar", len=True) == 6 - result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] - assert r.lcs("foo", "bar", idx=True, minmatchlen=3) == result + assert_resp_response( + r, + r.lcs("foo", "bar", idx=True, minmatchlen=3), + [b"matches", [[[4, 7], [5, 8]]], b"len", 6], + {b"matches": [[[4, 7], [5, 8]]], b"len": 6}, + ) with pytest.raises(redis.ResponseError): assert r.lcs("foo", "bar", len=True, idx=True) @@ -1539,10 +1553,7 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - if is_resp2_connection(r): - assert len(r.hrandfield("key", 2, True)) == 4 - else: - assert len(r.hrandfield("key", 2, True)) == 2 + assert_resp_response(r, len(r.hrandfield("key", 2, withvalues=True)), 4, 2) # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1695,30 +1706,26 @@ def test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - if is_resp2_connection(r): - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} - else: - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]} + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True), + {"len": len(res), "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]]}, + ) + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]]}, + ) + assert_resp_response( + r, + r.stralgo( + "LCS", value1, value2, idx=True, withmatchlen=True, minmatchlen=4 + ), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]}, + ) @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -2167,10 +2174,12 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - if is_resp2_connection(r): - assert r.spop("a", 1) == list(set(s) - set(values)) - else: - assert r.spop("a", 1) == set(s) - set(values) + assert_resp_response( + r, + r.spop("a", 1), + list(set(s) - set(values)), + set(s) - set(values), + ) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2221,18 +2230,12 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] - else: - assert r.zrange("a", 0, -1, withscores=True) == [ - [b"a1", 1.0], - [b"a2", 2.0], - [b"a3", 3.0], - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -2249,32 +2252,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - else: - assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] - else: - assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 99.0]] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 99.0)], + [[b"a1", 99.0]], + ) def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - if is_resp2_connection(r): - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] - else: - assert r.zrange("a", 0, -1, withscores=True) == [ - [b"a2", 2.0], - [b"a1", 99.0], - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a2", 2.0), (b"a1", 99.0)], + [[b"a2", 2.0], [b"a1", 99.0]], + ) def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2322,10 +2325,12 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - if is_resp2_connection(r): - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] - else: - assert r.zdiff(["a", "b"], withscores=True) == [[b"a3", 3.0]] + assert_resp_response( + r, + r.zdiff(["a", "b"], withscores=True), + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2334,10 +2339,12 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - if is_resp2_connection(r): - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] - else: - assert r.zrange("out", 0, -1, withscores=True) == [[b"a3", 3.0]] + assert_resp_response( + r, + r.zrange("out", 0, -1, withscores=True), + [(b"a3", 3.0)], + [[b"a3", 3.0]], + ) def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2362,48 +2369,34 @@ def test_zinter(self, r): # invalid aggregation with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) - if is_resp2_connection(r): - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] - else: - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [ - [b"a3", 8], - [b"a1", 9], - ] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - [b"a3", 5], - [b"a1", 6], - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - [b"a1", 1], - [b"a3", 1], - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - [b"a3", 20], - [b"a1", 23], - ] + # aggregate with SUM + assert_resp_response( + r, + r.zinter(["a", "b", "c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + # aggregate with MAX + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + # aggregate with MIN + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + # with weights + assert_resp_response( + r, + r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2420,10 +2413,12 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 8], [b"a1", 9]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2431,10 +2426,12 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 5], [b"a1", 6]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2442,10 +2439,12 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a1", 1], [b"a3", 3]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1], [b"a3", 3]], + ) @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2453,34 +2452,36 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] - else: - assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 20], [b"a1", 23]] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - assert r.zpopmax("a") == [(b"a3", 3)] - # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] - else: - assert r.zpopmax("a") == [b"a3", 3.0] - # with count - assert r.zpopmax("a", count=2) == [[b"a2", 2], [b"a1", 1]] + assert_resp_response(r, r.zpopmax("a"), [(b"a3", 3)], [b"a3", 3.0]) + # with count + assert_resp_response( + r, + r.zpopmax("a", count=2), + [(b"a2", 2), (b"a1", 1)], + [[b"a2", 2], [b"a1", 1]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - if is_resp2_connection(r): - assert r.zpopmin("a") == [(b"a1", 1)] - # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] - else: - assert r.zpopmin("a") == [b"a1", 1.0] - # with count - assert r.zpopmin("a", count=2) == [[b"a2", 2], [b"a3", 3]] + assert_resp_response(r, r.zpopmin("a"), [(b"a1", 1)], [b"a1", 1.0]) + # with count + assert_resp_response( + r, + r.zpopmin("a", count=2), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2488,10 +2489,12 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - if is_resp2_connection(r): - assert len(r.zrandmember("a", 2, True)) == 4 - else: - assert len(r.zrandmember("a", 2, True)) == 2 + assert_resp_response( + r, + len(r.zrandmember("a", 2, withscores=True)), + 4, + 2, + ) # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2527,24 +2530,41 @@ def test_bzpopmin(self, r): @skip_if_server_version_lt("7.0.0") def test_zmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.zmpop("2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.zmpop("2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - assert r.zmpop("2", ["b", "a"], max=True) == [b"b", [[b"b1", b"10"]]] + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_bzmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.bzmpop(1, "2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.bzmpop(1, "2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.bzmpop(1, "2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - res = [b"b", [[b"b1", b"10"]]] - assert r.bzmpop(0, "2", ["b", "a"], max=True) == res + assert_resp_response( + r, + r.bzmpop(0, "2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) assert r.bzmpop(1, "2", ["foo", "bar"], max=True) is None def test_zrange(self, r): @@ -2555,18 +2575,24 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - if is_resp2_connection(r): - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("a", 0, 1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] - else: - assert r.zrange("a", 0, 1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] - assert r.zrange("a", 1, 2, withscores=True) == [[b"a2", 2.0], [b"a3", 3.0]] + # # custom score function + # assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2598,25 +2624,20 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - if is_resp2_connection(r): - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] - - else: - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - [b"a2", 2.0], - [b"a3", 3.0], - [b"a4", 4.0], - ] - assert r.zrange( + assert_resp_response( + r, + r.zrange("a", 2, 4, byscore=True, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrange( "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [[b"a4", 4], [b"a3", 3], [b"a2", 2]] + ), + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2629,10 +2650,12 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - if is_resp2_connection(r): - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] - else: - assert r.zrange("b", 0, -1, withscores=True) == [[b"a2", 2], [b"a3", 3]] + assert_resp_response( + r, + r.zrange("b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2667,28 +2690,18 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - if is_resp2_connection(r): - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] - else: - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - [b"a2", 2.0], - [b"a3", 3.0], - [b"a4", 4.0], - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - [b"a2", 2], - [b"a3", 3], - [b"a4", 4], - ] + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int), + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2735,32 +2748,25 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 0, 1) == [b"a3", b"a2"] assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] - if is_resp2_connection(r): - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + # withscores + assert_resp_response( + r, + r.zrevrange("a", 0, 1, withscores=True), + [(b"a3", 3.0), (b"a2", 2.0)], + [[b"a3", 3.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrevrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a1", 1.0)], + [[b"a2", 2.0], [b"a1", 1.0]], + ) - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - else: - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [ - [b"a3", 3.0], - [b"a2", 2.0], - ] - assert r.zrevrange("a", 1, 2, withscores=True) == [ - [b"a2", 2.0], - [b"a1", 1.0], - ] + # # custom score function + # assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a3", 3.0), + # (b"a2", 2.0), + # ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2768,28 +2774,20 @@ def test_zrevrangebyscore(self, r): # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] - if is_resp2_connection(r): - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] - # custom score function - assert r.zrevrangebyscore( - "a", 4, 2, withscores=True, score_cast_func=int - ) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] - else: - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - [b"a4", 4.0], - [b"a3", 3.0], - [b"a2", 2.0], - ] + # withscores + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) + # custom score function + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2811,63 +2809,33 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - - if is_resp2_connection(r): - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - else: - assert r.zunion(["a", "b", "c"], withscores=True) == [ - [b"a2", 3], - [b"a4", 4], - [b"a3", 8], - [b"a1", 9], - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - [b"a2", 2], - [b"a4", 4], - [b"a3", 5], - [b"a1", 6], - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - [b"a1", 1], - [b"a2", 1], - [b"a3", 1], - [b"a4", 4], - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - [b"a2", 5], - [b"a4", 12], - [b"a3", 20], - [b"a1", 23], - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) + # max + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) + # min + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1], [b"a2", 1], [b"a3", 1], [b"a4", 4]], + ) + # with weight + assert_resp_response( + r, + r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2875,21 +2843,12 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 3], - [b"a4", 4], - [b"a3", 8], - [b"a1", 9], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): @@ -2897,20 +2856,12 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 2], - [b"a4", 4], - [b"a3", 5], - [b"a1", 6], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2918,20 +2869,12 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a1", 1], - [b"a2", 2], - [b"a3", 3], - [b"a4", 4], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1], [b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2939,20 +2882,12 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - if is_resp2_connection(r): - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - else: - assert r.zrange("d", 0, -1, withscores=True) == [ - [b"a2", 5], - [b"a4", 12], - [b"a3", 20], - [b"a1", 23], - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -4020,7 +3955,7 @@ def test_xadd_explicit_ms(self, r: redis.Redis): ms = message_id[: message_id.index(b"-")] assert ms == b"9999999999999999999" - @skip_if_server_version_lt("6.2.0") + @skip_if_server_version_lt("7.0.0") def test_xautoclaim(self, r): stream = "stream" group = "group" @@ -4035,7 +3970,7 @@ def test_xautoclaim(self, r): # trying to claim a message that isn't already pending doesn't # do anything response = r.xautoclaim(stream, group, consumer2, min_idle_time=0) - assert response == [b"0-0", []] + assert response == [b"0-0", [], []] # read the group as consumer1 to initially claim the messages r.xreadgroup(group, consumer1, streams={stream: ">"}) @@ -4327,10 +4262,12 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - if is_resp2_connection(r): - assert m1 in info["entries"] - else: - assert m1 in info["entries"][0] + assert_resp_response_in( + r, + m1, + info["entries"], + info["entries"].keys(), + ) assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4471,40 +4408,39 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - strem_name = stream.encode() + stream_name = stream.encode() expected_entries = [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - res = r.xread(streams={stream: 0}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: 0}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - res = r.xread(streams={stream: 0}, count=1) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: 0}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - res = r.xread(streams={stream: m1}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xread(streams={stream: m1}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) # xread starting at the last message returns an empty list - res = r.xread(streams={stream: m2}) - if is_resp2_connection(r): - assert res == [] - else: - assert res == {} + assert_resp_response(r, r.xread(streams={stream: m2}), [], {}) @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4515,18 +4451,19 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - strem_name = stream.encode() + stream_name = stream.encode() expected_entries = [ get_stream_message(r, stream, m1), get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - res = r.xreadgroup(group, consumer, streams={stream: ">"}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) @@ -4534,11 +4471,12 @@ def test_xreadgroup(self, r): expected_entries = [get_stream_message(r, stream, m1)] # xread with count=1 returns only the first message - res = r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) @@ -4547,10 +4485,9 @@ def test_xreadgroup(self, r): r.xgroup_create(stream, group, "$") # xread starting after the last message returns an empty message list - if is_resp2_connection(r): - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == [] - else: - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == {} + assert_resp_response( + r, r.xreadgroup(group, consumer, streams={stream: ">"}), [], {} + ) # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) @@ -4562,9 +4499,9 @@ def test_xreadgroup(self, r): # now there should be nothing pending assert len(empty_res[0][1]) == 0 else: - assert len(res[strem_name][0]) == 2 + assert len(res[stream_name][0]) == 2 # now there should be nothing pending - assert len(empty_res[strem_name][0]) == 0 + assert len(empty_res[stream_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") @@ -4572,11 +4509,12 @@ def test_xreadgroup(self, r): expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - res = r.xreadgroup(group, consumer, streams={stream: "0"}) - if is_resp2_connection(r): - assert res == [[strem_name, expected_entries]] - else: - assert res == {strem_name: [expected_entries]} + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: "0"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): @@ -4869,12 +4807,17 @@ def test_command(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_command_getkeysandflags(self, r: redis.Redis): - res = [ - [b"mylist1", [b"RW", b"access", b"delete"]], - [b"mylist2", [b"RW", b"insert"]], - ] - assert res == r.command_getkeysandflags( - "LMOVE", "mylist1", "mylist2", "left", "left" + assert_resp_response( + r, + r.command_getkeysandflags("LMOVE", "mylist1", "mylist2", "left", "left"), + [ + [b"mylist1", [b"RW", b"access", b"delete"]], + [b"mylist2", [b"RW", b"insert"]], + ], + [ + [b"mylist1", {b"RW", b"access", b"delete"}], + [b"mylist2", {b"RW", b"insert"}], + ], ) @pytest.mark.onlynoncluster diff --git a/tests/test_function.py b/tests/test_function.py index 7ce66a38e6..bb32fdf27c 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -2,7 +2,7 @@ from redis.exceptions import ResponseError -from .conftest import skip_if_server_version_lt +from .conftest import assert_resp_response, skip_if_server_version_lt engine = "lua" lib = "mylib" @@ -64,12 +64,22 @@ def test_function_list(self, r): [[b"name", b"myfunc", b"description", None, b"flags", [b"no-writes"]]], ] ] - assert r.function_list() == res - assert r.function_list(library="*lib") == res - assert ( - r.function_list(withcode=True)[0][7] - == f"#!{engine} name={lib} \n {function}".encode() + resp3_res = [ + { + b"library_name": b"mylib", + b"engine": b"LUA", + b"functions": [ + {b"name": b"myfunc", b"description": None, b"flags": {b"no-writes"}} + ], + } + ] + assert_resp_response(r, r.function_list(), res, resp3_res) + assert_resp_response(r, r.function_list(library="*lib"), res, resp3_res) + res[0].extend( + [b"library_code", f"#!{engine} name={lib} \n {function}".encode()] ) + resp3_res[0][b"library_code"] = f"#!{engine} name={lib} \n {function}".encode() + assert_resp_response(r, r.function_list(withcode=True), res, resp3_res) @pytest.mark.onlycluster def test_function_list_on_cluster(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 2f6b4bad80..fc98966d74 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -608,6 +608,19 @@ def test_push_handler(self, r): assert wait_for_message(p) is None assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @skip_if_server_version_lt("7.0.0") + def test_push_handler_sharded_pubsub(self, r): + if is_resp2_connection(r): + return + p = r.pubsub(push_handler_func=self.my_handler) + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"ssubscribe", b"foo", 1]] + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"smessage", b"foo", b"test message"]] + class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back"