From af821299d7520ab60e265cdf808a4221f556e38a Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk
Date: Thu, 22 Aug 2024 19:44:43 -0700
Subject: [PATCH] Save SLM chunks into registers prior to summing over
 delta_m, delta_n

This does not change performance, since compiler optimization is
already effective once the loops are unrolled, but it makes the
compiler's job easier and the intent clearer.
---
 .../include/kernels/linalg_functions/gemm.hpp | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp
index d71cb3272a..1026efcfe0 100644
--- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp
@@ -970,6 +970,7 @@ class GemmBatchFunctorThreadNM_vecm
         size_t i = block_i * wg_delta_n * wi_delta_n;
         size_t j = block_j * wg_delta_m * wi_total_delta_m;
 
+        using slmA_t = typename LocAccT1::value_type;
         using slmB_t = typename LocAccT2::value_type;
 
         const size_t a_st0 = k;
@@ -1057,16 +1058,29 @@
             const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n);
             const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs);
             for (std::uint32_t pr_k = 0; pr_k < wi_delta_k; ++pr_k) {
+                std::array<slmA_t, wi_delta_n> pr_lhs{};
+#pragma unroll
+                for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) {
+                    pr_lhs[pr_i] =
+                        local_lhs_block[pr_k * lo_lhs_st_k +
+                                        (local_i + pr_i * wg_delta_n)];
+                }
+
+                std::array<slmB_t, wi_delta_m_vecs> pr_rhs{};
+#pragma unroll
+                for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) {
+                    pr_rhs[pr_j] =
+                        local_rhs_block[pr_k * lo_rhs_rk_k +
+                                        (local_j + pr_j * wg_delta_m)];
+                }
+
 #pragma unroll
                 for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) {
 #pragma unroll
                     for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs;
                          ++pr_j) {
                         private_C[pr_i * wi_delta_m_vecs + pr_j] +=
-                            local_lhs_block[pr_k * lo_lhs_st_k +
-                                            (local_i + pr_i * wg_delta_n)] *
-                            local_rhs_block[pr_k * lo_rhs_rk_k +
-                                            (local_j + pr_j * wg_delta_m)];
+                            pr_lhs[pr_i] * pr_rhs[pr_j];
                     }
                 }
             }
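
The hunk boils down to a register-blocking pattern: stage the slice of each
SLM tile needed for the current k step in private std::array objects, which
the compiler can keep in registers, then run the rank-1 update on those
arrays instead of re-indexing local memory inside the doubly nested loop.
Below is a minimal self-contained sketch of that pattern; rank1_accumulate,
T, N, M, K and the raw pointers are illustrative placeholders, not the
kernel's own identifiers.

    // Sketch of the staging pattern; names are placeholders, not the
    // kernel's identifiers. Tiles are row-major: lhs is K x N, rhs is K x M.
    #include <array>
    #include <cstddef>

    template <typename T, std::size_t N, std::size_t M, std::size_t K>
    void rank1_accumulate(const T *local_lhs,
                          const T *local_rhs,
                          std::array<T, N * M> &private_C)
    {
        for (std::size_t k = 0; k < K; ++k) {
            // Copy the k-th row of each tile into registers once ...
            std::array<T, N> lhs_regs{};
            for (std::size_t i = 0; i < N; ++i)
                lhs_regs[i] = local_lhs[k * N + i];

            std::array<T, M> rhs_regs{};
            for (std::size_t j = 0; j < M; ++j)
                rhs_regs[j] = local_rhs[k * M + j];

            // ... so the N * M accumulation reads registers, not the tiles.
            for (std::size_t i = 0; i < N; ++i)
                for (std::size_t j = 0; j < M; ++j)
                    private_C[i * M + j] += lhs_regs[i] * rhs_regs[j];
        }
    }

Before the change, each of the N * M partial products indexed the SLM-backed
accessors directly (loads the compiler was already hoisting after unrolling,
hence no performance difference); after it, the staging is explicit: N + M
local-memory reads per k step, with the inner loops working purely on
pr_lhs / pr_rhs.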