From e5c75522002102fbf279dc7512ccc32642c99898 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 23 Jan 2023 16:10:34 -0600 Subject: [PATCH] Added _contract_iter3 utility to simplify iteration space over 3 arrays ``` In [1]: import dpctl.tensor as dpt, dpctl.tensor._tensor_impl as ti, dpctl In [4]: import itertools In [5]: ti._contract_iter2((2, 5, 3), (15, -3, 1), (0,0,1)) Out[5]: ([10, 3], [3, 1], -12, [0, 1], 0) In [6]: or_s = set( (15*i0 - 3*i1 + i2, i2, 15*i0 - 3*i1 + i2) for i0,i1,i2 in itertools.product(range(2), range(5), range(3)) ) In [7]: alt_s = set( (3*i0 + i1 - 12, i1, 3*i0 + i1 - 12) for i0,i1 in itertools.product(range(10), range(3)) ) In [8]: or_s == alt_s Out[8]: True ``` --- .../libtensor/include/utils/strided_iters.hpp | 154 +++++++++++++++++- dpctl/tensor/libtensor/source/tensor_py.cpp | 10 ++ 2 files changed, 159 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp index 0abd7f4f2a..e2fe6051bc 100644 --- a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp +++ b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp @@ -408,9 +408,8 @@ int simplify_iteration_stride(const int nd, The new shape and new strides, as well as the offset `(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that - iterating over them will traverse the same pairs of elements, possibly in - different order. - + iterating over them will traverse the same set of pairs of elements, + possibly in a different order. 
*/ template int simplify_iteration_two_strides(const int nd, @@ -447,7 +446,7 @@ int simplify_iteration_two_strides(const int nd, auto str1_p = strides1[p]; auto str2_p = strides2[p]; shape_w.push_back(sh_p); - if (str1_p < 0 && str2_p < 0) { + if (str1_p <= 0 && str2_p <= 0 && std::min(str1_p, str2_p) < 0) { disp1 += str1_p * (sh_p - 1); str1_p = -str1_p; disp2 += str2_p * (sh_p - 1); @@ -468,7 +467,7 @@ int simplify_iteration_two_strides(const int nd, StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1; StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2; - if (jump1 == str1 and jump2 == str2) { + if (jump1 == str1 && jump2 == str2) { changed = true; shape_w[i] *= shape_w[i + 1]; for (int j = i; j < nd_; ++j) { @@ -540,3 +539,148 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2) out_strides2.resize(nd); return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2); } + +/* + For purposes of iterating over triples of elements of three arrays + with `shape` and strides `strides1`, `strides2`, `strides3` given as + pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr, + strides2_ptr, strides3_ptr, disp1, disp2, disp3)` + may modify memory and returns new length of these arrays. + + The new shape and new strides, as well as the offset + `(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3)` + are such that iterating over them will traverse the same set of tuples of + elements, possibly in a different order. 
+ */
+template <class ShapeTy, class StridesTy>
+int simplify_iteration_three_strides(const int nd,
+                                     ShapeTy *shape,
+                                     StridesTy *strides1,
+                                     StridesTy *strides2,
+                                     StridesTy *strides3,
+                                     StridesTy &disp1,
+                                     StridesTy &disp2,
+                                     StridesTy &disp3)
+{
+    // Displacements are pure outputs: reset all three before accumulating.
+    disp1 = std::ptrdiff_t(0);
+    disp2 = std::ptrdiff_t(0);
+    disp3 = std::ptrdiff_t(0);
+    if (nd < 2)
+        return nd;
+
+    // Visit dimensions in order of decreasing |strides1|; ties are broken
+    // by decreasing extent.
+    std::vector<int> pos(nd);
+    std::iota(pos.begin(), pos.end(), 0);
+
+    std::stable_sort(
+        pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
+            auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
+            auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
+            return (abs_str1 > abs_str2) ||
+                   (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
+        });
+
+    std::vector<ShapeTy> shape_w;
+    std::vector<StridesTy> strides1_w;
+    std::vector<StridesTy> strides2_w;
+    std::vector<StridesTy> strides3_w;
+
+    bool contractable = true;
+    for (int i = 0; i < nd; ++i) {
+        auto p = pos[i];
+        auto sh_p = shape[p];
+        auto str1_p = strides1[p];
+        auto str2_p = strides2[p];
+        auto str3_p = strides3[p];
+        shape_w.push_back(sh_p);
+        // Reverse a dimension when no stride is positive and at least one is
+        // negative; the displacements absorb the shift to its last element.
+        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
+            std::min(std::min(str1_p, str2_p), str3_p) < 0)
+        {
+            disp1 += str1_p * (sh_p - 1);
+            str1_p = -str1_p;
+            disp2 += str2_p * (sh_p - 1);
+            str2_p = -str2_p;
+            disp3 += str3_p * (sh_p - 1);
+            str3_p = -str3_p;
+        }
+        if (str1_p < 0 || str2_p < 0 || str3_p < 0) {
+            contractable = false;
+        }
+        strides1_w.push_back(str1_p);
+        strides2_w.push_back(str2_p);
+        strides3_w.push_back(str3_p);
+    }
+    int nd_ = nd;
+    while (contractable) {
+        bool changed = false;
+        for (int i = 0; i + 1 < nd_; ++i) {
+            StridesTy str1 = strides1_w[i + 1];
+            StridesTy str2 = strides2_w[i + 1];
+            StridesTy str3 = strides3_w[i + 1];
+            StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
+            StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
+            StridesTy jump3 = strides3_w[i] - (shape_w[i + 1] - 1) * str3;
+
+            if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
+                changed = true;
+                // Fold dimension i + 1 into dimension i; erase() drops the
+                // merged entries so each vector holds exactly nd_ items.
+                shape_w[i] *= shape_w[i + 1];
+                shape_w.erase(shape_w.begin() + i + 1);
+                strides1_w.erase(strides1_w.begin() + i);
+                strides2_w.erase(strides2_w.begin() + i);
+                strides3_w.erase(strides3_w.begin() + i);
+                --nd_;
+                break;
+            }
+        }
+        if (!changed)
+            break;
+    }
+    for (int i = 0; i < nd_; ++i) {
+        shape[i] = shape_w[i];
+    }
+    for (int i = 0; i < nd_; ++i) {
+        strides1[i] = strides1_w[i];
+    }
+    for (int i = 0; i < nd_; ++i) {
+        strides2[i] = strides2_w[i];
+    }
+    for (int i = 0; i < nd_; ++i) {
+        strides3[i] = strides3_w[i];
+    }
+
+    return nd_;
+}
+
+template <typename T, class Error, typename vecT = std::vector<T>>
+std::tuple<vecT, vecT, T, vecT, T, vecT, T>
+contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
+{
+    const size_t dim = shape.size();
+    if (dim != strides1.size() || dim != strides2.size() ||
+        dim != strides3.size()) {
+        throw Error("Shape and strides must be of equal size.");
+    }
+    vecT out_shape = shape;
+    vecT out_strides1 = strides1;
+    vecT out_strides2 = strides2;
+    vecT out_strides3 = strides3;
+    T disp1(0);
+    T disp2(0);
+    T disp3(0);
+
+    int nd = simplify_iteration_three_strides(
+        dim, out_shape.data(), out_strides1.data(), out_strides2.data(),
+        out_strides3.data(), disp1, disp2, disp3);
+    out_shape.resize(nd);
+    out_strides1.resize(nd);
+    out_strides2.resize(nd);
+    out_strides3.resize(nd);
+    return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2,
+                           out_strides3, disp3);
+}
diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp
index 460855e76e..aa8634ecf4 100644
--- a/dpctl/tensor/libtensor/source/tensor_py.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_py.cpp
@@ -133,6 +133,16 @@ PYBIND11_MODULE(_tensor_impl, m)
         "as the original "
         "iterator, possibly in a different order.");
 
+    m.def(
+        "_contract_iter3", &contract_iter3<py::ssize_t, py::value_error>,
+        "Simplifies iteration over elements of 3-tuple of arrays of given "
+        "shape "
+        "with strides stride1, stride2, and stride3. Returns "
+        "a 7-tuple: shape, stride and offset for the new iterator of possible "
+        "smaller dimension for each array, which traverses the same elements "
+        "as the original "
+        "iterator, possibly in a different order.");
+
     m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
           "Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
           "number of elements using underlying 'C'-contiguous order for flat "