Skip to content

Commit 41b6da1

Browse files
authored
[SYCL][CUDA] Matrix MMA for double type using nvptx. (intel/llvm-test-suite#553)
Signed-off-by: JackAKirk <[email protected]>
1 parent 77ac9af commit 41b6da1

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// REQUIRES: gpu, cuda
2+
3+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out
4+
//
5+
// Specifying the sm version via the --cuda-gpu-arch flag is necessary
6+
// for the Nvidia case. DPC++ JIT compilation is not
7+
// supported for the Nvidia case, although some JIT optimizations are performed
8+
// at the level of the PTX assembly code.
9+
10+
#include <CL/sycl.hpp>

#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
14+
15+
// Example usage of Nvidia matrix multiply.
// Optimizations such as memory padding to avoid bank conflicts are omitted
// from this test to keep it easy to follow. This example forms a "Big matrix"
// corresponding to a single "TILE", using the terminology of the CUDA
// examples. Multiple TILES can be used to construct yet larger matrices.
// This example uses row_major a, b, and accumulator matrices.
21+
22+
// M, N, K define the unit sizes of dimensions of the three types (a, b,
// accumulator) of matrices per subgroup operation: the "A" sub-matrix is
// M x K, the "B" sub-matrix is K x N and the accumulator sub-matrix is M x N.
constexpr int M = 8; // number of rows of "C"/"D" (Accumulator) sub-matrices,
                     // number of rows of "A" sub-matrix.
constexpr int N = 8; // number of cols of "C"/"D" (Accumulator) sub-matrices,
                     // number of cols of "B" sub-matrix.
constexpr int K =
    4; // number of cols of "A"/number of rows of "B" sub-matrices.

constexpr int N_THREADS_PER_MATRIX_OP =
    32; // the number of threads per MMA subgroup is always 32 for Nvidia.

constexpr int SUB_TILES_M =
    77; // number of sub-matrices along the M dimension (sub-matrix rows of
        // the accumulator ("C", "D") and "A" matrices).
constexpr int SUB_TILES_N =
    55; // number of sub-matrices along the N dimension (sub-matrix cols of
        // the accumulator and "B" matrices).
constexpr int SUB_TILES_K =
    257; // number of sub-matrices along the K dimension (sub-matrix cols of
         // "A"/sub-matrix rows of "B").

constexpr int BIG_M =
    SUB_TILES_M *
    M; // total number of M dimension matrix elements for the "Big matrix".
constexpr int BIG_N =
    SUB_TILES_N *
    N; // total number of N dimension matrix elements for the "Big matrix".
constexpr int BIG_K =
    SUB_TILES_K *
    K; // total number of K dimension matrix elements for the "Big matrix".

// The stride is the number of elements between the starts of consecutive
// leading dimensions of the "Big matrix", e.g. the number of elements per row
// if the matrix is indexed row major. It tells the implementation how many
// elements to skip in memory to move from one row (or column) of a sub-matrix
// to the next.
constexpr int STRIDE_A = BIG_K; // row major. If col major should equal BIG_M.
constexpr int STRIDE_B = BIG_N; // row_major. If col major should equal BIG_K.
constexpr int STRIDE_C = BIG_N; // row major. If col major should equal BIG_M.

// Host storage for the "Big" matrices, all indexed row major:
// A is BIG_M x BIG_K, B is BIG_K x BIG_N, C and D are BIG_M x BIG_N.
double A[BIG_M * BIG_K];
double B[BIG_K * BIG_N];
double C[BIG_M * BIG_N];
double D[BIG_M * BIG_N];
63+
64+
// returns correct (m,n) element of matrix D = A*B + C (assuming all matrices
65+
// are indexed row_major).
66+
double matrix_ref_mn(const int &m, const int &n) {
67+
double res = C[m * BIG_N + n];
68+
69+
for (int k = 0; k < BIG_K; k++)
70+
res += A[m * BIG_K + k] * B[k * BIG_N + n];
71+
return res;
72+
}
73+
74+
int main() {
75+
for (int i = 0; i < BIG_M * BIG_N; i++) {
76+
C[i] = i;
77+
D[i] = 0;
78+
}
79+
80+
for (int i = 0; i < BIG_M * BIG_K; i++) {
81+
A[i] = i;
82+
}
83+
84+
for (int i = 0; i < BIG_K * BIG_N; i++) {
85+
B[i] = i;
86+
}
87+
88+
buffer<double, 1> bufA(A, range<1>(BIG_M * BIG_K));
89+
buffer<double, 1> bufB(B, range<1>(BIG_K * BIG_N));
90+
buffer<double, 1> bufC(C, range<1>(BIG_M * BIG_N));
91+
buffer<double, 1> bufD(D, range<1>(BIG_M * BIG_N));
92+
93+
queue q;
94+
q.submit([&](handler &cgh) {
95+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
96+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
97+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
98+
auto accD = bufD.get_access<access::mode::read_write>(cgh);
99+
100+
range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
101+
range<2> GlobalRange = {SUB_TILES_M, SUB_TILES_N * N_THREADS_PER_MATRIX_OP};
102+
103+
cgh.parallel_for<class imatrix>(
104+
nd_range<2>(GlobalRange, LocalRange), [=
105+
](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
106+
sycl::sub_group sg = item.get_sub_group();
107+
108+
const auto m =
109+
item.get_group()
110+
.get_id()[0]; // row id of current submatrix of BIG C matrix
111+
const auto n =
112+
item.get_group().get_id()[1]; // column id of current submatrix of
113+
// BIG C matrix
114+
115+
joint_matrix<double, matrix_use::accumulator, M, N,
116+
matrix_layout::row_major>
117+
sub_c;
118+
119+
joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
120+
sub_a;
121+
122+
joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
123+
sub_b;
124+
125+
joint_matrix_load(sg, sub_c,
126+
accC.get_pointer() + (m * M) * BIG_N + n * N,
127+
STRIDE_C);
128+
129+
for (int k = 0; k < SUB_TILES_K;
130+
k += 1) // row/col id of current submatrix of BIG A/B matrices
131+
{
132+
joint_matrix_load(sg, sub_a,
133+
accA.get_pointer() + (k * K) + (m * M * BIG_K),
134+
STRIDE_A);
135+
136+
joint_matrix_load(sg, sub_b,
137+
accB.get_pointer() + (k * K * BIG_N) + (n * N),
138+
STRIDE_B);
139+
140+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
141+
}
142+
joint_matrix_store(sg, sub_c,
143+
accD.get_pointer() + (m * M) * BIG_N + n * N,
144+
STRIDE_C);
145+
});
146+
});
147+
148+
const auto host_accessor = bufD.get_access<cl::sycl::access::mode::read>();
149+
150+
for (int m = 0; m < BIG_M; m++)
151+
for (int n = 0; n < BIG_N; n++) {
152+
assert(host_accessor[m * BIG_N + n] == matrix_ref_mn(m, n));
153+
}
154+
155+
return 0;
156+
};

0 commit comments

Comments
 (0)