diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 4062c56f..bb9c9cab 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = _mean(m, axis=1, xp=xp) + avg = _utils.mean(m, axis=1, xp=xp) fact = m.shape[1] - 1 if fact <= 0: @@ -199,26 +199,6 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: return xp.reshape(diag, (n, n)) -def _mean( - x: Array, - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - xp: ModuleType, -) -> Array: - """ - Complex mean, https://github.com/data-apis/array-api/issues/846. - """ - if xp.isdtype(x.dtype, "complex floating"): - x_real = xp.real(x) - x_imag = xp.imag(x) - mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims) - mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims) - return mean_real + (mean_imag * xp.asarray(1j)) - return xp.mean(x, axis=axis, keepdims=keepdims) - - def expand_dims( a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType ) -> Array: diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py index bf65340e..ddc5778d 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils.py @@ -7,7 +7,7 @@ from . import _compat -__all__ = ["in1d"] +__all__ = ["in1d", "mean"] def in1d( @@ -63,3 +63,23 @@ def in1d( if assume_unique: return ret[: x1.shape[0]] return xp.take(ret, rev_idx, axis=0) + + +def mean( + x: Array, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + xp: ModuleType, +) -> Array: + """ + Complex mean, https://github.com/data-apis/array-api/issues/846. + """ + if xp.isdtype(x.dtype, "complex floating"): + x_real = xp.real(x) + x_imag = xp.imag(x) + mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims) + mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims) + return mean_real + (mean_imag * xp.asarray(1j)) + return xp.mean(x, axis=axis, keepdims=keepdims) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 36411958..4599740d 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -157,6 +157,55 @@ def test_2d(self): create_diagonal(xp.asarray([[1]]), xp=xp) +class TestExpandDims: + def test_functionality(self): + def _squeeze_all(b: Array) -> Array: + """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" + for axis in range(b.ndim): + with contextlib.suppress(ValueError): + b = xp.squeeze(b, axis=axis) + return b + + s = (2, 3, 4, 5) + a = xp.empty(s) + for axis in range(-5, 4): + b = expand_dims(a, axis=axis, xp=xp) + assert b.shape[axis] == 1 + assert _squeeze_all(b).shape == s + + def test_axis_tuple(self): + a = xp.empty((3, 3, 3)) + assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3) + assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1) + assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1) + assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3) + + def test_axis_out_of_range(self): + s = (2, 3, 4, 5) + a = xp.empty(s) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=-6, xp=xp) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=5, xp=xp) + + a = xp.empty((3, 3, 3)) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=(0, -6), xp=xp) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=(0, 5), xp=xp) + + def test_repeated_axis(self): + a = xp.empty((3, 3, 3)) + with pytest.raises(ValueError, match="Duplicate dimensions"): + expand_dims(a, axis=(1, 1), xp=xp) + + def test_positive_negative_repeated(self): + # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 + a = xp.empty((2, 3, 4, 5)) + with pytest.raises(ValueError, match="Duplicate dimensions"): + expand_dims(a, axis=(3, -3), xp=xp) + + class TestKron: def test_basic(self): # Using 0-dimensional array @@ -222,55 +271,6 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]): assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron") -class TestExpandDims: - def test_functionality(self): - def _squeeze_all(b: Array) -> Array: - """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" - for axis in range(b.ndim): - with contextlib.suppress(ValueError): - b = xp.squeeze(b, axis=axis) - return b - - s = (2, 3, 4, 5) - a = xp.empty(s) - for axis in range(-5, 4): - b = expand_dims(a, axis=axis, xp=xp) - assert b.shape[axis] == 1 - assert _squeeze_all(b).shape == s - - def test_axis_tuple(self): - a = xp.empty((3, 3, 3)) - assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3) - assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1) - assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1) - assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3) - - def test_axis_out_of_range(self): - s = (2, 3, 4, 5) - a = xp.empty(s) - with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=-6, xp=xp) - with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=5, xp=xp) - - a = xp.empty((3, 3, 3)) - with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, -6), xp=xp) - with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, 5), xp=xp) - - def test_repeated_axis(self): - a = xp.empty((3, 3, 3)) - with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(1, 1), xp=xp) - - def test_positive_negative_repeated(self): - # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 - a = xp.empty((2, 3, 4, 5)) - with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(3, -3), xp=xp) - - class TestSetDiff1D: def test_setdiff1d(self): x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])