@@ -970,6 +970,7 @@ class GemmBatchFunctorThreadNM_vecm
970
970
size_t i = block_i * wg_delta_n * wi_delta_n;
971
971
size_t j = block_j * wg_delta_m * wi_total_delta_m;
972
972
973
+ using slmA_t = typename LocAccT1::value_type;
973
974
using slmB_t = typename LocAccT2::value_type;
974
975
975
976
const size_t a_st0 = k;
@@ -1037,8 +1038,7 @@ class GemmBatchFunctorThreadNM_vecm
1037
1038
slmB_t vec{};
1038
1039
#pragma unroll
1039
1040
for (std::uint32_t lane_id = 0 ; lane_id < m_vec_size;
1040
- ++lane_id)
1041
- {
1041
+ ++lane_id) {
1042
1042
const size_t g_j1 = g_j + lane_id;
1043
1043
vec[lane_id] = (g_j1 < m && g_s < k)
1044
1044
? static_cast <resT>(
@@ -1057,16 +1057,29 @@ class GemmBatchFunctorThreadNM_vecm
1057
1057
const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n);
1058
1058
const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs);
1059
1059
for (std::uint32_t pr_k = 0 ; pr_k < wi_delta_k; ++pr_k) {
1060
+ std::array<slmA_t, wi_delta_n> pr_lhs{};
1061
+ #pragma unroll
1062
+ for (std::uint32_t pr_i = 0 ; pr_i < wi_delta_n; ++pr_i) {
1063
+ pr_lhs[pr_i] =
1064
+ local_lhs_block[pr_k * lo_lhs_st_k +
1065
+ (local_i + pr_i * wg_delta_n)];
1066
+ }
1067
+
1068
+ std::array<slmB_t, wi_delta_m_vecs> pr_rhs{};
1069
+ #pragma unroll
1070
+ for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j) {
1071
+ pr_rhs[pr_j] =
1072
+ local_rhs_block[pr_k * lo_rhs_rk_k +
1073
+ (local_j + pr_j * wg_delta_m)];
1074
+ }
1075
+
1060
1076
#pragma unroll
1061
1077
for (std::uint32_t pr_i = 0 ; pr_i < wi_delta_n; ++pr_i) {
1062
1078
#pragma unroll
1063
1079
for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j)
1064
1080
{
1065
1081
private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1066
- local_lhs_block[pr_k * lo_lhs_st_k +
1067
- (local_i + pr_i * wg_delta_n)] *
1068
- local_rhs_block[pr_k * lo_rhs_rk_k +
1069
- (local_j + pr_j * wg_delta_m)];
1082
+ pr_lhs[pr_i] * pr_rhs[pr_j];
1070
1083
}
1071
1084
}
1072
1085
}
@@ -1106,8 +1119,7 @@ class GemmBatchFunctorThreadNM_vecm
1106
1119
j + (local_j + pr_j * wg_delta_m) * m_vec_size;
1107
1120
#pragma unroll
1108
1121
for (std::uint32_t lane_id = 0 ; lane_id < m_vec_size;
1109
- ++lane_id)
1110
- {
1122
+ ++lane_id) {
1111
1123
const size_t out_flat_id =
1112
1124
out_i * c_st0 + (out_j + lane_id) * c_st1;
1113
1125
if (out_j + lane_id < m) {
0 commit comments