// REQUIRES: gpu, cuda

// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out
//
// Specifying the SM version via the --cuda-gpu-arch flag is necessary for the
// Nvidia case: DPC++ JIT compilation is not supported for Nvidia devices,
// although some JIT optimizations are performed at the level of the PTX
// assembly code.

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// Example usage of Nvidia matrix multiply.
// Optimizations such as memory padding to avoid bank conflicts are omitted
// from this test to keep clear what is going on. This example forms a "Big
// matrix" corresponding to a single "TILE", using CUDA example terminology.
// Multiple TILEs can be used to construct yet larger matrices. This example
// uses row_major a, b, and accumulator matrices.

// M, N, K define the unit sizes of the dimensions of the three matrix types
// (a, b, accumulator) per sub-group operation:
// D (M x N) = A (M x K) * B (K x N) + C (M x N).
constexpr int M = 8; // number of rows of the "C"/"D" (accumulator)
                     // sub-matrices and of the "A" sub-matrix.
constexpr int N = 8; // number of cols of the "C"/"D" (accumulator)
                     // sub-matrices and of the "B" sub-matrix.
constexpr int K =
    4; // number of cols of the "A"/number of rows of the "B" sub-matrices.
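// For double precision these are the dimensions of the m8n8k4 MMA shape
// supported by sm_80 Tensor Cores: each sub-group operation computes an 8x8
// accumulator tile from an 8x4 "A" tile and a 4x8 "B" tile.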

constexpr int N_THREADS_PER_MATRIX_OP =
    32; // the number of threads per MMA subgroup is always 32 for Nvidia.

constexpr int SUB_TILES_M =
    77; // number of sub-matrices along the M dimension of the accumulator
        // ("C", "D") matrices.
constexpr int SUB_TILES_N =
    55; // number of sub-matrices along the N dimension of the accumulator
        // matrices.
constexpr int SUB_TILES_K =
    257; // number of sub-matrices along the K dimension, i.e. per row of "A"
         // and per column of "B".

constexpr int BIG_M =
    SUB_TILES_M *
    M; // total number of M dimension matrix elements for the "Big matrix".
constexpr int BIG_N =
    SUB_TILES_N *
    N; // total number of N dimension matrix elements for the "Big matrix".
constexpr int BIG_K =
    SUB_TILES_K *
    K; // total number of K dimension matrix elements for the "Big matrix".
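// With the sub-tile counts above, the "Big" matrices are A: 616 x 1028,
// B: 1028 x 440, and C/D: 616 x 440 elements.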

// The stride should equal the number of elements between consecutive leading
// dimensions of the "Big matrix", e.g. the number of elements per row if the
// matrix is indexed row major. The stride tells the implementation how many
// elements to skip in memory when moving from one row (or column) of the "Big
// matrix" to the next.
constexpr int STRIDE_A = BIG_K; // row major. If col major, should equal BIG_M.
constexpr int STRIDE_B = BIG_N; // row major. If col major, should equal BIG_K.
constexpr int STRIDE_C = BIG_N; // row major. If col major, should equal BIG_M.
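// For example, with row-major indexing element (row, col) of the "Big" A
// matrix is stored at linear offset row * STRIDE_A + col, so the sub-tile with
// sub-tile coordinates (m, k) starts at offset (m * M) * STRIDE_A + (k * K).
// The same pattern gives the accessor offsets passed to joint_matrix_load and
// joint_matrix_store in the kernel below.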

double A[BIG_M * BIG_K];
double B[BIG_K * BIG_N];
double C[BIG_M * BIG_N];
double D[BIG_M * BIG_N];

// Returns the correct (m, n) element of the matrix D = A * B + C (assuming all
// matrices are indexed row_major).
double matrix_ref_mn(int m, int n) {
  double res = C[m * BIG_N + n];

  for (int k = 0; k < BIG_K; k++)
    res += A[m * BIG_K + k] * B[k * BIG_N + n];
  return res;
}

int main() {
  for (int i = 0; i < BIG_M * BIG_N; i++) {
    C[i] = i;
    D[i] = 0;
  }

  for (int i = 0; i < BIG_M * BIG_K; i++) {
    A[i] = i;
  }

  for (int i = 0; i < BIG_K * BIG_N; i++) {
    B[i] = i;
  }

  buffer<double, 1> bufA(A, range<1>(BIG_M * BIG_K));
  buffer<double, 1> bufB(B, range<1>(BIG_K * BIG_N));
  buffer<double, 1> bufC(C, range<1>(BIG_M * BIG_N));
  buffer<double, 1> bufD(D, range<1>(BIG_M * BIG_N));

  queue q;
  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP};
    range<2> GlobalRange = {SUB_TILES_M, SUB_TILES_N * N_THREADS_PER_MATRIX_OP};

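    // Each work-group consists of a single 32-thread sub-group and computes
    // one M x N sub-tile of the accumulator: dimension 0 of the nd_range
    // selects the sub-tile row (m) and dimension 1 packs the sub-tile column
    // (n) together with the 32 cooperating threads.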
    cgh.parallel_for<class imatrix>(
        nd_range<2>(GlobalRange, LocalRange),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          // Row and column ids of the current sub-matrix of the "Big" C
          // matrix.
          const auto m = item.get_group().get_id()[0];
          const auto n = item.get_group().get_id()[1];

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;

          joint_matrix_load(sg, sub_c,
                            accC.get_pointer() + (m * M) * BIG_N + n * N,
                            STRIDE_C);

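          // Accumulate along the K dimension: iteration k loads the k-th
          // M x K sub-tile of "A" (k is its column id within the "Big" A
          // matrix) and the k-th K x N sub-tile of "B" (k is its row id
          // within the "Big" B matrix) and adds their product into sub_c, so
          // after the loop sub_c holds the (m, n) sub-tile of A * B + C.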
          for (int k = 0; k < SUB_TILES_K; k++) {
            joint_matrix_load(sg, sub_a,
                              accA.get_pointer() + (k * K) + (m * M * BIG_K),
                              STRIDE_A);

            joint_matrix_load(sg, sub_b,
                              accB.get_pointer() + (k * K * BIG_N) + (n * N),
                              STRIDE_B);

            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          }
          joint_matrix_store(sg, sub_c,
                             accD.get_pointer() + (m * M) * BIG_N + n * N,
                             STRIDE_C);
        });
  });

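  // Constructing the host accessor blocks until the kernel has finished, so
  // the device results in D can be checked against the host reference below.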
  const auto host_accessor = bufD.get_access<access::mode::read>();

  for (int m = 0; m < BIG_M; m++)
    for (int n = 0; n < BIG_N; n++) {
      assert(host_accessor[m * BIG_N + n] == matrix_ref_mn(m, n));
    }

  return 0;
}