diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp index 216c1102ab..4384f39cef 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp @@ -177,7 +177,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask, } const py::ssize_t *shape = mask.get_shape_raw(); - const py::ssize_t *strides = mask.get_strides_raw(); + auto const &strides_vector = mask.get_strides_vector(); using shT = std::vector; shT simplified_shape; @@ -187,13 +187,9 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask, int mask_nd = mask.get_ndim(); int nd = mask_nd; - constexpr py::ssize_t itemsize = 1; // in elements - bool is_c_contig = mask.is_c_contiguous(); - bool is_f_contig = mask.is_f_contiguous(); - dpctl::tensor::py_internal::simplify_iteration_space_1( - nd, shape, strides, itemsize, is_c_contig, is_f_contig, - simplified_shape, simplified_strides, offset); + nd, shape, strides_vector, simplified_shape, simplified_strides, + offset); if (nd == 1 && simplified_strides[0] == 1) { auto fn = mask_positions_contig_dispatch_vector[mask_typeid]; @@ -463,19 +459,13 @@ py_extract(dpctl::tensor::usm_ndarray src, std::vector simplified_ortho_dst_strides; const py::ssize_t *_shape = ortho_src_shape.data(); - const py::ssize_t *_src_strides = ortho_src_strides.data(); - const py::ssize_t *_dst_strides = ortho_dst_strides.data(); - constexpr py::ssize_t _itemsize = 1; // in elements - - constexpr bool is_c_contig = false; - constexpr bool is_f_contig = false; py::ssize_t ortho_src_offset(0); py::ssize_t ortho_dst_offset(0); dpctl::tensor::py_internal::simplify_iteration_space( - ortho_nd, _shape, _src_strides, _itemsize, is_c_contig, is_f_contig, - _dst_strides, _itemsize, is_c_contig, is_f_contig, + ortho_nd, _shape, ortho_src_strides, ortho_dst_strides, + // output simplified_ortho_shape, simplified_ortho_src_strides, simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset); @@ -775,19 +765,12 @@ py_place(dpctl::tensor::usm_ndarray dst, std::vector simplified_ortho_rhs_strides; const py::ssize_t *_shape = ortho_dst_shape.data(); - const py::ssize_t *_dst_strides = ortho_dst_strides.data(); - const py::ssize_t *_rhs_strides = ortho_rhs_strides.data(); - constexpr py::ssize_t _itemsize = 1; // in elements - - constexpr bool is_c_contig = false; - constexpr bool is_f_contig = false; py::ssize_t ortho_dst_offset(0); py::ssize_t ortho_rhs_offset(0); dpctl::tensor::py_internal::simplify_iteration_space( - ortho_nd, _shape, _dst_strides, _itemsize, is_c_contig, is_f_contig, - _rhs_strides, _itemsize, is_c_contig, is_f_contig, + ortho_nd, _shape, ortho_dst_strides, ortho_rhs_strides, simplified_ortho_shape, simplified_ortho_dst_strides, simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset); diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp index db92e2a18e..d8692c1098 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp @@ -167,8 +167,8 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src, copy_ev); } - const py::ssize_t *src_strides = src.get_strides_raw(); - const py::ssize_t *dst_strides = dst.get_strides_raw(); + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); using shT = std::vector; shT simplified_shape; @@ -180,25 +180,20 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src, int nd = src_nd; const py::ssize_t *shape = src_shape; - constexpr py::ssize_t src_itemsize = 1; // in elements - constexpr py::ssize_t dst_itemsize = 1; // in elements - - // all args except itemsizes and is_?_contig bools can be modified by - // reference + // nd, simplified_* and *_offset are modified by reference dpctl::tensor::py_internal::simplify_iteration_space( - nd, shape, src_strides, src_itemsize, is_src_c_contig, is_src_f_contig, - dst_strides, dst_itemsize, is_dst_c_contig, is_dst_f_contig, + nd, shape, src_strides, dst_strides, + // output simplified_shape, simplified_src_strides, simplified_dst_strides, src_offset, dst_offset); if (nd < 2) { if (nd == 1) { - std::array shape_arr = {shape[0]}; - // strides may be null + std::array shape_arr = {simplified_shape[0]}; std::array src_strides_arr = { - (src_strides ? src_strides[0] : 1)}; + simplified_src_strides[0]}; std::array dst_strides_arr = { - (dst_strides ? dst_strides[0] : 1)}; + simplified_dst_strides[0]}; sycl::event copy_and_cast_1d_event; if ((src_strides_arr[0] == 1) && (dst_strides_arr[0] == 1) && diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp index 3c2e84ec0e..3b02225f01 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include "dpctl4pybind11.hpp" @@ -143,10 +144,8 @@ void copy_numpy_ndarray_into_usm_ndarray( } } - const py::ssize_t *src_strides = - npy_src.strides(); // N.B.: strides in bytes - const py::ssize_t *dst_strides = - dst.get_strides_raw(); // N.B.: strides in elements + auto const &dst_strides = + dst.get_strides_vector(); // N.B.: strides in elements using shT = std::vector; shT simplified_shape; @@ -155,23 +154,42 @@ void copy_numpy_ndarray_into_usm_ndarray( py::ssize_t src_offset(0); py::ssize_t dst_offset(0); - py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes - constexpr py::ssize_t dst_itemsize = 1; // item size in elements - int nd = src_ndim; const py::ssize_t *shape = src_shape; + const py::ssize_t *src_strides_p = + npy_src.strides(); // N.B.: strides in bytes + py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes + bool is_src_c_contig = ((src_flags & py::array::c_style) != 0); bool is_src_f_contig = ((src_flags & py::array::f_style) != 0); - bool is_dst_c_contig = dst.is_c_contiguous(); - bool is_dst_f_contig = dst.is_f_contiguous(); + shT src_strides_in_elems; + if (src_strides_p) { + src_strides_in_elems.resize(nd); + // copy and convert strides from bytes to elements + std::transform( + src_strides_p, src_strides_p + nd, std::begin(src_strides_in_elems), + [src_itemsize](py::ssize_t el) { return el / src_itemsize; }); + } + else { + if (is_src_c_contig) { + src_strides_in_elems = + dpctl::tensor::c_contiguous_strides(nd, src_shape); + } + else if (is_src_f_contig) { + src_strides_in_elems = + dpctl::tensor::f_contiguous_strides(nd, src_shape); + } + else { + throw py::value_error("NumPy source array has null strides but is " + "neither C- nor F-contiguous."); + } + } - // all args except itemsizes and is_?_contig bools can be modified by - // reference - simplify_iteration_space(nd, shape, src_strides, src_itemsize, - is_src_c_contig, is_src_f_contig, dst_strides, - dst_itemsize, is_dst_c_contig, is_dst_f_contig, + // nd, simplified_* vectors and offsets are modified by reference + simplify_iteration_space(nd, shape, src_strides_in_elems, dst_strides, + // outputs simplified_shape, simplified_src_strides, simplified_dst_strides, src_offset, dst_offset); @@ -186,18 +204,16 @@ void copy_numpy_ndarray_into_usm_ndarray( simplified_shape.push_back(1); simplified_src_strides.reserve(nd); - simplified_src_strides.push_back(src_itemsize); + simplified_src_strides.push_back(1); simplified_dst_strides.reserve(nd); - simplified_dst_strides.push_back(dst_itemsize); + simplified_dst_strides.push_back(1); } // Minumum and maximum element offsets for source np.ndarray py::ssize_t npy_src_min_nelem_offset(0); py::ssize_t npy_src_max_nelem_offset(0); for (int i = 0; i < nd; ++i) { - // convert source strides from bytes to elements - simplified_src_strides[i] = simplified_src_strides[i] / src_itemsize; if (simplified_src_strides[i] < 0) { npy_src_min_nelem_offset += simplified_src_strides[i] * (simplified_shape[i] - 1); diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp index d6bd259942..e11495204a 100644 --- a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp +++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp @@ -41,11 +41,9 @@ using dpctl::tensor::c_contiguous_strides; using dpctl::tensor::f_contiguous_strides; void simplify_iteration_space_1(int &nd, - const py::ssize_t *&shape, - const py::ssize_t *&strides, - py::ssize_t itemsize, - bool is_c_contig, - bool is_f_contig, + const py::ssize_t *const &shape, + std::vector const &strides, + // output std::vector &simplified_shape, std::vector &simplified_strides, py::ssize_t &offset) @@ -55,29 +53,11 @@ void simplify_iteration_space_1(int &nd, // Simplify iteration space to reduce dimensionality // and improve access pattern simplified_shape.reserve(nd); - for (int i = 0; i < nd; ++i) { - simplified_shape.push_back(shape[i]); - } + simplified_shape.insert(std::end(simplified_shape), shape, shape + nd); simplified_strides.reserve(nd); - if (strides == nullptr) { - if (is_c_contig) { - simplified_strides = c_contiguous_strides(nd, shape, itemsize); - } - else if (is_f_contig) { - simplified_strides = f_contiguous_strides(nd, shape, itemsize); - } - else { - throw std::runtime_error( - "Array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_strides.push_back(strides[i]); - } - } + simplified_strides.insert(std::end(simplified_strides), + std::begin(strides), std::end(strides)); assert(simplified_shape.size() == static_cast(nd)); assert(simplified_strides.size() == static_cast(nd)); @@ -96,41 +76,18 @@ void simplify_iteration_space_1(int &nd, simplified_shape.push_back(shape[0]); simplified_strides.reserve(nd); - - if (strides == nullptr) { - if (is_c_contig) { - simplified_strides.push_back(itemsize); - } - else if (is_f_contig) { - simplified_strides.push_back(itemsize); - } - else { - throw std::runtime_error( - "Array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_strides.push_back(strides[0]); - } + simplified_strides.push_back(strides[0]); assert(simplified_shape.size() == static_cast(nd)); assert(simplified_strides.size() == static_cast(nd)); } - shape = const_cast(simplified_shape.data()); - strides = const_cast(simplified_strides.data()); } void simplify_iteration_space(int &nd, - const py::ssize_t *&shape, - const py::ssize_t *&src_strides, - py::ssize_t src_itemsize, - bool is_src_c_contig, - bool is_src_f_contig, - const py::ssize_t *&dst_strides, - py::ssize_t dst_itemsize, - bool is_dst_c_contig, - bool is_dst_f_contig, + const py::ssize_t *const &shape, + std::vector const &src_strides, + std::vector const &dst_strides, + // output std::vector &simplified_shape, std::vector &simplified_src_strides, std::vector &simplified_dst_strides, @@ -142,56 +99,22 @@ void simplify_iteration_space(int &nd, // Simplify iteration space to reduce dimensionality // and improve access pattern simplified_shape.reserve(nd); - for (int i = 0; i < nd; ++i) { - simplified_shape.push_back(shape[i]); - } + simplified_shape.insert(std::begin(simplified_shape), shape, + shape + nd); + assert(simplified_shape.size() == static_cast(nd)); simplified_src_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - if (src_strides == nullptr) { - if (is_src_c_contig) { - simplified_src_strides = - c_contiguous_strides(nd, shape, src_itemsize); - } - else if (is_src_f_contig) { - simplified_src_strides = - f_contiguous_strides(nd, shape, src_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src_strides.push_back(src_strides[i]); - } - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides = - c_contiguous_strides(nd, shape, dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides = - f_contiguous_strides(nd, shape, dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_dst_strides.push_back(dst_strides[i]); - } - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src_strides.insert(std::end(simplified_src_strides), + std::begin(src_strides), + std::end(src_strides)); assert(simplified_src_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.insert(std::end(simplified_dst_strides), + std::begin(dst_strides), + std::end(dst_strides)); assert(simplified_dst_strides.size() == static_cast(nd)); + int contracted_nd = simplify_iteration_two_strides( nd, simplified_shape.data(), simplified_src_strides.data(), simplified_dst_strides.data(), @@ -208,72 +131,27 @@ void simplify_iteration_space(int &nd, // Populate vectors simplified_shape.reserve(nd); simplified_shape.push_back(shape[0]); + assert(simplified_shape.size() == static_cast(nd)); simplified_src_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - - if (src_strides == nullptr) { - if (is_src_c_contig) { - simplified_src_strides.push_back(src_itemsize); - } - else if (is_src_f_contig) { - simplified_src_strides.push_back(src_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src_strides.push_back(src_strides[0]); - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_dst_strides.push_back(dst_strides[0]); - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src_strides.push_back(src_strides[0]); assert(simplified_src_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.push_back(dst_strides[0]); assert(simplified_dst_strides.size() == static_cast(nd)); } - shape = const_cast(simplified_shape.data()); - src_strides = - const_cast(simplified_src_strides.data()); - dst_strides = - const_cast(simplified_dst_strides.data()); } void simplify_iteration_space_3( int &nd, - const py::ssize_t *&shape, + const py::ssize_t *const &shape, // src1 - const py::ssize_t *&src1_strides, - py::ssize_t src1_itemsize, - bool is_src1_c_contig, - bool is_src1_f_contig, + std::vector const &src1_strides, // src2 - const py::ssize_t *&src2_strides, - py::ssize_t src2_itemsize, - bool is_src2_c_contig, - bool is_src2_f_contig, + std::vector const &src2_strides, // dst - const py::ssize_t *&dst_strides, - py::ssize_t dst_itemsize, - bool is_dst_c_contig, - bool is_dst_f_contig, + std::vector const &dst_strides, // output std::vector &simplified_shape, std::vector &simplified_src1_strides, @@ -288,78 +166,27 @@ void simplify_iteration_space_3( // Simplify iteration space to reduce dimensionality // and improve access pattern simplified_shape.reserve(nd); - for (int i = 0; i < nd; ++i) { - simplified_shape.push_back(shape[i]); - } + simplified_shape.insert(std::end(simplified_shape), shape, shape + nd); + assert(simplified_shape.size() == static_cast(nd)); simplified_src1_strides.reserve(nd); - simplified_src2_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - if (src1_strides == nullptr) { - if (is_src1_c_contig) { - simplified_src1_strides = - c_contiguous_strides(nd, shape, src1_itemsize); - } - else if (is_src1_f_contig) { - simplified_src1_strides = - f_contiguous_strides(nd, shape, src1_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src1_strides.push_back(src1_strides[i]); - } - } - if (src2_strides == nullptr) { - if (is_src2_c_contig) { - simplified_src2_strides = - c_contiguous_strides(nd, shape, src2_itemsize); - } - else if (is_src2_f_contig) { - simplified_src2_strides = - f_contiguous_strides(nd, shape, src2_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src2_strides.push_back(src2_strides[i]); - } - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides = - c_contiguous_strides(nd, shape, dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides = - f_contiguous_strides(nd, shape, dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_dst_strides.push_back(dst_strides[i]); - } - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src1_strides.insert(std::end(simplified_src1_strides), + std::begin(src1_strides), + std::end(src1_strides)); assert(simplified_src1_strides.size() == static_cast(nd)); + + simplified_src2_strides.reserve(nd); + simplified_src2_strides.insert(std::end(simplified_src2_strides), + std::begin(src2_strides), + std::end(src2_strides)); assert(simplified_src2_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.insert(std::end(simplified_dst_strides), + std::begin(dst_strides), + std::end(dst_strides)); assert(simplified_dst_strides.size() == static_cast(nd)); + int contracted_nd = simplify_iteration_three_strides( nd, simplified_shape.data(), simplified_src1_strides.data(), simplified_src2_strides.data(), simplified_dst_strides.data(), @@ -378,97 +205,33 @@ void simplify_iteration_space_3( // Populate vectors simplified_shape.reserve(nd); simplified_shape.push_back(shape[0]); + assert(simplified_shape.size() == static_cast(nd)); simplified_src1_strides.reserve(nd); - simplified_src2_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - - if (src1_strides == nullptr) { - if (is_src1_c_contig) { - simplified_src1_strides.push_back(src1_itemsize); - } - else if (is_src1_f_contig) { - simplified_src1_strides.push_back(src1_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src1_strides.push_back(src1_strides[0]); - } - if (src2_strides == nullptr) { - if (is_src2_c_contig) { - simplified_src2_strides.push_back(src2_itemsize); - } - else if (is_src2_f_contig) { - simplified_src2_strides.push_back(src2_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src2_strides.push_back(src2_strides[0]); - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_dst_strides.push_back(dst_strides[0]); - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src1_strides.push_back(src1_strides[0]); assert(simplified_src1_strides.size() == static_cast(nd)); + + simplified_src2_strides.reserve(nd); + simplified_src2_strides.push_back(src2_strides[0]); assert(simplified_src2_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.push_back(dst_strides[0]); assert(simplified_dst_strides.size() == static_cast(nd)); } - shape = const_cast(simplified_shape.data()); - src1_strides = - const_cast(simplified_src1_strides.data()); - src2_strides = - const_cast(simplified_src2_strides.data()); - dst_strides = - const_cast(simplified_dst_strides.data()); } void simplify_iteration_space_4( int &nd, - const py::ssize_t *&shape, + const py::ssize_t *const &shape, // src1 - const py::ssize_t *&src1_strides, - py::ssize_t src1_itemsize, - bool is_src1_c_contig, - bool is_src1_f_contig, + std::vector const &src1_strides, // src2 - const py::ssize_t *&src2_strides, - py::ssize_t src2_itemsize, - bool is_src2_c_contig, - bool is_src2_f_contig, + std::vector const &src2_strides, // src3 - const py::ssize_t *&src3_strides, - py::ssize_t src3_itemsize, - bool is_src3_c_contig, - bool is_src3_f_contig, + std::vector const &src3_strides, // dst - const py::ssize_t *&dst_strides, - py::ssize_t dst_itemsize, - bool is_dst_c_contig, - bool is_dst_f_contig, + std::vector const &dst_strides, // output std::vector &simplified_shape, std::vector &simplified_src1_strides, @@ -485,100 +248,33 @@ void simplify_iteration_space_4( // Simplify iteration space to reduce dimensionality // and improve access pattern simplified_shape.reserve(nd); - for (int i = 0; i < nd; ++i) { - simplified_shape.push_back(shape[i]); - } + simplified_shape.insert(std::end(simplified_shape), shape, shape + nd); + assert(simplified_shape.size() == static_cast(nd)); simplified_src1_strides.reserve(nd); - simplified_src2_strides.reserve(nd); - simplified_src3_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - if (src1_strides == nullptr) { - if (is_src1_c_contig) { - simplified_src1_strides = - c_contiguous_strides(nd, shape, src1_itemsize); - } - else if (is_src1_f_contig) { - simplified_src1_strides = - f_contiguous_strides(nd, shape, src1_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src1_strides.push_back(src1_strides[i]); - } - } - if (src2_strides == nullptr) { - if (is_src2_c_contig) { - simplified_src2_strides = - c_contiguous_strides(nd, shape, src2_itemsize); - } - else if (is_src2_f_contig) { - simplified_src2_strides = - f_contiguous_strides(nd, shape, src2_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src2_strides.push_back(src2_strides[i]); - } - } - if (src3_strides == nullptr) { - if (is_src3_c_contig) { - simplified_src3_strides = - c_contiguous_strides(nd, shape, src3_itemsize); - } - else if (is_src3_f_contig) { - simplified_src3_strides = - f_contiguous_strides(nd, shape, src3_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_src3_strides.push_back(src3_strides[i]); - } - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides = - c_contiguous_strides(nd, shape, dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides = - f_contiguous_strides(nd, shape, dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - for (int i = 0; i < nd; ++i) { - simplified_dst_strides.push_back(dst_strides[i]); - } - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src1_strides.insert(std::end(simplified_src1_strides), + std::begin(src1_strides), + std::end(src1_strides)); assert(simplified_src1_strides.size() == static_cast(nd)); + + simplified_src2_strides.reserve(nd); + simplified_src2_strides.insert(std::end(simplified_src2_strides), + std::begin(src2_strides), + std::end(src2_strides)); assert(simplified_src2_strides.size() == static_cast(nd)); + + simplified_src3_strides.reserve(nd); + simplified_src3_strides.insert(std::end(simplified_src3_strides), + std::begin(src3_strides), + std::end(src3_strides)); assert(simplified_src3_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.insert(std::end(simplified_dst_strides), + std::begin(dst_strides), + std::end(dst_strides)); assert(simplified_dst_strides.size() == static_cast(nd)); + int contracted_nd = simplify_iteration_four_strides( nd, simplified_shape.data(), simplified_src1_strides.data(), simplified_src2_strides.data(), simplified_src3_strides.data(), @@ -600,92 +296,24 @@ void simplify_iteration_space_4( // Populate vectors simplified_shape.reserve(nd); simplified_shape.push_back(shape[0]); + assert(simplified_shape.size() == static_cast(nd)); simplified_src1_strides.reserve(nd); - simplified_src2_strides.reserve(nd); - simplified_src3_strides.reserve(nd); - simplified_dst_strides.reserve(nd); - - if (src1_strides == nullptr) { - if (is_src1_c_contig) { - simplified_src1_strides.push_back(src1_itemsize); - } - else if (is_src1_f_contig) { - simplified_src1_strides.push_back(src1_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src1_strides.push_back(src1_strides[0]); - } - if (src2_strides == nullptr) { - if (is_src2_c_contig) { - simplified_src2_strides.push_back(src2_itemsize); - } - else if (is_src2_f_contig) { - simplified_src2_strides.push_back(src2_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src2_strides.push_back(src2_strides[0]); - } - if (src3_strides == nullptr) { - if (is_src3_c_contig) { - simplified_src3_strides.push_back(src3_itemsize); - } - else if (is_src3_f_contig) { - simplified_src3_strides.push_back(src3_itemsize); - } - else { - throw std::runtime_error( - "Source array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_src3_strides.push_back(src3_strides[0]); - } - if (dst_strides == nullptr) { - if (is_dst_c_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else if (is_dst_f_contig) { - simplified_dst_strides.push_back(dst_itemsize); - } - else { - throw std::runtime_error( - "Destination array has null strides " - "but has neither C- nor F- contiguous flag set"); - } - } - else { - simplified_dst_strides.push_back(dst_strides[0]); - } - - assert(simplified_shape.size() == static_cast(nd)); + simplified_src1_strides.push_back(src1_strides[0]); assert(simplified_src1_strides.size() == static_cast(nd)); + + simplified_src2_strides.reserve(nd); + simplified_src2_strides.push_back(src2_strides[0]); assert(simplified_src2_strides.size() == static_cast(nd)); + + simplified_src3_strides.reserve(nd); + simplified_src3_strides.push_back(src3_strides[0]); assert(simplified_src3_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.push_back(dst_strides[0]); assert(simplified_dst_strides.size() == static_cast(nd)); } - shape = const_cast(simplified_shape.data()); - src1_strides = - const_cast(simplified_src1_strides.data()); - src2_strides = - const_cast(simplified_src2_strides.data()); - src3_strides = - const_cast(simplified_src3_strides.data()); - dst_strides = - const_cast(simplified_dst_strides.data()); } } // namespace py_internal diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp index 5ce46a57a2..9a60830d1d 100644 --- a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp +++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp @@ -36,25 +36,16 @@ namespace py_internal namespace py = pybind11; void simplify_iteration_space_1(int &, - const py::ssize_t *&, - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + const py::ssize_t *const &, + std::vector const &, std::vector &, std::vector &, py::ssize_t &); void simplify_iteration_space(int &, - const py::ssize_t *&, - const py::ssize_t *&, - py::ssize_t, - bool, - bool, - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + const py::ssize_t *const &, + std::vector const &, + std::vector const &, std::vector &, std::vector &, std::vector &, @@ -62,22 +53,13 @@ void simplify_iteration_space(int &, py::ssize_t &); void simplify_iteration_space_3(int &, - const py::ssize_t *&, + const py::ssize_t *const &, // src1 - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // src2 - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // dst - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // output std::vector &, std::vector &, @@ -88,27 +70,15 @@ void simplify_iteration_space_3(int &, py::ssize_t &); void simplify_iteration_space_4(int &, - const py::ssize_t *&, + const py::ssize_t *const &, // src1 - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // src2 - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // src3 - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // dst - const py::ssize_t *&, - py::ssize_t, - bool, - bool, + std::vector const &, // output std::vector &, std::vector &, diff --git a/dpctl/tensor/libtensor/source/triul_ctor.cpp b/dpctl/tensor/libtensor/source/triul_ctor.cpp index 47fad15698..b9cf4543f9 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.cpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.cpp @@ -117,12 +117,6 @@ usm_ndarray_triul(sycl::queue exec_q, "Execution queue context is not the same as allocation contexts"); } - bool is_src_c_contig = src.is_c_contiguous(); - bool is_src_f_contig = src.is_f_contiguous(); - - bool is_dst_c_contig = dst.is_c_contiguous(); - bool is_dst_f_contig = dst.is_f_contiguous(); - auto src_strides = src.get_strides_vector(); auto dst_strides = dst.get_strides_vector(); @@ -133,17 +127,16 @@ usm_ndarray_triul(sycl::queue exec_q, py::ssize_t src_offset(0); py::ssize_t dst_offset(0); - constexpr py::ssize_t src_itemsize = 1; // item size in elements - constexpr py::ssize_t dst_itemsize = 1; // item size in elements - int nd = src_nd - 2; const py::ssize_t *shape = src_shape; - const py::ssize_t *p_src_strides = src_strides.data(); - const py::ssize_t *p_dst_strides = dst_strides.data(); - simplify_iteration_space(nd, shape, p_src_strides, src_itemsize, - is_src_c_contig, is_src_f_contig, p_dst_strides, - dst_itemsize, is_dst_c_contig, is_dst_f_contig, + const shT iter_src_strides(std::begin(src_strides), + std::begin(src_strides) + nd); + const shT iter_dst_strides(std::begin(dst_strides), + std::begin(dst_strides) + nd); + + simplify_iteration_space(nd, shape, iter_src_strides, iter_dst_strides, + // output simplified_shape, simplified_src_strides, simplified_dst_strides, src_offset, dst_offset); diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp index b3843844bd..8f47381cf6 100644 --- a/dpctl/tensor/libtensor/source/where.cpp +++ b/dpctl/tensor/libtensor/source/where.cpp @@ -177,10 +177,10 @@ py_where(dpctl::tensor::usm_ndarray condition, return std::make_pair(ht_ev, where_ev); } - const py::ssize_t *cond_strides = condition.get_strides_raw(); - const py::ssize_t *x1_strides = x1.get_strides_raw(); - const py::ssize_t *x2_strides = x2.get_strides_raw(); - const py::ssize_t *dst_strides = dst.get_strides_raw(); + auto const &cond_strides = condition.get_strides_vector(); + auto const &x1_strides = x1.get_strides_vector(); + auto const &x2_strides = x2.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); using shT = std::vector; shT simplified_shape; @@ -193,17 +193,12 @@ py_where(dpctl::tensor::usm_ndarray condition, py::ssize_t x2_offset(0); py::ssize_t dst_offset(0); - const py::ssize_t *_shape = x1_shape; - - constexpr py::ssize_t _itemsize = 1; - dpctl::tensor::py_internal::simplify_iteration_space_4( - nd, _shape, cond_strides, _itemsize, is_cond_c_contig, is_cond_f_contig, - x1_strides, _itemsize, is_x1_c_contig, is_x1_f_contig, x2_strides, - _itemsize, is_x2_c_contig, is_x2_f_contig, dst_strides, _itemsize, - is_dst_c_contig, is_dst_f_contig, simplified_shape, - simplified_cond_strides, simplified_x1_strides, simplified_x2_strides, - simplified_dst_strides, cond_offset, x1_offset, x2_offset, dst_offset); + nd, x1_shape, cond_strides, x1_strides, x2_strides, dst_strides, + // outputs + simplified_shape, simplified_cond_strides, simplified_x1_strides, + simplified_x2_strides, simplified_dst_strides, cond_offset, x1_offset, + x2_offset, dst_offset); auto fn = where_strided_dispatch_table[x1_typeid][cond_typeid];