diff --git a/.dockerignore b/.dockerignore index 7b9bc9b2df..4dc633c1a6 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,5 +1,4 @@ **/__pycache__ **/*.pyc -.tox .coverage .coverage.* diff --git a/.flake8 b/.flake8 index 6c663473e4..73b4a96bb6 100644 --- a/.flake8 +++ b/.flake8 @@ -4,7 +4,6 @@ exclude = *.egg-info, *.pyc, .git, - .tox, .venv*, build, docs/*, diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..1af2323fe9 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +doctests/* @dmaier-redislabs diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a3b0b0e4e7..722906b048 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,7 +2,7 @@ _Please make sure to review and check all of these items:_ -- [ ] Does `$ tox` pass with this change (including linting)? +- [ ] Do tests and lints pass with this change? - [ ] Do the CI tests pass with this change (enable it first in your forked repo and wait for the github action build to finish)? - [ ] Is the new or changed code fully tested? - [ ] Is a documentation update included (if this change modifies existing APIs, or introduces new ones)? diff --git a/.github/wordlist.txt b/.github/wordlist.txt index be16c437ff..22ae767e46 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -132,7 +132,6 @@ thevalueofmykey timeseries toctree topk -tox triaging txt un diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index e82e7e1530..61da2fce55 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -36,7 +36,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 61ec76e9f8..56f16fa2b0 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -25,7 +25,7 @@ jobs: name: Build docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 7e0fea2e41..207d58fac7 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -29,7 +29,7 @@ jobs: name: Dependency audit runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: pypa/gh-action-pip-audit@v1.0.8 with: inputs: requirements.txt dev_requirements.txt @@ -40,7 +40,7 @@ jobs: name: Code linters runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: 3.9 @@ -57,14 +57,14 @@ jobs: max-parallel: 15 fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] + python-version: ['3.8', '3.9', '3.10', '3.11', 'pypy-3.8', 'pypy-3.9'] test-type: ['standalone', 'cluster'] connection-type: ['hiredis', 'plain'] env: ACTIONS_ALLOW_UNSECURE_COMMANDS: true name: Python ${{ matrix.python-version }} ${{matrix.test-type}}-${{matrix.connection-type}} tests steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -111,7 +111,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.7', '3.11'] + python-version: ['3.8', '3.11'] test-type: ['standalone', 'cluster'] connection-type: ['hiredis', 'plain'] protocol: ['3'] @@ -119,7 +119,7 @@ jobs: ACTIONS_ALLOW_UNSECURE_COMMANDS: true name: RESP3 [${{ matrix.python-version }} ${{matrix.test-type}}-${{matrix.connection-type}}] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -146,7 +146,7 @@ jobs: matrix: extension: ['tar.gz', 'whl'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: 3.9 @@ -160,9 +160,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] + python-version: ['3.8', '3.9', '3.10', '3.11', 'pypy-3.8', 'pypy-3.9'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/pypi-publish.yaml b/.github/workflows/pypi-publish.yaml index 50332c1995..4f8833372f 100644 --- a/.github/workflows/pypi-publish.yaml +++ b/.github/workflows/pypi-publish.yaml @@ -12,7 +12,7 @@ jobs: build_and_package: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install python uses: actions/setup-python@v4 with: diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index e152841553..a48781aa84 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -6,9 +6,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.33.1 + uses: rojopolis/spellcheck-github-actions@0.35.0 with: 
config_path: .github/spellcheck-settings.yml task_name: Markdown diff --git a/.gitignore b/.gitignore index b392a2d748..3baa34034f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,6 @@ redis.egg-info build/ dist/ dump.rdb -/.tox _build vagrant/.vagrant .python-version diff --git a/CHANGES b/CHANGES index 7b3b4c5ac2..9e0db230c6 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,7 @@ + * Fix return types for `get`, `set_path` and `strappend` in JSONCommands + * Connection.register_connect_callback() is made public. + * Fix async `read_response` to use `disable_decoding`. + * Add 'aclose()' methods to async classes, deprecate async close(). * Fix #2831, add auto_close_connection_pool=True arg to asyncio.Redis.from_url() * Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error` * Fix dead weakref in sentinel connection causing ReferenceError (#2767) @@ -54,6 +58,8 @@ * Fix for Unhandled exception related to self.host with unix socket (#2496) * Improve error output for master discovery * Make `ClusterCommandsProtocol` an actual Protocol + * Add `sum` to DUPLICATE_POLICY documentation of `TS.CREATE`, `TS.ADD` and `TS.ALTER` + * Prevent async ClusterPipeline instances from becoming "false-y" in case of empty command stack (#3061) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1081f4cb46..4da55c737c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,16 +78,9 @@ It is possible to run only Redis client tests (with cluster mode disabled) by using `invoke standalone-tests`; similarly, RedisCluster tests can be run by using `invoke cluster-tests`. -Each run of tox starts and stops the various dockers required. Sometimes +Each run of tests starts and stops the various dockers required. Sometimes things get stuck, an `invoke clean` can help. -Continuous Integration uses these same wrappers to run all of these -tests against multiple versions of python. Feel free to test your -changes against all the python versions supported, as declared by the -tox.ini file (eg: tox -e py39). If you have the various python versions -on your desktop, you can run *tox* by itself, to test all supported -versions. - ### Docker Tips Following are a few tips that can help you work with the Docker-based @@ -97,10 +90,6 @@ To get a bash shell inside of a container: `$ docker run -it /bin/bash` -**Note**: The term \"service\" refers to the \"services\" defined in the -`tox.ini` file at the top of the repo: \"master\", \"replicaof\", -\"sentinel_1\", \"sentinel_2\", \"sentinel_3\". - Containers run a minimal Debian image that probably lacks tools you want to use. To install packages, first get a bash session (see previous tip) and then run: @@ -111,23 +100,6 @@ You can see the logging output of a containers like this: `$ docker logs -f ` -The command make test runs all tests in all tested Python -environments. To run the tests in a single environment, like Python 3.9, -use a command like this: - -`$ docker-compose run test tox -e py39 -- --redis-url=redis://master:6379/9` - -Here, the flag `-e py39` runs tests against the Python 3.9 tox -environment. And note from the example that whenever you run tests like -this, instead of using make test, you need to pass -`-- --redis-url=redis://master:6379/9`. This points the tests at the -\"master\" container. - -Our test suite uses `pytest`. 
You can run a specific test suite against -a specific Python version like this: - -`$ docker-compose run test tox -e py37 -- --redis-url=redis://master:6379/9 tests/test_commands.py` - ### Troubleshooting If you get any errors when running `make dev` or `make test`, make sure diff --git a/README.md b/README.md index f8c3a78ae7..2097e87bba 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,20 @@ The Python interface to the Redis key-value store. --------------------------------------------- +## How do I Redis? + +[Learn for free at Redis University](https://university.redis.com/) + +[Build faster with the Redis Launchpad](https://launchpad.redis.com/) + +[Try the Redis Cloud](https://redis.com/try-free/) + +[Dive in developer tutorials](https://developer.redis.com/) + +[Join the Redis community](https://redis.com/community/) + +[Work at Redis](https://redis.com/company/careers/jobs/) + ## Installation Start a redis via docker: @@ -42,7 +56,7 @@ Looking for a high-level library to handle object mapping? See [redis-om-python] ## Supported Redis Versions -The most recent version of this library supports redis version [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), and [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES). +The most recent version of this library supports redis version [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES) and [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES). The table below highlights version compatibility of the most-recent library versions and redis versions. diff --git a/dev_requirements.txt b/dev_requirements.txt index cdb3774ab6..3715599af0 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,7 +8,6 @@ packaging>=20.4 pytest==7.2.0 pytest-timeout==2.1.0 pytest-asyncio>=0.20.2 -tox==3.27.1 invoke==1.7.3 pytest-cov>=4.0.0 vulture>=2.3.0 diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index fd29d2f684..de645bd764 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -346,7 +346,7 @@ running. The third option runs an event loop in a separate thread. pubsub.run_in_thread() creates a new thread and starts the event loop. -The thread object is returned to the caller of [un_in_thread(). The +The thread object is returned to the caller of run_in_thread(). The caller can use the thread.stop() method to shut down the event loop and thread. 
Behind the scenes, this is simply a wrapper around get_message() that runs in a separate thread, essentially creating a tiny non-blocking diff --git a/docs/clustering.rst b/docs/clustering.rst index 9b4dee1c9f..f8320e4e59 100644 --- a/docs/clustering.rst +++ b/docs/clustering.rst @@ -92,7 +92,7 @@ The ‘target_nodes’ parameter is explained in the following section, >>> # target-node: default-node >>> rc.ping() -Specfiying Target Nodes +Specifying Target Nodes ----------------------- As mentioned above, all non key-based RedisCluster commands accept the diff --git a/docs/conf.py b/docs/conf.py index 8849752404..a201da2fc0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -86,7 +86,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ["_build", "**.ipynb_checkponts"] +exclude_patterns = ["_build", "**.ipynb_checkpoints"] # The reST default role (used for this markup: `text`) to use for all # documents. diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index f7e67e2ca7..5eab4db1f7 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -15,7 +15,7 @@ "\n", "## Connecting and Disconnecting\n", "\n", - "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.close` which disconnects all connections." + "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio destructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.aclose`, which disconnects all connections." ] }, { @@ -39,9 +39,29 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis()\n", - "print(f\"Ping successful: {await connection.ping()}\")\n", - "await connection.close()" + "client = redis.Redis()\n", + "print(f\"Ping successful: {await client.ping()}\")\n", + "await client.aclose()" ] }, { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you create a custom `ConnectionPool` for the `Redis` instance to use alone, use the `from_pool` class method to create it. This will cause the pool to be disconnected along with the Redis instance. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", + "client = redis.Redis.from_pool(pool)\n", + "await client.aclose()" ] }, { @@ -53,7 +73,8 @@ } }, "source": [ - "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + "\n", + "However, if you supply a `ConnectionPool` that is shared by several `Redis` instances, you may want to disconnect the connection pool explicitly. Use the `connection_pool` argument in that case."
] }, { @@ -69,10 +90,12 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis(auto_close_connection_pool=False)\n", - "await connection.close()\n", - "# Or: await connection.close(close_connection_pool=False)\n", - "await connection.connection_pool.disconnect()" + "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", + "client1 = redis.Redis(connection_pool=pool)\n", + "client2 = redis.Redis(connection_pool=pool)\n", + "await client1.aclose()\n", + "await client2.aclose()\n", + "await pool.aclose()" ] }, { @@ -90,9 +113,9 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis(protocol=3)\n", - "await connection.close()\n", - "await connection.ping()" + "client = redis.Redis(protocol=3)\n", + "await client.aclose()\n", + "await client.ping()" ] }, { diff --git a/docs/examples/pipeline_examples.ipynb b/docs/examples/pipeline_examples.ipynb index 4e20375bfa..36ce31d708 100644 --- a/docs/examples/pipeline_examples.ipynb +++ b/docs/examples/pipeline_examples.ipynb @@ -123,7 +123,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The responses of the three commands are stored in a list. In the above example, the two first boolean indicates that the `set` commands were successfull and the last element of the list is the result of the `get(\"a\")` comand." + "The responses of the three commands are stored in a list. In the above example, the first two booleans indicate that the `set` commands were successful, and the last element of the list is the result of the `get(\"a\")` command." ] }, { diff --git a/docs/examples/redis-stream-example.ipynb b/docs/examples/redis-stream-example.ipynb index a84bf19cb6..eb1f2e9a20 100644 --- a/docs/examples/redis-stream-example.ipynb +++ b/docs/examples/redis-stream-example.ipynb @@ -652,7 +652,7 @@ "metadata": {}, "source": [ "## delete all\n", - "To remove the messages with need to remote them explicitly with `xdel`." + "To remove the messages we need to remove them explicitly with `xdel`."
] }, { diff --git a/docs/examples/search_json_examples.ipynb b/docs/examples/search_json_examples.ipynb index b66e3361c7..9ce1efc0ec 100644 --- a/docs/examples/search_json_examples.ipynb +++ b/docs/examples/search_json_examples.ipynb @@ -34,7 +34,7 @@ "from redis.commands.search.query import NumericFilter, Query\n", "\n", "\n", - "r = redis.Redis(host='localhost', port=36379)\n", + "r = redis.Redis(host='localhost', port=6379)\n", "user1 = {\n", " \"user\":{\n", " \"name\": \"Paul John\",\n", @@ -59,9 +59,19 @@ " \"city\": \"Tel Aviv\"\n", " }\n", "}\n", + "\n", + "user4 = {\n", + " \"user\":{\n", + " \"name\": \"Sarah Zamir\",\n", + " \"email\": \"sarah.zamir@example.com\",\n", + " \"age\": 30,\n", + " \"city\": \"Paris\"\n", + " }\n", + "}\n", "r.json().set(\"user:1\", Path.root_path(), user1)\n", "r.json().set(\"user:2\", Path.root_path(), user2)\n", "r.json().set(\"user:3\", Path.root_path(), user3)\n", + "r.json().set(\"user:4\", Path.root_path(), user4)\n", "\n", "schema = (TextField(\"$.user.name\", as_name=\"name\"),TagField(\"$.user.city\", as_name=\"city\"), NumericField(\"$.user.age\", as_name=\"age\"))\n", "r.ft().create_index(schema, definition=IndexDefinition(prefix=[\"user:\"], index_type=IndexType.JSON))" @@ -102,6 +112,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -133,13 +144,72 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Projecting using JSON Path expressions " + "### Paginating and Ordering search Results" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Result{4 total, docs: [Document {'id': 'user:1', 'payload': None, 'age': '42', 'json': '{\"user\":{\"name\":\"Paul John\",\"email\":\"paul.john@example.com\",\"age\":42,\"city\":\"London\"}}'}, Document {'id': 'user:3', 'payload': None, 'age': '35', 'json': '{\"user\":{\"name\":\"Paul Zamir\",\"email\":\"paul.zamir@example.com\",\"age\":35,\"city\":\"Tel Aviv\"}}'}]}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Search for all users, returning 2 users at a time and sorting by age in descending order\n", + "offset = 0\n", + "num = 2\n", + "q = Query(\"*\").paging(offset, num).sort_by(\"age\", asc=False) # pass asc=True to sort in ascending order\n", + "r.ft().search(q)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Counting the total number of Items" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q = Query(\"*\").paging(0, 0)\n", + "r.ft().search(q).total" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Projecting using JSON Path expressions " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ { "data": { @@ -148,7 +218,7 @@ " Document {'id': 'user:3', 'payload': None, 'city': 'Tel Aviv'}]" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -166,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -175,7 +245,7 @@ "[[b'age', b'35'], [b'age', b'42']]" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -184,6 +254,36 @@ "req = 
aggregations.AggregateRequest(\"Paul\").sort_by(\"@age\")\n", "r.ft().aggregate(req).rows" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Count the total number of Items" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[b'total', b'4']]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The group_by expects a string or list of strings to group the results before applying the aggregation function to\n", + "# each group. Passing an empty list here acts as `GROUPBY 0` which applies the aggregation function to the whole results\n", + "req = aggregations.AggregateRequest(\"*\").group_by([], reducers.count().alias(\"total\"))\n", + "r.ft().aggregate(req).rows" + ] } ], "metadata": { @@ -191,9 +291,9 @@ "hash": "d45c99ba0feda92868abafa8257cbb4709c97f1a0b5dc62bbeebdf89d4fad7fe" }, "kernelspec": { - "display_name": "Python 3.8.12 64-bit ('venv': venv)", + "display_name": "redis-py", "language": "python", - "name": "python3" + "name": "redis-py" }, "language_info": { "codemirror_mode": { @@ -205,10 +305,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" - }, - "orig_nbformat": 4 + "version": "3.11.3" + } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f77296df6a..8e59249bef 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -46,7 +46,6 @@ class BaseParser(ABC): - EXCEPTION_CLASSES = { "ERR": { "max number of clients reached": ConnectionError, @@ -138,12 +137,6 @@ def __init__(self, socket_read_size: int): self._stream: Optional[StreamReader] = None self._read_size = socket_read_size - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - async def can_read_destructive(self) -> bool: raise NotImplementedError() diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index fb5da831fe..bdd749a5bc 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -322,7 +322,7 @@ def float_or_none(response): return float(response) -def bool_ok(response): +def bool_ok(response, **options): return str_if_bytes(response) == "OK" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index b3247b71ec..a52dbbd013 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -1,15 +1,13 @@ import asyncio import socket import sys -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, TypedDict, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: from async_timeout import timeout as async_timeout -from redis.compat import TypedDict - from ..exceptions import ConnectionError, InvalidResponse, RedisError from ..typing import EncodableT from ..utils import HIREDIS_AVAILABLE @@ -198,10 +196,16 @@ async def read_response( if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response = self._reader.gets() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() while response is False: await self.read_from_socket() - response = self._reader.gets() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() # if the response is a ConnectionError or the response is a list and # the first item is 
a ConnectionError, raise it as something bad diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 1275686710..13aa1ffccb 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -6,15 +6,18 @@ from .base import _AsyncRESPBase, _RESPBase from .socket import SERVER_CLOSED_CONNECTION_ERROR +_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] + class _RESP3Parser(_RESPBase): """RESP3 protocol implementation""" def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.push_handler_func = self.handle_push_response + self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.invalidations_push_handler_func = None - def handle_push_response(self, response): + def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -96,8 +99,9 @@ def _read_response(self, disable_decoding=False, push_request=False): pass # map response elif byte == b"%": - # we use this approach and not dict comprehension here - # because this dict comprehension fails in python 3.7 + # We cannot use a dict-comprehension to parse stream. + # Evaluation order of key:val expression in dict comprehension only + # became defined to be left-right in version 3.8 resp_dict = {} for _ in range(int(response)): key = self._read_response(disable_decoding=disable_decoding) @@ -113,13 +117,7 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - res = self.push_handler_func(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + self.handle_push_response(response, disable_decoding, push_request) else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -127,16 +125,32 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.encoder.decode(response) return response - def set_push_handler(self, push_handler_func): - self.push_handler_func = push_handler_func + def handle_push_response(self, response, disable_decoding, push_request): + if response[0] in _INVALIDATION_MESSAGE: + res = self.invalidation_push_handler_func(response) + else: + res = self.pubsub_push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + + def set_pubsub_push_handler(self, pubsub_push_handler_func): + self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidations_push_handler_func): + self.invalidation_push_handler_func = invalidations_push_handler_func class _AsyncRESP3Parser(_AsyncRESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.push_handler_func = self.handle_push_response + self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.invalidations_push_handler_func = None - def handle_push_response(self, response): + def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -225,12 +239,16 @@ async def _read_response( pass # map response elif byte == b"%": - response = { - (await self._read_response(disable_decoding=disable_decoding)): ( - await self._read_response(disable_decoding=disable_decoding) + # We cannot use a dict-comprehension to parse stream. 
+ # Evaluation order of key:val expression in dict comprehension only + # became defined to be left-right in version 3.8 + resp_dict = {} + for _ in range(int(response)): + key = await self._read_response(disable_decoding=disable_decoding) + resp_dict[key] = await self._read_response( + disable_decoding=disable_decoding, push_request=push_request ) - for _ in range(int(response)) - } + response = resp_dict # push response elif byte == b">": response = [ @@ -241,15 +259,7 @@ async def _read_response( ) for _ in range(int(response)) ] - res = self.push_handler_func(response) - if not push_request: - return await ( - self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - ) - else: - return res + await self.handle_push_response(response, disable_decoding, push_request) else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -257,5 +267,20 @@ async def _read_response( response = self.encoder.decode(response) return response - def set_push_handler(self, push_handler_func): - self.push_handler_func = push_handler_func + async def handle_push_response(self, response, disable_decoding, push_request): + if response[0] in _INVALIDATION_MESSAGE: + res = self.invalidation_push_handler_func(response) + else: + res = self.pubsub_push_handler_func(response) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + + def set_pubsub_push_handler(self, pubsub_push_handler_func): + self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidations_push_handler_func): + self.invalidation_push_handler_func = invalidations_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index f0c1ab7536..79689fcb5e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -14,11 +14,12 @@ List, Mapping, MutableMapping, - NoReturn, Optional, + Protocol, Set, Tuple, Type, + TypedDict, TypeVar, Union, cast, @@ -38,6 +39,12 @@ ) from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry +from redis.cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from redis.client import ( EMPTY_RESPONSE, NEVER_DECODE, @@ -50,7 +57,6 @@ AsyncSentinelCommands, list_or_args, ) -from redis.compat import Protocol, TypedDict from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, @@ -61,10 +67,11 @@ TimeoutError, WatchError, ) -from redis.typing import ChannelT, EncodableT, KeyT +from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, + deprecated_function, get_lib_version, safe_str, str_if_bytes, @@ -114,7 +121,7 @@ def from_url( cls, url: str, single_connection_client: bool = False, - auto_close_connection_pool: bool = True, + auto_close_connection_pool: Optional[bool] = None, **kwargs, ): """ @@ -160,12 +167,39 @@ class initializer. In the case of conflicting arguments, querystring """ connection_pool = ConnectionPool.from_url(url, **kwargs) - redis = cls( + client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, ) - redis.auto_close_connection_pool = auto_close_connection_pool - return redis + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.1. 
" + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." + ) + ) + else: + auto_close_connection_pool = True + client.auto_close_connection_pool = auto_close_connection_pool + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client def __init__( self, @@ -200,10 +234,17 @@ def __init__( lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, - auto_close_connection_pool: bool = True, + auto_close_connection_pool: Optional[bool] = None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + cache_enable: bool = False, + client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, ): """ Initialize a new Redis client. @@ -214,13 +255,22 @@ def __init__( """ kwargs: Dict[str, Any] # auto_close_connection_pool only has an effect if connection_pool is - # None. This is a similar feature to the missing __del__ to resolve #1103, - # but it accounts for whether a user wants to manually close the connection - # pool, as a similar feature to ConnectionPool's __del__. - self.auto_close_connection_pool = ( - auto_close_connection_pool if connection_pool is None else False - ) + # None. It is assumed that if connection_pool is not None, the user + # wants to manage the connection pool themselves. + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.1. " + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." 
+ ) + ) + else: + auto_close_connection_pool = True + if not connection_pool: + # Create internal connection pool, expected to be closed by Redis instance if not retry_on_error: retry_on_error = [] if retry_on_timeout is True: @@ -277,7 +327,13 @@ def __init__( "ssl_check_hostname": ssl_check_hostname, } ) + # This arg only used if no pool is passed in + self.auto_close_connection_pool = auto_close_connection_pool connection_pool = ConnectionPool(**kwargs) + else: + # If a pool is passed in, do not close it + self.auto_close_connection_pool = False + self.connection_pool = connection_pool self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None @@ -294,8 +350,21 @@ def __init__( # on a set of redis commands self._single_conn_lock = asyncio.Lock() + self.client_cache = client_cache + if cache_enable: + self.client_cache = _LocalCache( + cache_max_size, cache_ttl, cache_eviction_policy + ) + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist + self.client_cache_initialized = False + def __repr__(self): - return f"{self.__class__.__name__}<{self.connection_pool!r}>" + return ( + f"<{self.__class__.__module__}.{self.__class__.__name__}" + f"({self.connection_pool!r})>" + ) def __await__(self): return self.initialize().__await__() @@ -305,6 +374,10 @@ async def initialize(self: _RedisT) -> _RedisT: async with self._single_conn_lock: if self.connection is None: self.connection = await self.connection_pool.get_connection("_") + if self.client_cache is not None: + self.connection._parser.set_invalidation_push_handler( + self._cache_invalidation_process + ) return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -486,19 +559,27 @@ async def __aenter__(self: _RedisT) -> _RedisT: return await self.initialize() async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() + await self.aclose() _DEL_MESSAGE = "Unclosed Redis client" - def __del__(self, _warnings: Any = warnings) -> None: + # passing _warnings and _grl as argument default since they may be gone + # by the time __del__ is called at shutdown + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: if hasattr(self, "connection") and (self.connection is not None): - _warnings.warn( - f"Unclosed client session {self!r}", ResourceWarning, source=self - ) - context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + _grl().call_exception_handler(context) + except RuntimeError: + pass + self.connection._close() - async def close(self, close_connection_pool: Optional[bool] = None) -> None: + async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: """ Closes Redis client connection @@ -515,6 +596,15 @@ async def close(self, close_connection_pool: Optional[bool] = None) -> None: close_connection_pool is None and self.auto_close_connection_pool ): await self.connection_pool.disconnect() + if self.client_cache: + self.client_cache.flush() + + @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Alias for aclose(), for backwards compatibility + """ + await self.aclose(close_connection_pool) async def 
_send_command_parse_response(self, conn, command_name, *args, **options): """ @@ -536,28 +626,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. + (if the list of keys is None, then all keys are invalidated) + """ + if data[1] is not None: + for key in data[1]: + self.client_cache.invalidate(str_if_bytes(key)) + else: + self.client_cache.flush() + + async def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + while not self.connection._is_socket_empty(): + await self.connection.read_response(push_request=True) + return self.client_cache.get(command) + + def _add_to_local_cache( + self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + ): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response, keys) + + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: + self.client_cache.delete(command) + except AttributeError: + pass + # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - pool = self.connection_pool command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) + keys = options.pop("keys", None) # keys are used only for client side caching + response_from_cache = await self._get_from_local_cache(args) + if response_from_cache is not None: + return response_from_cache + else: + pool = self.connection_pool + conn = self.connection or await pool.get_connection(command_name, **options) - if self.single_connection_client: - await self._single_conn_lock.acquire() - try: - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - finally: if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await pool.release(conn) + await self._single_conn_lock.acquire() + try: + if self.client_cache is not None and not self.client_cache_initialized: + await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, "CLIENT", *("CLIENT", "TRACKING", "ON") + ), + lambda error: self._disconnect_raise(conn, error), + ) + self.client_cache_initialized = True + response = await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + self._add_to_local_cache(args, response, keys) + return response + finally: + if self.single_connection_client: + self._single_conn_lock.release() + if not self.connection: + await pool.release(conn) async def 
parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -604,7 +761,7 @@ class Monitor: listen() method yields commands from monitor. """ - monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)") command_re = re.compile(r'"(.*?)(? Awaitable[NoReturn]: - # In case a connection property does not yet exist - # (due to a crash earlier in the Redis() constructor), return - # immediately as there is nothing to clean-up. - if not hasattr(self, "connection"): - return - return self.reset() + @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() + + @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset") + async def reset(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() async def on_connect(self, connection: Connection): """Re-subscribe to any channels and patterns previously subscribed to""" @@ -798,7 +962,7 @@ async def connect(self): else: await self.connection.connect() if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ @@ -1191,6 +1355,10 @@ async def reset(self): await self.connection_pool.release(self.connection) self.connection = None + async def aclose(self) -> None: + """Alias for reset(), a standard method name for cleanup""" + await self.reset() + def multi(self): """ Start a transactional block of the pipeline after WATCH commands @@ -1207,6 +1375,7 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) @@ -1223,14 +1392,14 @@ async def _disconnect_reset_raise(self, conn, error): # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. 
if self.watching: - await self.reset() + await self.aclose() raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) # if retry_on_timeout is not set, or the error is not # a TimeoutError, raise it if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): - await self.reset() + await self.aclose() raise async def immediate_execute_command(self, *args, **options): diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 84407116ed..6a1753ad19 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -62,7 +62,13 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes +from redis.utils import ( + deprecated_function, + dict_merge, + get_lib_version, + safe_str, + str_if_bytes, +) TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -395,12 +401,12 @@ async def initialize(self) -> "RedisCluster": ) self._initialize = False except BaseException: - await self.nodes_manager.close() - await self.nodes_manager.close("startup_nodes") + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") raise return self - async def close(self) -> None: + async def aclose(self) -> None: """Close all connections & client if initialized.""" if not self._initialize: if not self._lock: @@ -408,28 +414,37 @@ async def close(self) -> None: async with self._lock: if not self._initialize: self._initialize = True - await self.nodes_manager.close() - await self.nodes_manager.close("startup_nodes") + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """alias for aclose() for backwards compatibility""" + await self.aclose() async def __aenter__(self) -> "RedisCluster": return await self.initialize() async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - await self.close() + await self.aclose() def __await__(self) -> Generator[Any, None, "RedisCluster"]: return self.initialize().__await__() _DEL_MESSAGE = "Unclosed RedisCluster client" - def __del__(self) -> None: + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: if hasattr(self, "_initialize") and not self._initialize: - warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) try: context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _grl().call_exception_handler(context) except RuntimeError: - ... + pass async def on_connect(self, connection: Connection) -> None: await connection.on_connect() @@ -588,13 +603,13 @@ async def _determine_slot(self, command: str, *args: Any) -> int: # EVAL/EVALSHA. # - issue: https://github.com/redis/redis/issues/9493 # - fix: https://github.com/redis/redis/pull/9733 - if command in ("EVAL", "EVALSHA"): + if command.upper() in ("EVAL", "EVALSHA"): # command syntax: EVAL "script body" num_keys ... 
if len(args) < 2: raise RedisClusterException( f"Invalid args in command: {command, *args}" ) - keys = args[2 : 2 + args[1]] + keys = args[2 : 2 + int(args[1])] # if there are 0 keys, that means the script can be run on any node # so we can just return a random slot if not keys: @@ -604,7 +619,7 @@ async def _determine_slot(self, command: str, *args: Any) -> int: if not keys: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot - if command in ("FCALL", "FCALL_RO"): + if command.upper() in ("FCALL", "FCALL_RO"): return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) raise RedisClusterException( "No way to dispatch this command to Redis Cluster. " @@ -667,6 +682,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ + kwargs.pop("keys", None) # the keys are used only for client side caching command = args[0] target_nodes = [] target_nodes_specified = False @@ -767,13 +783,13 @@ async def _execute_command( self.nodes_manager.startup_nodes.pop(target_node.name, None) # Hard force of reinitialize of the node/slots setup # and try again with the new setup - await self.close() + await self.aclose() raise except ClusterDownError: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command - await self.close() + await self.aclose() await asyncio.sleep(0.25) raise except MovedError as e: @@ -790,7 +806,7 @@ async def _execute_command( self.reinitialize_steps and self.reinitialize_counter % self.reinitialize_steps == 0 ): - await self.close() + await self.aclose() # Reset the counter self.reinitialize_counter = 0 else: @@ -958,17 +974,20 @@ def __eq__(self, obj: Any) -> bool: _DEL_MESSAGE = "Unclosed ClusterNode object" - def __del__(self) -> None: + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: for connection in self._connections: if connection.is_connected: - warnings.warn( - f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self - ) + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + try: context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _grl().call_exception_handler(context) except RuntimeError: - ... 
+ pass break async def disconnect(self) -> None: @@ -1117,13 +1136,13 @@ def set_nodes( if remove_old: for name in list(old.keys()): if name not in new: - asyncio.create_task(old.pop(name).disconnect()) + task = asyncio.create_task(old.pop(name).disconnect()) # noqa for name, node in new.items(): if name in old: if old[name] is node: continue - asyncio.create_task(old[name].disconnect()) + task = asyncio.create_task(old[name].disconnect()) # noqa old[name] = node def _update_moved_slots(self) -> None: @@ -1323,7 +1342,7 @@ async def initialize(self) -> None: # If initialize was called after a MovedError, clear it self._moved_exception = None - async def close(self, attr: str = "nodes_cache") -> None: + async def aclose(self, attr: str = "nodes_cache") -> None: self.default_node = None await asyncio.gather( *( @@ -1410,7 +1429,8 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: self._command_stack = [] def __bool__(self) -> bool: - return bool(self._command_stack) + "Pipeline instances should always evaluate to True on Python 3+" + return True def __len__(self) -> int: return len(self._command_stack) @@ -1429,6 +1449,7 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ + kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) @@ -1471,7 +1492,7 @@ async def execute( if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # Try again with the new cluster setup. exception = e - await self._client.close() + await self._client.aclose() await asyncio.sleep(0.25) else: # All other errors should be raised. diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index c1cc1d310c..df2bd20f9f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -2,11 +2,10 @@ import copy import enum import inspect -import os import socket import ssl import sys -import threading +import warnings import weakref from abc import abstractmethod from itertools import chain @@ -18,9 +17,11 @@ List, Mapping, Optional, + Protocol, Set, Tuple, Type, + TypedDict, TypeVar, Union, ) @@ -35,13 +36,11 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff -from redis.compat import Protocol, TypedDict from redis.connection import DEFAULT_RESP_VERSION from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - ChildDeadlockedError, ConnectionError, DataError, RedisError, @@ -97,7 +96,6 @@ class AbstractConnection: """Manages communication to and from a Redis server""" __slots__ = ( - "pid", "db", "username", "client_name", @@ -158,7 +156,6 @@ def __init__( "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) - self.pid = os.getpid() self.db = db self.client_name = client_name self.lib_name = lib_name @@ -209,9 +206,27 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol + def __del__(self, _warnings: Any = warnings): + # For some reason, the individual streams don't get properly garbage + # collected and therefore produce no resource warnings. We add one + # here, in the same style as those from the stdlib. 
+ if getattr(self, "_writer", None): + _warnings.warn( + f"unclosed Connection {self!r}", ResourceWarning, source=self + ) + self._close() + + def _close(self): + """ + Internal method to silently close the connection without waiting + """ + if self._writer: + self._writer.close() + self._writer = self._reader = None + def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) - return f"{self.__class__.__name__}<{repr_args}>" + return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" @abstractmethod def repr_pieces(self): @@ -222,10 +237,27 @@ def is_connected(self): return self._reader is not None and self._writer is not None def register_connect_callback(self, callback): - self._connect_callbacks.append(weakref.WeakMethod(callback)) + """ + Register a callback to be called when the connection is established either + initially or reconnected. This allows listeners to issue commands that + are ephemeral to the connection, for example pub/sub subscription or + key tracking. The callback must be a _method_ and will be kept as + a weak reference. + """ + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) - def clear_connect_callbacks(self): - self._connect_callbacks = [] + def deregister_connect_callback(self, callback): + """ + De-register a previously registered callback. It will no-longer receive + notifications on connection events. Calling this is not required when the + listener goes away, since the callbacks are kept as weak methods. + """ + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass def set_parser(self, parser_class: Type[BaseParser]) -> None: """ @@ -268,6 +300,8 @@ async def connect(self): # run any user callbacks. right now the only internal callback # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] for ref in self._connect_callbacks: callback = ref() task = callback(self) @@ -381,12 +415,11 @@ async def disconnect(self, nowait: bool = False) -> None: if not self.is_connected: return try: - if os.getpid() == self.pid: - self._writer.close() # type: ignore[union-attr] - # wait for close to finish, except when handling errors and - # forcefully disconnecting. - if not nowait: - await self._writer.wait_closed() # type: ignore[union-attr] + self._writer.close() # type: ignore[union-attr] + # wait for close to finish, except when handling errors and + # forcefully disconnecting. + if not nowait: + await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass finally: @@ -613,6 +646,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] output.append(SYM_EMPTY.join(pieces)) return output + def _is_socket_empty(self): + """Check if the socket is empty""" + return not self._reader.at_eof() + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -860,6 +897,7 @@ def to_bool(value) -> Optional[bool]: "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "timeout": float, } ) @@ -1004,125 +1042,49 @@ def __init__( self.connection_kwargs = connection_kwargs self.max_connections = max_connections - # a lock to protect the critical section in _checkpid(). - # this lock is acquired when the process id changes, such as - # after a fork. 
during this time, multiple threads in the child - # process could attempt to acquire this lock. the first thread - # to acquire the lock will reset the data structures and lock - # object of this pool. subsequent threads acquiring this lock - # will notice the first thread already did the work and simply - # release the lock. - self._fork_lock = threading.Lock() - self._lock = asyncio.Lock() - self._created_connections: int - self._available_connections: List[AbstractConnection] - self._in_use_connections: Set[AbstractConnection] - self.reset() # lgtm [py/init-calls-subclass] + self._available_connections: List[AbstractConnection] = [] + self._in_use_connections: Set[AbstractConnection] = set() self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) def __repr__(self): return ( - f"{self.__class__.__name__}" - f"<{self.connection_class(**self.connection_kwargs)!r}>" + f"<{self.__class__.__module__}.{self.__class__.__name__}" + f"({self.connection_class(**self.connection_kwargs)!r})>" ) def reset(self): - self._lock = asyncio.Lock() - self._created_connections = 0 self._available_connections = [] - self._in_use_connections = set() - - # this must be the last operation in this method. while reset() is - # called when holding _fork_lock, other threads in this process - # can call _checkpid() which compares self.pid and os.getpid() without - # holding any lock (for performance reasons). keeping this assignment - # as the last operation ensures that those other threads will also - # notice a pid difference and block waiting for the first thread to - # release _fork_lock. when each of these threads eventually acquire - # _fork_lock, they will notice that another thread already called - # reset() and they will immediately release _fork_lock and continue on. - self.pid = os.getpid() - - def _checkpid(self): - # _checkpid() attempts to keep ConnectionPool fork-safe on modern - # systems. this is called by all ConnectionPool methods that - # manipulate the pool's state such as get_connection() and release(). - # - # _checkpid() determines whether the process has forked by comparing - # the current process id to the process id saved on the ConnectionPool - # instance. if these values are the same, _checkpid() simply returns. - # - # when the process ids differ, _checkpid() assumes that the process - # has forked and that we're now running in the child process. the child - # process cannot use the parent's file descriptors (e.g., sockets). - # therefore, when _checkpid() sees the process id change, it calls - # reset() in order to reinitialize the child's ConnectionPool. this - # will cause the child to make all new connection objects. - # - # _checkpid() is protected by self._fork_lock to ensure that multiple - # threads in the child process do not call reset() multiple times. - # - # there is an extremely small chance this could fail in the following - # scenario: - # 1. process A calls _checkpid() for the first time and acquires - # self._fork_lock. - # 2. while holding self._fork_lock, process A forks (the fork() - # could happen in a different thread owned by process A) - # 3. process B (the forked child process) inherits the - # ConnectionPool's state from the parent. that state includes - # a locked _fork_lock. process B will not be notified when - # process A releases the _fork_lock and will thus never be - # able to acquire the _fork_lock. - # - # to mitigate this possible deadlock, _checkpid() will only wait 5 - # seconds to acquire _fork_lock. 
if _fork_lock cannot be acquired in - # that time it is assumed that the child is deadlocked and a - # redis.ChildDeadlockedError error is raised. - if self.pid != os.getpid(): - acquired = self._fork_lock.acquire(timeout=5) - if not acquired: - raise ChildDeadlockedError - # reset() the instance for the new process if another thread - # hasn't already done so - try: - if self.pid != os.getpid(): - self.reset() - finally: - self._fork_lock.release() + self._in_use_connections = weakref.WeakSet() - async def get_connection(self, command_name, *keys, **options): - """Get a connection from the pool""" - self._checkpid() - async with self._lock: - try: - connection = self._available_connections.pop() - except IndexError: - connection = self.make_connection() - self._in_use_connections.add(connection) + def can_get_connection(self) -> bool: + """Return True if a connection can be retrieved from the pool.""" + return ( + self._available_connections + or len(self._in_use_connections) < self.max_connections + ) + async def get_connection(self, command_name, *keys, **options): + """Get a connected connection from the pool""" + connection = self.get_available_connection() try: - # ensure this connection is connected to Redis - await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the - # pool before all data has been read or the socket has been - # closed. either way, reconnect and verify everything is good. - try: - if await connection.can_read_destructive(): - raise ConnectionError("Connection has data") from None - except (ConnectionError, OSError): - await connection.disconnect() - await connection.connect() - if await connection.can_read_destructive(): - raise ConnectionError("Connection not ready") from None + await self.ensure_connection(connection) except BaseException: - # release the connection back to the pool so that we don't - # leak it await self.release(connection) raise return connection + def get_available_connection(self): + """Get a connection from the pool, without making sure it is connected""" + try: + connection = self._available_connections.pop() + except IndexError: + if len(self._in_use_connections) >= self.max_connections: + raise ConnectionError("Too many connections") from None + connection = self.make_connection() + self._in_use_connections.add(connection) + return connection + def get_encoder(self): """Return an encoder based on encoding settings""" kwargs = self.connection_kwargs @@ -1133,35 +1095,31 @@ def get_encoder(self): ) def make_connection(self): - """Create a new connection""" - if self._created_connections >= self.max_connections: - raise ConnectionError("Too many connections") - self._created_connections += 1 + """Create a new connection. Can be overridden by child classes.""" return self.connection_class(**self.connection_kwargs) + async def ensure_connection(self, connection: AbstractConnection): + """Ensure that the connection object is connected and valid""" + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. 
+ try: + if await connection.can_read_destructive(): + raise ConnectionError("Connection has data") from None + except (ConnectionError, OSError): + await connection.disconnect() + await connection.connect() + if await connection.can_read_destructive(): + raise ConnectionError("Connection not ready") from None + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool""" - self._checkpid() - async with self._lock: - try: - self._in_use_connections.remove(connection) - except KeyError: - # Gracefully fail when a connection is returned to this pool - # that the pool doesn't actually own - pass - - if self.owns_connection(connection): - self._available_connections.append(connection) - else: - # pool doesn't own this connection. do not add it back - # to the pool and decrement the count so that another - # connection can take its place if needed - self._created_connections -= 1 - await connection.disconnect() - return - - def owns_connection(self, connection: AbstractConnection): - return connection.pid == self.pid + # Connections should always be returned to the correct pool, + # not doing so is an error that will cause an exception here. + self._in_use_connections.remove(connection) + self._available_connections.append(connection) async def disconnect(self, inuse_connections: bool = True): """ @@ -1171,21 +1129,23 @@ async def disconnect(self, inuse_connections: bool = True): current in use, potentially by other tasks. Otherwise only disconnect connections that are idle in the pool. """ - self._checkpid() - async with self._lock: - if inuse_connections: - connections: Iterable[AbstractConnection] = chain( - self._available_connections, self._in_use_connections - ) - else: - connections = self._available_connections - resp = await asyncio.gather( - *(connection.disconnect() for connection in connections), - return_exceptions=True, + if inuse_connections: + connections: Iterable[AbstractConnection] = chain( + self._available_connections, self._in_use_connections ) - exc = next((r for r in resp if isinstance(r, BaseException)), None) - if exc: - raise exc + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc + + async def aclose(self) -> None: + """Close the pool, disconnecting all connections""" + await self.disconnect() def set_retry(self, retry: "Retry") -> None: for conn in self._available_connections: @@ -1196,21 +1156,21 @@ def set_retry(self, retry: "Retry") -> None: class BlockingConnectionPool(ConnectionPool): """ - Thread-safe blocking connection pool:: + A blocking connection pool:: - >>> from redis.client import Redis - >>> client = Redis(connection_pool=BlockingConnectionPool()) + >>> from redis.asyncio import Redis, BlockingConnectionPool + >>> client = Redis.from_pool(BlockingConnectionPool()) It performs the same function as the default - :py:class:`~redis.ConnectionPool` implementation, in that, + :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, it maintains a pool of reusable connections that can be shared by - multiple redis clients (safely across threads if required). + multiple async redis clients. 
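With the fork-safety machinery removed, the asyncio ConnectionPool above reduces to acquire/validate/release, with in-use connections tracked in a WeakSet. A minimal sketch of driving the pool directly, assuming a Redis server on localhost:6379 (the Redis.from_pool() route shown in the docstring above is the more typical entry point):

import asyncio

from redis.asyncio.connection import ConnectionPool


async def main() -> None:
    # Assumes a Redis server is listening on localhost:6379.
    pool = ConnectionPool(host="localhost", port=6379, max_connections=10)

    # Acquire a connection; the pool connects it and validates it for us.
    conn = await pool.get_connection("PING")
    try:
        await conn.send_command("PING")
        print(await conn.read_response())  # b'PONG'
    finally:
        # Always return the connection to the pool that created it.
        await pool.release(conn)

    # Close every pooled connection once the pool is no longer needed.
    await pool.aclose()


asyncio.run(main())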
The difference is that, in the event that a client tries to get a connection from the pool when all of connections are in use, rather than raising a :py:class:`~redis.ConnectionError` (as the default - :py:class:`~redis.ConnectionPool` implementation does), it - makes the client wait ("blocks") for a specified number of seconds until + :py:class:`~redis.asyncio.ConnectionPool` implementation does), it + makes blocks the current `Task` for a specified number of seconds until a connection becomes available. Use ``max_connections`` to increase / decrease the pool size:: @@ -1233,131 +1193,37 @@ def __init__( max_connections: int = 50, timeout: Optional[int] = 20, connection_class: Type[AbstractConnection] = Connection, - queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, + queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated **connection_kwargs, ): - - self.queue_class = queue_class - self.timeout = timeout - self._connections: List[AbstractConnection] super().__init__( connection_class=connection_class, max_connections=max_connections, **connection_kwargs, ) - - def reset(self): - # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except asyncio.QueueFull: - break - - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] - - # this must be the last operation in this method. while reset() is - # called when holding _fork_lock, other threads in this process - # can call _checkpid() which compares self.pid and os.getpid() without - # holding any lock (for performance reasons). keeping this assignment - # as the last operation ensures that those other threads will also - # notice a pid difference and block waiting for the first thread to - # release _fork_lock. when each of these threads eventually acquire - # _fork_lock, they will notice that another thread already called - # reset() and they will immediately release _fork_lock and continue on. - self.pid = os.getpid() - - def make_connection(self): - """Make a fresh connection.""" - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + self._condition = asyncio.Condition() + self.timeout = timeout async def get_connection(self, command_name, *keys, **options): - """ - Get a connection, blocking for ``self.timeout`` until a connection - is available from the pool. - - If the connection returned is ``None`` then creates a new connection. - Because we use a last-in first-out queue, the existing connections - (having been returned to the pool after the initial ``None`` values - were added) will be returned before ``None`` values. This means we only - create new connections when we need to, i.e.: the actual number of - connections will only increase in response to demand. - """ - # Make sure we haven't changed process. - self._checkpid() - - # Try and get a connection from the pool. If one isn't available within - # self.timeout then raise a ``ConnectionError``. - connection = None + """Gets a connection from the pool, blocking until one is available""" try: - async with async_timeout(self.timeout): - connection = await self.pool.get() - except (asyncio.QueueEmpty, asyncio.TimeoutError): - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. 
If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() - + async with self._condition: + async with async_timeout(self.timeout): + await self._condition.wait_for(self.can_get_connection) + connection = super().get_available_connection() + except asyncio.TimeoutError as err: + raise ConnectionError("No connection available.") from err + + # We now perform the connection check outside of the lock. try: - # ensure this connection is connected to Redis - await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the - # pool before all data has been read or the socket has been - # closed. either way, reconnect and verify everything is good. - try: - if await connection.can_read_destructive(): - raise ConnectionError("Connection has data") from None - except (ConnectionError, OSError): - await connection.disconnect() - await connection.connect() - if await connection.can_read_destructive(): - raise ConnectionError("Connection not ready") from None + await self.ensure_connection(connection) + return connection except BaseException: - # release the connection back to the pool so that we don't leak it await self.release(connection) raise - return connection - async def release(self, connection: AbstractConnection): """Releases the connection back to the pool.""" - # Make sure we haven't changed process. - self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - await connection.disconnect() - self.pool.put_nowait(None) - return - - # Put the connection back into the pool. - try: - self.pool.put_nowait(connection) - except asyncio.QueueFull: - # perhaps the pool has been reset() after a fork? 
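The rewritten get_connection above replaces the old queue-of-None placeholders with an asyncio.Condition: waiters block on wait_for(self.can_get_connection) and release() notifies one of them. A self-contained toy illustration of that same wait_for/notify idiom (BoundedPool is a stand-in for the real pool, not library code):

import asyncio


class BoundedPool:
    """Toy stand-in for the connection pool, showing the wait_for/notify idiom."""

    def __init__(self, max_items: int = 2) -> None:
        self.max_items = max_items
        self.in_use = 0
        self._condition = asyncio.Condition()

    def _can_acquire(self) -> bool:
        return self.in_use < self.max_items

    async def acquire(self) -> None:
        async with self._condition:
            # Blocks the current Task until the predicate is true.  The real
            # pool wraps this wait in a timeout and turns expiry into
            # ConnectionError("No connection available.").
            await self._condition.wait_for(self._can_acquire)
            self.in_use += 1

    async def release(self) -> None:
        async with self._condition:
            self.in_use -= 1
            self._condition.notify()  # wake exactly one waiter, as release() does


async def worker(pool: BoundedPool, name: str) -> None:
    await pool.acquire()
    print(f"{name} got a slot")
    await asyncio.sleep(0.1)
    await pool.release()


async def main() -> None:
    pool = BoundedPool(max_items=2)
    await asyncio.gather(*(worker(pool, f"task-{i}") for i in range(5)))


asyncio.run(main())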
regardless, - # we don't want this connection - pass - - async def disconnect(self, inuse_connections: bool = True): - """Disconnects all connections in the pool.""" - self._checkpid() - async with self._lock: - resp = await asyncio.gather( - *(connection.disconnect() for connection in self._connections), - return_exceptions=True, - ) - exc = next((r for r in resp if isinstance(r, BaseException)), None) - if exc: - raise exc + async with self._condition: + await super().release(connection) + self._condition.notify() diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 5ed924096c..d88babc59c 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -30,11 +30,14 @@ def __init__(self, **kwargs): def __repr__(self): pool = self.connection_pool - s = f"{self.__class__.__name__}" + return s + ")>" async def connect_to(self, address): self.host, self.port = address @@ -69,12 +72,14 @@ async def read_response( timeout: Optional[float] = None, *, disconnect_on_error: Optional[float] = True, + push_request: Optional[bool] = False, ): try: return await super().read_response( disable_decoding=disable_decoding, timeout=timeout, disconnect_on_error=disconnect_on_error, + push_request=push_request, ) except ReadOnlyError: if self.connection_pool.is_master: @@ -118,8 +123,8 @@ def __init__(self, service_name, sentinel_manager, **kwargs): def __repr__(self): return ( - f"{self.__class__.__name__}" - f"" + f"<{self.__class__.__module__}.{self.__class__.__name__}" + f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>" ) def reset(self): @@ -218,6 +223,7 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. 
""" + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") @@ -239,7 +245,10 @@ def __repr__(self): f"{sentinel.connection_pool.connection_kwargs['host']}:" f"{sentinel.connection_pool.connection_kwargs['port']}" ) - return f"{self.__class__.__name__}" + return ( + f"<{self.__class__}.{self.__class__.__name__}" + f"(sentinels=[{','.join(sentinel_addresses)}])>" + ) def check_master_state(self, state: dict, service_name: str) -> bool: if not state["is_master"] or state["is_sdown"] or state["is_odown"]: @@ -338,12 +347,7 @@ def master_for( connection_pool = connection_pool_class(service_name, self, **connection_kwargs) # The Redis object "owns" the pool - auto_close_connection_pool = True - client = redis_class( - connection_pool=connection_pool, - ) - client.auto_close_connection_pool = auto_close_connection_pool - return client + return redis_class.from_pool(connection_pool) def slave_for( self, @@ -375,9 +379,4 @@ def slave_for( connection_pool = connection_pool_class(service_name, self, **connection_kwargs) # The Redis object "owns" the pool - auto_close_connection_pool = True - client = redis_class( - connection_pool=connection_pool, - ) - client.auto_close_connection_pool = auto_close_connection_pool - return client + return redis_class.from_pool(connection_pool) diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..d920702339 --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,328 @@ +import random +import time +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List + +from redis.typing import KeyT, ResponseT + +DEFAULT_EVICTION_POLICY = "lru" + + +DEFAULT_BLACKLIST = [ + "BF.CARD", + "BF.DEBUG", + "BF.EXISTS", + "BF.INFO", + "BF.MEXISTS", + "BF.SCANDUMP", + "CF.COMPACT", + "CF.COUNT", + "CF.DEBUG", + "CF.EXISTS", + "CF.INFO", + "CF.MEXISTS", + "CF.SCANDUMP", + "CMS.INFO", + "CMS.QUERY", + "DUMP", + "EXPIRETIME", + "FT.AGGREGATE", + "FT.ALIASADD", + "FT.ALIASDEL", + "FT.ALIASUPDATE", + "FT.CURSOR", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT.GET", + "FT.INFO", + "FT.MGET", + "FT.PROFILE", + "FT.SEARCH", + "FT.SPELLCHECK", + "FT.SUGGET", + "FT.SUGLEN", + "FT.SYNDUMP", + "FT.TAGVALS", + "FT._ALIASADDIFNX", + "FT._ALIASDELIFX", + "HRANDFIELD", + "JSON.DEBUG", + "PEXPIRETIME", + "PFCOUNT", + "PTTL", + "SRANDMEMBER", + "TDIGEST.BYRANK", + "TDIGEST.BYREVRANK", + "TDIGEST.CDF", + "TDIGEST.INFO", + "TDIGEST.MAX", + "TDIGEST.MIN", + "TDIGEST.QUANTILE", + "TDIGEST.RANK", + "TDIGEST.REVRANK", + "TDIGEST.TRIMMED_MEAN", + "TOPK.INFO", + "TOPK.LIST", + "TOPK.QUERY", + "TOUCH", + "TTL", +] + + +DEFAULT_WHITELIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + 
"ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", +] + +_RESPONSE = "response" +_KEYS = "keys" +_CTIME = "ctime" +_ACCESS_COUNT = "access_count" + + +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" + + +class _LocalCache: + """ + A caching mechanism for storing redis commands and their responses. + + Args: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. + + Attributes: + max_size (int): The maximum number of commands to be stored in the cache. + ttl (int): The time-to-live for each command in seconds. + eviction_policy (EvictionPolicy): The eviction policy used for cache management. + cache (OrderedDict): The ordered dictionary to store commands and their metadata. + key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. + commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa + """ + + def __init__( + self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs + ): + self.max_size = max_size + self.ttl = ttl + self.eviction_policy = eviction_policy + self.cache = OrderedDict() + self.key_commands_map = defaultdict(set) + self.commands_ttl_list = [] + + def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + """ + Set a redis command and its response in the cache. + + Args: + command (str): The redis command. + response (ResponseT): The response associated with the command. + keys_in_command (List[KeyT]): The list of keys used in the command. + """ + if len(self.cache) >= self.max_size: + self._evict() + self.cache[command] = { + _RESPONSE: response, + _KEYS: keys_in_command, + _CTIME: time.monotonic(), + _ACCESS_COUNT: 0, # Used only for LFU + } + self._update_key_commands_map(keys_in_command, command) + self.commands_ttl_list.append(command) + + def get(self, command: str) -> ResponseT: + """ + Get the response for a redis command from the cache. + + Args: + command (str): The redis command. + + Returns: + ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa + """ + if command in self.cache: + if self._is_expired(command): + self.delete(command) + return + self._update_access(command) + return self.cache[command]["response"] + + def delete(self, command: str): + """ + Delete a redis command and its metadata from the cache. + + Args: + command (str): The redis command to be deleted. + """ + if command in self.cache: + keys_in_command = self.cache[command].get("keys") + self._del_key_commands_map(keys_in_command, command) + self.commands_ttl_list.remove(command) + del self.cache[command] + + def delete_many(self, commands): + pass + + def flush(self): + """Clear the entire cache, removing all redis commands and metadata.""" + self.cache.clear() + self.key_commands_map.clear() + self.commands_ttl_list = [] + + def _is_expired(self, command: str) -> bool: + """ + Check if a redis command has expired based on its time-to-live. + + Args: + command (str): The redis command. + + Returns: + bool: True if the command has expired, False otherwise. 
+ """ + if self.ttl == 0: + return False + return time.monotonic() - self.cache[command]["ctime"] > self.ttl + + def _update_access(self, command: str): + """ + Update the access information for a redis command based on the eviction policy. + + Args: + command (str): The redis command. + """ + if self.eviction_policy == EvictionPolicy.LRU.value: + self.cache.move_to_end(command) + elif self.eviction_policy == EvictionPolicy.LFU.value: + self.cache[command]["access_count"] = ( + self.cache.get(command, {}).get("access_count", 0) + 1 + ) + self.cache.move_to_end(command) + elif self.eviction_policy == EvictionPolicy.RANDOM.value: + pass # Random eviction doesn't require updates + + def _evict(self): + """Evict a redis command from the cache based on the eviction policy.""" + if self._is_expired(self.commands_ttl_list[0]): + self.delete(self.commands_ttl_list[0]) + elif self.eviction_policy == EvictionPolicy.LRU.value: + self.cache.popitem(last=False) + elif self.eviction_policy == EvictionPolicy.LFU.value: + min_access_command = min( + self.cache, key=lambda k: self.cache[k].get("access_count", 0) + ) + self.cache.pop(min_access_command) + elif self.eviction_policy == EvictionPolicy.RANDOM.value: + random_command = random.choice(list(self.cache.keys())) + self.cache.pop(random_command) + + def _update_key_commands_map(self, keys: List[KeyT], command: str): + """ + Update the key_commands_map with command that uses the keys. + + Args: + keys (List[KeyT]): The list of keys used in the command. + command (str): The redis command. + """ + for key in keys: + self.key_commands_map[key].add(command) + + def _del_key_commands_map(self, keys: List[KeyT], command: str): + """ + Remove a redis command from the key_commands_map. + + Args: + keys (List[KeyT]): The list of keys used in the redis command. + command (str): The redis command. + """ + for key in keys: + self.key_commands_map[key].remove(command) + + def invalidate(self, key: KeyT): + """ + Invalidate (delete) all redis commands associated with a specific key. + + Args: + key (KeyT): The key to be invalidated. + """ + if key not in self.key_commands_map: + return + commands = list(self.key_commands_map[key]) + for command in commands: + self.delete(command) diff --git a/redis/client.py b/redis/client.py index f695cef534..7f2c8d290d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,14 +4,21 @@ import time import warnings from itertools import chain -from typing import Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import ( + DEFAULT_BLACKLIST, + DEFAULT_EVICTION_POLICY, + DEFAULT_WHITELIST, + _LocalCache, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -31,6 +38,7 @@ ) from redis.lock import Lock from redis.retry import Retry +from redis.typing import KeysT, ResponseT from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, @@ -49,7 +57,7 @@ class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." 
- def __init__(self, data): + def __init__(self, data: Dict[str, str]) -> None: for k, v in data.items(): self[k.upper()] = v @@ -93,7 +101,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls, url: str, **kwargs) -> "Redis": """ Return a Redis client object configured from the given URL @@ -136,10 +144,28 @@ class initializer. In the case of conflicting arguments, querystring """ single_connection_client = kwargs.pop("single_connection_client", False) connection_pool = ConnectionPool.from_url(url, **kwargs) - return cls( + client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, ) + client.auto_close_connection_pool = True + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client def __init__( self, @@ -184,7 +210,14 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - ): + cache_enable: bool = False, + client_cache: Optional[_LocalCache] = None, + cache_max_size: int = 100, + cache_ttl: int = 0, + cache_eviction_policy: str = DEFAULT_EVICTION_POLICY, + cache_blacklist: List[str] = DEFAULT_BLACKLIST, + cache_whitelist: List[str] = DEFAULT_WHITELIST, + ) -> None: """ Initialize a new Redis client. To specify a retry policy for specific errors, first set @@ -275,6 +308,10 @@ def __init__( } ) connection_pool = ConnectionPool(**kwargs) + self.auto_close_connection_pool = True + else: + self.auto_close_connection_pool = False + self.connection_pool = connection_pool self.connection = None if single_connection_client: @@ -287,14 +324,30 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) - def __repr__(self): - return f"{type(self).__name__}<{repr(self.connection_pool)}>" + self.client_cache = client_cache + if cache_enable: + self.client_cache = _LocalCache( + cache_max_size, cache_ttl, cache_eviction_policy + ) + if self.client_cache is not None: + self.cache_blacklist = cache_blacklist + self.cache_whitelist = cache_whitelist + self.client_tracking_on() + self.connection._parser.set_invalidation_push_handler( + self._cache_invalidation_process + ) + + def __repr__(self) -> str: + return ( + f"<{type(self).__module__}.{type(self).__name__}" + f"({repr(self.connection_pool)})>" + ) - def get_encoder(self): + def get_encoder(self) -> "Encoder": """Get the connection pool's encoder""" return self.connection_pool.get_encoder() - def get_connection_kwargs(self): + def get_connection_kwargs(self) -> Dict: """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs @@ -305,11 +358,26 @@ def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) - def set_response_callback(self, command, callback): + def set_response_callback(self, command: str, callback: Callable) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback - def load_external_module(self, funcname, func): + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + """ + 
Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. + (if the list of keys is None, then all keys are invalidated) + """ + if data[1] is not None: + for key in data[1]: + self.client_cache.invalidate(str_if_bytes(key)) + else: + self.client_cache.flush() + + def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -332,7 +400,7 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def pipeline(self, transaction=True, shard_hint=None): + def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": """ Return a new pipeline object that can queue multiple commands for later execution. ``transaction`` indicates whether all commands @@ -344,7 +412,9 @@ def pipeline(self, transaction=True, shard_hint=None): self.connection_pool, self.response_callbacks, transaction, shard_hint ) - def transaction(self, func, *watches, **kwargs): + def transaction( + self, func: Callable[["Pipeline"], None], *watches, **kwargs + ) -> None: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. The 'func' callable @@ -368,13 +438,13 @@ def transaction(self, func, *watches, **kwargs): def lock( self, - name, - timeout=None, - sleep=0.1, - blocking=True, - blocking_timeout=None, - lock_class=None, - thread_local=True, + name: str, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + lock_class: Union[None, Any] = None, + thread_local: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -477,6 +547,11 @@ def close(self): self.connection = None self.connection_pool.release(conn) + if self.auto_close_connection_pool: + self.connection_pool.disconnect() + if self.client_cache: + self.client_cache.flush() + def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response @@ -497,23 +572,67 @@ def _disconnect_raise(self, conn, error): ): raise error + def _get_from_local_cache(self, command: str): + """ + If the command is in the local cache, return the response + """ + if ( + self.client_cache is None + or command[0] in self.cache_blacklist + or command[0] not in self.cache_whitelist + ): + return None + while not self.connection._is_socket_empty(): + self.connection.read_response(push_request=True) + return self.client_cache.get(command) + + def _add_to_local_cache( + self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + ): + """ + Add the command and response to the local cache if the command + is allowed to be cached + """ + if ( + self.client_cache is not None + and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) + and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + ): + self.client_cache.set(command, response, keys) + + def delete_from_local_cache(self, command: str): + """ + Delete the command from the local cache + """ + try: + self.client_cache.delete(command) + except AttributeError: + pass + # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): """Execute a command and return a parsed response""" - pool = self.connection_pool command_name = args[0] - conn = self.connection or 
pool.get_connection(command_name, **options) + keys = options.pop("keys", None) + response_from_cache = self._get_from_local_cache(args) + if response_from_cache is not None: + return response_from_cache + else: + pool = self.connection_pool + conn = self.connection or pool.get_connection(command_name, **options) - try: - return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - finally: - if not self.connection: - pool.release(conn) + try: + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + self._add_to_local_cache(args, response, keys) + return response + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" @@ -546,7 +665,7 @@ class Monitor: listen() method yields commands from monitor. """ - monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)") command_re = re.compile(r'"(.*?)(? "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.reset() - def __del__(self): + def __del__(self) -> None: try: # if this object went out of scope prior to shutting down # subscriptions, close the connection manually before @@ -662,10 +781,10 @@ def __del__(self): except Exception: pass - def reset(self): + def reset(self) -> None: if self.connection: self.connection.disconnect() - self.connection.clear_connect_callbacks() + self.connection.deregister_connect_callback(self.on_connect) self.connection_pool.release(self.connection) self.connection = None self.health_check_response_counter = 0 @@ -677,10 +796,10 @@ def reset(self): self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() - def close(self): + def close(self) -> None: self.reset() - def on_connect(self, connection): + def on_connect(self, connection) -> None: "Re-subscribe to any channels and patterns previously subscribed to" # NOTE: for python3, we can't pass bytestrings as keyword arguments # so we need to decode channel/pattern names back to unicode strings @@ -706,7 +825,7 @@ def on_connect(self, connection): self.ssubscribe(**shard_channels) @property - def subscribed(self): + def subscribed(self) -> bool: """Indicates if there are subscriptions to any channels or patterns""" return self.subscribed_event.is_set() @@ -725,14 +844,14 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) - def clean_health_check_responses(self): + def clean_health_check_responses(self) -> None: """ If any health check responses are present, clean them """ @@ -750,7 +869,7 @@ def clean_health_check_responses(self): ) ttl -= 1 - def _disconnect_raise_connect(self, conn, error): + def _disconnect_raise_connect(self, conn, error) -> None: """ Close the 
connection and raise an exception if retry_on_timeout is not set or the error @@ -801,7 +920,7 @@ def try_read(): return None return response - def is_health_check_response(self, response): + def is_health_check_response(self, response) -> bool: """ Check if the response is a health check response. If there are no subscriptions redis responds to PING command with a @@ -812,7 +931,7 @@ def is_health_check_response(self, response): self.health_check_response_b, # If there wasn't ] - def check_health(self): + def check_health(self) -> None: conn = self.connection if conn is None: raise RuntimeError( @@ -824,7 +943,7 @@ def check_health(self): conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) self.health_check_response_counter += 1 - def _normalize_keys(self, data): + def _normalize_keys(self, data) -> Dict: """ normalize channel/pattern names to be either bytes or strings based on whether responses are automatically decoded. this saves us @@ -958,7 +1077,9 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0.0): + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): """ Get the next message if one is available, otherwise None. @@ -987,7 +1108,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): get_sharded_message = get_message - def ping(self, message=None): + def ping(self, message: Union[str, None] = None) -> bool: """ Ping the Redis server """ @@ -1068,7 +1189,12 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message - def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: raise PubSubError(f"Channel: '{channel}' has no handler registered") @@ -1089,7 +1215,15 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): + def __init__( + self, + pubsub, + sleep_time: float, + daemon: bool = False, + exception_handler: Union[ + Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None + ] = None, + ): super().__init__() self.daemon = daemon self.pubsub = pubsub @@ -1097,7 +1231,7 @@ def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): self.exception_handler = exception_handler self._running = threading.Event() - def run(self): + def run(self) -> None: if self._running.is_set(): return self._running.set() @@ -1112,7 +1246,7 @@ def run(self): self.exception_handler(e, pubsub, self) pubsub.close() - def stop(self): + def stop(self) -> None: # trip the flag so the run loop exits. the run loop will # close the pubsub connection, which disconnects the socket # and returns the connection to the pool. 
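Returning to the client-side caching wired into Redis above: execute_command() now consults _get_from_local_cache() before sending a whitelisted read command, stores the reply via _add_to_local_cache(), and _cache_invalidation_process() drops entries when the server pushes invalidation messages. A usage sketch under the assumptions noted in the comments (local Redis 6+ server, RESP3):

from redis import Redis

# Assumed setup: a local Redis 6+ server and RESP3 (protocol=3), since the
# invalidations arrive as server push messages.  The invalidation handler is
# registered on self.connection, so a single persistent connection is used.
r = Redis(
    protocol=3,
    single_connection_client=True,
    cache_enable=True,
    cache_max_size=100,
    cache_ttl=0,
)

r.set("greeting", "hello")
print(r.get("greeting"))  # sent to the server, response stored locally
print(r.get("greeting"))  # answered from the local cache

r.set("greeting", "bye")  # the tracked key changes; the server pushes an invalidation
print(r.get("greeting"))  # the cached entry was dropped, so this fetches b"bye"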
@@ -1150,7 +1284,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint) self.watching = False self.reset() - def __enter__(self): + def __enter__(self) -> "Pipeline": return self def __exit__(self, exc_type, exc_value, traceback): @@ -1162,14 +1296,14 @@ def __del__(self): except Exception: pass - def __len__(self): + def __len__(self) -> int: return len(self.command_stack) - def __bool__(self): + def __bool__(self) -> bool: """Pipeline instances should always evaluate to True""" return True - def reset(self): + def reset(self) -> None: self.command_stack = [] self.scripts = set() # make sure to reset the connection state in the event that we were @@ -1192,7 +1326,11 @@ def reset(self): self.connection_pool.release(self.connection) self.connection = None - def multi(self): + def close(self) -> None: + """Close the pipeline""" + self.reset() + + def multi(self) -> None: """ Start a transactional block of the pipeline after WATCH commands are issued. End the transactional block with `execute`. @@ -1206,11 +1344,12 @@ def multi(self): self.explicit_transaction = True def execute_command(self, *args, **kwargs): + kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - def _disconnect_reset_raise(self, conn, error): + def _disconnect_reset_raise(self, conn, error) -> None: """ Close the connection, reset watching state and raise an exception if we were watching, @@ -1253,7 +1392,7 @@ def immediate_execute_command(self, *args, **options): lambda error: self._disconnect_reset_raise(conn, error), ) - def pipeline_execute_command(self, *args, **options): + def pipeline_execute_command(self, *args, **options) -> "Pipeline": """ Stage a command to be executed when execute() is next called @@ -1268,7 +1407,7 @@ def pipeline_execute_command(self, *args, **options): self.command_stack.append((args, options)) return self - def _execute_transaction(self, connection, commands, raise_on_error): + def _execute_transaction(self, connection, commands, raise_on_error) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] @@ -1386,7 +1525,7 @@ def load_scripts(self): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_reset(self, conn, error): + def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: """ Close the connection, raise an exception if we were watching, and raise an exception if TimeoutError is not part of retry_on_error, @@ -1448,6 +1587,6 @@ def watch(self, *names): raise RedisError("Cannot issue a WATCH after a MULTI") return self.execute_command("WATCH", *names) - def unwatch(self): + def unwatch(self) -> bool: """Unwatches all previously specified keys""" return self.watching and self.execute_command("UNWATCH") or True diff --git a/redis/cluster.py b/redis/cluster.py index cba62de077..8032173e66 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -295,6 +295,10 @@ class AbstractRedisCluster: "LATENCY HISTORY", "LATENCY LATEST", "LATENCY RESET", + "MODULE LIST", + "MODULE LOAD", + "MODULE UNLOAD", + "MODULE LOADEX", ], DEFAULT_NODE, ), @@ -967,11 +971,11 @@ def determine_slot(self, *args): # redis server to parse the keys. 
Besides, there is a bug in redis<7.0 # where `self._get_command_keys()` fails anyway. So, we special case # EVAL/EVALSHA. - if command in ("EVAL", "EVALSHA"): + if command.upper() in ("EVAL", "EVALSHA"): # command syntax: EVAL "script body" num_keys ... if len(args) <= 2: raise RedisClusterException(f"Invalid args in command: {args}") - num_actual_keys = args[2] + num_actual_keys = int(args[2]) eval_keys = args[3 : 3 + num_actual_keys] # if there are 0 keys, that means the script can be run on any node # so we can just return a random slot @@ -983,7 +987,7 @@ def determine_slot(self, *args): if keys is None or len(keys) == 0: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot - if command in ("FCALL", "FCALL_RO"): + if command.upper() in ("FCALL", "FCALL_RO"): return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) raise RedisClusterException( "No way to dispatch this command to Redis Cluster. " @@ -1056,6 +1060,7 @@ def execute_command(self, *args, **kwargs): list dict """ + kwargs.pop("keys", None) # the keys are used only for client side caching target_nodes_specified = False is_default_node = False target_nodes = None @@ -1773,7 +1778,7 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -1958,6 +1963,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ + kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): @@ -2192,7 +2198,7 @@ def _send_cluster_commands( ) if attempt and allow_redirections: # RETRY MAGIC HAPPENS HERE! - # send these remaing commands one at a time using `execute_command` + # send these remaining commands one at a time using `execute_command` # in the main client. This keeps our retry logic # in one place mostly, # and allows us to be more confident in correctness of behavior. @@ -2453,7 +2459,6 @@ def read(self): """ """ connection = self.connection for c in self.commands: - # if there is a result on this command, # it means we ran into an exception # like a connection error. 
Trying to parse diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 691cab3def..8dd463ed18 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -7,13 +7,13 @@ Iterable, Iterator, List, + Literal, Mapping, NoReturn, Optional, Union, ) -from redis.compat import Literal from redis.crc import key_slot from redis.exceptions import RedisClusterException, RedisError from redis.typing import ( @@ -23,6 +23,7 @@ KeysT, KeyT, PatternT, + ResponseT, ) from .core import ( @@ -32,13 +33,14 @@ AsyncFunctionCommands, AsyncGearsCommands, AsyncManagementCommands, + AsyncModuleCommands, AsyncScriptCommands, DataAccessCommands, FunctionCommands, GearsCommands, ManagementCommands, + ModuleCommands, PubSubCommands, - ResponseT, ScriptCommands, ) from .helpers import list_or_args @@ -223,7 +225,7 @@ def delete(self, *keys: KeyT) -> ResponseT: The keys are first split up into slots and then an DEL command is sent for every slot - Non-existant keys are ignored. + Non-existent keys are ignored. Returns the number of keys that were deleted. For more information see https://redis.io/commands/del @@ -238,7 +240,7 @@ def touch(self, *keys: KeyT) -> ResponseT: The keys are first split up into slots and then an TOUCH command is sent for every slot - Non-existant keys are ignored. + Non-existent keys are ignored. Returns the number of keys that were touched. For more information see https://redis.io/commands/touch @@ -252,7 +254,7 @@ def unlink(self, *keys: KeyT) -> ResponseT: The keys are first split up into slots and then an TOUCH command is sent for every slot - Non-existant keys are ignored. + Non-existent keys are ignored. Returns the number of keys that were unlinked. For more information see https://redis.io/commands/unlink @@ -873,6 +875,7 @@ class RedisClusterCommands( ScriptCommands, FunctionCommands, GearsCommands, + ModuleCommands, RedisModuleCommands, ): """ @@ -903,6 +906,7 @@ class AsyncRedisClusterCommands( AsyncScriptCommands, AsyncFunctionCommands, AsyncGearsCommands, + AsyncModuleCommands, ): """ A class for all Redis Cluster commands diff --git a/redis/commands/core.py b/redis/commands/core.py index 031781d75d..fbeeb88789 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5,7 +5,6 @@ import warnings from typing import ( TYPE_CHECKING, - Any, AsyncIterator, Awaitable, Callable, @@ -13,6 +12,7 @@ Iterable, Iterator, List, + Literal, Mapping, Optional, Sequence, @@ -21,7 +21,6 @@ Union, ) -from redis.compat import Literal from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from redis.typing import ( AbsExpiryT, @@ -37,6 +36,7 @@ KeysT, KeyT, PatternT, + ResponseT, ScriptTextT, StreamIdT, TimeoutSecT, @@ -49,8 +49,6 @@ from redis.asyncio.client import Redis as AsyncRedis from redis.client import Redis -ResponseT = Union[Awaitable, Any] - class ACLCommands(CommandsProtocol): """ @@ -403,7 +401,7 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password, username=None, **kwargs): + def auth(self, password: str, username: Optional[str] = None, **kwargs): """ Authenticates the user. If you do not pass username, Redis will try to authenticate for the "default" user. 
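Relatedly, the determine_slot() fix above (upper-casing the command and coercing num_keys with int()) matters because EVAL/EVALSHA keys are parsed out of the argument list by the client itself. A standalone sketch of that slot computation using redis.crc.key_slot; the eval_slot helper is illustrative, not part of the library:

from redis.crc import key_slot


def eval_slot(*args: str) -> int:
    """Illustrative: pick the hash slot for an EVAL/EVALSHA command."""
    # command syntax: EVAL "script body" num_keys key [key ...] arg [arg ...]
    command = args[0].upper()
    assert command in ("EVAL", "EVALSHA")
    num_keys = int(args[2])          # may arrive as str or int, hence int()
    keys = args[3 : 3 + num_keys]
    if not keys:
        # a 0-key script can run anywhere; the real client picks a random slot
        raise ValueError("0-key scripts can run on any node")
    slots = {key_slot(k.encode()) for k in keys}
    if len(slots) != 1:
        raise ValueError("EVAL keys must map to a single slot")
    return slots.pop()


print(eval_slot("eval", "return redis.call('GET', KEYS[1])", "1", "{user}:1"))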
If you do pass username, it will @@ -1590,7 +1588,7 @@ def bitcount( raise DataError("Both start and end must be specified") if mode is not None: params.append(mode) - return self.execute_command("BITCOUNT", *params) + return self.execute_command("BITCOUNT", *params, keys=[key]) def bitfield( self: Union["Redis", "AsyncRedis"], @@ -1626,7 +1624,7 @@ def bitfield_ro( items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) - return self.execute_command("BITFIELD_RO", *params) + return self.execute_command("BITFIELD_RO", *params, keys=[key]) def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: """ @@ -1666,7 +1664,7 @@ def bitpos( if mode is not None: params.append(mode) - return self.execute_command("BITPOS", *params) + return self.execute_command("BITPOS", *params, keys=[key]) def copy( self, @@ -1733,7 +1731,7 @@ def exists(self, *names: KeyT) -> ResponseT: For more information see https://redis.io/commands/exists """ - return self.execute_command("EXISTS", *names) + return self.execute_command("EXISTS", *names, keys=names) __contains__ = exists @@ -1826,7 +1824,7 @@ def get(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/get """ - return self.execute_command("GET", name) + return self.execute_command("GET", name, keys=[name]) def getdel(self, name: KeyT) -> ResponseT: """ @@ -1920,7 +1918,7 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: For more information see https://redis.io/commands/getbit """ - return self.execute_command("GETBIT", name, offset) + return self.execute_command("GETBIT", name, offset, keys=[name]) def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ @@ -1929,7 +1927,7 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: For more information see https://redis.io/commands/getrange """ - return self.execute_command("GETRANGE", key, start, end) + return self.execute_command("GETRANGE", key, start, end, keys=[key]) def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -2012,6 +2010,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] + options["keys"] = keys return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: @@ -2458,14 +2457,14 @@ def strlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/strlen """ - return self.execute_command("STRLEN", name) + return self.execute_command("STRLEN", name, keys=[name]) def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. 
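The many keys=[...] additions in these hunks all feed the client-side cache: execute_command() pops the keys option and hands it to _add_to_local_cache() so a pushed invalidation for a key can be mapped back to the cached responses that depend on it (Pipeline and cluster execute_command simply discard the option). A sketch of following the same convention from user code; hget_cached is an illustrative helper, not a library API:

from redis import Redis


def hget_cached(client: Redis, name: str, field: str):
    # Same convention as the built-in read commands: declare which Redis keys
    # the reply depends on so the local cache can invalidate it later.
    # (Without client-side caching enabled, ``keys`` is simply discarded.)
    return client.execute_command("HGET", name, field, keys=[name])


r = Redis()  # assumes a local server; pass cache_enable=True to benefit
r.hset("user:1", "name", "alice")
print(hget_cached(r, "user:1", "name"))  # b'alice'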
""" - return self.execute_command("SUBSTR", name, start, end) + return self.execute_command("SUBSTR", name, start, end, keys=[name]) def touch(self, *args: KeyT) -> ResponseT: """ @@ -2490,7 +2489,7 @@ def type(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/type """ - return self.execute_command("TYPE", name) + return self.execute_command("TYPE", name, keys=[name]) def watch(self, *names: KeyT) -> None: """ @@ -2543,7 +2542,7 @@ def lcs( pieces.extend(["MINMATCHLEN", minmatchlen]) if withmatchlen: pieces.append("WITHMATCHLEN") - return self.execute_command("LCS", *pieces) + return self.execute_command("LCS", *pieces, keys=[key1, key2]) class AsyncBasicKeyCommands(BasicKeyCommands): @@ -2682,7 +2681,7 @@ def lindex( For more information see https://redis.io/commands/lindex """ - return self.execute_command("LINDEX", name, index) + return self.execute_command("LINDEX", name, index, keys=[name]) def linsert( self, name: str, where: str, refvalue: str, value: str @@ -2704,7 +2703,7 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/llen """ - return self.execute_command("LLEN", name) + return self.execute_command("LLEN", name, keys=[name]) def lpop( self, @@ -2751,7 +2750,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list For more information see https://redis.io/commands/lrange """ - return self.execute_command("LRANGE", name, start, end) + return self.execute_command("LRANGE", name, start, end, keys=[name]) def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ @@ -2823,13 +2822,13 @@ def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name: str, value: str) -> Union[Awaitable[int], int]: + def rpushx(self, name: str, *values: str) -> Union[Awaitable[int], int]: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists For more information see https://redis.io/commands/rpushx """ - return self.execute_command("RPUSHX", name, value) + return self.execute_command("RPUSHX", name, *values) def lpos( self, @@ -2874,7 +2873,7 @@ def lpos( if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces) + return self.execute_command("LPOS", *pieces, keys=[name]) def sort( self, @@ -2946,6 +2945,7 @@ def sort( ) options = {"groups": len(get) if groups else None} + options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) def sort_ro( @@ -3319,7 +3319,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/scard """ - return self.execute_command("SCARD", name) + return self.execute_command("SCARD", name, keys=[name]) def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: """ @@ -3328,7 +3328,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sdiff """ args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args) + return self.execute_command("SDIFF", *args, keys=args) def sdiffstore( self, dest: str, keys: List, *args: List @@ -3349,7 +3349,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: For more information see https://redis.io/commands/sinter """ args = list_or_args(keys, args) - return self.execute_command("SINTER", *args) + return self.execute_command("SINTER", 
*args, keys=args) def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 @@ -3364,7 +3364,7 @@ def sintercard( For more information see https://redis.io/commands/sintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("SINTERCARD", *args) + return self.execute_command("SINTERCARD", *args, keys=keys) def sinterstore( self, dest: str, keys: List, *args: List @@ -3388,7 +3388,7 @@ def sismember( For more information see https://redis.io/commands/sismember """ - return self.execute_command("SISMEMBER", name, value) + return self.execute_command("SISMEMBER", name, value, keys=[name]) def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ @@ -3396,7 +3396,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: For more information see https://redis.io/commands/smembers """ - return self.execute_command("SMEMBERS", name) + return self.execute_command("SMEMBERS", name, keys=[name]) def smismember( self, name: str, values: List, *args: List @@ -3413,7 +3413,7 @@ def smismember( For more information see https://redis.io/commands/smismember """ args = list_or_args(values, args) - return self.execute_command("SMISMEMBER", name, *args) + return self.execute_command("SMISMEMBER", name, *args, keys=[name]) def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ @@ -3462,7 +3462,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/sunion """ args = list_or_args(keys, args) - return self.execute_command("SUNION", *args) + return self.execute_command("SUNION", *args, keys=args) def sunionstore( self, dest: str, keys: List, *args: List @@ -3820,7 +3820,7 @@ def xlen(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/xlen """ - return self.execute_command("XLEN", name) + return self.execute_command("XLEN", name, keys=[name]) def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ @@ -3830,7 +3830,7 @@ def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: For more information see https://redis.io/commands/xpending """ - return self.execute_command("XPENDING", name, groupname) + return self.execute_command("XPENDING", name, groupname, keys=[name]) def xpending_range( self, @@ -3919,7 +3919,7 @@ def xrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces) + return self.execute_command("XRANGE", name, *pieces, keys=[name]) def xread( self, @@ -3957,7 +3957,7 @@ def xread( keys, values = zip(*streams.items()) pieces.extend(keys) pieces.extend(values) - return self.execute_command("XREAD", *pieces) + return self.execute_command("XREAD", *pieces, keys=keys) def xreadgroup( self, @@ -4036,7 +4036,7 @@ def xrevrange( pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces) + return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) def xtrim( self, @@ -4175,7 +4175,7 @@ def zcard(self, name: KeyT) -> ResponseT: For more information see https://redis.io/commands/zcard """ - return self.execute_command("ZCARD", name) + return self.execute_command("ZCARD", name, keys=[name]) def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ @@ -4184,7 +4184,7 @@ def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: For more information see https://redis.io/commands/zcount """ - return self.execute_command("ZCOUNT", name, min, max) + return 
self.execute_command("ZCOUNT", name, min, max, keys=[name]) def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ @@ -4196,7 +4196,7 @@ def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: pieces = [len(keys), *keys] if withscores: pieces.append("WITHSCORES") - return self.execute_command("ZDIFF", *pieces) + return self.execute_command("ZDIFF", *pieces, keys=keys) def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ @@ -4264,7 +4264,7 @@ def zintercard( For more information see https://redis.io/commands/zintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("ZINTERCARD", *args) + return self.execute_command("ZINTERCARD", *args, keys=keys) def zlexcount(self, name, min, max): """ @@ -4273,7 +4273,7 @@ def zlexcount(self, name, min, max): For more information see https://redis.io/commands/zlexcount """ - return self.execute_command("ZLEXCOUNT", name, min, max) + return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: """ @@ -4456,6 +4456,7 @@ def _zrange( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrange( @@ -4544,6 +4545,7 @@ def zrevrange( if withscores: pieces.append(b"WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = name return self.execute_command(*pieces, **options) def zrangestore( @@ -4618,7 +4620,7 @@ def zrangebylex( pieces = ["ZRANGEBYLEX", name, min, max] if start is not None and num is not None: pieces.extend([b"LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrevrangebylex( self, @@ -4642,7 +4644,7 @@ def zrevrangebylex( pieces = ["ZREVRANGEBYLEX", name, max, min] if start is not None and num is not None: pieces.extend(["LIMIT", start, num]) - return self.execute_command(*pieces) + return self.execute_command(*pieces, keys=[name]) def zrangebyscore( self, @@ -4676,6 +4678,7 @@ def zrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrevrangebyscore( @@ -4710,6 +4713,7 @@ def zrevrangebyscore( if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} + options["keys"] = [name] return self.execute_command(*pieces, **options) def zrank( @@ -4727,8 +4731,8 @@ def zrank( For more information see https://redis.io/commands/zrank """ if withscore: - return self.execute_command("ZRANK", name, value, "WITHSCORE") - return self.execute_command("ZRANK", name, value) + return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) + return self.execute_command("ZRANK", name, value, keys=[name]) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -4786,8 +4790,10 @@ def zrevrank( For more information see https://redis.io/commands/zrevrank """ if withscore: - return self.execute_command("ZREVRANK", name, value, "WITHSCORE") - return self.execute_command("ZREVRANK", name, value) + return self.execute_command( + "ZREVRANK", name, value, "WITHSCORE", keys=[name] + ) + return self.execute_command("ZREVRANK", name, value, keys=[name]) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -4795,7 +4801,7 @@ def 
zscore(self, name: KeyT, value: EncodableT) -> ResponseT: For more information see https://redis.io/commands/zscore """ - return self.execute_command("ZSCORE", name, value) + return self.execute_command("ZSCORE", name, value, keys=[name]) def zunion( self, @@ -4842,7 +4848,7 @@ def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: if not members: raise DataError("ZMSCORE members must be a non-empty list") pieces = [key] + members - return self.execute_command("ZMSCORE", *pieces) + return self.execute_command("ZMSCORE", *pieces, keys=[key]) def _zaggregate( self, @@ -4872,6 +4878,7 @@ def _zaggregate( raise DataError("aggregate can be sum, min or max.") if options.get("withscores", False): pieces.append(b"WITHSCORES") + options["keys"] = keys return self.execute_command(*pieces, **options) @@ -4933,7 +4940,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: For more information see https://redis.io/commands/hexists """ - return self.execute_command("HEXISTS", name, key) + return self.execute_command("HEXISTS", name, key, keys=[name]) def hget( self, name: str, key: str @@ -4943,7 +4950,7 @@ def hget( For more information see https://redis.io/commands/hget """ - return self.execute_command("HGET", name, key) + return self.execute_command("HGET", name, key, keys=[name]) def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ @@ -4951,7 +4958,7 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: For more information see https://redis.io/commands/hgetall """ - return self.execute_command("HGETALL", name) + return self.execute_command("HGETALL", name, keys=[name]) def hincrby( self, name: str, key: str, amount: int = 1 @@ -4979,7 +4986,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hkeys """ - return self.execute_command("HKEYS", name) + return self.execute_command("HKEYS", name, keys=[name]) def hlen(self, name: str) -> Union[Awaitable[int], int]: """ @@ -4987,7 +4994,7 @@ def hlen(self, name: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hlen """ - return self.execute_command("HLEN", name) + return self.execute_command("HLEN", name, keys=[name]) def hset( self, @@ -5054,7 +5061,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li For more information see https://redis.io/commands/hmget """ args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args) + return self.execute_command("HMGET", name, *args, keys=[name]) def hvals(self, name: str) -> Union[Awaitable[List], List]: """ @@ -5062,7 +5069,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: For more information see https://redis.io/commands/hvals """ - return self.execute_command("HVALS", name) + return self.execute_command("HVALS", name, keys=[name]) def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: """ @@ -5071,7 +5078,7 @@ def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: For more information see https://redis.io/commands/hstrlen """ - return self.execute_command("HSTRLEN", name, key) + return self.execute_command("HSTRLEN", name, key, keys=[name]) AsyncHashCommands = HashCommands @@ -5464,7 +5471,7 @@ def geodist( raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) - return self.execute_command("GEODIST", *pieces) + return self.execute_command("GEODIST", *pieces, keys=[name]) def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ 
-5473,7 +5480,7 @@ def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geohash """ - return self.execute_command("GEOHASH", name, *values) + return self.execute_command("GEOHASH", name, *values, keys=[name]) def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -5483,7 +5490,7 @@ def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: For more information see https://redis.io/commands/geopos """ - return self.execute_command("GEOPOS", name, *values) + return self.execute_command("GEOPOS", name, *values, keys=[name]) def georadius( self, @@ -5823,6 +5830,8 @@ def _geosearchgeneric( if kwargs[arg_name]: pieces.append(byte_repr) + kwargs["keys"] = [args[0] if command == "GEOSEARCH" else args[1]] + return self.execute_command(command, *pieces, **kwargs) diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py index a882dd514d..ffaf1fb4ff 100644 --- a/redis/commands/graph/__init__.py +++ b/redis/commands/graph/__init__.py @@ -1,3 +1,5 @@ +import warnings + from ..helpers import quote_string, random_string, stringify_param_value from .commands import AsyncGraphCommands, GraphCommands from .edge import Edge # noqa @@ -18,6 +20,12 @@ def __init__(self, client, name=random_string()): """ Create a new graph. """ + warnings.warn( + DeprecationWarning( + "RedisGraph support is deprecated as of Redis Stack 7.2 \ + (https://redis.com/blog/redisgraph-eol/)" + ) + ) self.NAME = name # Graph key self.client = client self.execute_command = client.execute_command diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 324d981d66..127141f650 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -64,6 +64,11 @@ def parse_list_to_dict(response): for i in range(0, len(response), 2): if isinstance(response[i], list): res["Child iterators"].append(parse_list_to_dict(response[i])) + try: + if isinstance(response[i + 1], list): + res["Child iterators"].append(parse_list_to_dict(response[i + 1])) + except IndexError: + pass elif isinstance(response[i + 1], list): res["Child iterators"] = [parse_list_to_dict(response[i + 1])] else: diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 3abe155796..4c2e58369c 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -49,7 +49,7 @@ def arrindex( if stop is not None: pieces.append(stop) - return self.execute_command("JSON.ARRINDEX", *pieces) + return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name]) def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] @@ -72,7 +72,7 @@ def arrlen( For more information see `JSON.ARRLEN `_. """ # noqa - return self.execute_command("JSON.ARRLEN", name, str(path)) + return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name]) def arrpop( self, @@ -80,7 +80,6 @@ def arrpop( path: Optional[str] = Path.root_path(), index: Optional[int] = -1, ) -> List[Union[str, None]]: - """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. @@ -103,14 +102,14 @@ def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]: For more information see `JSON.TYPE `_. """ # noqa - return self.execute_command("JSON.TYPE", name, str(path)) + return self.execute_command("JSON.TYPE", name, str(path), keys=[name]) def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: """Return the JSON value under ``path`` at key ``name``. For more information see `JSON.RESP `_. 
""" # noqa - return self.execute_command("JSON.RESP", name, str(path)) + return self.execute_command("JSON.RESP", name, str(path), keys=[name]) def objkeys( self, name: str, path: Optional[str] = Path.root_path() @@ -120,7 +119,7 @@ def objkeys( For more information see `JSON.OBJKEYS `_. """ # noqa - return self.execute_command("JSON.OBJKEYS", name, str(path)) + return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name]) def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: """Return the length of the dictionary JSON value under ``path`` at key @@ -128,7 +127,7 @@ def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: For more information see `JSON.OBJLEN `_. """ # noqa - return self.execute_command("JSON.OBJLEN", name, str(path)) + return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name]) def numincrby(self, name: str, path: str, number: int) -> str: """Increment the numeric (integer or floating point) JSON value under @@ -174,7 +173,7 @@ def delete(self, key: str, path: Optional[str] = Path.root_path()) -> int: def get( self, name: str, *args, no_escape: Optional[bool] = False - ) -> List[JsonType]: + ) -> Optional[List[JsonType]]: """ Get the object stored as a JSON value at key ``name``. @@ -198,7 +197,7 @@ def get( # Handle case where key doesn't exist. The JSONDecoder would raise a # TypeError exception since it can't decode None try: - return self.execute_command("JSON.GET", *pieces) + return self.execute_command("JSON.GET", *pieces, keys=[name]) except TypeError: return None @@ -212,7 +211,7 @@ def mget(self, keys: List[str], path: str) -> List[JsonType]: pieces = [] pieces += keys pieces.append(str(path)) - return self.execute_command("JSON.MGET", *pieces) + return self.execute_command("JSON.MGET", *pieces, keys=keys) def set( self, @@ -325,7 +324,7 @@ def set_path( nx: Optional[bool] = False, xx: Optional[bool] = False, decode_keys: Optional[bool] = False, - ) -> List[Dict[str, bool]]: + ) -> Dict[str, bool]: """ Iterate over ``root_folder`` and set each JSON file to a value under ``json_path`` with the file name as the key. @@ -365,7 +364,7 @@ def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None] pieces = [name] if path is not None: pieces.append(str(path)) - return self.execute_command("JSON.STRLEN", *pieces) + return self.execute_command("JSON.STRLEN", *pieces, keys=[name]) def toggle( self, name: str, path: Optional[str] = Path.root_path() @@ -378,7 +377,7 @@ def toggle( return self.execute_command("JSON.TOGGLE", name, str(path)) def strappend( - self, name: str, value: str, path: Optional[int] = Path.root_path() + self, name: str, value: str, path: Optional[str] = Path.root_path() ) -> Union[int, List[Optional[int]]]: """Append to the string JSON value. If two options are specified after the key name, the path is determined to be the first. 
If a single diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index e635f91e99..a2bb23b76d 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -27,7 +27,6 @@ class BatchIndexer: """ def __init__(self, client, chunk_size=1000): - self.client = client self.execute_command = client.execute_command self._pipeline = client.pipeline(transaction=False, shard_hint=None) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 93a3d9273b..50d18f476a 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,8 +1,10 @@ +from typing import List, Union + FIELDNAME = object() class Limit: - def __init__(self, offset=0, count=0): + def __init__(self, offset: int = 0, count: int = 0) -> None: self.offset = offset self.count = count @@ -22,12 +24,12 @@ class Reducer: NAME = None - def __init__(self, *args): + def __init__(self, *args: List[str]) -> None: self._args = args self._field = None self._alias = None - def alias(self, alias): + def alias(self, alias: str) -> "Reducer": """ Set the alias for this reducer. @@ -51,7 +53,7 @@ def alias(self, alias): return self @property - def args(self): + def args(self) -> List[str]: return self._args @@ -62,7 +64,7 @@ class SortDirection: DIRSTRING = None - def __init__(self, field): + def __init__(self, field: str) -> None: self.field = field @@ -87,7 +89,7 @@ class AggregateRequest: Aggregation request which can be passed to `Client.aggregate`. """ - def __init__(self, query="*"): + def __init__(self, query: str = "*") -> None: """ Create an aggregation request. This request may then be passed to `client.aggregate()`. @@ -110,7 +112,7 @@ def __init__(self, query="*"): self._cursor = [] self._dialect = None - def load(self, *fields): + def load(self, *fields: List[str]) -> "AggregateRequest": """ Indicate the fields to be returned in the response. These fields are returned in addition to any others implicitly specified. @@ -126,7 +128,9 @@ def load(self, *fields): self._loadall = True return self - def group_by(self, fields, *reducers): + def group_by( + self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + ) -> "AggregateRequest": """ Specify by which fields to group the aggregation. @@ -151,7 +155,7 @@ def group_by(self, fields, *reducers): self._aggregateplan.extend(ret) return self - def apply(self, **kwexpr): + def apply(self, **kwexpr) -> "AggregateRequest": """ Specify one or more projection expressions to add to each result @@ -169,7 +173,7 @@ def apply(self, **kwexpr): return self - def limit(self, offset, num): + def limit(self, offset: int, num: int) -> "AggregateRequest": """ Sets the limit for the most recent group or query. @@ -215,7 +219,7 @@ def limit(self, offset, num): self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields, **kwargs): + def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -262,7 +266,7 @@ def sort_by(self, *fields, **kwargs): self._aggregateplan.extend(ret) return self - def filter(self, expressions): + def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest": """ Specify filter for post-query results using predicates relating to values in the result set. 
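For context, the builder methods annotated above all return the same AggregateRequest, so an aggregation plan can be chained in a single expression. A minimal sketch of typical usage, assuming a locally running RediSearch-enabled server; the index name "idx" and the @brand field are illustrative:

from redis import Redis
from redis.commands.search import reducers
from redis.commands.search.aggregation import AggregateRequest, Desc

r = Redis(decode_responses=True)  # assumes redis://localhost:6379 with RediSearch loaded

# Group all documents by brand, count each group, and return the five largest groups.
req = (
    AggregateRequest("*")
    .group_by("@brand", reducers.count().alias("total"))
    .sort_by(Desc("@total"))
    .limit(0, 5)
)
result = r.ft("idx").aggregate(req)  # "idx" is an illustrative index name
for row in result.rows:
    print(row)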
@@ -280,7 +284,7 @@ def filter(self, expressions): return self - def with_schema(self): + def with_schema(self) -> "AggregateRequest": """ If set, the `schema` property will contain a list of `[field, type]` entries in the result object. @@ -288,11 +292,11 @@ def with_schema(self): self._with_schema = True return self - def verbatim(self): + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self - def cursor(self, count=0, max_idle=0.0): + def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest": args = ["WITHCURSOR"] if count: args += ["COUNT", str(count)] @@ -301,7 +305,7 @@ def cursor(self, count=0, max_idle=0.0): self._cursor = args return self - def build_args(self): + def build_args(self) -> List[str]: # @foo:bar ... ret = [self._query] @@ -329,7 +333,7 @@ def build_args(self): return ret - def dialect(self, dialect): + def dialect(self, dialect: int) -> "AggregateRequest": """ Add a dialect field to the aggregate command. @@ -340,7 +344,7 @@ def dialect(self, dialect): class Cursor: - def __init__(self, cid): + def __init__(self, cid: int) -> None: self.cid = cid self.max_idle = 0 self.count = 0 @@ -355,12 +359,12 @@ def build_args(self): class AggregateResult: - def __init__(self, rows, cursor, schema): + def __init__(self, rows, cursor: Cursor, schema) -> None: self.rows = rows self.cursor = cursor self.schema = schema - def __repr__(self): + def __repr__(self) -> str: cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 87b572195c..2df2b5a754 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,11 +1,11 @@ import itertools import time -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from redis.client import Pipeline from redis.utils import deprecated_function -from ..helpers import parse_to_dict +from ..helpers import get_protocol_version, parse_to_dict from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .document import Document @@ -64,7 +64,7 @@ class SearchCommands: """Search commands.""" def _parse_results(self, cmd, res, **kwargs): - if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + if get_protocol_version(self.client) in ["3", 3]: return res else: return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) @@ -220,7 +220,7 @@ def create_index( return self.execute_command(*args) - def alter_schema_add(self, fields): + def alter_schema_add(self, fields: List[str]): """ Alter the existing search index by adding new fields. The index must already exist. @@ -240,7 +240,7 @@ def alter_schema_add(self, fields): return self.execute_command(*args) - def dropindex(self, delete_documents=False): + def dropindex(self, delete_documents: bool = False): """ Drop the index if it exists. Replaced `drop_index` in RediSearch 2.0. @@ -322,15 +322,15 @@ def _add_document_hash( ) def add_document( self, - doc_id, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - language=None, - no_create=False, - **fields, + doc_id: str, + nosave: bool = False, + score: float = 1.0, + payload: Optional[str] = None, + replace: bool = False, + partial: bool = False, + language: Optional[str] = None, + no_create: bool = False, + **fields: List[str], ): """ Add a single document to the index. 
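A short, hedged sketch of the index lifecycle that `create_index`, `alter_schema_add` and `dropindex` above manage, with documents supplied as plain hashes; the index name, key prefix and field names are illustrative, and a RediSearch-enabled server is assumed:

from redis import Redis
from redis.commands.search.field import NumericField, TextField
from redis.commands.search.indexDefinition import IndexDefinition

r = Redis(decode_responses=True)
ft = r.ft("idx")  # "idx" is an illustrative index name

# Index hashes whose keys start with "doc:".
ft.create_index(
    (TextField("title"), NumericField("price")),
    definition=IndexDefinition(prefix=["doc:"]),
)
r.hset("doc:1", mapping={"title": "hello world", "price": 9})

# Extend the schema later with additional field objects, or drop the index
# while keeping the underlying hashes.
ft.alter_schema_add((TextField("body"),))
ft.dropindex(delete_documents=False)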
@@ -554,7 +554,9 @@ def aggregate( AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor ) - def _get_aggregate_result(self, raw, query, has_cursor): + def _get_aggregate_result( + self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool + ): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] @@ -642,7 +644,7 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - def dict_add(self, name, *terms): + def dict_add(self, name: str, *terms: List[str]): """Adds terms to a dictionary. ### Parameters @@ -656,7 +658,7 @@ def dict_add(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_del(self, name, *terms): + def dict_del(self, name: str, *terms: List[str]): """Deletes terms from a dictionary. ### Parameters @@ -670,7 +672,7 @@ def dict_del(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_dump(self, name): + def dict_dump(self, name: str): """Dumps all terms in the given dictionary. ### Parameters @@ -682,7 +684,7 @@ def dict_dump(self, name): cmd = [DICT_DUMP_CMD, name] return self.execute_command(*cmd) - def config_set(self, option, value): + def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -696,7 +698,7 @@ def config_set(self, option, value): raw = self.execute_command(*cmd) return raw == "OK" - def config_get(self, option): + def config_get(self, option: str) -> str: """Get runtime configuration option value. ### Parameters @@ -709,7 +711,7 @@ def config_get(self, option): res = self.execute_command(*cmd) return self._parse_results(CONFIG_CMD, res) - def tagvals(self, tagfield): + def tagvals(self, tagfield: str): """ Return a list of all possible tag values @@ -722,7 +724,7 @@ def tagvals(self, tagfield): return self.execute_command(TAGVALS_CMD, self.index_name, tagfield) - def aliasadd(self, alias): + def aliasadd(self, alias: str): """ Alias a search index - will fail if alias already exists @@ -735,7 +737,7 @@ def aliasadd(self, alias): return self.execute_command(ALIAS_ADD_CMD, alias, self.index_name) - def aliasupdate(self, alias): + def aliasupdate(self, alias: str): """ Updates an alias - will fail if alias does not already exist @@ -748,7 +750,7 @@ def aliasupdate(self, alias): return self.execute_command(ALIAS_UPDATE_CMD, alias, self.index_name) - def aliasdel(self, alias): + def aliasdel(self, alias: str): """ Removes an alias to a search index @@ -783,7 +785,7 @@ def sugadd(self, key, *suggestions, **kwargs): return pipe.execute()[-1] - def suglen(self, key): + def suglen(self, key: str) -> int: """ Return the number of entries in the AutoCompleter index. @@ -791,7 +793,7 @@ def suglen(self, key): """ # noqa return self.execute_command(SUGLEN_COMMAND, key) - def sugdel(self, key, string): + def sugdel(self, key: str, string: str) -> int: """ Delete a string from the AutoCompleter index. Returns 1 if the string was found and deleted, 0 otherwise. @@ -801,8 +803,14 @@ def sugdel(self, key, string): return self.execute_command(SUGDEL_COMMAND, key, string) def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. 
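A brief sketch of the suggestion (autocompleter) commands typed above; the key name "ac:cities" and the city strings are illustrative, and a RediSearch-enabled server is assumed:

from redis import Redis
from redis.commands.search.suggestion import Suggestion

r = Redis(decode_responses=True)
ac = r.ft()  # suggestion commands are addressed by key, not by index

ac.sugadd("ac:cities", Suggestion("San Francisco", score=2.0), Suggestion("San Diego"))
print(ac.suglen("ac:cities"))  # 2
for s in ac.sugget("ac:cities", "san", fuzzy=False, num=5, with_scores=True):
    print(s.string, s.score)
ac.sugdel("ac:cities", "San Diego")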
@@ -850,7 +858,7 @@ def sugget( parser = SuggestionParser(with_scores, with_payloads, res) return [s for s in parser] - def synupdate(self, groupid, skipinitial=False, *terms): + def synupdate(self, groupid: str, skipinitial: bool = False, *terms: List[str]): """ Updates a synonym group. The command is used to create or update a synonym group with @@ -986,7 +994,7 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - async def config_set(self, option, value): + async def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -1000,7 +1008,7 @@ async def config_set(self, option, value): raw = await self.execute_command(*cmd) return raw == "OK" - async def config_get(self, option): + async def config_get(self, option: str) -> str: """Get runtime configuration option value. ### Parameters @@ -1053,8 +1061,14 @@ async def sugadd(self, key, *suggestions, **kwargs): return (await pipe.execute())[-1] async def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 6f31ce1fc2..f316ed9f14 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -4,7 +4,6 @@ class Field: - NUMERIC = "NUMERIC" TEXT = "TEXT" WEIGHT = "WEIGHT" @@ -14,6 +13,7 @@ class Field: SORTABLE = "SORTABLE" NOINDEX = "NOINDEX" AS = "AS" + GEOSHAPE = "GEOSHAPE" def __init__( self, @@ -92,6 +92,21 @@ def __init__(self, name: str, **kwargs): Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) +class GeoShapeField(Field): + """ + GeoShapeField is used to enable within/contain indexing/searching + """ + + SPHERICAL = "SPHERICAL" + FLAT = "FLAT" + + def __init__(self, name: str, coord_system=None, **kwargs): + args = [Field.GEOSHAPE] + if coord_system: + args.append(coord_system) + Field.__init__(self, name, args=args, **kwargs) + + class GeoField(Field): """ GeoField is used to define a geo-indexing field in a schema definition diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 5071cfabf2..113ddf9da8 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,3 +1,6 @@ +from typing import List, Optional, Union + + class Query: """ Query is used to build complex queries that have more parameters than just @@ -8,52 +11,52 @@ class Query: i.e. `Query("foo").verbatim().filter(...)` etc. """ - def __init__(self, query_string): + def __init__(self, query_string: str) -> None: """ Create a new query object. The query string is set in the constructor, and other options have setter functions. 
""" - self._query_string = query_string - self._offset = 0 - self._num = 10 - self._no_content = False - self._no_stopwords = False - self._fields = None - self._verbatim = False - self._with_payloads = False - self._with_scores = False - self._scorer = False - self._filters = list() - self._ids = None - self._slop = -1 - self._timeout = None - self._in_order = False - self._sortby = None - self._return_fields = [] - self._summarize_fields = [] - self._highlight_fields = [] - self._language = None - self._expander = None - self._dialect = None - - def query_string(self): + self._query_string: str = query_string + self._offset: int = 0 + self._num: int = 10 + self._no_content: bool = False + self._no_stopwords: bool = False + self._fields: Optional[List[str]] = None + self._verbatim: bool = False + self._with_payloads: bool = False + self._with_scores: bool = False + self._scorer: Optional[str] = None + self._filters: List = list() + self._ids: Optional[List[str]] = None + self._slop: int = -1 + self._timeout: Optional[float] = None + self._in_order: bool = False + self._sortby: Optional[SortbyField] = None + self._return_fields: List = [] + self._summarize_fields: List = [] + self._highlight_fields: List = [] + self._language: Optional[str] = None + self._expander: Optional[str] = None + self._dialect: Optional[int] = None + + def query_string(self) -> str: """Return the query string of this query only.""" return self._query_string - def limit_ids(self, *ids): + def limit_ids(self, *ids) -> "Query": """Limit the results to a specific set of pre-known document ids of any length.""" self._ids = ids return self - def return_fields(self, *fields): + def return_fields(self, *fields) -> "Query": """Add fields to return fields.""" self._return_fields += fields return self - def return_field(self, field, as_field=None): + def return_field(self, field: str, as_field: Optional[str] = None) -> "Query": """Add field to return fields (Optional: add 'AS' name to the field).""" self._return_fields.append(field) @@ -61,12 +64,18 @@ def return_field(self, field, as_field=None): self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields): + def _mk_field_list(self, fields: List[str]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) - def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): + def summarize( + self, + fields: Optional[List] = None, + context_len: Optional[int] = None, + num_frags: Optional[int] = None, + sep: Optional[str] = None, + ) -> "Query": """ Return an abridged format of the field, containing only the segments of the field which contain the matching term(s). @@ -98,7 +107,9 @@ def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): self._summarize_fields = args return self - def highlight(self, fields=None, tags=None): + def highlight( + self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None + ) -> None: """ Apply specified markup to matched term(s) within the returned field(s). @@ -116,7 +127,7 @@ def highlight(self, fields=None, tags=None): self._highlight_fields = args return self - def language(self, language): + def language(self, language: str) -> "Query": """ Analyze the query as being in the specified language. 
@@ -125,19 +136,19 @@ def language(self, language): self._language = language return self - def slop(self, slop): + def slop(self, slop: int) -> "Query": """Allow a maximum of N intervening non matched terms between phrase terms (0 means exact phrase). """ self._slop = slop return self - def timeout(self, timeout): + def timeout(self, timeout: float) -> "Query": """overrides the timeout parameter of the module""" self._timeout = timeout return self - def in_order(self): + def in_order(self) -> "Query": """ Match only documents where the query terms appear in the same order in the document. @@ -146,7 +157,7 @@ def in_order(self): self._in_order = True return self - def scorer(self, scorer): + def scorer(self, scorer: str) -> "Query": """ Use a different scoring function to evaluate document relevance. Default is `TFIDF`. @@ -157,7 +168,7 @@ def scorer(self, scorer): self._scorer = scorer return self - def get_args(self): + def get_args(self) -> List[str]: """Format the redis arguments for this query and return them.""" args = [self._query_string] args += self._get_args_tags() @@ -165,7 +176,7 @@ def get_args(self): args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self): + def _get_args_tags(self) -> List[str]: args = [] if self._no_content: args.append("NOCONTENT") @@ -194,7 +205,7 @@ def _get_args_tags(self): args += self._ids if self._slop >= 0: args += ["SLOP", self._slop] - if self._timeout: + if self._timeout is not None: args += ["TIMEOUT", self._timeout] if self._in_order: args.append("INORDER") @@ -216,7 +227,7 @@ def _get_args_tags(self): return args - def paging(self, offset, num): + def paging(self, offset: int, num: int) -> "Query": """ Set the paging for the query (defaults to 0..10). @@ -227,19 +238,19 @@ def paging(self, offset, num): self._num = num return self - def verbatim(self): + def verbatim(self) -> "Query": """Set the query to be verbatim, i.e. use no query expansion or stemming. """ self._verbatim = True return self - def no_content(self): + def no_content(self) -> "Query": """Set the query to only return ids and not the document content.""" self._no_content = True return self - def no_stopwords(self): + def no_stopwords(self) -> "Query": """ Prevent the query from being filtered for stopwords. Only useful in very big queries that you are certain contain @@ -248,17 +259,17 @@ def no_stopwords(self): self._no_stopwords = True return self - def with_payloads(self): + def with_payloads(self) -> "Query": """Ask the engine to return document payloads.""" self._with_payloads = True return self - def with_scores(self): + def with_scores(self) -> "Query": """Ask the engine to return document search scores.""" self._with_scores = True return self - def limit_fields(self, *fields): + def limit_fields(self, *fields: List[str]) -> "Query": """ Limit the search to specific TEXT fields only. @@ -268,7 +279,7 @@ def limit_fields(self, *fields): self._fields = fields return self - def add_filter(self, flt): + def add_filter(self, flt: "Filter") -> "Query": """ Add a numeric or geo filter to the query. **Currently only one of each filter is supported by the engine** @@ -280,7 +291,7 @@ def add_filter(self, flt): self._filters.append(flt) return self - def sort_by(self, field, asc=True): + def sort_by(self, field: str, asc: bool = True) -> "Query": """ Add a sortby field to the query. 
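The filter classes defined further down in this file pair with `add_filter` and `sort_by`; a hedged sketch of combining them (the field names, coordinates and the "idx" index are illustrative):

from redis import Redis
from redis.commands.search.query import GeoFilter, NumericFilter, Query

r = Redis(decode_responses=True)

# Price between 10 and 100 inclusive, within 50 km of a point, cheapest first.
q = (
    Query("*")
    .add_filter(NumericFilter("price", 10, 100))
    .add_filter(GeoFilter("location", -122.41, 37.77, 50, unit=GeoFilter.KILOMETERS))
    .sort_by("price", asc=True)
)
res = r.ft("idx").search(q)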
@@ -290,7 +301,7 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self - def expander(self, expander): + def expander(self, expander: str) -> "Query": """ Add a expander field to the query. @@ -310,7 +321,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword, field, *args): + def __init__(self, keyword: str, field: str, *args: List[str]) -> None: self.args = [keyword, field] + list(args) @@ -318,7 +329,14 @@ class NumericFilter(Filter): INF = "+inf" NEG_INF = "-inf" - def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): + def __init__( + self, + field: str, + minval: Union[int, str], + maxval: Union[int, str], + minExclusive: bool = False, + maxExclusive: bool = False, + ) -> None: args = [ minval if not minExclusive else f"({minval}", maxval if not maxExclusive else f"({maxval}", @@ -333,10 +351,12 @@ class GeoFilter(Filter): FEET = "ft" MILES = "mi" - def __init__(self, field, lon, lat, radius, unit=KILOMETERS): + def __init__( + self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS + ) -> None: Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit) class SortbyField: - def __init__(self, field, asc=True): + def __init__(self, field: str, asc=True) -> None: self.args = [field, "ASC" if asc else "DESC"] diff --git a/redis/commands/search/reducers.py b/redis/commands/search/reducers.py index 41ed11a238..8b60f23283 100644 --- a/redis/commands/search/reducers.py +++ b/redis/commands/search/reducers.py @@ -1,8 +1,12 @@ -from .aggregation import Reducer, SortDirection +from typing import Union + +from .aggregation import Asc, Desc, Reducer, SortDirection class FieldOnlyReducer(Reducer): - def __init__(self, field): + """See https://redis.io/docs/interact/search-and-query/search/aggregations/""" + + def __init__(self, field: str) -> None: super().__init__(field) self._field = field @@ -14,7 +18,7 @@ class count(Reducer): NAME = "COUNT" - def __init__(self): + def __init__(self) -> None: super().__init__() @@ -25,7 +29,7 @@ class sum(FieldOnlyReducer): NAME = "SUM" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -36,7 +40,7 @@ class min(FieldOnlyReducer): NAME = "MIN" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -47,7 +51,7 @@ class max(FieldOnlyReducer): NAME = "MAX" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -58,7 +62,7 @@ class avg(FieldOnlyReducer): NAME = "AVG" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -69,7 +73,7 @@ class tolist(FieldOnlyReducer): NAME = "TOLIST" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -81,7 +85,7 @@ class count_distinct(FieldOnlyReducer): NAME = "COUNT_DISTINCT" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -103,7 +107,7 @@ class quantile(Reducer): NAME = "QUANTILE" - def __init__(self, field, pct): + def __init__(self, field: str, pct: float) -> None: super().__init__(field, str(pct)) self._field = field @@ -115,7 +119,7 @@ class stddev(FieldOnlyReducer): NAME = "STDDEV" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -126,7 +130,7 @@ class first_value(Reducer): NAME = "FIRST_VALUE" - def __init__(self, field, *byfields): + def __init__(self, field: str, *byfields: 
Union[Asc, Desc]) -> None: """ Selects the first value of the given field within the group. @@ -166,7 +170,7 @@ class random_sample(Reducer): NAME = "RANDOM_SAMPLE" - def __init__(self, field, size): + def __init__(self, field: str, size: int) -> None: """ ### Parameter diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 451bf89bb7..5b19e6faa4 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -69,5 +69,5 @@ def __init__( ) self.docs.append(doc) - def __repr__(self): + def __repr__(self) -> str: return f"Result{{{self.total} total, docs: {self.docs}}}" diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index 5d1eba64b8..499c8d917e 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -1,3 +1,5 @@ +from typing import Optional + from ._util import to_string @@ -7,12 +9,14 @@ class Suggestion: autocomplete server """ - def __init__(self, string, score=1.0, payload=None): + def __init__( + self, string: str, score: float = 1.0, payload: Optional[str] = None + ) -> None: self.string = to_string(string) self.payload = to_string(payload) self.score = score - def __repr__(self): + def __repr__(self) -> str: return self.string @@ -23,7 +27,7 @@ class SuggestionParser: the return value depending on what objects were requested """ - def __init__(self, with_scores, with_payloads, ret): + def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None: self.with_scores = with_scores self.with_payloads = with_payloads diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py index 13e3cdf498..208ddfb09f 100644 --- a/redis/commands/timeseries/commands.py +++ b/redis/commands/timeseries/commands.py @@ -59,6 +59,9 @@ def create( - 'last': override with latest value. - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. + - 'sum': If a previous sample exists, add the new sample to it so that \ + the updated value is equal to (previous + new). If no previous sample \ + exists, set the updated value equal to the new value. For more information: https://redis.io/commands/ts.create/ """ # noqa @@ -103,6 +106,9 @@ def alter( - 'last': override with latest value. - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. + - 'sum': If a previous sample exists, add the new sample to it so that \ + the updated value is equal to (previous + new). If no previous sample \ + exists, set the updated value equal to the new value. For more information: https://redis.io/commands/ts.alter/ """ # noqa @@ -154,6 +160,9 @@ def add( - 'last': override with latest value. - 'min': only override if the value is lower than the existing value. - 'max': only override if the value is higher than the existing value. + - 'sum': If a previous sample exists, add the new sample to it so that \ + the updated value is equal to (previous + new). If no previous sample \ + exists, set the updated value equal to the new value. 
For more information: https://redis.io/commands/ts.add/ """ # noqa @@ -425,7 +434,7 @@ def range( bucket_timestamp, empty, ) - return self.execute_command(RANGE_CMD, *params) + return self.execute_command(RANGE_CMD, *params, keys=[key]) def revrange( self, @@ -497,7 +506,7 @@ def revrange( bucket_timestamp, empty, ) - return self.execute_command(REVRANGE_CMD, *params) + return self.execute_command(REVRANGE_CMD, *params, keys=[key]) def __mrange_params( self, @@ -721,7 +730,7 @@ def get(self, key: KeyT, latest: Optional[bool] = False): """ # noqa params = [key] self._append_latest(params, latest) - return self.execute_command(GET_CMD, *params) + return self.execute_command(GET_CMD, *params, keys=[key]) def mget( self, @@ -761,7 +770,7 @@ def info(self, key: KeyT): For more information: https://redis.io/commands/ts.info/ """ # noqa - return self.execute_command(INFO_CMD, key) + return self.execute_command(INFO_CMD, key, keys=[key]) def queryindex(self, filters: List[str]): """# noqa diff --git a/redis/commands/timeseries/utils.py b/redis/commands/timeseries/utils.py index c49b040271..12ed656277 100644 --- a/redis/commands/timeseries/utils.py +++ b/redis/commands/timeseries/utils.py @@ -5,7 +5,7 @@ def list_to_dict(aList): return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} -def parse_range(response): +def parse_range(response, **kwargs): """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" return [tuple((r[0], float(r[1]))) for r in response] diff --git a/redis/compat.py b/redis/compat.py deleted file mode 100644 index e478493467..0000000000 --- a/redis/compat.py +++ /dev/null @@ -1,6 +0,0 @@ -# flake8: noqa -try: - from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] -except ImportError: - from typing_extensions import Literal # lgtm [py/unused-import] - from typing_extensions import Protocol, TypedDict diff --git a/redis/connection.py b/redis/connection.py index 00d293a238..35a4ff4a37 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,5 +1,6 @@ import copy import os +import select import socket import ssl import sys @@ -9,7 +10,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Type, Union +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser @@ -55,7 +56,7 @@ class HiredisRespSerializer: - def pack(self, *args): + def pack(self, *args: List): """Pack a series of arguments into the Redis protocol""" output = [] @@ -128,27 +129,27 @@ class AbstractConnection: def __init__( self, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - retry_on_timeout=False, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, retry_on_error=SENTINEL, - encoding="utf-8", - encoding_errors="strict", - decode_responses=False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, parser_class=DefaultParser, - socket_read_size=65536, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + 
lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - command_packer=None, + command_packer: Optional[Callable[[], None]] = None, ): """ Initialize a new Connection. @@ -217,7 +218,7 @@ def __init__( def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) - return f"{self.__class__.__name__}<{repr_args}>" + return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" @abstractmethod def repr_pieces(self): @@ -238,10 +239,27 @@ def _construct_command_packer(self, packer): return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) def register_connect_callback(self, callback): - self._connect_callbacks.append(weakref.WeakMethod(callback)) + """ + Register a callback to be called when the connection is established either + initially or reconnected. This allows listeners to issue commands that + are ephemeral to the connection, for example pub/sub subscription or + key tracking. The callback must be a _method_ and will be kept as + a weak reference. + """ + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) - def clear_connect_callbacks(self): - self._connect_callbacks = [] + def deregister_connect_callback(self, callback): + """ + De-register a previously registered callback. It will no-longer receive + notifications on connection events. Calling this is not required when the + listener goes away, since the callbacks are kept as weak methods. + """ + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass def set_parser(self, parser_class): """ @@ -279,6 +297,8 @@ def connect(self): # run any user callbacks. right now the only internal callback # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] for ref in self._connect_callbacks: callback = ref() if callback: @@ -513,7 +533,10 @@ def read_response( self.next_health_check = time() + self.health_check_interval if isinstance(response, ResponseError): - raise response + try: + raise response + finally: + del response # avoid creating ref cycles return response def pack_command(self, *args): @@ -550,6 +573,11 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output + def _is_socket_empty(self): + """Check if the socket is empty""" + r, _, _ = select.select([self._sock], [], [], 0) + return not bool(r) + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -843,6 +871,7 @@ def to_bool(value): "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "timeout": float, } @@ -967,7 +996,10 @@ class initializer. 
In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, connection_class=Connection, max_connections=None, **connection_kwargs + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: @@ -988,13 +1020,13 @@ def __init__( self._fork_lock = threading.Lock() self.reset() - def __repr__(self): + def __repr__(self) -> (str, str): return ( - f"{type(self).__name__}" - f"<{repr(self.connection_class(**self.connection_kwargs))}>" + f"<{type(self).__module__}.{type(self).__name__}" + f"({repr(self.connection_class(**self.connection_kwargs))})>" ) - def reset(self): + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] @@ -1011,7 +1043,7 @@ def reset(self): # reset() and they will immediately release _fork_lock and continue on. self.pid = os.getpid() - def _checkpid(self): + def _checkpid(self) -> None: # _checkpid() attempts to keep ConnectionPool fork-safe on modern # systems. this is called by all ConnectionPool methods that # manipulate the pool's state such as get_connection() and release(). @@ -1058,7 +1090,7 @@ def _checkpid(self): finally: self._fork_lock.release() - def get_connection(self, command_name, *keys, **options): + def get_connection(self, command_name: str, *keys, **options) -> "Connection": "Get a connection from the pool" self._checkpid() with self._lock: @@ -1091,7 +1123,7 @@ def get_connection(self, command_name, *keys, **options): return connection - def get_encoder(self): + def get_encoder(self) -> Encoder: "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( @@ -1100,14 +1132,14 @@ def get_encoder(self): decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self): + def make_connection(self) -> "Connection": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - def release(self, connection): + def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" self._checkpid() with self._lock: @@ -1128,10 +1160,10 @@ def release(self, connection): connection.disconnect() return - def owns_connection(self, connection): + def owns_connection(self, connection: "Connection") -> int: return connection.pid == self.pid - def disconnect(self, inuse_connections=True): + def disconnect(self, inuse_connections: bool = True) -> None: """ Disconnects connections in the pool @@ -1151,6 +1183,10 @@ def disconnect(self, inuse_connections=True): for connection in connections: connection.disconnect() + def close(self) -> None: + """Close the pool, disconnecting all connections""" + self.disconnect() + def set_retry(self, retry: "Retry") -> None: self.connection_kwargs.update({"retry": retry}) for conn in self._available_connections: @@ -1201,7 +1237,6 @@ def __init__( queue_class=LifoQueue, **connection_kwargs, ): - self.queue_class = queue_class self.timeout = timeout super().__init__( diff --git a/redis/ocsp.py b/redis/ocsp.py index b0420b4711..8819848fa9 100644 --- a/redis/ocsp.py +++ b/redis/ocsp.py @@ -61,7 +61,7 @@ def _check_certificate(issuer_cert, ocsp_bytes, validate=True): ) else: raise ConnectionError( - "failed to retrieve a sucessful response from the ocsp responder" + 
"failed to retrieve a successful response from the ocsp responder" ) if ocsp_response.this_update >= datetime.datetime.now(): @@ -139,7 +139,7 @@ def _get_pubkey_hash(certificate): def ocsp_staple_verifier(con, ocsp_bytes, expected=None): - """An implemention of a function for set_ocsp_client_callback in PyOpenSSL. + """An implementation of a function for set_ocsp_client_callback in PyOpenSSL. This function validates that the provide ocsp_bytes response is valid, and matches the expected, stapled responses. @@ -266,7 +266,7 @@ def build_certificate_url(self, server, cert, issuer_cert): return url def check_certificate(self, server, cert, issuer_url): - """Checks the validitity of an ocsp server for an issuer""" + """Checks the validity of an ocsp server for an issuer""" r = requests.get(issuer_url) if not r.ok: diff --git a/redis/sentinel.py b/redis/sentinel.py index 836e781e7f..dfcd8ff64b 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -24,7 +24,10 @@ def __init__(self, **kwargs): def __repr__(self): pool = self.connection_pool - s = f"{type(self).__name__}" + s = ( + f"<{type(self).__module__}.{type(self).__name__}" + f"(service={pool.service_name}%s)>" + ) if self.host: host_info = f",host={self.host},port={self.port}" s = s % host_info @@ -162,7 +165,10 @@ def __init__(self, service_name, sentinel_manager, **kwargs): def __repr__(self): role = "master" if self.is_master else "slave" - return f"{type(self).__name__}" + ) def reset(self): super().reset() @@ -244,6 +250,7 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ + kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") @@ -261,7 +268,10 @@ def __repr__(self): sentinel_addresses.append( "{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs) ) - return f'{type(self).__name__}' + return ( + f"<{type(self).__module__}.{type(self).__name__}" + f'(sentinels=[{",".join(sentinel_addresses)}])>' + ) def check_master_state(self, state, service_name): if not state["is_master"] or state["is_sdown"] or state["is_odown"]: @@ -353,10 +363,8 @@ def master_for( kwargs["is_master"] = True connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) ) def slave_for( @@ -386,8 +394,6 @@ def slave_for( kwargs["is_master"] = False connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) ) diff --git a/redis/typing.py b/redis/typing.py index 56a1e99ba7..a5d1369d63 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -7,13 +7,12 @@ Awaitable, Iterable, Mapping, + Protocol, Type, TypeVar, Union, ) -from redis.compat import Protocol - if TYPE_CHECKING: from redis._parsers import Encoder from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool @@ -33,6 +32,7 @@ PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, 
Iterable[KeyT]] +ResponseT = Union[Awaitable, Any] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name diff --git a/setup.py b/setup.py index 475e3565fa..89aa2e6658 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.0.0", + version="5.1.0a1", packages=find_packages( include=[ "redis", @@ -34,10 +34,8 @@ }, author="Redis Inc.", author_email="oss@redis.com", - python_requires=">=3.7", + python_requires=">=3.8", install_requires=[ - 'importlib-metadata >= 1.0; python_version < "3.8"', - 'typing-extensions; python_version<"3.8"', 'async-timeout>=4.0.2; python_full_version<="3.11.2"', ], classifiers=[ @@ -49,7 +47,6 @@ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/tests/conftest.py b/tests/conftest.py index 16f3fbb9db..bad9f43e42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import random import time from typing import Callable, TypeVar +from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -9,7 +10,7 @@ import redis from packaging.version import Version from redis.backoff import NoBackoff -from redis.connection import parse_url +from redis.connection import Connection, parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry @@ -39,7 +40,6 @@ def __init__( help=None, metavar=None, ): - _option_strings = [] for option_string in option_strings: _option_strings.append(option_string) @@ -72,7 +72,6 @@ def format_usage(self): def pytest_addoption(parser): - parser.addoption( "--redis-url", default=default_redis_url, @@ -354,23 +353,23 @@ def sslclient(request): def _gen_cluster_mock_resp(r, response): - connection = Mock() + connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - r.connection = connection - return r + with mock.patch.object(r, "connection", connection): + yield r @pytest.fixture() def mock_cluster_resp_ok(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, "OK") + yield from _gen_cluster_mock_resp(r, "OK") @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, 2) + yield from _gen_cluster_mock_resp(r, 2) @pytest.fixture() @@ -384,7 +383,7 @@ def mock_cluster_resp_info(request, **kwargs): "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" "cluster_stats_messages_received:105653\r\n" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) @pytest.fixture() @@ -408,7 +407,7 @@ def mock_cluster_resp_nodes(request, **kwargs): "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) @pytest.fixture() @@ -419,7 +418,7 @@ def mock_cluster_resp_slaves(request, **kwargs): "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " "1447836789290 3 connected']" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) 
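The fixtures above switch from returning a mocked client to yielding it from inside `mock.patch.object`, so the patched `connection` attribute is always restored. A minimal sketch of the same pattern; the `make_client` fixture name is hypothetical:

import pytest
from unittest import mock

@pytest.fixture()
def mocked_ok_client(make_client):
    connection = mock.Mock()
    connection.read_response.return_value = "OK"
    # Patching inside a with-block and yielding keeps the mock active for the
    # test and restores the real attribute afterwards, even if the test fails.
    with mock.patch.object(make_client, "connection", connection):
        yield make_client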
@pytest.fixture(scope="session") diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py index 5edcd4ae54..4a9778b70a 100644 --- a/tests/test_asyncio/compat.py +++ b/tests/test_asyncio/compat.py @@ -6,6 +6,18 @@ except AttributeError: import mock +try: + from contextlib import aclosing +except ImportError: + import contextlib + + @contextlib.asynccontextmanager + async def aclosing(thing): + try: + yield thing + finally: + await thing.aclose() + def create_task(coroutine): return asyncio.create_task(coroutine) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index c837f284f7..5d9e0b4f2e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -8,7 +8,7 @@ from packaging.version import Version from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.asyncio.client import Monitor -from redis.asyncio.connection import parse_url +from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.utils import HIREDIS_AVAILABLE @@ -100,7 +100,7 @@ async def teardown(): # handle cases where a test disconnected a client # just manually retry the flushdb await client.flushdb() - await client.close() + await client.aclose() await client.connection_pool.disconnect() else: if flushdb: @@ -110,7 +110,7 @@ async def teardown(): # handle cases where a test disconnected a client # just manually retry the flushdb await client.flushdb(target_nodes="primaries") - await client.close() + await client.aclose() teardown_clients.append(teardown) return client @@ -138,23 +138,25 @@ async def decoded_r(create_redis): def _gen_cluster_mock_resp(r, response): - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - r.connection = connection - return r + with mock.patch.object(r, "connection", connection): + yield r @pytest_asyncio.fixture() async def mock_cluster_resp_ok(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, "OK") + for mocked in _gen_cluster_mock_resp(r, "OK"): + yield mocked @pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, 2) + for mocked in _gen_cluster_mock_resp(r, 2): + yield mocked @pytest_asyncio.fixture() @@ -168,7 +170,8 @@ async def mock_cluster_resp_info(create_redis, **kwargs): "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" "cluster_stats_messages_received:105653\r\n" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked @pytest_asyncio.fixture() @@ -192,7 +195,8 @@ async def mock_cluster_resp_nodes(create_redis, **kwargs): "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked @pytest_asyncio.fixture() @@ -203,7 +207,8 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " "1447836789290 3 connected']" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked async def wait_for_command( diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py new file mode 100644 index 
0000000000..c837acfed1 --- /dev/null +++ b/tests/test_asyncio/test_cache.py @@ -0,0 +1,129 @@ +import time + +import pytest +import redis.asyncio as redis +from redis.utils import HIREDIS_AVAILABLE + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_get_from_cache(): + r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + r2 = redis.Redis(protocol=3) + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == b"barbar" + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_max_size(): + r = redis.Redis( + cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 + ) + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) == b"bar2" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_ttl(): + r = redis.Redis( + cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 + ) + # add key to redis + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + await r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_lfu_eviction(): + r = redis.Redis( + cache_enable=True, + cache_max_size=3, + cache_eviction_policy="lfu", + single_connection_client=True, + protocol=3, + ) + # add 3 keys to redis + await r.set("foo", "bar") + await r.set("foo2", "bar2") + await r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert await r.get("foo") == b"bar" + assert await r.get("foo2") == b"bar2" + assert await r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + await r.set("foo4", "bar4") + assert await r.get("foo4") == b"bar4" + # test the eviction policy + assert len(r.client_cache.cache) == 3 + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) is None + + await 
r.aclose() + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_cache_decode_response(): + r = redis.Redis( + decode_responses=True, + cache_enable=True, + single_connection_client=True, + protocol=3, + ) + await r.set("foo", "bar") + # get key from redis and save in local cache + assert await r.get("foo") == "bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert await r.get("foo") == "barbar" + + await r.aclose() diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 1cb1fa5195..e6cf2e4ce7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -38,7 +38,7 @@ ) from ..ssl_utils import get_ssl_filename -from .compat import mock +from .compat import aclosing, mock pytestmark = pytest.mark.onlycluster @@ -175,7 +175,7 @@ def cmd_init_mock(self, r: ClusterNode) -> None: def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response while node._free: @@ -185,7 +185,7 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc while node._free: @@ -270,7 +270,38 @@ async def test_host_port_startup_node(self) -> None: cluster = await get_mocked_redis_client(host=default_host, port=default_port) assert cluster.get_node(host=default_host, port=default_port) is not None - await cluster.close() + await cluster.aclose() + + async def test_aclosing(self) -> None: + cluster = await get_mocked_redis_client(host=default_host, port=default_port) + called = 0 + + async def mock_aclose(): + nonlocal called + called += 1 + + with mock.patch.object(cluster, "aclose", mock_aclose): + async with aclosing(cluster): + pass + assert called == 1 + await cluster.aclose() + + async def test_close_is_aclose(self) -> None: + """ + Test that it is possible to use host & port arguments as startup node + args + """ + cluster = await get_mocked_redis_client(host=default_host, port=default_port) + called = 0 + + async def mock_aclose(): + nonlocal called + called += 1 + + with mock.patch.object(cluster, "aclose", mock_aclose): + await cluster.close() + assert called == 1 + await cluster.aclose() async def test_startup_nodes(self) -> None: """ @@ -289,7 +320,7 @@ async def test_startup_nodes(self) -> None: and cluster.get_node(host=default_host, port=port_2) is not None ) - await cluster.close() + await cluster.aclose() startup_node = ClusterNode("127.0.0.1", 16379) async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc: @@ -417,7 +448,7 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: ) ) - await rc.close() + await rc.aclose() async def test_execute_command_errors(self, r: RedisCluster) -> None: """ @@ -461,7 +492,7 @@ async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None conn = primary._free.pop() assert conn.read_response.called 
is not True - await r.close() + await r.aclose() async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> None: """ @@ -690,7 +721,7 @@ def execute_command_mock_third(self, *args, **options): await read_cluster.get("foo") mocks["send_command"].assert_has_calls([mock.call("READONLY")]) - await read_cluster.close() + await read_cluster.aclose() async def test_keyslot(self, r: RedisCluster) -> None: """ @@ -762,7 +793,7 @@ def raise_error(target_node, *args, **kwargs): await rc.get("bar") assert execute_command.failed_calls == rc.cluster_error_retry_attempts - await rc.close() + await rc.aclose() async def test_set_default_node_success(self, r: RedisCluster) -> None: """ @@ -843,7 +874,7 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non *(rc.echo("i", target_nodes=RedisCluster.ALL_NODES) for i in range(100)) ) ) - await rc.close() + await rc.aclose() def test_replace_cluster_node(self, r: RedisCluster) -> None: prev_default_node = r.get_default_node() @@ -901,7 +932,7 @@ def address_remap(address): assert await r.set("byte_string", b"giraffe") assert await r.get("byte_string") == b"giraffe" finally: - await r.close() + await r.aclose() finally: await asyncio.gather(*[p.aclose() for p in proxies]) @@ -1002,7 +1033,7 @@ async def test_initialize_before_execute_multi_key_command( url = request.config.getoption("--redis-url") r = RedisCluster.from_url(url) assert 0 == await r.exists("a", "b", "c") - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_cluster_myid(self, r: RedisCluster) -> None: @@ -1065,7 +1096,7 @@ async def test_cluster_delslots(self) -> None: assert node0._free.pop().read_response.called assert node1._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -1076,7 +1107,7 @@ async def test_cluster_delslotsrange(self): await r.cluster_addslots(node, 1, 2, 3, 4, 5) assert await r.cluster_delslotsrange(1, 5) assert node._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_cluster_failover(self, r: RedisCluster) -> None: @@ -1286,7 +1317,7 @@ async def test_readonly(self) -> None: for replica in r.get_replicas(): assert replica._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_readwrite(self) -> None: @@ -1299,7 +1330,7 @@ async def test_readwrite(self) -> None: for replica in r.get_replicas(): assert replica._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_bgsave(self, r: RedisCluster) -> None: @@ -1524,7 +1555,7 @@ async def test_client_kill( ] assert len(clients) == 1 assert clients[0].get("name") == "redis-py-c1" - await r2.close() + await r2.aclose() @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None: @@ -2302,7 +2333,7 @@ async def test_acl_log( await r.acl_deluser(username, target_nodes="primaries") - await user_client.close() + await user_client.aclose() class TestNodesManager: @@ -2359,7 +2390,7 @@ async def test_init_slots_cache_not_all_slots_covered(self) -> None: cluster_slots=cluster_slots, require_full_coverage=True, ) - await rc.close() + await rc.aclose() assert str(ex.value).startswith( "All slots are not covered after query all startup_nodes." 
) @@ -2385,7 +2416,7 @@ async def test_init_slots_cache_not_require_full_coverage_success(self) -> None: assert 5460 not in rc.nodes_manager.slots_cache - await rc.close() + await rc.aclose() async def test_init_slots_cache(self) -> None: """ @@ -2416,7 +2447,7 @@ async def test_init_slots_cache(self) -> None: assert len(n_manager.nodes_cache) == 6 - await rc.close() + await rc.aclose() async def test_init_slots_cache_cluster_mode_disabled(self) -> None: """ @@ -2427,7 +2458,7 @@ async def test_init_slots_cache_cluster_mode_disabled(self) -> None: rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_enabled=False ) - await rc.close() + await rc.aclose() assert "Cluster mode is not enabled on this node" in str(e.value) async def test_empty_startup_nodes(self) -> None: @@ -2514,7 +2545,7 @@ async def test_cluster_one_instance(self) -> None: for i in range(0, REDIS_CLUSTER_HASH_SLOTS): assert n.slots_cache[i] == [n_node] - await rc.close() + await rc.aclose() async def test_init_with_down_node(self) -> None: """ diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 6a3a2eca59..35b9f2a29f 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -370,10 +370,12 @@ async def test_client_setinfo(self, r: redis.Redis): info = await r2.client_info() assert info["lib-name"] == "test2" assert info["lib-ver"] == "1234" + await r2.aclose() r3 = redis.asyncio.Redis(lib_name=None, lib_version=None) info = await r3.client_info() assert info["lib-name"] == "" assert info["lib-ver"] == "" + await r3.aclose() @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster @@ -3213,7 +3215,6 @@ async def test_memory_usage(self, r: redis.Redis): assert isinstance(await r.memory_usage("foo"), int) @skip_if_server_version_lt("4.0.0") - @pytest.mark.onlynoncluster async def test_module_list(self, r: redis.Redis): assert isinstance(await r.module_list(), list) for x in await r.module_list(): diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index bead7208f5..5e6b120fb3 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -62,6 +62,7 @@ async def test_tcp_ssl_connect(tcp_address): socket_timeout=10, ) await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + await conn.disconnect() async def _assert_connect(conn, server_address, certfile=None, keyfile=None): @@ -72,6 +73,8 @@ async def _handler(reader, writer): try: return await _redis_request_handler(reader, writer, stop_event) finally: + writer.close() + await writer.wait_closed() finished.set() if isinstance(server_address, str): diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index d1aad796e7..55a1c3a2f6 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -11,7 +11,7 @@ _AsyncRESP3Parser, _AsyncRESPBase, ) -from redis.asyncio import Redis +from redis.asyncio import ConnectionPool, Redis from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff @@ -85,6 +85,8 @@ async def get_conn(_): assert init_call_count == 1 assert command_call_count == 2 + r.connection = None # it was a Mock + await r.aclose() @skip_if_server_version_lt("4.0.0") @@ -143,6 +145,7 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + 
await conn.disconnect() async def test_connect_without_retry_on_os_error(): @@ -194,6 +197,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): pytest.fail("didn't receive a response") assert response assert i > 0 + await conn.disconnect() @pytest.mark.onlynoncluster @@ -254,9 +258,8 @@ async def do_close(): async def do_read(): return await conn.read_response() - reader = mock.AsyncMock() - writer = mock.AsyncMock() - writer.transport = mock.Mock() + reader = mock.Mock(spec=asyncio.StreamReader) + writer = mock.Mock(spec=asyncio.StreamWriter) writer.transport.get_extra_info.side_effect = None # for HiredisParser @@ -289,7 +292,7 @@ def test_create_single_connection_client_from_url(): assert client.single_connection_client is True -@pytest.mark.parametrize("from_url", (True, False)) +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) async def test_pool_auto_close(request, from_url): """Verify that basic Redis instances have auto_close_connection_pool set to True""" @@ -303,23 +306,187 @@ async def get_redis_connection(): r1 = await get_redis_connection() assert r1.auto_close_connection_pool is True - await r1.close() + await r1.aclose() -@pytest.mark.parametrize("from_url", (True, False)) -async def test_pool_auto_close_disable(request, from_url): - """Verify that auto_close_connection_pool can be disabled""" +async def test_close_is_aclose(request): + """Verify close() calls aclose()""" + calls = 0 + + async def mock_aclose(self): + nonlocal calls + calls += 1 + + url: str = request.config.getoption("--redis-url") + r1 = await Redis.from_url(url) + with patch.object(r1, "aclose", mock_aclose): + with pytest.deprecated_call(): + await r1.close() + assert calls == 1 + + with pytest.deprecated_call(): + await r1.close() + + +async def test_pool_from_url_deprecation(request): + url: str = request.config.getoption("--redis-url") + + with pytest.deprecated_call(): + return Redis.from_url(url, auto_close_connection_pool=False) + + +async def test_pool_auto_close_disable(request): + """Verify that auto_close_connection_pool can be disabled (deprecated)""" url: str = request.config.getoption("--redis-url") url_args = parse_url(url) async def get_redis_connection(): - if from_url: - return Redis.from_url(url, auto_close_connection_pool=False) url_args["auto_close_connection_pool"] = False - return Redis(**url_args) + with pytest.deprecated_call(): + return Redis(**url_args) r1 = await get_redis_connection() assert r1.auto_close_connection_pool is False await r1.connection_pool.disconnect() - await r1.close() + await r1.aclose() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +async def test_redis_connection_pool(request, from_url): + """Verify that basic Redis instances using `connection_pool` + have auto_close_connection_pool set to False""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + async def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis(connection_pool=pool) + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + await pool.disconnect() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +async 
def test_redis_from_pool(request, from_url): + """Verify that basic Redis instances created using `from_pool()` + have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + async def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis.from_pool(pool) + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is True + + assert called == 1 + await pool.disconnect() + + +@pytest.mark.parametrize("auto_close", (True, False)) +async def test_redis_pool_auto_close_arg(request, auto_close): + """test that redis instance where pool is provided have + auto_close_connection_pool set to False, regardless of arg""" + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + async def get_redis_connection(): + with pytest.deprecated_call(): + client = Redis(connection_pool=pool, auto_close_connection_pool=auto_close) + return client + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + await pool.disconnect() + + +async def test_client_garbage_collection(request): + """ + Test that a Redis client will call _close() on any + connection that it holds at time of destruction + """ + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + # create a client with a connection from the pool + client = Redis(connection_pool=pool, single_connection_client=True) + await client.initialize() + with mock.patch.object(client, "connection") as a: + # we cannot, in unittests, or from asyncio, reliably trigger garbage collection + # so we must just invoke the handler + with pytest.warns(ResourceWarning): + client.__del__() + assert a._close.called + + await client.aclose() + await pool.aclose() + + +async def test_connection_garbage_collection(request): + """ + Test that a Connection object will call close() on the + stream that it holds. 
+ """ + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + # create a client with a connection from the pool + client = Redis(connection_pool=pool, single_connection_client=True) + await client.initialize() + conn = client.connection + + with mock.patch.object(conn, "_reader"): + with mock.patch.object(conn, "_writer") as a: + # we cannot, in unittests, or from asyncio, reliably trigger + # garbage collection so we must just invoke the handler + with pytest.warns(ResourceWarning): + conn.__del__() + assert a.close.called + + await client.aclose() + await pool.aclose() diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 7672dc74b4..5e4d3f206f 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -1,5 +1,4 @@ import asyncio -import os import re import pytest @@ -8,7 +7,7 @@ from redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt -from .compat import mock +from .compat import aclosing, mock from .conftest import asynccontextmanager from .test_pubsub import wait_for_message @@ -43,7 +42,7 @@ async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): new_conn = await self.create_two_conn(r) assert new_conn != r.connection assert self.get_total_connected_connections(r.connection_pool) == 2 - await r.close() + await r.aclose() assert self.has_no_connected_connections(r.connection_pool) async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): @@ -53,7 +52,7 @@ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): ) new_conn = await self.create_two_conn(r2) assert self.get_total_connected_connections(r2.connection_pool) == 2 - await r2.close() + await r2.aclose() assert r2.connection_pool._in_use_connections == {new_conn} assert new_conn.is_connected assert len(r2.connection_pool._available_connections) == 1 @@ -62,7 +61,7 @@ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): assert r.auto_close_connection_pool is True, "This is from the class fixture" await self.create_two_conn(r) - await r.close() + await r.aclose() assert self.get_total_connected_connections(r.connection_pool) == 2, ( "The connection pool should not be disconnected as a manually created " "connection pool was passed in in conftest.py" @@ -73,7 +72,7 @@ async def test_auto_release_override_true_manual_created_pool(self, r: redis.Red async def test_close_override(self, r: redis.Redis, auto_close_conn_pool): r.auto_close_connection_pool = auto_close_conn_pool await self.create_two_conn(r) - await r.close(close_connection_pool=True) + await r.aclose(close_connection_pool=True) assert self.has_no_connected_connections(r.connection_pool) @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) @@ -82,7 +81,7 @@ async def test_negate_auto_close_client_pool( ): r.auto_close_connection_pool = auto_close_conn_pool new_conn = await self.create_two_conn(r) - await r.close(close_connection_pool=False) + await r.aclose(close_connection_pool=False) assert not self.has_no_connected_connections(r.connection_pool) assert r.connection_pool._in_use_connections == {new_conn} assert r.connection_pool._available_connections[0].is_connected @@ -94,7 +93,9 @@ class DummyConnection(Connection): def __init__(self, **kwargs): self.kwargs = 
kwargs - self.pid = os.getpid() + + def repr_pieces(self): + return [("id", id(self)), ("kwargs", self.kwargs)] async def connect(self): pass @@ -134,6 +135,16 @@ async def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + async def test_aclosing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + async with aclosing(pool): + pass + async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: @@ -169,11 +180,8 @@ async def test_repr_contains_db_info_tcp(self): async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.Connection ) as pool: - expected = ( - "ConnectionPool<Connection<" - "host=localhost,port=6379,db=1,client_name=test-client>>" - ) - assert repr(pool) == expected + expected = "host=localhost,port=6379,db=1,client_name=test-client" + assert expected in repr(pool) async def test_repr_contains_db_info_unix(self): connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} @@ -181,11 +189,8 @@ connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection, ) as pool: - expected = ( - "ConnectionPool<UnixDomainSocketConnection<" - "path=/abc,db=1,client_name=test-client>>" - ) - assert repr(pool) == expected + expected = "path=/abc,db=1,client_name=test-client" + assert expected in repr(pool) class TestBlockingConnectionPool: @@ -282,11 +287,8 @@ def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( host="localhost", port=6379, client_name="test-client" ) - expected = ( - "ConnectionPool<Connection<" - "host=localhost,port=6379,db=0,client_name=test-client>>" - ) - assert repr(pool) == expected + expected = "host=localhost,port=6379,db=0,client_name=test-client" + assert expected in repr(pool) def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( @@ -294,11 +296,8 @@ path="abc", client_name="test-client", ) - expected = ( - "ConnectionPool<UnixDomainSocketConnection<" - "path=abc,db=0,client_name=test-client>>" - ) - assert repr(pool) == expected + expected = "path=abc,db=0,client_name=test-client" + assert expected in repr(pool) class TestConnectionPoolURLParsing: @@ -443,6 +442,31 @@ def test_invalid_scheme_raises_error(self): ) + +class TestBlockingConnectionPoolURLParsing: + def test_extra_typed_querystring_options(self): + pool = redis.BlockingConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10&timeout=13.37" + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + assert pool.timeout == 13.37 + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + redis.BlockingConnectionPool.from_url( + "redis://localhost/2?timeout=_not_a_float_" + ) + + class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") @@ -623,7 +647,10 @@ def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + print(repr(pool)) + assert re.match( + r"< .*?([^\.]+) \( < .*?([^\.]+) \( (.+) \) > \) >", repr(pool), re.VERBOSE + ).groups() == ( "ConnectionPool",
"Connection", "host=localhost,port=6379,db=0", @@ -633,7 +660,9 @@ def test_connect_from_url_unix(self): connection = redis.Redis.from_url("unix:///path/to/socket") pool = connection.connection_pool - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + assert re.match( + r"< .*?([^\.]+) \( < .*?([^\.]+) \( (.+) \) > \) >", repr(pool), re.VERBOSE + ).groups() == ( "ConnectionPool", "UnixDomainSocketConnection", "path=/path/to/socket,db=0", diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index ff588861e4..df46cabc43 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -15,6 +15,8 @@ def __init__(self, addr, redis_addr, delay: float = 0.0): self.send_event = asyncio.Event() self.server = None self.task = None + self.cond = asyncio.Condition() + self.running = 0 async def __aenter__(self): await self.start() @@ -49,24 +51,24 @@ def set_delay(self, delay: float = 0.0): async def handle(self, reader, writer): # establish connection to redis redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) - try: - pipe1 = asyncio.create_task( - self.pipe(reader, redis_writer, "to redis:", self.send_event) - ) - pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) - await asyncio.gather(pipe1, pipe2) - finally: - redis_writer.close() + pipe1 = asyncio.create_task( + self.pipe(reader, redis_writer, "to redis:", self.send_event) + ) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) + await asyncio.gather(pipe1, pipe2) async def stop(self): - # clean up enough so that we can reuse the looper + # shutdown the server self.task.cancel() try: await self.task except asyncio.CancelledError: pass - loop = self.server.get_loop() - await loop.shutdown_asyncgens() + await self.server.wait_closed() + # Server does not wait for all spawned tasks. We must do that also to ensure + # that all sockets are closed. + async with self.cond: + await self.cond.wait_for(lambda: self.running == 0) async def pipe( self, @@ -75,32 +77,43 @@ async def pipe( name="", event: asyncio.Event = None, ): - while True: - data = await reader.read(1000) - if not data: - break - # print(f"{name} read {len(data)} delay {self.delay}") - if event: - event.set() - await asyncio.sleep(self.delay) - writer.write(data) - await writer.drain() + self.running += 1 + try: + while True: + data = await reader.read(1000) + if not data: + break + # print(f"{name} read {len(data)} delay {self.delay}") + if event: + event.set() + await asyncio.sleep(self.delay) + writer.write(data) + await writer.drain() + finally: + try: + writer.close() + await writer.wait_closed() + except RuntimeError: + # ignore errors on close pertaining to no event loop. 
Don't want + # to clutter the test output with errors if being garbage collected + pass + async with self.cond: + self.running -= 1 + if self.running == 0: + self.cond.notify_all() @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) async def test_standalone(delay, master_host): - # create a tcp socket proxy that relays data to Redis and back, # inserting 0.1 seconds of delay async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: - for b in [True, False]: # note that we connect to proxy, rather than to Redis directly async with Redis( host="127.0.0.1", port=5380, single_connection_client=b ) as r: - await r.set("foo", "foo") await r.set("bar", "bar") @@ -180,7 +193,6 @@ async def op(pipe): @pytest.mark.onlycluster async def test_cluster(master_host): - delay = 0.1 cluster_port = 16379 remap_base = 7372 @@ -204,8 +216,9 @@ def all_clear(): p.send_event.clear() async def wait_for_send(): - asyncio.wait( - [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED + await asyncio.wait( + [asyncio.Task(p.send_event.wait()) for p in proxies], + return_when=asyncio.FIRST_COMPLETED, ) @contextlib.contextmanager @@ -219,11 +232,10 @@ def set_delay(delay: float): for p in proxies: await stack.enter_async_context(p) - with contextlib.closing( - RedisCluster.from_url( - f"redis://127.0.0.1:{remap_base}", address_remap=remap - ) - ) as r: + r = RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", address_remap=remap + ) + try: await r.initialize() await r.set("foo", "foo") await r.set("bar", "bar") @@ -241,10 +253,12 @@ async def op(r): with pytest.raises(asyncio.CancelledError): await t - # try a number of requests to excercise all the connections + # try a number of requests to exercise all the connections async def doit(): assert await r.get("bar") == b"bar" assert await r.ping() assert await r.get("foo") == b"foo" await asyncio.gather(*[doit() for _ in range(10)]) + finally: + await r.close() diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index ed651cd903..a35bd4795f 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -490,7 +490,6 @@ async def test_json_mget_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_numby_commands_dollar(decoded_r: redis.Redis): - # Test NUMINCRBY await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} @@ -546,7 +545,6 @@ async def test_numby_commands_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strappend_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -578,7 +576,6 @@ async def test_strappend_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strlen_dollar(decoded_r: redis.Redis): - # Test multi await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -713,7 +710,6 @@ async def test_arrinsert_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrlen_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -802,7 +798,6 @@ async def test_arrpop_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrtrim_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -960,7 +955,6 @@ async def test_type_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_clear_dollar(decoded_r: redis.Redis): - await 
decoded_r.json().set( "doc1", "$", diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 75484a2791..c052eae2a0 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -234,7 +234,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index edd2f6d147..4b29360d72 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -2,11 +2,11 @@ import redis from tests.conftest import skip_if_server_version_lt +from .compat import aclosing, mock from .conftest import wait_for_command class TestPipeline: - @pytest.mark.onlynoncluster async def test_pipeline_is_true(self, r): """Ensure pipeline instances are not false-y""" async with r.pipeline() as pipe: @@ -286,6 +286,24 @@ async def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command["command"] == "UNWATCH" + @pytest.mark.onlynoncluster + async def test_aclose_is_reset(self, r): + async with r.pipeline() as pipe: + called = 0 + + async def mock_reset(): + nonlocal called + called += 1 + + with mock.patch.object(pipe, "reset", mock_reset): + await pipe.aclose() + assert called == 1 + + @pytest.mark.onlynoncluster + async def test_aclosing(self, r): + async with aclosing(r.pipeline()): + pass + @pytest.mark.onlynoncluster async def test_transaction_callable(self, r): await r.set("a", 1) @@ -377,7 +395,6 @@ async def test_pipeline_get(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") async def test_pipeline_discard(self, r): - # empty pipeline should raise an error async with r.pipeline() as pipe: pipe.set("key", "someval") diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 858576584f..19d4b1c650 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -20,7 +20,7 @@ from redis.utils import HIREDIS_AVAILABLE from tests.conftest import get_protocol_version, skip_if_server_version_lt -from .compat import create_task, mock +from .compat import aclosing, create_task, mock def with_timeout(t): @@ -84,9 +84,8 @@ def make_subscribe_test_data(pubsub, type): @pytest_asyncio.fixture() async def pubsub(r: redis.Redis): - p = r.pubsub() - yield p - await p.close() + async with r.pubsub() as p: + yield p @pytest.mark.onlynoncluster @@ -122,7 +121,6 @@ async def test_pattern_subscribe_unsubscribe(self, pubsub): async def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert await sub_func(key) is None @@ -164,7 +162,6 @@ async def test_resubscribe_to_patterns_on_reconnection(self, pubsub): async def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False await sub_func(keys[0]) # we're now subscribed even though we haven't processed the @@ -217,6 +214,46 @@ async def test_subscribe_property_with_patterns(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribed_property(**kwargs) + async def test_aclosing(self, r: redis.Redis): + p = r.pubsub() + async with aclosing(p): + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + assert p.subscribed is False + + async def test_context_manager(self, r: redis.Redis): + p = r.pubsub() + async 
with p: + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + assert p.subscribed is False + + async def test_close_is_aclose(self, r: redis.Redis): + """ + Test backwards compatible close method + """ + p = r.pubsub() + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + with pytest.deprecated_call(): + await p.close() + assert p.subscribed is False + + async def test_reset_is_aclose(self, r: redis.Redis): + """ + Test backwards compatible reset method + """ + p = r.pubsub() + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + with pytest.deprecated_call(): + await p.reset() + assert p.subscribed is False + async def test_ignore_all_subscribe_messages(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -233,7 +270,7 @@ async def test_ignore_all_subscribe_messages(self, r: redis.Redis): assert p.subscribed is True assert await wait_for_message(p) is None assert p.subscribed is False - await p.close() + await p.aclose() async def test_ignore_individual_subscribe_messages(self, pubsub): p = pubsub @@ -350,7 +387,7 @@ async def test_channel_message_handler(self, r: redis.Redis): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") - await p.close() + await p.aclose() async def test_channel_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -359,7 +396,7 @@ async def test_channel_async_message_handler(self, r): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.async_message == make_message("message", "foo", "test message") - await p.close() + await p.aclose() async def test_channel_sync_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -371,7 +408,7 @@ async def test_channel_sync_async_message_handler(self, r): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") assert self.async_message == make_message("message", "bar", "test message 2") - await p.close() + await p.aclose() @pytest.mark.onlynoncluster async def test_pattern_message_handler(self, r: redis.Redis): @@ -383,7 +420,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", "foo", "test message", pattern="f*" ) - await p.close() + await p.aclose() async def test_unicode_channel_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -394,7 +431,7 @@ async def test_unicode_channel_message_handler(self, r: redis.Redis): assert await r.publish(channel, "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") - await p.close() + await p.aclose() @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html @@ -410,7 +447,7 @@ async def test_unicode_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", channel, "test message", pattern=pattern ) - await p.close() + await p.aclose() async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): p = pubsub @@ -524,7 +561,7 @@ async def test_channel_message_handler(self, r: redis.Redis): await r.publish(self.channel, new_data) assert await wait_for_message(p) is None assert self.message == self.make_message("message", self.channel, 
new_data) - await p.close() + await p.aclose() async def test_pattern_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -546,7 +583,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == self.make_message( "pmessage", self.channel, new_data, pattern=self.pattern ) - await p.close() + await p.aclose() async def test_context_manager(self, r: redis.Redis): async with r.pubsub() as pubsub: @@ -556,7 +593,7 @@ async def test_context_manager(self, r: redis.Redis): assert pubsub.connection is None assert pubsub.channels == {} assert pubsub.patterns == {} - await pubsub.close() + await pubsub.aclose() @pytest.mark.onlynoncluster @@ -597,9 +634,9 @@ async def test_pubsub_numsub(self, r: redis.Redis): channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] assert await r.pubsub_numsub("foo", "bar", "baz") == channels - await p1.close() - await p2.close() - await p3.close() + await p1.aclose() + await p2.aclose() + await p3.aclose() @skip_if_server_version_lt("2.8.0") async def test_pubsub_numpat(self, r: redis.Redis): @@ -608,7 +645,7 @@ async def test_pubsub_numpat(self, r: redis.Redis): for i in range(3): assert (await wait_for_message(p))["type"] == "psubscribe" assert await r.pubsub_numpat() == 3 - await p.close() + await p.aclose() @pytest.mark.onlynoncluster @@ -621,7 +658,7 @@ async def test_send_pubsub_ping(self, r: redis.Redis): assert await wait_for_message(p) == make_message( type="pong", channel=None, data="", pattern=None ) - await p.close() + await p.aclose() @skip_if_server_version_lt("3.0.0") async def test_send_pubsub_ping_message(self, r: redis.Redis): @@ -631,7 +668,7 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis): assert await wait_for_message(p) == make_message( type="pong", channel=None, data="hello world", pattern=None ) - await p.close() + await p.aclose() @pytest.mark.onlynoncluster diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 2912ca786c..8bc71c1479 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -131,5 +131,8 @@ async def test_get_set_retry_object(self, request): assert r.get_retry()._retries == new_retry_policy._retries assert isinstance(r.get_retry()._backoff, ExponentialBackoff) assert exiting_conn.retry._retries == new_retry_policy._retries + await r.connection_pool.release(exiting_conn) new_conn = await r.connection_pool.get_connection("_") assert new_conn.retry._retries == new_retry_policy._retries + await r.connection_pool.release(new_conn) + await r.aclose() diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index e46de39c70..1f1931e28a 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -77,7 +77,6 @@ async def createIndex(decoded_r, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -163,10 +162,8 @@ async def test_client(decoded_r: redis.Redis): ) ).total both_total = ( - await ( - decoded_r.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") ) ).total assert 129 == txt_total @@ -370,18 +367,14 @@ async def test_stopwords(decoded_r: redis.Redis): @pytest.mark.redismod async def test_filters(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), 
NumericField("num"), GeoField("loc")) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num"), GeoField("loc")) ) - await ( - decoded_r.hset( - "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} - ) + await decoded_r.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) - await ( - decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + await decoded_r.hset( + "doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"} ) await waitForIndex(decoded_r, "idx") @@ -432,10 +425,8 @@ async def test_filters(decoded_r: redis.Redis): @pytest.mark.redismod async def test_sort_by(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), NumericField("num", sortable=True)) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num", sortable=True)) ) await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) @@ -488,8 +479,8 @@ async def test_drop_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_example(decoded_r: redis.Redis): # Creating the index definition and schema - await ( - decoded_r.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + await decoded_r.ft().create_index( + (TextField("title", weight=5.0), TextField("body")) ) # Indexing a document @@ -550,8 +541,8 @@ async def test_auto_complete(decoded_r: redis.Redis): await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - sugs = await ( - decoded_r.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + sugs = await decoded_r.ft().sugget( + "ac", "pay", with_payloads=True, with_scores=True ) assert 3 == len(sugs) for sug in sugs: @@ -639,8 +630,8 @@ async def test_no_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_explain(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + await decoded_r.ft().create_index( + (TextField("f1"), TextField("f2"), TextField("f3")) ) res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @@ -903,10 +894,8 @@ async def test_alter_schema_add(decoded_r: redis.Redis): async def test_spell_check(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - await ( - decoded_r.hset( - "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} - ) + await decoded_r.hset( + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) await decoded_r.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) await waitForIndex(decoded_r, "idx") @@ -1042,8 +1031,8 @@ async def test_scorer(decoded_r: redis.Redis): assert 1.0 == res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) assert 1.0 == res.docs[0].score - res = await ( - decoded_r.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() ) assert 0.14285714285714285 == res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) @@ -1514,7 +1503,7 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] assert await decoded_r.ft().dropindex("idx") - # create withsuffixtrie index 
(text fiels) + # create withsuffixtrie index (text fields) assert await decoded_r.ft().create_index((TextField("t", withsuffixtrie=True))) waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) info = await decoded_r.ft().info() diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index a2d52f17b7..51e59d69d0 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -72,7 +72,6 @@ def client(self, host, port, **kwargs): @pytest_asyncio.fixture() async def cluster(master_ip): - cluster = SentinelTestCluster(ip=master_ip) saved_Redis = redis.asyncio.sentinel.Redis redis.asyncio.sentinel.Redis = cluster.client @@ -184,13 +183,13 @@ async def test_discover_slaves(cluster, sentinel): @pytest.mark.onlynoncluster async def test_master_for(cluster, sentinel, master_ip): - master = sentinel.master_for("mymaster", db=9) - assert await master.ping() - assert master.connection_pool.master_address == (master_ip, 6379) + async with sentinel.master_for("mymaster", db=9) as master: + assert await master.ping() + assert master.connection_pool.master_address == (master_ip, 6379) # Use internal connection check - master = sentinel.master_for("mymaster", db=9, check_connection=True) - assert await master.ping() + async with sentinel.master_for("mymaster", db=9, check_connection=True) as master: + assert await master.ping() @pytest.mark.onlynoncluster @@ -198,16 +197,16 @@ async def test_slave_for(cluster, sentinel): cluster.slaves = [ {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False} ] - slave = sentinel.slave_for("mymaster", db=9) - assert await slave.ping() + async with sentinel.slave_for("mymaster", db=9) as slave: + assert await slave.ping() @pytest.mark.onlynoncluster async def test_slave_for_slave_not_found_error(cluster, sentinel): cluster.master["is_odown"] = True - slave = sentinel.slave_for("mymaster", db=9) - with pytest.raises(SlaveNotFoundError): - await slave.ping() + async with sentinel.slave_for("mymaster", db=9) as slave: + with pytest.raises(SlaveNotFoundError): + await slave.ping() @pytest.mark.onlynoncluster @@ -261,7 +260,7 @@ async def mock_disconnect(): calls += 1 with mock.patch.object(pool, "disconnect", mock_disconnect): - await client.close() + await client.aclose() assert calls == 1 await pool.disconnect() diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index e784690c77..cae4b9581f 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -10,11 +10,11 @@ pytestmark = pytest.mark.asyncio -async def test_connect_retry_on_timeout_error(): +async def test_connect_retry_on_timeout_error(connect_args): """Test that the _connect function is retried in case of a timeout""" connection_pool = mock.AsyncMock() connection_pool.get_master_address = mock.AsyncMock( - return_value=("localhost", 6379) + return_value=(connect_args["host"], connect_args["port"]) ) conn = SentinelManagedConnection( retry_on_timeout=True, @@ -34,3 +34,4 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + await conn.disconnect() diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index 48ffdfd889..91c15c3db2 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -108,7 +108,6 @@ async def test_add(decoded_r: 
redis.Redis): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_add_duplicate_policy(r: redis.Redis): - # Test for duplicate policy BLOCK assert 1 == await r.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000000..45621fe77e --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,119 @@ +import time + +import pytest +import redis +from redis.utils import HIREDIS_AVAILABLE + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_get_from_cache(): + r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3) + r2 = redis.Redis(protocol=3) + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar" + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_max_size(): + r = redis.Redis( + cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3 + ) + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) == b"bar2" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_ttl(): + r = redis.Redis( + cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3 + ) + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_lfu_eviction(): + r = redis.Redis( + cache_enable=True, + cache_max_size=3, + cache_eviction_policy="lfu", + single_connection_client=True, + protocol=3, + ) + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # test the eviction policy + assert len(r.client_cache.cache) == 3 + assert 
r.client_cache.get(("GET", "foo")) == b"bar" + assert r.client_cache.get(("GET", "foo2")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +def test_cache_decode_response(): + r = redis.Redis( + decode_responses=True, + cache_enable=True, + single_connection_client=True, + protocol=3, + ) + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == "bar" + # get key from local cache + assert r.client_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert r.client_cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == "barbar" diff --git a/tests/test_commands.py b/tests/test_commands.py index 114fb6b686..6660c2c6b0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -516,7 +516,6 @@ def test_client_trackinginfo(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() def test_client_tracking(self, r, r2): - # simple case assert r.client_tracking_on() assert r.client_tracking_off() @@ -4909,7 +4908,6 @@ def test_latency_latest(self, r: redis.Redis): def test_latency_reset(self, r: redis.Redis): assert r.latency_reset() == 0 - @pytest.mark.onlynoncluster @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise() def test_module_list(self, r): @@ -5012,7 +5010,6 @@ def test_module_loadex(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") def test_restore(self, r): - # standard restore key = "foo" r.set(key, "bar") diff --git a/tests/test_connect.py b/tests/test_connect.py index f07750dc80..696e69ceea 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -115,7 +115,7 @@ def get_request(self): return connstream, fromaddr -if hasattr(socket, "UnixStreamServer"): +if hasattr(socketserver, "UnixStreamServer"): class _RedisUDSServer(socketserver.UnixStreamServer): def __init__(self, *args, **kw) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index 760b23c9c1..bff249559e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,9 +5,15 @@ import pytest import redis +from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, + parse_url, +) from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -209,3 +215,84 @@ def test_create_single_connection_client_from_url(): "redis://localhost:6379/0?", single_connection_client=True ) assert client.connection is not None + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_pool_auto_close(request, from_url): + """Verify that basic Redis instances have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + def get_redis_connection(): + if from_url: + return Redis.from_url(url) + return Redis(**url_args) + + r1 = get_redis_connection() + assert r1.auto_close_connection_pool is True + r1.close() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_redis_connection_pool(request, 
from_url): + """Verify that basic Redis instances using `connection_pool` + have auto_close_connection_pool set to False""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis(connection_pool=pool) + + called = 0 + + def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + with get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + pool.disconnect() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_redis_from_pool(request, from_url): + """Verify that basic Redis instances created using `from_pool()` + have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis.from_pool(pool) + + called = 0 + + def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + with get_redis_connection() as r1: + assert r1.auto_close_connection_pool is True + + assert called == 1 + pool.disconnect() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index ab0fc6be98..dee7c554d3 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,6 +1,7 @@ import os import re import time +from contextlib import closing from threading import Thread from unittest import mock @@ -51,6 +52,16 @@ def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + def test_closing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + with closing(pool): + pass + def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) @@ -84,11 +95,8 @@ def test_repr_contains_db_info_tcp(self): pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.Connection ) - expected = ( - "ConnectionPool>" - ) - assert repr(pool) == expected + expected = "host=localhost,port=6379,db=1,client_name=test-client" + assert expected in repr(pool) def test_repr_contains_db_info_unix(self): connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} @@ -96,11 +104,8 @@ def test_repr_contains_db_info_unix(self): connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection, ) - expected = ( - "ConnectionPool>" - ) - assert repr(pool) == expected + expected = "path=/abc,db=1,client_name=test-client" + assert expected in repr(pool) class TestBlockingConnectionPool: @@ -179,11 +184,8 @@ def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( host="localhost", port=6379, client_name="test-client" ) - expected = ( - "ConnectionPool>" - ) - assert repr(pool) == expected + expected = "host=localhost,port=6379,db=0,client_name=test-client" + assert expected in repr(pool) def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( @@ -191,11 +193,8 @@ def 
test_repr_contains_db_info_unix(self): path="abc", client_name="test-client", ) - expected = ( - "ConnectionPool<UnixDomainSocketConnection<" - "path=abc,db=0,client_name=test-client>>" - ) - assert repr(pool) == expected + expected = "path=abc,db=0,client_name=test-client" + assert expected in repr(pool) class TestConnectionPoolURLParsing: @@ -348,6 +347,31 @@ def test_invalid_scheme_raises_error_when_double_slash_missing(self): ) +class TestBlockingConnectionPoolURLParsing: + def test_extra_typed_querystring_options(self): + pool = redis.BlockingConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10&timeout=42" + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + assert pool.timeout == 42.0 + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + redis.BlockingConnectionPool.from_url( + "redis://localhost/2?timeout=_not_a_float_" + ) + + class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") @@ -543,7 +567,9 @@ def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + assert re.match( + r"< .*?([^\.]+) \( < .*?([^\.]+) \( (.+) \) > \) >", repr(pool), re.VERBOSE + ).groups() == ( + "ConnectionPool", "Connection", "host=localhost,port=6379,db=0", @@ -553,7 +579,9 @@ def test_connect_from_url_unix(self): connection = redis.Redis.from_url("unix:///path/to/socket") pool = connection.connection_pool - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + assert re.match( + r"< .*?([^\.]+) \( < .*?([^\.]+) \( (.+) \) > \) >", repr(pool), re.VERBOSE + ).groups() == ( "ConnectionPool", "UnixDomainSocketConnection", "path=/path/to/socket,db=0", diff --git a/tests/test_graph.py b/tests/test_graph.py index 6fa9977d98..c6d128908e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -404,7 +404,7 @@ def test_cache_sync(client): # Client B will try to get Client A out of sync by: # 1. deleting the graph # 2. 
reconstructing the graph in a different order, this will casuse - # a differance in the current mapping between string IDs and the + # a difference in the current mapping between string IDs and the # mapping Client A is aware of # # Client A should pick up on the changes by comparing graph versions diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py index 581ebfab5d..1918a6ff44 100644 --- a/tests/test_graph_utils/test_edge.py +++ b/tests/test_graph_utils/test_edge.py @@ -4,7 +4,6 @@ @pytest.mark.redismod def test_init(): - with pytest.raises(AssertionError): edge.Edge(None, None, None) edge.Edge(node.Node(), None, None) @@ -62,7 +61,7 @@ def test_stringify(): @pytest.mark.redismod -def test_comparision(): +def test_comparison(): node1 = node.Node(node_id=1) node2 = node.Node(node_id=2) node3 = node.Node(node_id=3) diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py index c3b34ac6ff..22e6d59414 100644 --- a/tests/test_graph_utils/test_node.py +++ b/tests/test_graph_utils/test_node.py @@ -33,7 +33,7 @@ def test_stringify(fixture): @pytest.mark.redismod -def test_comparision(fixture): +def test_comparison(fixture): no_args, no_props, props_only, no_label, multi_label = fixture assert node.Node() == node.Node() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 57a94d2f45..66ee1c5390 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -41,6 +41,7 @@ def test_parse_to_dict(): "Child iterators", ["Type", "bar", "Time", "0.0729", "Counter", 3], ["Type", "barbar", "Time", "0.058", "Counter", 3], + ["Type", "barbarbar", "Time", "0.0234", "Counter", 3], ], ], ] @@ -49,6 +50,7 @@ def test_parse_to_dict(): "Child iterators": [ {"Counter": 3.0, "Time": 0.0729, "Type": "bar"}, {"Counter": 3.0, "Time": 0.058, "Type": "barbar"}, + {"Counter": 3.0, "Time": 0.0234, "Type": "barbarbar"}, ], "Counter": 3.0, "Time": 0.2089, diff --git a/tests/test_json.py b/tests/test_json.py index be347f6677..73d72b8cc9 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -464,7 +464,6 @@ def test_json_mget_dollar(client): def test_numby_commands_dollar(client): - # Test NUMINCRBY client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) # Test multi @@ -508,7 +507,6 @@ def test_numby_commands_dollar(client): def test_strappend_dollar(client): - client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -539,7 +537,6 @@ def test_strappend_dollar(client): def test_strlen_dollar(client): - # Test multi client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -672,7 +669,6 @@ def test_arrinsert_dollar(client): def test_arrlen_dollar(client): - client.json().set( "doc1", "$", @@ -762,7 +758,6 @@ def test_arrpop_dollar(client): def test_arrtrim_dollar(client): - client.json().set( "doc1", "$", @@ -1015,7 +1010,6 @@ def test_toggle_dollar(client): def test_resp_dollar(client): - data = { "L1": { "a": { @@ -1244,7 +1238,6 @@ def test_resp_dollar(client): def test_arrindex_dollar(client): - client.json().set( "store", "$", diff --git a/tests/test_lock.py b/tests/test_lock.py index b34f7f0159..72af87fa81 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -247,7 +247,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 
7b048eec01..7f10fcad4f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,6 @@ +from contextlib import closing +from unittest import mock + import pytest import redis @@ -284,6 +287,24 @@ def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command["command"] == "UNWATCH" + @pytest.mark.onlynoncluster + def test_close_is_reset(self, r): + with r.pipeline() as pipe: + called = 0 + + def mock_reset(): + nonlocal called + called += 1 + + with mock.patch.object(pipe, "reset", mock_reset): + pipe.close() + assert called == 1 + + @pytest.mark.onlynoncluster + def test_closing(self, r): + with closing(r.pipeline()): + pass + @pytest.mark.onlynoncluster def test_transaction_callable(self, r): r["a"] = 1 @@ -369,7 +390,6 @@ def test_pipeline_with_bitfield(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") def test_pipeline_discard(self, r): - # empty pipeline should raise an error with r.pipeline() as pipe: pipe.set("key", "someval") diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index ba097e3194..fb46772af3 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -152,7 +152,6 @@ def test_shard_channel_subscribe_unsubscribe_cluster(self, r): def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert sub_func(key) is None @@ -201,7 +200,6 @@ def test_resubscribe_to_shard_channels_on_reconnection(self, r): def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False sub_func(keys[0]) # we're now subscribed even though we haven't processed the diff --git a/tests/test_search.py b/tests/test_search.py index f3f9619d92..bfe204254c 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -13,6 +13,7 @@ from redis.commands.search import Search from redis.commands.search.field import ( GeoField, + GeoShapeField, NumericField, TagField, TextField, @@ -86,7 +87,6 @@ def createIndex(client, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -820,7 +820,6 @@ def test_spell_check(client): waitForIndex(client, getattr(client.ft(), "index_name", "idx")) if is_resp2_connection(client): - # test spellcheck res = client.ft().spellcheck("impornant") assert "important" == res["impornant"][0]["suggestion"] @@ -2100,7 +2099,6 @@ def test_numeric_params(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_geo_params(client): - client.ft().create_index((GeoField("g"))) client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) @@ -2229,7 +2227,7 @@ def test_withsuffixtrie(client: redis.Redis): assert "WITHSUFFIXTRIE" not in info["attributes"][0] assert client.ft().dropindex("idx") - # create withsuffixtrie index (text fiels) + # create withsuffixtrie index (text fields) assert client.ft().create_index((TextField("t", withsuffixtrie=True))) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) info = client.ft().info() @@ -2246,7 +2244,7 @@ def test_withsuffixtrie(client: redis.Redis): assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] assert client.ft().dropindex("idx") - # create withsuffixtrie index (text fiels) + # create withsuffixtrie index (text fields) assert client.ft().create_index((TextField("t", withsuffixtrie=True))) 
waitForIndex(client, getattr(client.ft(), "index_name", "idx")) info = client.ft().info() @@ -2264,6 +2262,25 @@ def test_withsuffixtrie(client: redis.Redis): def test_query_timeout(r: redis.Redis): q1 = Query("foo").timeout(5000) assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] + q1 = Query("foo").timeout(0) + assert q1.get_args() == ["foo", "TIMEOUT", 0, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): r.ft().search(q2) + + +@pytest.mark.redismod +def test_geoshape(client: redis.Redis): + client.ft().create_index((GeoShapeField("geom", GeoShapeField.FLAT))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + client.hset("small", "geom", "POLYGON((1 1, 1 100, 100 100, 100 1, 1 1))") + client.hset("large", "geom", "POLYGON((1 1, 1 200, 200 200, 200 1, 1 1))") + q1 = Query("@geom:[WITHIN $poly]").dialect(3) + qp1 = {"poly": "POLYGON((0 0, 0 150, 150 150, 150 0, 0 0))"} + q2 = Query("@geom:[CONTAINS $poly]").dialect(3) + qp2 = {"poly": "POLYGON((2 2, 2 50, 50 50, 50 2, 2 2))"} + result = client.ft().search(q1, query_params=qp1) + assert len(result.docs) == 1 + assert result.docs[0]["id"] == "small" + result = client.ft().search(q2, query_params=qp2) + assert len(result.docs) == 2 diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index b7bcc27de2..54b9647098 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,4 +1,5 @@ import socket +from unittest import mock import pytest import redis.sentinel @@ -240,3 +241,28 @@ def test_flushconfig(cluster, sentinel): def test_reset(cluster, sentinel): cluster.master["is_odown"] = True assert sentinel.sentinel_reset("mymaster") + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("method_name", ["master_for", "slave_for"]) +def test_auto_close_pool(cluster, sentinel, method_name): + """ + Check that the connection pool created by the sentinel client is + automatically closed + """ + + method = getattr(sentinel, method_name) + client = method("mymaster", db=9) + pool = client.connection_pool + assert client.auto_close_connection_pool is True + calls = 0 + + def mock_disconnect(): + nonlocal calls + calls += 1 + + with mock.patch.object(pool, "disconnect", mock_disconnect): + client.close() + + assert calls == 1 + pool.disconnect() diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 4ab86cd56e..6b59967f3c 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -104,7 +104,6 @@ def test_add(client): @skip_ifmodversion_lt("1.4.0", "timeseries") def test_add_duplicate_policy(client): - # Test for duplicate policy BLOCK assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception):