diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index 10b5c58395..361dd906c3 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -1,6 +1,6 @@ # Data Parallel Control (dpctl) # -# Copyright 2020-2022 Intel Corporation +# Copyright 2020-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,11 @@ # limitations under the License. import numbers +from cpython.buffer cimport PyObject_CheckBuffer + + +cdef bint _is_buffer(object o): + return PyObject_CheckBuffer(o) cdef Py_ssize_t _slice_len( @@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len( cdef bint _is_integral(object x) except *: """Gives True if x is an integral slice spec""" - if isinstance(x, (int, numbers.Integral)): - return True if isinstance(x, usm_ndarray): if x.ndim > 0: return False if x.dtype.kind not in "ui": return False return True + if isinstance(x, bool): + return False + if isinstance(x, int): + return True + if _is_buffer(x): + mbuf = memoryview(x) + if mbuf.ndim == 0: + f = mbuf.format + return f in "bBhHiIlLqQ" + else: + return False if callable(getattr(x, "__index__", None)): try: x.__index__() @@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *: return False +cdef bint _is_boolean(object x) except *: + """Gives True if x is an integral slice spec""" + if isinstance(x, usm_ndarray): + if x.ndim > 0: + return False + if x.dtype.kind not in "b": + return False + return True + if isinstance(x, bool): + return True + if isinstance(x, int): + return False + if _is_buffer(x): + mbuf = memoryview(x) + if mbuf.ndim == 0: + f = mbuf.format + return f in "?" + else: + return False + if callable(getattr(x, "__bool__", None)): + try: + x.__bool__() + except (TypeError, ValueError): + return False + return True + return False + + def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): """ Give basic slicing index `ind` and array layout information produce @@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): _no_advanced_ind, _no_advanced_pos ) + elif _is_boolean(ind): + if ind: + return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) + else: + return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) elif _is_integral(ind): ind = ind.__index__() if 0 <= ind < shape[0]: @@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): axes_referenced += 1 if array_streak_started: array_streak_interrupted = True + elif _is_boolean(i): + newaxis_count += 1 + if array_streak_started: + array_streak_interrupted = True elif _is_integral(i): explicit_index += 1 axes_referenced += 1 @@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): "separated by basic slicing specs." ) dt_k = i.dtype.kind - if dt_k == "b": + if dt_k == "b" and i.ndim > 0: axes_referenced += i.ndim - elif dt_k in "ui": + elif dt_k in "ui" and i.ndim > 0: axes_referenced += 1 else: raise IndexError( @@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): if sh_i == 0: is_empty = True k = k_new + elif _is_boolean(ind_i): + new_shape.append(1 if ind_i else 0) + new_strides.append(0) elif _is_integral(ind_i): ind_i = ind_i.__index__() if 0 <= ind_i < shape[k]: diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index a57ea83cea..41688075e0 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -455,6 +455,32 @@ def test_integer_strided_indexing(): assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all() +def test_TrueFalse_indexing(): + get_queue_or_skip() + n0, n1 = 2, 3 + x = dpt.ones((n0, n1)) + for ind in [True, dpt.asarray(True)]: + y1 = x[ind] + assert y1.shape == (1, n0, n1) + assert y1._pointer == x._pointer + y2 = x[:, ind] + assert y2.shape == (n0, 1, n1) + assert y2._pointer == x._pointer + y3 = x[..., ind] + assert y3.shape == (n0, n1, 1) + assert y3._pointer == x._pointer + for ind in [False, dpt.asarray(False)]: + y1 = x[ind] + assert y1.shape == (0, n0, n1) + assert y1._pointer == x._pointer + y2 = x[:, ind] + assert y2.shape == (n0, 0, n1) + assert y2._pointer == x._pointer + y3 = x[..., ind] + assert y3.shape == (n0, n1, 0) + assert y3._pointer == x._pointer + + @pytest.mark.parametrize( "data_dt", _all_dtypes,