Skip to content

Commit 92514f1

Browse files
Save SLM chunks into registers prior to summing over delta_m, delta_n
This does not change performance, since compiler optimization is already effective once the loops are unrolled, but it makes the compiler's job easier and the intent clearer.
1 parent cfba263 commit 92514f1

File tree

1 file changed

+20
-8
lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

1 file changed

+20
-8
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -970,6 +970,7 @@ class GemmBatchFunctorThreadNM_vecm
970970
size_t i = block_i * wg_delta_n * wi_delta_n;
971971
size_t j = block_j * wg_delta_m * wi_total_delta_m;
972972

973+
using slmA_t = typename LocAccT1::value_type;
973974
using slmB_t = typename LocAccT2::value_type;
974975

975976
const size_t a_st0 = k;
@@ -1037,8 +1038,7 @@ class GemmBatchFunctorThreadNM_vecm
10371038
slmB_t vec{};
10381039
#pragma unroll
10391040
for (std::uint32_t lane_id = 0; lane_id < m_vec_size;
1040-
++lane_id)
1041-
{
1041+
++lane_id) {
10421042
const size_t g_j1 = g_j + lane_id;
10431043
vec[lane_id] = (g_j1 < m && g_s < k)
10441044
? static_cast<resT>(
@@ -1057,16 +1057,29 @@ class GemmBatchFunctorThreadNM_vecm
10571057
const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n);
10581058
const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs);
10591059
for (std::uint32_t pr_k = 0; pr_k < wi_delta_k; ++pr_k) {
1060+
std::array<slmA_t, wi_delta_n> pr_lhs{};
1061+
#pragma unroll
1062+
for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) {
1063+
pr_lhs[pr_i] =
1064+
local_lhs_block[pr_k * lo_lhs_st_k +
1065+
(local_i + pr_i * wg_delta_n)];
1066+
}
1067+
1068+
std::array<slmB_t, wi_delta_m_vecs> pr_rhs{};
1069+
#pragma unroll
1070+
for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) {
1071+
pr_rhs[pr_j] =
1072+
local_rhs_block[pr_k * lo_rhs_rk_k +
1073+
(local_j + pr_j * wg_delta_m)];
1074+
}
1075+
10601076
#pragma unroll
10611077
for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) {
10621078
#pragma unroll
10631079
for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j)
10641080
{
10651081
private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1066-
local_lhs_block[pr_k * lo_lhs_st_k +
1067-
(local_i + pr_i * wg_delta_n)] *
1068-
local_rhs_block[pr_k * lo_rhs_rk_k +
1069-
(local_j + pr_j * wg_delta_m)];
1082+
pr_lhs[pr_i] * pr_rhs[pr_j];
10701083
}
10711084
}
10721085
}
@@ -1106,8 +1119,7 @@ class GemmBatchFunctorThreadNM_vecm
11061119
j + (local_j + pr_j * wg_delta_m) * m_vec_size;
11071120
#pragma unroll
11081121
for (std::uint32_t lane_id = 0; lane_id < m_vec_size;
1109-
++lane_id)
1110-
{
1122+
++lane_id) {
11111123
const size_t out_flat_id =
11121124
out_i * c_st0 + (out_j + lane_id) * c_st1;
11131125
if (out_j + lane_id < m) {

0 commit comments

Comments
 (0)