From ee2c1d3fddc32cca75353a31d751c2e0f11734b6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 6 Jan 2025 20:47:40 +0000 Subject: [PATCH 1/3] DOC: clarify `at()` patterns/antipatterns --- src/array_api_extra/_funcs.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 80ad3c2d..24101e75 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -649,22 +649,39 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 Warnings -------- - (a) When you omit the ``copy`` parameter, you should always immediately overwrite - the parameter array:: + (a) When you omit the ``copy`` parameter, you should never reuse the parameter + array later on; ideally, you should dereference it immediately:: >>> import array_api_extra as xpx >>> x = xpx.at(x, 0).set(2) - The anti-pattern below must be avoided, as it will result in different - behaviour on read-only versus writeable arrays:: + The above best practice pattern ensures that the behaviour won't change depending + on whether ``x`` is writeable or not, as the original ``x`` object is dereferenced + as soon as ``xpx.at`` returns; this way there is no risk to accidentally update it + twice. + + On the reverse, the anti-pattern below must be avoided, as it will result in + different behaviour on read-only versus writeable arrays:: >>> x = xp.asarray([0, 0, 0]) >>> y = xpx.at(x, 0).set(2) >>> z = xpx.at(x, 1).set(3) - In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` - when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is - writeable! + In the above example, both calls to ``xpx.at`` update ``x`` in place *if possible*. + This causes the behaviour to diverge depending on whether ``x`` is writeable or not: + + - If ``x`` is writeable, then after the snippet above you'll have + ``x == y == z == [2, 3, 0]`` + - If ``x`` is read-only, then you'll end up with + ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``. + + The correct pattern to use if you want diverging outputs from the same input is + to enforce copies:: + + >>> x = xp.asarray([0, 0, 0]) + >>> y = xpx.at(x, 0).set(2, copy=True) # Never updates x + >>> z = xpx.at(x, 1).set(3) # May or may not update x in place + >>> del x # avoid accidental reuse of x as we don't know its state anymore (b) The array API standard does not support integer array indices. The behaviour of update methods when the index is an array of integers is From ad74076a3306b86609f7d1be877b331ffb2c9384 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 7 Jan 2025 00:47:34 +0000 Subject: [PATCH 2/3] Update src/array_api_extra/_funcs.py --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 24101e75..0860903e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -650,7 +650,7 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 Warnings -------- (a) When you omit the ``copy`` parameter, you should never reuse the parameter - array later on; ideally, you should dereference it immediately:: + array later on; ideally, you should reassign it immediately:: >>> import array_api_extra as xpx >>> x = xpx.at(x, 0).set(2) From 21fd3a0cac895dda60d83e980ebc20059ab02d02 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 7 Jan 2025 00:47:43 +0000 Subject: [PATCH 3/3] Update src/array_api_extra/_funcs.py Co-authored-by: Lucas Colley --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 0860903e..db0a1af1 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -673,7 +673,7 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 - If ``x`` is writeable, then after the snippet above you'll have ``x == y == z == [2, 3, 0]`` - If ``x`` is read-only, then you'll end up with - ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``. + ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]``. The correct pattern to use if you want diverging outputs from the same input is to enforce copies::