diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 052fe28c60..8ed840b8f6 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -97,23 +97,22 @@ sycl::event single_reduction_for_gemm(sycl::queue &exec_q, { sycl::event red_ev; if (reduction_nelems < wg) { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + const ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + const ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); sycl::range<1> iter_range{iter_nelems}; @@ -128,23 +127,22 @@ sycl::event single_reduction_for_gemm(sycl::queue &exec_q, }); } else { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + const ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + const ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); if (iter_nelems == 1) { // increase GPU occupancy @@ -194,21 +192,20 @@ single_reduction_for_gemm_contig(sycl::queue &exec_q, { sycl::event red_ev; if (reduction_nelems < wg) { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using NoOpIndexerT = 
dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + constexpr InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); sycl::range<1> iter_range{iter_nelems}; @@ -223,21 +220,20 @@ single_reduction_for_gemm_contig(sycl::queue &exec_q, }); } else { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + constexpr InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); if (iter_nelems == 1) { // increase GPU occupancy @@ -288,9 +284,7 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, const ssize_t *res_shape_strides, const std::vector &depends) { - - const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler - &cgh) { + sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; @@ -304,9 +298,9 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, // iter_shape_and_strides are going to be accessed by // inp_indexer - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ + constexpr InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{ 0, /* size */ static_cast(reduction_nelems), /* step */ static_cast(iter_nelems)}; @@ -348,14 +342,15 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, InputIndexerT, ResIndexerT>; using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - InputIndexerT inp_indexer{0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; - ResIndexerT res_iter_indexer{}; + const InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + constexpr ResIndexerT res_iter_indexer{}; - 
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; + constexpr ReductionIndexerT reduction_indexer{}; auto globalRange = sycl::range<1>{iter_nelems * reduction_groups_ * wg}; @@ -391,15 +386,15 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, InputIndexerT, ResIndexerT>; using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - InputIndexerT inp_indexer{ + const InputIndexerT inp_indexer{ 0, static_cast(iter_nelems), static_cast(remaining_reduction_nelems)}; - ResIndexerT res_iter_indexer{res_nd, static_cast(res_offset), - res_shape_strides}; + const ResIndexerT res_iter_indexer{ + res_nd, static_cast(res_offset), res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + constexpr ReductionIndexerT reduction_indexer{}; wg = max_wg; reductions_per_wi = @@ -462,9 +457,9 @@ tree_reduction_for_gemm_contig(sycl::queue &exec_q, // iter_shape_and_strides are going to be accessed by // inp_indexer - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ + constexpr InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{ 0, /* size */ static_cast(reduction_nelems), /* step */ static_cast(iter_nelems)}; @@ -509,14 +504,15 @@ tree_reduction_for_gemm_contig(sycl::queue &exec_q, // n * m = iter_nelems because essentially, this process // creates a stack of reduction_nelems 2D matrices and we reduce // along the stack axis - InputIndexerT inp_indexer{0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; - ResIndexerT res_iter_indexer{}; + const InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + constexpr ResIndexerT res_iter_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; + constexpr ReductionIndexerT reduction_indexer{}; auto globalRange = sycl::range<1>{iter_nelems * reduction_groups_ * wg}; @@ -551,14 +547,14 @@ tree_reduction_for_gemm_contig(sycl::queue &exec_q, InputIndexerT, ResIndexerT>; using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - InputIndexerT inp_indexer{ + const InputIndexerT inp_indexer{ 0, static_cast(iter_nelems), static_cast(remaining_reduction_nelems)}; - ResIndexerT res_iter_indexer{}; + constexpr ResIndexerT res_iter_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + constexpr ReductionIndexerT reduction_indexer{}; wg = max_wg; reductions_per_wi = @@ -592,9 +588,10 @@ template -class GemmFunctorThreadNM +class GemmBatchFunctorThreadNM { private: const lhsT *lhs = nullptr; @@ -610,31 +607,36 @@ class GemmFunctorThreadNM size_t m = 0; size_t m_blocks = 0; size_t wg_delta_m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; + size_t batch_nelems; + const BatchDimsIndexerT batch_indexer; + const OuterInnerDimsIndexerT lhs_indexer; + const 
OuterInnerDimsIndexerT rhs_indexer; + const OuterInnerDimsIndexerT res_indexer; public: - GemmFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + const BatchDimsIndexerT &batch_indexer_, + const OuterInnerDimsIndexerT &lhs_indexer_, + const OuterInnerDimsIndexerT &rhs_indexer_, + const OuterInnerDimsIndexerT &res_indexer_) : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) { @@ -642,25 +644,35 @@ class GemmFunctorThreadNM void operator()(sycl::nd_item<1> it) const { - const size_t gr_id = it.get_group_linear_id(); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks - const size_t block_i = gr_id / (m_blocks * k_blocks); - const size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - const size_t block_j = block_r / k_blocks; - const size_t block_s = block_r - block_j * k_blocks; + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, + // 0 <= block_s < k_blocks - const size_t lid = it.get_local_linear_id(); - const size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - const size_t local_j = - lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m // load A block and B blocks into SLM size_t i = block_i * wi_delta_n * wg_delta_n; size_t j = block_j * wi_delta_m * wg_delta_m; - const size_t s = block_s * wi_delta_k; + size_t s = block_s * wi_delta_k; const std::int64_t a_st0 = k; const std::int64_t a_st1 = 1; @@ -674,34 +686,35 @@ class GemmFunctorThreadNM size_t lws = it.get_local_range(0); for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - const size_t v_i = - vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - const size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + size_t 
v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k - const size_t g_i = i + v_i; - const size_t g_s = s + v_s; + size_t g_i = i + v_i; + size_t g_s = s + v_s; local_A_block[vid] = (g_i < n && g_s < k) ? static_cast( - lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) : resT(0); } using slmB_t = typename LocAccT2::value_type; for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - const size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m - const size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - const size_t g_j = j + v_j * wi_delta_m; - const size_t g_s = s + v_s; + size_t g_j = j + v_j * wi_delta_m; + size_t g_s = s + v_s; if constexpr (wi_delta_m == 1 && std::is_same_v) { local_B_block[vid] = (g_j < m && g_s < k) ? static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } else { @@ -713,7 +726,8 @@ class GemmFunctorThreadNM vec[lane_id] = (g_j1 < m && g_s < k) ? static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) : resT(0); } @@ -726,10 +740,11 @@ class GemmFunctorThreadNM i += local_i * wi_delta_n; j += local_j * wi_delta_m; - const size_t a_offset = local_i * wi_delta_k * wi_delta_n; - const size_t b_offset = local_j * wi_delta_k; + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { size_t a_pr_offset = private_i * wi_delta_k; @@ -740,7 +755,7 @@ class GemmFunctorThreadNM local_B_block[b_offset + private_s]); } - const size_t gl_i = i + private_i; + size_t gl_i = i + private_i; if constexpr (wi_delta_m == 1 && std::is_same_v) { const size_t gl_j = j; @@ -748,7 +763,8 @@ class GemmFunctorThreadNM sycl::atomic_ref - aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); aout += local_sum; } @@ -764,7 +780,8 @@ class GemmFunctorThreadNM resT, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::global_space> - aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); aout += local_sum[lane_id]; } @@ -779,8 +796,9 @@ template -class GemmFunctorThreadK +class GemmBatchFunctorThreadK { private: const lhsT *lhs = nullptr; @@ -796,54 +814,76 @@ class GemmFunctorThreadK size_t delta_k = 0; size_t n_wi = 0; size_t m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; + size_t batch_nelems = 0; + const BatchDimsIndexerT batch_indexer; + const OuterInnerDimsIndexerT lhs_indexer; + const OuterInnerDimsIndexerT rhs_indexer; + const OuterInnerDimsIndexerT res_indexer; public: - GemmFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT 
*res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + const BatchDimsIndexerT &batch_indexer_, + const OuterInnerDimsIndexerT &lhs_indexer_, + const OuterInnerDimsIndexerT &rhs_indexer_, + const OuterInnerDimsIndexerT &res_indexer_) : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) { } void operator()(sycl::nd_item<1> it) const { - size_t gr_id = it.get_group_linear_id(); - size_t lid = it.get_local_linear_id(); + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + const size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); // lift gr_id -> (block_i, block_j, block_s) // block_i moves fastest, then block_s, then block_j - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + const size_t r_size = (n_blocks * k_blocks); + // 0 <= block_j < m_blocks, + const size_t block_j = gr_id / r_size; + // 0 <= block_r < n_blocks * k_blocks + const size_t block_r = gr_id - block_j * r_size; + // 0 <= block_s < k_blocks + const size_t block_s = block_r / n_blocks; + // 0 <= block_i < n_blocks + const size_t block_i = block_r - block_s * n_blocks; + + // 0 <= local_i < delta_n + const size_t local_i = lid / (delta_k); + // 0 <= local_s < delta_k + const size_t local_s = lid - local_i * (delta_k); size_t i = block_i * delta_n + local_i; size_t j = m_groups * block_j; @@ -854,13 +894,14 @@ class GemmFunctorThreadK constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; + const size_t sq = s + q; + const size_t sqmj = sq * m + j; if constexpr (m_groups == 1 && std::is_same_v) { local_B_block[local_s + q] = (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj)]) : identity_; } else { @@ -870,7 +911,8 @@ class GemmFunctorThreadK local_B_vec[vec_idx] = (sq < k && j + vec_idx < m) ? 
static_cast( - rhs[rhs_indexer(sqmj + vec_idx)]) + rhs[rhs_offset + + rhs_indexer(sqmj + vec_idx)]) : identity_; } local_B_block[local_s + q] = local_B_vec; @@ -886,11 +928,12 @@ class GemmFunctorThreadK accV_t private_sum(identity_); constexpr accV_t vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : vec_identity_; + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -907,7 +950,7 @@ class GemmFunctorThreadK sycl::atomic_ref - aout0(res[res_indexer(i * m + j)]); + aout0(res[res_offset + res_indexer(i * m + j)]); if constexpr (m_groups == 1 && std::is_same_v) { aout0 += local_sum; @@ -922,7 +965,8 @@ class GemmFunctorThreadK resT, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::global_space> - aout1(res[res_indexer(i * m + j + vec_id)]); + aout1(res[res_offset + + res_indexer(i * m + j + vec_id)]); aout1 += local_sum[vec_id]; } @@ -940,562 +984,495 @@ class gemm_k_krn; template class gemm_nm_krn; -typedef sycl::event (*gemm_impl_fn_ptr_t)( - sycl::queue &, - const char *, // lhs - const char *, // rhs - char *, // res - size_t, // lhs_outer_nelems (n) - size_t, // inner_nelems (k) - size_t, // rhs_outer_nelems (m) - int, // inner nd - int, // lhs outer nd - const ssize_t *, // lhs shape and strides - int, // rhs outer nd - const ssize_t *, // rhs shape and strides - int, // res outer nd - const ssize_t *, // res shape and strides - std::vector const &); +template +class gemm_batch_k_krn; -template -sycl::event gemm_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_shape_strides, - int rhs_outer_nd, - const ssize_t *rhs_shape_strides, - int res_outer_nd, - const ssize_t *res_shape_strides, - std::vector const &depends = {}) -{ - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); +template +class gemm_batch_nm_krn; - sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); +namespace gemm_detail +{ - using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; - IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - using InitKernelName = class gemm_init_krn; - cgh.parallel_for( - sycl::range<1>(n * m), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = resTy(0); - }); - }); +template +sycl::event _gemm_old_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + static_assert(std::is_same_v); + static_assert(std::is_same_v); - if (k == 0) { - return res_init_ev; - } + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI const sycl::device &dev = 
exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_shape_strides); - OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_shape_strides); - OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - - if (m < 4) { - constexpr size_t m_groups = 1; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + const size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + const size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + const size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + const size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t lws = delta_n * delta_k; + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, std::move(local_A_block), + std::move(local_B_block), n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; +} - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); +template +sycl::event 
_gemm_old_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + static_assert(std::is_same_v); + static_assert(std::is_same_v); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); - }); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - return gemm_ev; - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t lws = delta_n * delta_k; - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_nm_krn; - cgh.parallel_for( - ndRange, - GemmFunctorThreadNM( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - }); - return gemm_ev; - } + using KernelName = + class 
gemm_batch_k_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; } -typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( - sycl::queue &, - const char *, // lhs - const char *, // rhs - char *, // res - size_t, // n - size_t, // k - size_t, // m - std::vector const &); - -template -sycl::event gemm_contig_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - std::vector const &depends = {}) +template +sycl::event _gemm_old_small_m_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); - - sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - cgh.fill(res_tp, resTy(0), n * m); - }); + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - if (k == 0) { - return res_init_ev; - } + static_assert(std::is_same_v); + static_assert(std::is_same_v); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerIndexerT lhs_indexer{}; - OuterInnerIndexerT rhs_indexer{}; - OuterInnerIndexerT res_indexer{}; - - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); - }); - - return gemm_ev; - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; - - auto gRange = 
sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); - }); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - return gemm_ev; - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t lws = delta_n * delta_k; - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_nm_krn; - cgh.parallel_for( - ndRange, - GemmFunctorThreadNM( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - }); + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); - return gemm_ev; - } + return gemm_ev; } +} // end of namespace gemm_detail + template -class GemmNoAtomicFunctorThreadNM + std::uint32_t wi_delta_n, + std::uint32_t wi_delta_m_vecs, + std::uint32_t m_vec_size> +class GemmBatchFunctorThreadNM_vecm { private: const lhsT *lhs = nullptr; const rhsT *rhs = nullptr; resT 
*res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; + LocAccT1 local_lhs_block; + LocAccT2 local_rhs_block; + size_t batch_nelems; size_t n = 0; - size_t wg_delta_n = 0; size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; + size_t n_groups = 0; + std::uint32_t wg_delta_n = 0; + std::uint32_t wg_delta_m = 0; + std::uint32_t wi_delta_k = 0; + const BatchDimsIndexerT batch_indexer; + const LhsIndexerT lhs_indexer; + const RhsIndexerT rhs_indexer; + const ResIndexerT res_indexer; public: - GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) + /*! @brief */ + GemmBatchFunctorThreadNM_vecm(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_lhs_block_, + LocAccT2 local_rhs_block_, + size_t batch_nelems_, + size_t n_, + size_t k_, + size_t m_, + size_t n_groups_, + size_t wg_delta_n_, + size_t wg_delta_m_, + size_t wi_delta_k_, + const BatchDimsIndexerT &batch_indexer_, + const LhsIndexerT &lhs_indexer_, + const RhsIndexerT &rhs_indexer_, + const ResIndexerT &res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_lhs_block(local_lhs_block_), + local_rhs_block(local_rhs_block_), batch_nelems(batch_nelems_), n(n_), + k(k_), m(m_), n_groups(n_groups_), wg_delta_n(wg_delta_n_), + wg_delta_m(wg_delta_m_), wi_delta_k(wi_delta_k_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) { } void operator()(sycl::nd_item<1> it) const { - size_t gr_id = it.get_group_linear_id(); - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; - - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m - - // load A block and B blocks into SLM - - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wi_delta_m * wg_delta_m; - size_t s = block_s * wi_delta_k; + constexpr resT zero_(0); + constexpr std::uint32_t wi_total_delta_m = wi_delta_m_vecs * m_vec_size; - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; - - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; + const size_t gws_per_batch = it.get_group_range(0) / batch_nelems; + const size_t batch_id = it.get_group_linear_id() / gws_per_batch; + const size_t gr_id = + it.get_group_linear_id() - batch_id * gws_per_batch; - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; + const auto &three_offsets_ = + 
batch_indexer(static_cast(batch_id)); - size_t lws = it.get_local_range(0); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + // 0 <= block_j < m_groups + const size_t block_j = gr_id / n_groups; + // 0 <= block_i < n_groups + const size_t block_i = gr_id - block_j * n_groups; - size_t g_i = i + v_i; - size_t g_s = s + v_s; + // Assumption: lws == wg_delta_n * wg_delta_m + const std::uint32_t lid = it.get_local_linear_id(); + // 0 <= local_j < (lws / wg_delta_n == wg_delta_m) + const std::uint32_t local_j = lid / wg_delta_n; + // sub-group lanes map to adjacent local_i + const std::uint32_t local_i = lid - local_j * wg_delta_n; - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); - } + // Coordinates of the block of C the work-group works on + size_t i = block_i * wg_delta_n * wi_delta_n; + size_t j = block_j * wg_delta_m * wi_total_delta_m; using slmB_t = typename LocAccT2::value_type; - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j = j + v_j * wi_delta_m; - size_t g_s = s + v_s; - - if constexpr (wi_delta_m == 1 && std::is_same_v) { - local_B_block[vid] = - (g_j < m && g_s < k) + const size_t a_st0 = k; + const size_t a_st1 = 1; + + const size_t b_st0 = m; + const size_t b_st1 = 1; + + const size_t c_st0 = m; + const size_t c_st1 = 1; + + // allocate/initialize private matrix C + // size ( wi_total_delta_n, wi_total_delta_m ) + constexpr std::uint32_t C_size = wi_delta_n * wi_delta_m_vecs; + std::array private_C{slmB_t{zero_}}; + + for (size_t s = 0; s < k; s += wi_delta_k) { + // populate local_lhs_block ( wg_delta_n * wi_delta_n, + // wi_delta_k) + for (std::uint32_t vid = lid; vid < local_lhs_block.size(); + vid += it.get_local_range()[0]) + { + // 0 <= v_i < wg_delta_n * wi_delta_n + const std::uint32_t v_i = vid / wi_delta_k; + // 0 <= v_s < wi_delta_k + const std::uint32_t v_s = vid - v_i * wi_delta_k; + + const size_t g_i = i + v_i; + const size_t g_s = s + v_s; + + const std::uint32_t mapped_vid = + wg_delta_n * wi_delta_n * v_s + v_i; + local_lhs_block[mapped_vid] = + (g_i < n && g_s < k) ? static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) - : resT(0); + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : zero_; } - else { - slmB_t vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) - { - size_t g_j1 = g_j + lane_id; - vec[lane_id] = - (g_j1 < m && g_s < k) + + // populate local_rhs_block> ( wg_delta_m * + // wi_delta_m_vecs, wi_delta_k ) + for (std::uint32_t vid = lid; vid < local_rhs_block.size(); + vid += it.get_local_range()[0]) + { + // 0 <= v_j < wg_delta_m * wi_delta_m_vecs + const std::uint32_t v_j = vid / wi_delta_k; + // 0 <= v_s < wi_delta_k + const std::uint32_t v_s = vid - v_j * wi_delta_k; + + const size_t g_j = j + v_j * m_vec_size; + const size_t g_s = s + v_s; + const std::uint32_t mapped_vid = + wg_delta_m * wi_delta_m_vecs * v_s + v_j; + + if constexpr (m_vec_size == 1) { + local_rhs_block[mapped_vid] = + (g_j < m && g_s < k) ? 
static_cast( - rhs[rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) - : resT(0); + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : zero_; + } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint32_t lane_id = 0; lane_id < m_vec_size; + ++lane_id) { + const size_t g_j1 = g_j + lane_id; + vec[lane_id] = (g_j1 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + + g_j1 * b_st1)]) + : zero_; + }; + + local_rhs_block[mapped_vid] = vec; } - - local_B_block[vid] = vec; } - } - - it.barrier(sycl::access::fence_space::local_space); - - i += local_i * wi_delta_n; - j += local_j * wi_delta_m; - - const size_t a_offset = local_i * wi_delta_k * wi_delta_n; - const size_t b_offset = local_j * wi_delta_k; - - constexpr resT identity_(0); - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - const size_t a_pr_offset = private_i * wi_delta_k; + it.barrier(sycl::access::fence_space::local_space); - slmB_t local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); + const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n); + const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs); + for (std::uint32_t pr_k = 0; pr_k < wi_delta_k; ++pr_k) { +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { +#pragma unroll + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + private_C[pr_i * wi_delta_m_vecs + pr_j] += + local_lhs_block[pr_k * lo_lhs_st_k + + (local_i + pr_i * wg_delta_n)] * + local_rhs_block[pr_k * lo_rhs_rk_k + + (local_j + pr_j * wg_delta_m)]; + } + } } - size_t gl_i = i + private_i; + it.barrier(sycl::access::fence_space::local_space); + } - if constexpr (wi_delta_m == 1 && std::is_same_v) { - const size_t gl_j = j; - if (gl_i < n && gl_j < m) { - res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + - block_s * n * m)] = local_sum; + if constexpr (m_vec_size == 1) { +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { + size_t out_i = i + local_i + pr_i * wg_delta_n; + if (out_i < n) { +#pragma unroll + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + const size_t out_j = + j + (local_j + pr_j * wg_delta_m) * m_vec_size; + const size_t out_flat_id = + out_i * c_st0 + out_j * c_st1; + if (out_j < m) { + res[res_offset + res_indexer(out_flat_id)] = + private_C[pr_i * wi_delta_m_vecs + pr_j]; + } + } } } - else { + } + else { #pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) - { - const size_t gl_j = j + lane_id; - - if (gl_i < n && gl_j < m) { - res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + - block_s * n * m)] = local_sum[lane_id]; + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { + size_t out_i = i + local_i + pr_i * wg_delta_n; + if (out_i < n) { + // could be unrolled + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + size_t out_j = + j + (local_j + pr_j * wg_delta_m) * m_vec_size; +#pragma unroll + for (std::uint32_t lane_id = 0; lane_id < m_vec_size; + ++lane_id) { + const size_t out_flat_id = + out_i * c_st0 + (out_j + lane_id) * c_st1; + if (out_j + lane_id < m) { + res[res_offset + res_indexer(out_flat_id)] = + private_C[pr_i * wi_delta_m_vecs + pr_j] + [lane_id]; + } + } } } } @@ -1503,156 +1480,73 @@ class GemmNoAtomicFunctorThreadNM } }; -template -class GemmNoAtomicFunctorThreadK +struct GemmBatchFunctorThreadNM_vecm_HyperParameters { 
private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; + std::uint32_t wi_delta_n = 2; + std::uint32_t wi_delta_m_vecs = 4; + std::uint32_t m_vec_size = 1; public: - GemmNoAtomicFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters(); + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters( + std::uint32_t wi_delta_n_, + std::uint32_t wi_delta_m_vecs_, + std::uint32_t m_vec_size_) + : wi_delta_n(wi_delta_n_), wi_delta_m_vecs(wi_delta_m_vecs_), + m_vec_size(m_vec_size_) { } - void operator()(sycl::nd_item<1> it) const + constexpr std::uint32_t get_wi_delta_n() const { - size_t gr_id = it.get_group_linear_id(); - size_t lid = it.get_local_linear_id(); - - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j - - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks - - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k - - size_t i = block_i * delta_n + local_i; - size_t j = m_groups * block_j; - size_t s = block_s * delta_k * n_wi + local_s; - - using accV_t = typename LocAccT::value_type; + return wi_delta_n; + } + constexpr std::uint32_t get_wi_delta_m_vecs() const + { + return wi_delta_m_vecs; + } + constexpr std::uint32_t get_m_vec_size() const + { + return m_vec_size; + } +}; - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; +template +struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector +{ + constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector() {} - if constexpr (m_groups == 1 && std::is_same_v) { - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) - : identity_; - ; - } - else { - accV_t local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) - ? 
static_cast( - rhs[rhs_indexer(sqmj + vec_idx)]) - : identity_; - } - local_B_block[local_s + q] = local_B_vec; - } - } + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters get() const + { + if constexpr (sizeof(resT) == 1) { + // 1 * 8 * 2 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(8, 2, 4); } - - it.barrier(sycl::access::fence_space::local_space); - - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; - - accV_t private_sum(identity_); - constexpr accV_t vec_identity_(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : vec_identity_; + else if constexpr (sizeof(resT) == 2) { + // 2 * 4 * 2 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 2, 4); } - - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; - - it.barrier(sycl::access::fence_space::local_space); - - if (local_s == 0 && i < n) { - accV_t local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; - } - - const size_t res_offset = (block_s * n * m); - - if constexpr (m_groups == 1 && std::is_same_v) { - res[res_indexer(i * m + j) + res_offset] = local_sum; + else if constexpr (sizeof(resT) == 4) { + // 4 * 4 * 1 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 1, 4); + } + else if constexpr (sizeof(resT) == 8) { + // 8 * 2 * 1 * 4 == 64 + if constexpr (std::is_same_v>) { + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 4, 1); } else { - static_assert(m_groups >= 1); - res[res_indexer(i * m + j) + res_offset] = local_sum[0]; - -#pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - res[res_indexer(i * m + j + vec_id) + res_offset] = - local_sum[vec_id]; - } - } + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 1, 4); } } + else if constexpr (std::is_same_v>) { + // 16 * 2 * 2 * 1 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1); + } + else { + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1); + } } }; @@ -1661,1444 +1555,579 @@ template -class gemm_tree_nm_krn; + typename T6, + typename T7, + std::uint32_t p1, + std::uint32_t p2, + std::uint32_t p3> +class gemm_batch_nm_vecm_krn; -template -class gemm_tree_k_krn; +namespace gemm_detail +{ -template -sycl::event gemm_tree_k_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t n, - size_t k, - size_t m, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_nd, - const ssize_t *res_shapes_strides, - const std::vector &depends) +template +std::tuple +get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size, + const std::uint32_t wg_delta_n, + const std::uint32_t suggested_wg_delta_m) { - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + std::uint32_t wg_delta_m = suggested_wg_delta_m; - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; + const size_t slm_max_rows = + slm_byte_size / + ((wg_delta_n * wi_delta_n + wg_delta_m * wi_delta_m) * sizeof(T)); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by 
reference - delta_n // modified by reference - ); + std::uint32_t wi_delta_k = + (slm_max_rows >= 64) + ? 64 + : 32 * static_cast(slm_max_rows / 32); - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + for (std::uint32_t it = 0; !wi_delta_k && (it < 4); ++it) { + wg_delta_m /= 2; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); - return gemm_ev; + const size_t slm_max_rows = + slm_byte_size / + ((wg_delta_n * wi_delta_n + wg_delta_m * wi_delta_m) * sizeof(T)); + + wi_delta_k = + (slm_max_rows >= 64) + ? 64 + : ((slm_max_rows >= 32) + ? 32 + : (slm_max_rows >= 16 ? 
16 + : 8 * static_cast( + slm_max_rows / 8))); } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + if (!wi_delta_k) { + throw std::runtime_error("Insufficient resources"); + } - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + return std::make_tuple(wg_delta_m, wi_delta_k); +} - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); +template +sycl::event _gemm_batch_new_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + std::vector const &depends) +{ + constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector + selector{}; + constexpr auto hyper_params = selector.get(); - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); + constexpr std::uint32_t wi_delta_n = hyper_params.get_wi_delta_n(); + constexpr std::uint32_t wi_delta_m_vecs = + hyper_params.get_wi_delta_m_vecs(); + constexpr std::uint32_t m_vec_size = hyper_params.get_m_vec_size(); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + constexpr std::uint32_t wi_total_delta_m = wi_delta_m_vecs * m_vec_size; - if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - if (!tmp) { - throw std::runtime_error("Unable to allocate device memory"); - } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using KernelName = + class gemm_batch_nm_vecm_krn; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using 
KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + const auto &kernel_id = sycl::get_kernel_id(); - sycl::event red_ev = single_reduction_for_gemm( - exec_q, tmp, res_tp, identity_val, iter_nelems, - reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, - res_shapes_strides, {gemm_ev}); + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + auto krn = kb.get_kernel(kernel_id); - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - else { - assert(reduction_groups > 1); + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + // Limit work-group size + constexpr size_t wg_sz_limit(2048); + const size_t max_wg_sz = std::min( + dev.get_info(), wg_sz_limit); + const std::uint32_t max_subgroups_per_wg = + static_cast(max_wg_sz / max_sg_size); - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + const size_t reserved_slm_byte_size = 512; + const size_t slm_byte_size = + dev.get_info(); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + const std::uint32_t wg_delta_n = max_sg_size; + std::uint32_t wg_delta_m = 0; + std::uint32_t wi_delta_k = 0; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, 
partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - // tree_reduction_for_gemm returns sycl::event for reduction - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - res_nd, 0, res_shapes_strides, {gemm_ev}); + std::tie(wg_delta_m, wi_delta_k) = + get_wg_delta_m_and_wi_delta_k( + slm_byte_size - reserved_slm_byte_size, wg_delta_n, + max_subgroups_per_wg); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + const std::uint32_t lws = wg_delta_n * wg_delta_m; - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + const size_t n_groups = + (n + wg_delta_n * wi_delta_n - 1) / (wg_delta_n * wi_delta_n); + const size_t m_groups = (m + wg_delta_m * wi_total_delta_m - 1) / + (wg_delta_m * wi_total_delta_m); - return cleanup_host_task_event; - } - } + const size_t gws = lws * batch_nelems * n_groups * m_groups; + + sycl::range<1> lRange(lws); + sycl::range<1> gRange(gws); + sycl::nd_range<1> ndRange(gRange, lRange); + + using slmB_t = + typename std::conditional>::type; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block(wg_delta_n * wi_delta_n * wi_delta_k, cgh); + + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wg_delta_m * wi_delta_m_vecs * wi_delta_k, cgh); + + using Impl_FunctorT = GemmBatchFunctorThreadNM_vecm< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, BatchIndexerT, LhsIndexerT, + RhsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m_vecs, m_vec_size>; + + cgh.parallel_for( + ndRange, Impl_FunctorT( + lhs_tp, rhs_tp, res_tp, std::move(local_A_block), + std::move(local_B_block), batch_nelems, n, k, m, + n_groups, wg_delta_n, wg_delta_m, wi_delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; } -template -sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t n, - size_t k, - size_t m, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_nd, - const ssize_t *res_shapes_strides, - const std::vector &depends) -{ - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; +} // namespace gemm_detail - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // lhs_outer_nelems (n) + size_t, // inner_nelems (k) + size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const ssize_t *, // lhs shape and strides + int, // rhs outer nd + 
const ssize_t *, // rhs shape and strides + int, // res outer nd + const ssize_t *, // res shape and strides + std::vector const &); - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); +template +sycl::event gemm_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const ssize_t *rhs_shape_strides, + int res_outer_nd, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - return gemm_ev; + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + const OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + const OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; + + constexpr size_t single_batch_nelems = 1; + + const size_t min_nm = 
std::min(n, m); + const size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - if (!tmp) { - throw std::runtime_error("Unable to allocate device memory"); - } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>( - lhs_tp, 
rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - sycl::event red_ev = single_reduction_for_gemm( - exec_q, tmp, res_tp, identity_val, iter_nelems, - reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, - res_shapes_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - else { - assert(reduction_groups > 1); - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, - m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, - m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - } + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const 
IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + using InitKernelName = class gemm_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); }); + }); - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - res_nd, 0, res_shapes_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + if (k == 0) { + return res_init_ev; + } - return cleanup_host_task_event; + if ((max_nm < 64)) { + if (m < 4) { + return gemm_detail::_gemm_old_small_m_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } + return gemm_detail::_gemm_old_k_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } + + return gemm_detail::_gemm_old_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); } -template class gemm_tree_empty_krn; +typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // n + size_t, // k + size_t, // m + std::vector const &); template -sycl::event gemm_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_nd, - const ssize_t *res_shapes_strides, - std::vector const &depends = {}) +sycl::event gemm_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) { const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); resTy *res_tp = reinterpret_cast(res_cp); - if (k == 0) { - sycl::event gemm_no_reduction_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerIndexerT lhs_indexer{}; + constexpr OuterInnerIndexerT rhs_indexer{}; + constexpr OuterInnerIndexerT res_indexer{}; + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; + + constexpr size_t single_batch_nelems = 1; + + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, 
single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } - using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; - IndexerT res_indexer(res_nd, 0, res_shapes_strides); - using InitKernelName = - class gemm_tree_empty_krn; - cgh.parallel_for( - sycl::range<1>(n * m), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = resTy(0); - }); - }); - return gemm_no_reduction_ev; + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + + if (k == 0) { + return res_init_ev; } - if ((k > n && k > m) || m < 4) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m < 4) { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } - else { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } - } - else { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); + if (max_nm < 64) { + if (m < 4) { + return gemm_detail::_gemm_old_small_m_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } + return gemm_detail::_gemm_old_k_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } - else { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } - } -} - -template -sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t n, - size_t k, - size_t m, - std::vector const &depends) -{ - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n 
+ delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-groups is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - if (!tmp) { - throw std::runtime_error("Unable to allocate device memory"); - } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, 
rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - sycl::event red_ev = - single_reduction_for_gemm_contig( - exec_q, tmp, res_tp, identity_val, iter_nelems, - reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - else { - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - // tree_reduction_for_gemm_contig returns sycl::event - // for reduction - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; 
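The tree-reduction branches above, and the batch variants retained further below, carve a single USM allocation into two consecutive regions: iter_nelems * reduction_nelems partial products followed by iter_nelems * reduction_groups first-stage scratch, which is what the assignment partially_reduced_tmp2 = partially_reduced_tmp + reduction_nelems * iter_nelems expresses. A minimal sketch of that layout, assuming only standard headers and sycl/sycl.hpp; the helper name is illustrative and not part of this header:

template <typename resTy>
std::pair<resTy *, resTy *>
allocate_tree_reduction_scratch(sycl::queue &q,
                                size_t iter_nelems,
                                size_t reduction_nelems,
                                size_t reduction_groups)
{
    // One allocation, split into two consecutive regions
    resTy *tmp = sycl::malloc_device<resTy>(
        iter_nelems * (reduction_nelems + reduction_groups), q);
    if (!tmp) {
        throw std::runtime_error("Unable to allocate device memory");
    }
    // first region: one partial result per (output element, k-block)
    // second region: scratch for the first reduction stage
    return {tmp, tmp + iter_nelems * reduction_nelems};
}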
- } - } + return gemm_detail::_gemm_old_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); } -template -sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t n, - size_t k, - size_t m, - std::vector const &depends) -{ - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-groups is needed, requires a 
temporary - // wi_delta_k elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); +template class gemm_batch_init_krn; - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // lhs outer nelems (n) + size_t, // inner nelems (k) + size_t, // rhs outer nelems (m) + int, // batching nd + const ssize_t *, // batch shape strides + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + int, // inner dims + int, // lhs outer dims + const ssize_t *, // lhs outer and inner shape and strides + int, // rhs outer dims + const ssize_t *, // rhs outer and inner shape and strides + int, // res outer dims + const ssize_t *, // res outer and inner shape and strides + const ssize_t *, // res full shape and strides + std::vector const &); - if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - if (!tmp) { - throw std::runtime_error("Unable to allocate device memory"); - } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); +template +sycl::event gemm_batch_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - 
GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } - sycl::event red_ev = - single_reduction_for_gemm_contig( - exec_q, tmp, res_tp, identity_val, iter_nelems, - reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = class gemm_batch_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - else { - assert(reduction_groups > 1); + if (k == 0) { + return res_init_ev; + } - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + if (m < 4) { + return gemm_detail::_gemm_old_small_m_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, 
batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + else if (k > n && k > m) { + return gemm_detail::_gemm_old_k_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + else { + return gemm_detail::_gemm_old_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } +} - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } +typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // n + size_t, // k + size_t, // m + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + std::vector const &); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); +template +sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using 
KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); - return cleanup_host_task_event; - } + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); } -} -template -sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - std::vector const &depends = {}) -{ - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); if (k == 0) { - sycl::event gemm_no_reduction_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - cgh.fill(res_tp, resTy(0), n * m); - }); - return gemm_no_reduction_ev; + return res_init_ev; } - if ((k > n && k > m) || m < 4) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m < 4) { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - else { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - } - else { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - else { - return 
gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + if (max_nm < 64) { + if (m < 4) { + return gemm_detail::_gemm_old_small_m_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } + return gemm_detail::_gemm_old_k_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); } + + return gemm_detail::_gemm_old_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); } +// ========== Gemm Tree + template -class GemmBatchFunctorThreadNM +class GemmBatchNoAtomicFunctorThreadNM { private: const lhsT *lhs = nullptr; @@ -3115,30 +2144,30 @@ class GemmBatchFunctorThreadNM size_t m_blocks = 0; size_t wg_delta_m = 0; size_t batch_nelems; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; + const BatchDimsIndexerT batch_indexer; + const OuterInnerDimsIndexerT lhs_indexer; + const OuterInnerDimsIndexerT rhs_indexer; + const ResIndexerT res_indexer; public: - GemmBatchFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + const BatchDimsIndexerT batch_indexer_, + const OuterInnerDimsIndexerT lhs_indexer_, + const OuterInnerDimsIndexerT rhs_indexer_, + const ResIndexerT res_indexer_) : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), @@ -3229,7 +2258,7 @@ class GemmBatchFunctorThreadNM #pragma unroll for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { - const size_t g_j1 = g_j + lane_id; + size_t g_j1 = g_j + lane_id; vec[lane_id] = (g_j1 < m && g_s < k) ? 
static_cast( @@ -3247,13 +2276,13 @@ class GemmBatchFunctorThreadNM i += local_i * wi_delta_n; j += local_j * wi_delta_m; - size_t a_offset = local_i * wi_delta_k * wi_delta_n; - size_t b_offset = local_j * wi_delta_k; + const size_t a_offset = local_i * wi_delta_k * wi_delta_n; + const size_t b_offset = local_j * wi_delta_k; constexpr resT identity_(0); for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - size_t a_pr_offset = private_i * wi_delta_k; + const size_t a_pr_offset = private_i * wi_delta_k; slmB_t local_sum(identity_); for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { @@ -3262,18 +2291,13 @@ class GemmBatchFunctorThreadNM local_B_block[b_offset + private_s]); } - size_t gl_i = i + private_i; + const size_t gl_i = i + private_i; if constexpr (wi_delta_m == 1 && std::is_same_v) { const size_t gl_j = j; if (gl_i < n && gl_j < m) { - sycl::atomic_ref - aout(res[res_offset + - res_indexer(gl_i * c_st0 + gl_j * c_st1)]); - - aout += local_sum; + res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum; } } else { @@ -3283,14 +2307,10 @@ class GemmBatchFunctorThreadNM const size_t gl_j = j + lane_id; if (gl_i < n && gl_j < m) { - sycl::atomic_ref< - resT, sycl::memory_order::relaxed, - sycl::memory_scope::device, - sycl::access::address_space::global_space> - aout(res[res_offset + - res_indexer(gl_i * c_st0 + gl_j * c_st1)]); - - aout += local_sum[lane_id]; + res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = + local_sum[lane_id]; } } } @@ -3303,9 +2323,10 @@ template -class GemmBatchFunctorThreadK +class GemmBatchNoAtomicFunctorThreadK { private: const lhsT *lhs = nullptr; @@ -3322,30 +2343,30 @@ class GemmBatchFunctorThreadK size_t n_wi = 0; size_t m = 0; size_t batch_nelems = 0; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - OuterInnerDimsIndexerT res_indexer; - -public: - GemmBatchFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - OuterInnerDimsIndexerT res_indexer_) + const BatchDimsIndexerT batch_indexer; + const OuterInnerDimsIndexerT lhs_indexer; + const OuterInnerDimsIndexerT rhs_indexer; + const ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + const BatchDimsIndexerT &batch_indexer_, + const OuterInnerDimsIndexerT &lhs_indexer_, + const OuterInnerDimsIndexerT &rhs_indexer_, + const ResIndexerT &res_indexer_) : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), @@ -3357,19 +2378,13 @@ class GemmBatchFunctorThreadK void operator()(sycl::nd_item<1> it) const { - // for batching: - // (current matrix in batch) m_id = global_id / (global_range / - // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = - // m_id - // * (k * m) for res, 
offset = m_id * (n * m) const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; const size_t gr_id = it.get_group_linear_id() - m_id * n_groups_per_batch; - const size_t lid = it.get_local_linear_id(); + size_t lid = it.get_local_linear_id(); const auto &three_offsets_ = batch_indexer(static_cast(m_id)); - const auto &lhs_offset = three_offsets_.get_first_offset(); const auto &rhs_offset = three_offsets_.get_second_offset(); const auto &res_offset = three_offsets_.get_third_offset(); @@ -3377,18 +2392,18 @@ class GemmBatchFunctorThreadK // lift gr_id -> (block_i, block_j, block_s) // block_i moves fastest, then block_s, then block_j - const size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - const size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - const size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - const size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + const size_t r_size = (n_blocks * k_blocks); + // 0 <= block_j < m_blocks + size_t block_j = gr_id / r_size; + // 0 <= block_r < n_blocks * k_blocks + size_t block_r = gr_id - block_j * r_size; + // 0 <= block_s < k_blocks + size_t block_s = block_r / n_blocks; + // 0 <= block_i < n_blocks + size_t block_i = block_r - block_s * n_blocks; - const size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - const size_t local_s = - lid - local_i * (delta_k); // 0 <= local_s < delta_k + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k size_t i = block_i * delta_n + local_i; size_t j = m_groups * block_j; @@ -3399,8 +2414,8 @@ class GemmBatchFunctorThreadK constexpr resT identity_ = resT(0); if (local_i == 0) { for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - const size_t sq = s + q; - const size_t sqmj = sq * m + j; + size_t sq = s + q; + size_t sqmj = sq * m + j; if constexpr (m_groups == 1 && std::is_same_v) { local_B_block[local_s + q] = @@ -3452,923 +2467,1407 @@ class GemmBatchFunctorThreadK local_sum += workspace[workspace_i_shift + t]; } - sycl::atomic_ref - aout0(res[res_offset + res_indexer(i * m + j)]); + const size_t total_offset = + res_offset + (block_s * n * m * batch_nelems); if constexpr (m_groups == 1 && std::is_same_v) { - aout0 += local_sum; + res[total_offset + res_indexer(i * m + j)] = local_sum; } else { - aout0 += local_sum[0]; + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; #pragma unroll for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { if (j + vec_id < m) { - sycl::atomic_ref< - resT, sycl::memory_order::relaxed, - sycl::memory_scope::device, - sycl::access::address_space::global_space> - aout1(res[res_offset + - res_indexer(i * m + j + vec_id)]); + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[1]; + } + } + } + } + } +}; + +template +class gemm_batch_tree_k_krn; + +template +class gemm_batch_tree_nm_krn; + +namespace gemm_detail +{ + +template +sycl::event _gemm_tree_k_step(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const size_t delta_n, + const size_t n_wi, + const size_t delta_k, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + 
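// Illustrative host-side sketch (function and main() below are assumptions,
// not part of this patch): the flat work-group id consumed by
// GemmBatchNoAtomicFunctorThreadK decomposes into
// (batch, block_i, block_s, block_j), with block_i varying fastest, then
// block_s, then block_j, matching the global range of
// batch_nelems * n_blocks * m_blocks * k_blocks work-groups built below.
#include <array>
#include <cassert>
#include <cstddef>

static std::array<std::size_t, 4> decompose_group_id(std::size_t gid,
                                                     std::size_t n_blocks,
                                                     std::size_t k_blocks,
                                                     std::size_t m_blocks)
{
    const std::size_t n_groups_per_batch = n_blocks * k_blocks * m_blocks;
    const std::size_t batch_id = gid / n_groups_per_batch;
    const std::size_t gr_id = gid - batch_id * n_groups_per_batch;

    const std::size_t block_j = gr_id / (n_blocks * k_blocks); // slowest
    const std::size_t block_r = gr_id - block_j * (n_blocks * k_blocks);
    const std::size_t block_s = block_r / n_blocks;
    const std::size_t block_i = block_r - block_s * n_blocks; // fastest

    return {batch_id, block_i, block_s, block_j};
}

int main()
{
    // Round-trip check of the linearization
    // gid = ((batch * m_blocks + block_j) * k_blocks + block_s) * n_blocks
    //       + block_i
    const std::size_t batch = 3, nb = 4, kb = 2, mb = 5;
    std::size_t gid = 0;
    for (std::size_t b = 0; b < batch; ++b)
        for (std::size_t j = 0; j < mb; ++j)
            for (std::size_t s = 0; s < kb; ++s)
                for (std::size_t i = 0; i < nb; ++i, ++gid) {
                    const auto d = decompose_group_id(gid, nb, kb, mb);
                    assert(d[0] == b && d[1] == i && d[2] == s && d[3] == j);
                }
    return 0;
}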
static_assert(std::is_same_v); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const size_t n_blocks = (n + delta_n - 1) / delta_n; + const size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + const size_t m_blocks = (m + m_groups - 1) / m_groups; + + const size_t lws = delta_n * delta_k; + const size_t gws = batch_nelems * n_blocks * m_blocks * k_blocks * lws; + + auto gRange = sycl::range<1>(gws); + auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using slmB_t = + typename std::conditional>::type; + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_tree_k_krn; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +} // end of namespace gemm_detail + +template +sycl::event +gemm_batch_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + depends); + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, 
sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + constexpr TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + constexpr TmpIndexerT res_indexer{}; + using 
dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const StridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +namespace gemm_detail +{ + +template +sycl::event _gemm_tree_nm_step(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + const std::uint32_t wg_delta_n, + const std::uint32_t wg_delta_m, + const std::uint32_t wi_delta_k, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + static_assert(std::is_same_v); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const size_t lws = wg_delta_n * wg_delta_m; + + const size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + const size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + const size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - aout1 += local_sum[vec_id]; - } - } - } - } - } -}; + const size_t gws = batch_nelems * n_blocks * m_blocks * k_blocks * lws; -template class gemm_batch_init_krn; + auto gwsRange = sycl::range<1>(gws); + auto lwsRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); -template -class gemm_batch_k_krn; + using slmB_t = + typename std::conditional>::type; + using LocAccT1 = sycl::local_accessor; + using LocAccT2 = sycl::local_accessor; -template -class gemm_batch_nm_krn; + const sycl::range<1> local_A_size((wi_delta_n * wg_delta_n) * + wi_delta_k); + const sycl::range<1> local_B_size(wi_delta_k * wg_delta_m); -typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( - sycl::queue &, - const char *, // lhs - const char *, // rhs - char *, // res - size_t, // batch nelems - size_t, // lhs outer nelems (n) - size_t, // inner nelems (k) - size_t, // rhs outer nelems (m) - int, // batching nd - const ssize_t *, // batch shape strides - ssize_t, // lhs batch offset - ssize_t, // rhs batch offset - 
ssize_t, // res batch offset - int, // inner dims - int, // lhs outer dims - const ssize_t *, // lhs outer and inner shape and strides - int, // rhs outer dims - const ssize_t *, // rhs outer and inner shape and strides - int, // res outer dims - const ssize_t *, // res outer and inner shape and strides - const ssize_t *, // res full shape and strides - std::vector const &); + LocAccT1 local_A_block(local_A_size, cgh); + LocAccT2 local_B_block(local_B_size, cgh); -template -sycl::event gemm_batch_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - int batch_nd, - const ssize_t *batch_shape_strides, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_outer_nd, - const ssize_t *res_outer_shapes_strides, - const ssize_t *res_shape_strides, - std::vector const &depends = {}) + using KernelName = + class gemm_batch_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, LhsIndexerT, + ResIndexerT, BatchIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, std::move(local_A_block), + std::move(local_B_block), n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +} // end namespace gemm_detail + +template +sycl::event +gemm_batch_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; - IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, - res_shape_strides); - using InitKernelName = class gemm_batch_init_krn; - cgh.parallel_for( - sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = resTy(0); - }); - }); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - if (k == 0) { - return res_init_ev; + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= 
wi_delta_k) { + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, - res_outer_shapes_strides); - using BatchDimsIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, - batch_shape_strides); + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (m < 4) { - constexpr size_t m_groups = 1; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - size_t lws = delta_n * delta_k; + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } - auto ndRange = sycl::nd_range<1>(gRange, lRange); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const 
OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + constexpr TmpIndexerT res_indexer{}; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, wg_delta_n, + wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); - using KernelName = - class gemm_batch_k_krn; - cgh.parallel_for( - ndRange, GemmBatchFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - return gemm_ev; + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + constexpr TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const 
StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); +} - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); +template +sycl::event +gemm_batch_new_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); - size_t lws = delta_n * delta_k; + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); - auto ndRange = 
sycl::nd_range<1>(gRange, lRange); + return gemm_ev; +} - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); +template +class gemm_batch_tree_empty_krn; - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); +template +sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); - using KernelName = - class gemm_batch_k_krn; - cgh.parallel_for( - ndRange, GemmBatchFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_batch_new_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(batch_nd + res_outer_nd, + res_batch_offset, res_shape_strides); + using InitKernelName = + class gemm_batch_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_batch_no_reduction_ev; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - 
sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = - class gemm_batch_nm_krn; - cgh.parallel_for( - ndRange, - GemmBatchFunctorThreadNM( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { + constexpr std::uint32_t m_groups_four = 4; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { + constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + constexpr std::uint32_t m_groups_four = 4; + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } } } -typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( - sycl::queue &, - const char *, // lhs - const char *, // rhs - char *, // res - size_t, // batch nelems - size_t, // n - size_t, // k - size_t, // m - ssize_t, // lhs batch offset - ssize_t, // rhs batch offset - ssize_t, // res batch offset - std::vector const &); - -template -sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - std::vector const &depends = {}) +template +sycl::event +gemm_batch_contig_tree_k_impl(sycl::queue 
&exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) { - const lhsTy *lhs_tp = - reinterpret_cast(lhs_cp) + lhs_batch_offset; - const rhsTy *rhs_tp = - reinterpret_cast(rhs_cp) + rhs_batch_offset; - resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - cgh.fill(res_tp, resTy(0), n * m * batch_nelems); - }); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - if (k == 0) { - return res_init_ev; + if (k <= (delta_k * n_wi)) { + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + depends); } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; - using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - if (m < 4) { - constexpr size_t m_groups = 1; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - 
size_t lws = delta_n * delta_k; + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } - using KernelName = - class gemm_batch_k_krn; - cgh.parallel_for( - ndRange, GemmBatchFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; - return gemm_ev; - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - const size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer, depends); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); - size_t lws = delta_n * delta_k; + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + 
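// Host-side sketch of the scratch sizing used on this branch; ceil_div and
// tree_tmp_nelems are illustrative helpers assumed here, not APIs from this
// file.  A single allocation holds iter_nelems * reduction_nelems partial
// products written by the no-atomic GEMM kernel, immediately followed by
// iter_nelems * reduction_groups elements of first-level reduction scratch
// (partially_reduced_tmp2 aliases that second slab).
#include <cstddef>

static std::size_t ceil_div(std::size_t num, std::size_t den)
{
    return (num + den - 1) / den;
}

static std::size_t tree_tmp_nelems(std::size_t iter_nelems,
                                   std::size_t reduction_nelems,
                                   std::size_t preferred_reductions_per_wi,
                                   std::size_t wg)
{
    const std::size_t reduction_groups =
        ceil_div(reduction_nelems, preferred_reductions_per_wi * wg);
    return iter_nelems * (reduction_nelems + reduction_groups);
}

// With the starting parameters above (before scale_gemm_k_parameters may
// rescale them), each k-block covers delta_k * n_wi = 256 elements of k,
// so reduction_nelems = ceil_div(k, 256) partial products per output
// element, and iter_nelems = batch_nelems * n * m.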
resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; - using KernelName = - class gemm_batch_k_krn; - cgh.parallel_for( - ndRange, GemmBatchFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - return gemm_ev; - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer, depends); - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = 
- sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = - class gemm_batch_nm_krn; - cgh.parallel_for( - ndRange, - GemmBatchFunctorThreadNM( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); - return gemm_ev; + return cleanup_host_task_event; + } } } -template -class GemmBatchNoAtomicFunctorThreadNM +template +sycl::event +gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) { -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT1 local_A_block; - LocAccT2 local_B_block; - size_t n = 0; - size_t wg_delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t wi_delta_k = 0; - size_t m = 0; - size_t m_blocks = 0; - size_t wg_delta_m = 0; - size_t batch_nelems; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI -public: - GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT1 local_A_block_, - LocAccT2 local_B_block_, - size_t n_, - size_t wg_delta_n_, - size_t k_, - size_t k_blocks_, - size_t wi_delta_k_, - size_t m_, - size_t m_blocks_, - size_t wg_delta_m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), - local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), - k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), - m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), - batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), - lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), - res_indexer(res_indexer_) - { - } + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; - void operator()(sycl::nd_item<1> it) const - { - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using 
dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s - // < k_blocks + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); - size_t block_i = gr_id / (m_blocks * k_blocks); - size_t block_r = gr_id - block_i * (m_blocks * k_blocks); - size_t block_j = block_r / k_blocks; - size_t block_s = block_r - block_j * k_blocks; + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); - size_t lid = it.get_local_linear_id(); - size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n - size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - // load A block and B blocks into SLM + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); - size_t i = block_i * wi_delta_n * wg_delta_n; - size_t j = block_j * wi_delta_m * wg_delta_m; - size_t s = block_s * wi_delta_k; + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } - const std::int64_t a_st0 = k; - const std::int64_t a_st1 = 1; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT tmp_indexer{}; - const std::int64_t b_st0 = m; - const std::int64_t b_st1 = 1; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; - const std::int64_t c_st0 = m; - const std::int64_t c_st1 = 1; + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * 
m)}); - size_t lws = it.get_local_range(0); + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer, depends); - for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { - size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n - size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); - size_t g_i = i + v_i; - size_t g_s = s + v_s; + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - local_A_block[vid] = - (g_i < n && g_s < k) - ? static_cast( - lhs[lhs_offset + - lhs_indexer(g_i * a_st0 + g_s * a_st1)]) - : resT(0); + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - using slmB_t = typename LocAccT2::value_type; - - for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { - size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m - size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k - - size_t g_j = j + v_j * wi_delta_m; - size_t g_s = s + v_s; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - if constexpr (wi_delta_m == 1 && std::is_same_v) { - local_B_block[vid] = - (g_j < m && g_s < k) - ? static_cast( - rhs[rhs_offset + - rhs_indexer(g_s * b_st0 + g_j * b_st1)]) - : resT(0); + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); } else { - slmB_t vec{}; -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) - { - size_t g_j1 = g_j + lane_id; - vec[lane_id] = - (g_j1 < m && g_s < k) - ? 
static_cast( - rhs[rhs_offset + - rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) - : resT(0); - } - - local_B_block[vid] = vec; - } - } - - it.barrier(sycl::access::fence_space::local_space); + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - i += local_i * wi_delta_n; - j += local_j * wi_delta_m; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT tmp_indexer{}; - const size_t a_offset = local_i * wi_delta_k * wi_delta_n; - const size_t b_offset = local_j * wi_delta_k; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; - constexpr resT identity_(0); + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { - const size_t a_pr_offset = private_i * wi_delta_k; + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + batch_nelems, n, k, m, wg_delta_n, wg_delta_m, + wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer, depends); - slmB_t local_sum(identity_); - for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); - } + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); - const size_t gl_i = i + private_i; + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - if constexpr (wi_delta_m == 1 && std::is_same_v) { - const size_t gl_j = j; - if (gl_i < n && gl_j < m) { - res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + - (block_s * n * m * batch_nelems)] = local_sum; - } - } - else { -#pragma unroll - for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) - { - const size_t gl_j = j + lane_id; + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); - if (gl_i < n && gl_j < m) { - res[res_offset + - res_indexer(gl_i * c_st0 + gl_j * c_st1) + - (block_s * n * m * batch_nelems)] = - local_sum[lane_id]; - } - } - } + return cleanup_host_task_event; } } -}; +} -template -class GemmBatchNoAtomicFunctorThreadK +template +sycl::event gemm_new_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const ssize_t *rhs_shape_strides, + int res_outer_nd, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) { -private: - const lhsT *lhs = nullptr; - const rhsT *rhs = nullptr; - resT *res = nullptr; - LocAccT workspace; - LocAccT 
local_B_block; - size_t n = 0; - size_t n_blocks = 0; - size_t delta_n = 0; - size_t k = 0; - size_t k_blocks = 0; - size_t delta_k = 0; - size_t n_wi = 0; - size_t m = 0; - size_t batch_nelems = 0; - BatchDimsIndexerT batch_indexer; - OuterInnerDimsIndexerT lhs_indexer; - OuterInnerDimsIndexerT rhs_indexer; - ResIndexerT res_indexer; - -public: - GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, - const rhsT *rhs_, - resT *res_, - LocAccT workspace_, - LocAccT local_B_block_, - size_t n_, - size_t n_blocks_, - size_t delta_n_, - size_t k_, - size_t k_blocks_, - size_t delta_k_, - size_t n_wi_, - size_t m_, - size_t batch_nelems_, - BatchDimsIndexerT batch_indexer_, - OuterInnerDimsIndexerT lhs_indexer_, - OuterInnerDimsIndexerT rhs_indexer_, - ResIndexerT res_indexer_) - : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), - local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), - delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), - n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), - batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), - rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; - const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; - const size_t gr_id = - it.get_group_linear_id() - m_id * n_groups_per_batch; - size_t lid = it.get_local_linear_id(); - - const auto &three_offsets_ = batch_indexer(static_cast(m_id)); - const auto &lhs_offset = three_offsets_.get_first_offset(); - const auto &rhs_offset = three_offsets_.get_second_offset(); - const auto &res_offset = three_offsets_.get_third_offset(); - - // lift gr_id -> (block_i, block_j, block_s) - // block_i moves fastest, then block_s, then block_j + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_shape_strides); - size_t block_j = - gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks - size_t block_r = - gr_id - block_j * (n_blocks * - k_blocks); // 0 <= block_r < n_blocks * k_blocks - size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks - size_t block_i = - block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchDimsIndexerT batch_indexer{}; - size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n - size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + constexpr size_t single_batch_nelems = 1; - size_t i = block_i * delta_n + local_i; - size_t j = m_groups * block_j; - size_t s = block_s * delta_k * n_wi + local_s; + sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); - using accV_t = typename LocAccT::value_type; + return gemm_ev; +} - constexpr resT identity_ = resT(0); - if (local_i == 0) { - for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { - size_t sq = s + q; - size_t sqmj = sq * m + j; +template +sycl::event +gemm_batch_new_nm_contig_impl(sycl::queue &exec_q, + const lhsTy 
*lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const size_t batch_nelems, + const size_t n, + const size_t k, + const size_t m, + std::vector const &depends = {}) +{ + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; - if constexpr (m_groups == 1 && std::is_same_v) { - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast( - rhs[rhs_offset + rhs_indexer(sqmj)]) - : identity_; - } - else { - accV_t local_B_vec; -#pragma unroll - for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { - local_B_vec[vec_idx] = - (sq < k && j + vec_idx < m) - ? static_cast( - rhs[rhs_offset + - rhs_indexer(sqmj + vec_idx)]) - : identity_; - } - local_B_block[local_s + q] = local_B_vec; - } - } - } + constexpr size_t single_batch_nelems = 1; - it.barrier(sycl::access::fence_space::local_space); + if (batch_nelems == single_batch_nelems) { + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchDimsIndexerT batch_indexer{}; - size_t t_shift = block_s * delta_k * n_wi; - size_t global_s_offset = i * k + t_shift; + sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); - accV_t private_sum(identity_); - constexpr accV_t vec_identity_(identity_); - for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += - ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : vec_identity_; - } + return gemm_ev; + } + else { + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + + const ssize_t ss_batch_nelems = static_cast(batch_nelems); + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, ss_batch_nelems, static_cast(n * k)}, + Strided1DIndexer{0, ss_batch_nelems, static_cast(k * m)}, + Strided1DIndexer{0, ss_batch_nelems, static_cast(n * m)}); + + sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); - size_t workspace_i_shift = local_i * delta_k; - workspace[workspace_i_shift + local_s] = private_sum; + return gemm_ev; + } +} - it.barrier(sycl::access::fence_space::local_space); +template +sycl::event +gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; - if (local_s == 0 && i < n) { - accV_t local_sum(workspace[workspace_i_shift]); - for (size_t t = 1; t < delta_k; ++t) { - 
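// Illustrative sketch (not part of the patch): for C-contiguous batches the
// three Strided1DIndexer objects above reduce to fixed per-batch offsets of
// n*k, k*m and n*m elements into lhs, rhs and res.  The helper name
// contig_batch_offsets and the BatchOffsets struct are hypothetical.
#include <cstddef>

struct BatchOffsets {
    std::size_t lhs;
    std::size_t rhs;
    std::size_t res;
};

// lhs is laid out as (batch, n, k), rhs as (batch, k, m), res as (batch, n, m)
inline BatchOffsets contig_batch_offsets(std::size_t b,
                                         std::size_t n,
                                         std::size_t k,
                                         std::size_t m)
{
    return BatchOffsets{b * n * k, b * k * m, b * n * m};
}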
local_sum += workspace[workspace_i_shift + t]; - } + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); - const size_t total_offset = - res_offset + (block_s * n * m * batch_nelems); + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_batch_new_nm_contig_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } - if constexpr (m_groups == 1 && std::is_same_v) { - res[total_offset + res_indexer(i * m + j)] = local_sum; - } - else { - res[total_offset + res_indexer(i * m + j)] = local_sum[0]; + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + return gemm_batch_no_reduction_ev; + } -#pragma unroll - for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { - if (j + vec_id < m) { - res[total_offset + res_indexer(i * m + j + vec_id)] = - local_sum[1]; - } - } + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); } } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } } -}; + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } +} + +// Gemm tree non-batched template -class gemm_batch_tree_k_krn; +class gemm_tree_nm_krn; template -class gemm_batch_tree_nm_krn; +class gemm_tree_k_krn; template -sycl::event -gemm_batch_tree_k_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - int batch_nd, - const ssize_t *batch_shape_strides, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_outer_nd, - const ssize_t *res_outer_shapes_strides, - const ssize_t *res_shape_strides, - std::vector const &depends) +sycl::event gemm_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) { size_t delta_k(4); size_t n_wi(64); @@ -4385,94 +3884,44 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, delta_n // modified by reference ); - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - 
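// Illustrative sketch (not part of the patch): the kernel selection performed
// by gemm_batch_contig_tree_impl above, written out as a plain function; the
// enum and the function name are hypothetical.
#include <algorithm>
#include <cstddef>

enum class GemmKernelChoice { NewNM, ZeroFill, TreeK, TreeNM };

inline GemmKernelChoice
choose_gemm_kernel(std::size_t n, std::size_t k, std::size_t m)
{
    const std::size_t min_nm = std::min(n, m);
    const std::size_t max_nm = std::max(n, m);

    // Large outputs (roughly n * m >= 64K elements) go to the dedicated
    // nm kernel.
    if (min_nm > 0 && max_nm >= (64 * 1024) / min_nm) {
        return GemmKernelChoice::NewNM;
    }
    // An empty inner dimension makes the product identically zero.
    if (k == 0) {
        return GemmKernelChoice::ZeroFill;
    }
    // Small max(n, m) favors the k-oriented kernel; otherwise tile over n, m.
    return (max_nm < 64) ? GemmKernelChoice::TreeK : GemmKernelChoice::TreeNM;
}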
OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, - res_outer_shapes_strides); - using BatchDimsIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, - batch_shape_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + constexpr size_t single_batch_nelems = 1; - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - - const auto &krn_body = GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, - res_indexer); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); - cgh.parallel_for(ndRange, krn_body); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - return gemm_ev; + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + const OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); } else { using ReductionOpT = sycl::plus; constexpr resTy identity_val = sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; + size_t iter_nelems = n * m; size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple const auto &sg_sizes 
= dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - constexpr size_t preferred_reductions_per_wi = 4; + constexpr size_t preferred_reductions_per_wi = 8; size_t reductions_per_wi(preferred_reductions_per_wi); size_t reduction_groups = @@ -4488,96 +3937,26 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { throw std::runtime_error("Unable to allocate device memory"); } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - }); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr ResIndexerT 
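// Illustrative sketch (not part of the patch): the arithmetic behind the
// single- versus tree-reduction decision used above.  choose_workgroup_size
// and the exact reduction_groups formula live elsewhere in this header; the
// helper name fits_single_reduction is hypothetical.
#include <cstddef>

inline bool fits_single_reduction(std::size_t k,
                                  std::size_t delta_k,
                                  std::size_t n_wi,
                                  std::size_t preferred_reductions_per_wi,
                                  std::size_t max_wg)
{
    // Each k-block covers delta_k * n_wi elements along k, so this many
    // partial results per output element still need to be summed.
    const std::size_t reduction_nelems =
        (k + delta_k * n_wi - 1) / (delta_k * n_wi);

    // A single reduction kernel suffices when one work-group can consume all
    // partial results; otherwise the tree-reduction path is taken.
    return reduction_nelems <= preferred_reductions_per_wi * max_wg;
}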
res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); sycl::event red_ev = single_reduction_for_gemm( exec_q, tmp, res_tp, identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { @@ -4598,100 +3977,29 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, resTy *partially_reduced_tmp2 = nullptr; if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); + throw std::runtime_error("Unable to allocate device memory"); } else { partially_reduced_tmp2 = partially_reduced_tmp + reduction_nelems * iter_nelems; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr ResIndexerT res_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - StridedIndexer rhs_batch_indexer(batch_nd, rhs_batch_offset, - batch_shape_strides + - 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = 
class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, delta_n, n_wi, delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + // tree_reduction_for_gemm returns sycl::event for reduction sycl::event red_ev = tree_reduction_for_gemm( exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + res_nd, 0, res_shapes_strides, {gemm_ev}); sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { @@ -4709,29 +4017,21 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, } template -sycl::event -gemm_batch_tree_nm_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - int batch_nd, - const ssize_t *batch_shape_strides, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_outer_nd, - const ssize_t *res_outer_shapes_strides, - const ssize_t *res_shape_strides, - std::vector const &depends) +sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) { constexpr int wi_delta_n = 2; size_t wg_delta_n(16); // rows of A processed in WG @@ -4750,106 +4050,45 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, wg_delta_m // modified by reference ); - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, - res_outer_shapes_strides); - using BatchDimsIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, - batch_shape_strides); - - size_t lws 
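// Illustrative sketch (not part of the patch): the tree-reduction branch above
// carves two buffers out of a single USM allocation; partially_reduced_tmp2 is
// simply a pointer past the gemm partial results.  The helper name is
// hypothetical.
#include <sycl/sycl.hpp>
#include <cstddef>
#include <stdexcept>
#include <utility>

template <typename T>
std::pair<T *, T *> allocate_tree_reduction_temps(sycl::queue &q,
                                                  std::size_t iter_nelems,
                                                  std::size_t reduction_nelems,
                                                  std::size_t reduction_groups)
{
    T *tmp = sycl::malloc_device<T>(
        iter_nelems * (reduction_nelems + reduction_groups), q);
    if (tmp == nullptr) {
        throw std::runtime_error("Unable to allocate device memory");
    }
    // Scratch for the first reduction sweep starts right after the block
    // holding the gemm partial results.
    T *tmp2 = tmp + reduction_nelems * iter_nelems;
    return {tmp, tmp2};
}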
= wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - }); - return gemm_ev; + constexpr size_t single_batch_nelems = 1; + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + const OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); } else { using ReductionOpT = sycl::plus; constexpr resTy identity_val = sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; + + size_t iter_nelems = n * m; size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to // process use 
multiple const auto &sg_sizes = dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - constexpr size_t preferred_reductions_per_wi = 4; + constexpr size_t preferred_reductions_per_wi = 8; size_t reductions_per_wi(preferred_reductions_per_wi); size_t reduction_groups = @@ -4865,106 +4104,26 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { throw std::runtime_error("Unable to allocate device memory"); } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - 
GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr ResIndexerT res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); sycl::event red_ev = single_reduction_for_gemm( exec_q, tmp, res_tp, identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { @@ -4992,103 +4151,22 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, partially_reduced_tmp + reduction_nelems * iter_nelems; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr ResIndexerT res_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, 
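// Illustrative sketch (not part of the patch): the launch geometry of the
// nm-tiled kernels, reconstructed from the removed submission code above.
// The actual nd_range is now computed inside gemm_detail::_gemm_tree_nm_step;
// the helper name make_gemm_nm_range is hypothetical.
#include <sycl/sycl.hpp>
#include <cstddef>

inline sycl::nd_range<1> make_gemm_nm_range(std::size_t batch_nelems,
                                            std::size_t n,
                                            std::size_t k,
                                            std::size_t m,
                                            std::size_t wg_delta_n,
                                            std::size_t wg_delta_m,
                                            std::size_t wi_delta_n,
                                            std::size_t wi_delta_m,
                                            std::size_t wi_delta_k)
{
    const std::size_t lws = wg_delta_n * wg_delta_m;
    const std::size_t n_blocks =
        (n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n);
    const std::size_t k_blocks = (k + wi_delta_k - 1) / wi_delta_k;
    const std::size_t m_blocks =
        (m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m);

    return sycl::nd_range<1>(
        sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws),
        sycl::range<1>(lws));
}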
OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, wg_delta_n, wg_delta_m, + wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); sycl::event red_ev = tree_reduction_for_gemm( exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + res_nd, 0, res_shapes_strides, {gemm_ev}); sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { @@ -5105,122 +4183,112 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, } } -template -class gemm_batch_tree_empty_krn; +template class gemm_tree_empty_krn; template -sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - int batch_nd, - const ssize_t *batch_shape_strides, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - int inner_nd, - int lhs_outer_nd, - const ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const ssize_t *rhs_outer_inner_shapes_strides, - int res_outer_nd, - const ssize_t *res_outer_shapes_strides, - const ssize_t *res_shape_strides, - std::vector const &depends = {}) +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + std::vector const &depends = {}) { const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); resTy *res_tp = reinterpret_cast(res_cp); + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return 
gemm_new_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + if (k == 0) { - sycl::event gemm_batch_no_reduction_ev = + sycl::event gemm_no_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; - IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, - res_shape_strides); + const IndexerT res_indexer(res_nd, 0, res_shapes_strides); using InitKernelName = - class gemm_batch_tree_empty_krn; + class gemm_tree_empty_krn; cgh.parallel_for( - sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + sycl::range<1>(n * m), [=](sycl::id<1> id) { auto res_offset = res_indexer(id[0]); res_tp[res_offset] = resTy(0); }); }); - return gemm_batch_no_reduction_ev; + return gemm_no_reduction_ev; } - if ((k > n && k > m) || m < 4) { + if (max_nm < 64) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m < 4) { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } else { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } } else { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, - batch_shape_strides, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, inner_nd, lhs_outer_nd, + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } } else { // m > 1, n > k or m > k using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, - batch_shape_strides, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, inner_nd, lhs_outer_nd, + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, 
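// Illustrative sketch (not part of the patch): the k == 0 fast path above
// writes zeros because an empty inner dimension makes the product identically
// zero; for contiguous results a plain fill suffices, while the strided
// variant drives the writes through an indexer.  The helper name is
// hypothetical.
#include <sycl/sycl.hpp>
#include <cstddef>
#include <vector>

template <typename T>
sycl::event zero_fill_result(sycl::queue &q,
                             T *res,
                             std::size_t nelems,
                             const std::vector<sycl::event> &depends)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.fill(res, T(0), nelems);
    });
}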
n, k, m, batch_nd, - batch_shape_strides, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, inner_nd, lhs_outer_nd, + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } } } template -sycl::event -gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - std::vector const &depends) +sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) { size_t delta_k(4); size_t n_wi(64); @@ -5237,100 +4305,41 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, delta_n // modified by reference ); - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - using dpctl::tensor::offset_utils::Strided1DIndexer; - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + constexpr size_t single_batch_nelems = 1; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - 
OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - return gemm_ev; + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); } else { using ReductionOpT = sycl::plus; constexpr resTy identity_val = sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; + size_t iter_nelems = n * m; size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a + // more than one work-groups is needed, requires a // temporary delta_k * n_wi elements processed along k, // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - constexpr size_t preferred_reductions_per_wi = 4; + constexpr size_t preferred_reductions_per_wi = 8; size_t reductions_per_wi(preferred_reductions_per_wi); size_t reduction_groups = @@ -5346,82 +4355,17 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { throw std::runtime_error("Unable to allocate device memory"); } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, 
local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); sycl::event red_ev = single_reduction_for_gemm_contig( @@ -5455,80 +4399,15 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, partially_reduced_tmp + reduction_nelems * iter_nelems; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - 
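// Illustrative sketch (not part of the patch): the m_groups template parameter
// used above lets one work-item accumulate several adjacent output columns in
// a sycl::vec, so a single pass over k produces m_groups results.  A
// simplified, hypothetical accumulation loop, not the removed functor.
#include <sycl/sycl.hpp>
#include <cstddef>

template <typename T, int m_groups>
sycl::vec<T, m_groups> accumulate_columns(const T *lhs_row,
                                          const T *rhs,
                                          std::size_t k,
                                          std::size_t m,
                                          std::size_t j)
{
    sycl::vec<T, m_groups> acc(T(0));
    for (std::size_t s = 0; s < k; ++s) {
#pragma unroll
        for (int lane = 0; lane < m_groups; ++lane) {
            if (j + static_cast<std::size_t>(lane) < m) {
                acc[lane] += lhs_row[s] * rhs[s * m + j + lane];
            }
        }
    }
    return acc;
}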
GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, delta_n, n_wi, delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction sycl::event red_ev = tree_reduction_for_gemm_contig( exec_q, partially_reduced_tmp, partially_reduced_tmp2, @@ -5552,16 +4431,14 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, } template -sycl::event -gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - std::vector const &depends) +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) { constexpr int wi_delta_n = 2; size_t wg_delta_n(16); // rows of A processed in WG @@ -5580,111 +4457,43 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, wg_delta_m // modified by reference ); - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + constexpr OuterInnerDimsIndexerT lhs_indexer{}; + constexpr OuterInnerDimsIndexerT rhs_indexer{}; + constexpr OuterInnerDimsIndexerT res_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + constexpr BatchIndexerT batch_indexer{}; - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); + constexpr size_t single_batch_nelems = 1; - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 
local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - } - }); - return gemm_ev; + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); } else { using ReductionOpT = sycl::plus; constexpr resTy identity_val = sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; + + size_t iter_nelems = n * m; size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to // process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - constexpr size_t preferred_reductions_per_wi = 4; + constexpr size_t preferred_reductions_per_wi = 8; size_t reductions_per_wi(preferred_reductions_per_wi); size_t reduction_groups = @@ -5700,92 +4509,18 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { throw std::runtime_error("Unable to allocate device memory"); } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, 
Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, - n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); sycl::event red_ev = single_reduction_for_gemm_contig( @@ -5819,91 +4554,13 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, partially_reduced_tmp + reduction_nelems * iter_nelems; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - 
Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - } - }); + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, wg_delta_n, + wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); sycl::event red_ev = tree_reduction_for_gemm_contig( @@ -5928,63 +4585,64 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, } template -sycl::event -gemm_batch_contig_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - ssize_t lhs_batch_offset, - ssize_t rhs_batch_offset, - ssize_t res_batch_offset, - std::vector const &depends = {}) +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) { - const lhsTy *lhs_tp = - reinterpret_cast(lhs_cp) + lhs_batch_offset; - const rhsTy *rhs_tp = - reinterpret_cast(rhs_cp) + rhs_batch_offset; - 
resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const size_t min_nm = std::min(n, m); + const size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + constexpr size_t single_batch_nelems = 1; + return gemm_batch_new_nm_contig_impl( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + depends); + } if (k == 0) { - sycl::event gemm_batch_no_reduction_ev = + sycl::event gemm_no_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + cgh.fill(res_tp, resTy(0), n * m); }); - return gemm_batch_no_reduction_ev; + return gemm_no_reduction_ev; } - if ((k > n && k > m) || m < 4) { + if (max_nm < 64) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m < 4) { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } else { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } else { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } else { // m > 1, n > k or m > k using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } } diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp index 91e07e3793..2d26c42909 100644 --- a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp @@ -350,10 +350,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, int inner_nd = inner_dims; const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims; using shT = std::vector; - shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, - std::end(x1_strides_vec)); - shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims, - std::end(x2_strides_vec)); + const shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, + std::end(x1_strides_vec)); + const shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims, + std::end(x2_strides_vec)); shT simplified_inner_shape; shT simplified_inner_x1_strides; @@ -369,10 +369,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, const py::ssize_t *batch_shape_ptr = x1_shape_ptr; - shT batch_x1_strides(std::begin(x1_strides_vec), - std::begin(x1_strides_vec) + batch_dims); - shT batch_x2_strides(std::begin(x2_strides_vec), - std::begin(x2_strides_vec) + batch_dims); + const shT batch_x1_strides(std::begin(x1_strides_vec), + std::begin(x1_strides_vec) + batch_dims); + const shT batch_x2_strides(std::begin(x2_strides_vec), + 
std::begin(x2_strides_vec) + batch_dims); shT const &batch_dst_strides = dst_strides_vec; shT simplified_batch_shape; @@ -551,9 +551,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, } sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); - py::ssize_t *x1_shape_strides = packed_shapes_strides; - py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd); - py::ssize_t *dst_shape_strides = + const py::ssize_t *x1_shape_strides = packed_shapes_strides; + const py::ssize_t *x2_shape_strides = + packed_shapes_strides + 2 * (x1_nd); + const py::ssize_t *dst_shape_strides = packed_shapes_strides + 2 * (x1_nd + x2_nd); std::vector all_deps; @@ -619,9 +620,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, shT outer_inner_x1_strides; dpctl::tensor::py_internal::split_iteration_space( x1_shape_vec, x1_strides_vec, batch_dims, - batch_dims + x1_outer_inner_dims, batch_x1_shape, - outer_inner_x1_shape, // 4 vectors modified - batch_x1_strides, outer_inner_x1_strides); + batch_dims + x1_outer_inner_dims, + // 4 vectors modified + batch_x1_shape, outer_inner_x1_shape, batch_x1_strides, + outer_inner_x1_strides); shT batch_x2_shape; shT outer_inner_x2_shape; @@ -629,9 +631,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, shT outer_inner_x2_strides; dpctl::tensor::py_internal::split_iteration_space( x2_shape_vec, x2_strides_vec, batch_dims, - batch_dims + x2_outer_inner_dims, batch_x2_shape, - outer_inner_x2_shape, // 4 vectors modified - batch_x2_strides, outer_inner_x2_strides); + batch_dims + x2_outer_inner_dims, + // 4 vectors modified + batch_x2_shape, outer_inner_x2_shape, batch_x2_strides, + outer_inner_x2_strides); shT batch_dst_shape; shT outer_inner_dst_shape; @@ -639,9 +642,10 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, shT outer_inner_dst_strides; dpctl::tensor::py_internal::split_iteration_space( dst_shape_vec, dst_strides_vec, batch_dims, - batch_dims + dst_outer_inner_dims, batch_dst_shape, - outer_inner_dst_shape, // 4 vectors modified - batch_dst_strides, outer_inner_dst_strides); + batch_dims + dst_outer_inner_dims, + // 4 vectors modified + batch_dst_shape, outer_inner_dst_shape, batch_dst_strides, + outer_inner_dst_strides); using shT = std::vector; shT simplified_batch_shape; @@ -746,16 +750,16 @@ py_dot(const dpctl::tensor::usm_ndarray &x1, sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); - auto batch_shape_strides = packed_shapes_strides; - auto x1_outer_inner_shapes_strides = + const auto batch_shape_strides = packed_shapes_strides; + const auto x1_outer_inner_shapes_strides = packed_shapes_strides + 4 * batch_dims; - auto x2_outer_inner_shapes_strides = packed_shapes_strides + - 4 * batch_dims + - 2 * (x1_outer_inner_dims); - auto dst_outer_shapes_strides = + const auto x2_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims); + const auto dst_outer_shapes_strides = packed_shapes_strides + 4 * batch_dims + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims); - auto dst_full_shape_strides = + const auto dst_full_shape_strides = packed_shapes_strides + 4 * batch_dims + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) + 2 * (dst_outer_inner_dims); diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp index de59450174..19ba9ad0c3 100644 --- a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -23,11 
+23,21 @@ template struct DotAtomicOutputType T2, std::uint32_t, std::uint32_t>, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct DotAtomicOutputType std::int64_t, std::int64_t>, td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; }; @@ -75,11 +86,21 @@ template struct DotNoAtomicOutputType T2, std::uint32_t, std::uint32_t>, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct DotNoAtomicOutputType T2, std::complex, std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, td_ns::BinaryTypeMapResultEntry, T2, diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index e51c0a2ac7..f939bb39fa 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -571,6 +571,80 @@ def test_matmul_inplace_same_tensors(): assert dpt.all(ar2 == dpt.full(sh, n, dtype=ar2.dtype)) +@pytest.fixture +def random_matrix(): + rs = np.random.RandomState(seed=123456) + m_np = rs.randint(low=0, high=6, size=(400, 400)) + return m_np + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_largish_square(dtype, random_matrix): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m_np = random_matrix.astype(dtype) + x_np = np.matmul(m_np.T, m_np) + + m = dpt.asarray(m_np) + mT = dpt.asarray(m.mT, copy=True, order="C") + x1 = dpt.matmul(m.mT, m) + x2 = dpt.matmul(mT, m) + + tol = 0 + if dpt.isdtype(x2.dtype, ("real floating", "complex floating")): + tol = 32 * dpt.finfo(x2.dtype).eps + + assert dpt.allclose(x1, x2, atol=tol, rtol=tol) + assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol) + + # check stided input + m_np = m_np[:-1, :-1] + x_np = np.matmul(m_np.T, m_np) + + m = m[:-1, :-1] + mT = dpt.asarray(m.mT, copy=True, order="C") + x1 = dpt.matmul(m.mT, m) + x2 = dpt.matmul(mT, m) + + assert dpt.allclose(x1, x2, atol=tol, rtol=tol) + assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_largish_rect(dtype, random_matrix): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m_np = random_matrix.astype(dtype)[:, :-1] + x_np = np.matmul(m_np.T[:-2, :], m_np) + + m = dpt.asarray(m_np) + mmT = m.mT[:-2, :] + mT = dpt.asarray(mmT, copy=True, order="C") + x1 = dpt.matmul(mmT, m) + x2 = dpt.matmul(mT, m) + + tol = 0 + if dpt.isdtype(x2.dtype, ("real floating", "complex floating")): + tol = 32 * dpt.finfo(x2.dtype).eps + + assert dpt.allclose(x1, x2, atol=tol, rtol=tol) + assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol) + + m_np = m_np[:-1, :-1] + x_np = np.matmul(m_np.T[:-2, :], m_np) + + m = m[:-1, :-1] + mmT = m.mT[:-2, :] + mT = dpt.asarray(mmT, copy=True, order="C") + x1 = dpt.matmul(mmT, m) + x2 = dpt.matmul(mT, m) + + assert dpt.allclose(x1, x2, atol=tol, rtol=tol) + assert dpt.allclose(x1, dpt.asarray(x_np), atol=tol, rtol=tol) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_tensordot_outer(dtype): q = get_queue_or_skip()
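# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the "largish" matmul tests above
# use 400 x 400 operands so that the product takes the new contiguous n x m
# kernel path (gemm_batch_new_nm_contig_impl) added in gemm_contig_tree_impl,
# which is selected when max(n, m) >= (64 * 1024) / min(n, m), i.e. roughly
# when n * m >= 64 * 1024. A minimal sketch of that dispatch condition,
# using a hypothetical helper name (_hits_largish_nm_path) purely for
# illustration:


def _hits_largish_nm_path(n, m, threshold=64 * 1024):
    # Mirrors the C++ condition: min_nm > 0 and max_nm >= threshold / min_nm
    # (integer division), which is approximately n * m >= threshold.
    min_nm, max_nm = min(n, m), max(n, m)
    return min_nm > 0 and max_nm >= threshold // min_nm


# 400 * 400 = 160000 elements, comfortably past the 64 * 1024 threshold:
assert _hits_largish_nm_path(400, 400)
# a small product such as 16 x 16 stays on the existing tree-reduction paths:
assert not _hits_largish_nm_path(16, 16)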