diff --git a/dpctl/tensor/_stride_utils.pxi b/dpctl/tensor/_stride_utils.pxi index 37d5a366b7..c019cd8ddd 100644 --- a/dpctl/tensor/_stride_utils.pxi +++ b/dpctl/tensor/_stride_utils.pxi @@ -61,6 +61,7 @@ cdef int _from_input_shape_strides( Otherwise they are set to NULL """ cdef int i + cdef int j cdef int all_incr = 1 cdef int all_decr = 1 cdef Py_ssize_t elem_count = 1 @@ -115,6 +116,15 @@ cdef int _from_input_shape_strides( contig[0] = USM_ARRAY_C_CONTIGUOUS else: contig[0] = USM_ARRAY_F_CONTIGUOUS + if nd == 1: + contig[0] = USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS + else: + j = 0 + for i in range(nd): + if shape_arr[i] > 1: + j = j + 1 + if j < 2: + contig[0] = USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS min_disp[0] = 0 max_disp[0] = (elem_count - 1) strides_ptr[0] = (0) @@ -137,26 +147,42 @@ cdef int _from_input_shape_strides( min_disp[0] = min_shift max_disp[0] = max_shift if max_shift == min_shift + (elem_count - 1): + if elem_count == 1: + contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS) + return 0 if nd == 1: if strides_arr[0] == 1: contig[0] = USM_ARRAY_C_CONTIGUOUS else: contig[0] = 0 return 0 - for i in range(0, nd - 1): - if all_incr: - all_incr = ( - (strides_arr[i] > 0) and - (strides_arr[i+1] > 0) and - (strides_arr[i] <= strides_arr[i + 1]) - ) - if all_decr: - all_decr = ( - (strides_arr[i] > 0) and - (strides_arr[i+1] > 0) and - (strides_arr[i] >= strides_arr[i + 1]) - ) - if all_incr: + i = 0 + while i < nd: + if shape_arr[i] == 1: + i = i + 1 + continue + j = i + 1 + while (j < nd and shape_arr[j] == 1): + j = j + 1 + if j < nd: + if all_incr: + all_incr = ( + (strides_arr[i] > 0) and + (strides_arr[j] > 0) and + (strides_arr[i] <= strides_arr[j]) + ) + if all_decr: + all_decr = ( + (strides_arr[i] > 0) and + (strides_arr[j] > 0) and + (strides_arr[i] >= strides_arr[j]) + ) + i = j + else: + break + if all_incr and all_decr: + contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS) + elif all_incr: contig[0] = USM_ARRAY_F_CONTIGUOUS elif all_decr: contig[0] = USM_ARRAY_C_CONTIGUOUS diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7a33e8d87c..b4d2ed7872 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -59,6 +59,16 @@ def test_allocate_usm_ndarray(shape, usm_type): assert X.shape == X.__sycl_usm_array_interface__["shape"] +def test_usm_ndarray_flags(): + assert dpt.usm_ndarray((5,)).flags == 3 + assert dpt.usm_ndarray((5, 2)).flags == 1 + assert dpt.usm_ndarray((5, 2), order="F").flags == 2 + assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2 + assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1 + assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2 + assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3 + + @pytest.mark.parametrize( "dtype", [ @@ -703,11 +713,10 @@ def relaxed_strides_equal(st1, st2, sh): 5, ) X = dpt.usm_ndarray(sh_s, dtype="d") - expected_flags = X.flags X.shape = sh_f assert X.shape == sh_f assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f) - assert X.flags == expected_flags + assert X.flags & 1, "reshaped array expected to be C-contiguous" sh_s = ( 2, @@ -842,6 +851,10 @@ def test_reshape(): W = dpt.reshape(Z, (-1,), order="C") assert W.shape == (Z.size,) + X = dpt.usm_ndarray((1,)) + Y = dpt.reshape(X, X.shape) + assert Y.flags == X.flags + def test_transpose(): n, m = 2, 3