diff --git a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp index e2fe6051bc..e7a7b1d75f 100644 --- a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp +++ b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp @@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd, std::iota(pos.begin(), pos.end(), 0); std::stable_sort( - pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) { - auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; - auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; - return (abs_str1 > abs_str2) || - (abs_str1 == abs_str2 && shape[i1] > shape[i2]); + pos.begin(), pos.end(), [&strides1, &strides2, &shape](int i1, int i2) { + auto abs_str1_i1 = + (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; + auto abs_str1_i2 = + (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; + auto abs_str2_i1 = + (strides2[i1] < 0) ? -strides2[i1] : strides2[i1]; + auto abs_str2_i2 = + (strides2[i2] < 0) ? -strides2[i2] : strides2[i2]; + return (abs_str1_i1 > abs_str1_i2) || + (abs_str1_i1 == abs_str1_i2 && + (abs_str2_i1 > abs_str2_i2 || + (abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2]))); }); std::vector shape_w; @@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd, strides1_w.push_back(str1_p); strides2_w.push_back(str2_p); } + int nd_ = nd; while (contractable) { bool changed = false; @@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd, std::vector pos(nd); std::iota(pos.begin(), pos.end(), 0); - std::stable_sort( - pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) { - auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; - auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; - return (abs_str1 > abs_str2) || - (abs_str1 == abs_str2 && shape[i1] > shape[i2]); - }); + std::stable_sort(pos.begin(), pos.end(), + [&strides1, &strides2, &strides3, &shape](int i1, int i2) { + auto abs_str1_i1 = + (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; + auto abs_str1_i2 = + (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; + auto abs_str2_i1 = + (strides2[i1] < 0) ? -strides2[i1] : strides2[i1]; + auto abs_str2_i2 = + (strides2[i2] < 0) ? -strides2[i2] : strides2[i2]; + auto abs_str3_i1 = + (strides3[i1] < 0) ? -strides3[i1] : strides3[i1]; + auto abs_str3_i2 = + (strides3[i2] < 0) ? -strides3[i2] : strides3[i2]; + return (abs_str1_i1 > abs_str1_i2) || + ((abs_str1_i1 == abs_str1_i2) && + ((abs_str2_i1 > abs_str2_i2) || + ((abs_str2_i1 == abs_str2_i2) && + ((abs_str3_i1 > abs_str3_i2) || + ((abs_str3_i1 == abs_str3_i2) && + (shape[i1] > shape[i2])))))); + }); std::vector shape_w; std::vector strides1_w;