From 16c0bbed1cfe28010f4b17ed4a4b32fe9d1cc0f8 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 18 Nov 2021 00:36:54 +0800 Subject: [PATCH 1/8] [Matrix] Enable wi_slice for joint_matrix --- sycl/include/CL/__spirv/spirv_ops.hpp | 13 ++ .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 58 ++++++ sycl/test/matrix/matrix-int8-test-slice.cpp | 172 ++++++++++++++++++ 3 files changed, 243 insertions(+) create mode 100644 sycl/test/matrix/matrix-int8-test-slice.cpp diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index c8579d3f49a2e..951c1d66901c7 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -86,6 +86,19 @@ __spirv_JointMatrixSUMadINTEL( __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); +template +using __spirv_wi_slice_t = T __attribute__((ext_vector_type(0xffffff))); + +template +extern SYCL_EXTERNAL __spirv_wi_slice_t &__spirv_JointMatrixGetSliceData( + __spv::__spirv_JointMatrixINTEL *); + +template +extern SYCL_EXTERNAL size_t __spirv_JointMatrixGetSliceLength( + __spv::__spirv_JointMatrixINTEL *); + #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ "SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag." diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index e81881e52f6a7..1f5ba441bd37f 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -191,6 +191,64 @@ joint_matrix_mad(Group sg, joint_matrix &mA, PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + +#ifdef __clang__ +template +using wi_slice_t = T __attribute__((ext_vector_type(0xffffff))); +#else +template +using wi_slice_t __attribute__((vector_size(0xffffff))) = T; +#endif // __clang__ + +// dummy value for initializing wi_slice::data in host code. +wi_slice_t dummy_i32; +wi_slice_t dummy_i8; +wi_slice_t dummy_u8; +wi_slice_t dummy_u16; +wi_slice_t dummy_f32; + +template wi_slice_t &getDummy() {} +template <> wi_slice_t &getDummy() { return dummy_i32; } +template <> wi_slice_t &getDummy() { return dummy_i8; } +template <> wi_slice_t &getDummy() { return dummy_u8; } +template <> wi_slice_t &getDummy() { return dummy_f32; } +template <> wi_slice_t &getDummy() { return dummy_u16; } + +template +class wi_slice { + joint_matrix &M; + +public: + wi_slice(joint_matrix &Mat) + : M(Mat), +#ifdef __SYCL_DEVICE_ONLY__ + data(__spirv_JointMatrixGetSliceData(Mat.spvm)) { + } +#else + data(getDummy()) { + } +#endif // __SYCL_DEVICE_ONLY__ + wi_slice_t &data; + size_t length() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_JointMatrixGetSliceLength(M.spvm); +#else + throw runtime_error("wi_slice is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } +}; + +// TODO: must be a member function of joint_matrix class. +template +inline __SYCL_ALWAYS_INLINE wi_slice +joint_matrix_get_slice(joint_matrix &M) { + return wi_slice(M); +} + } // namespace experimental::matrix } // namespace oneapi } // namespace ext diff --git a/sycl/test/matrix/matrix-int8-test-slice.cpp b/sycl/test/matrix/matrix-int8-test-slice.cpp new file mode 100644 index 0000000000000..5e5f519f56c8c --- /dev/null +++ b/sycl/test/matrix/matrix-int8-test-slice.cpp @@ -0,0 +1,172 @@ +// RUN: %clangxx -fsycl -O2 %s -o %t.out +#include +#if (SYCL_EXT_ONEAPI_MATRIX == 2) +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define TILE_SZ 16 +#define TM (TILE_SZ-4) +#define TN (TILE_SZ-4) +#define TK (4 * TILE_SZ-16) + +#define SG_SZ 16 + +template struct big_matrix{ +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) { + } +}; + +template +void matrix_multiply(big_matrix &C, big_matrix &A, big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed layout, + // users need to specify the updated VNNI sizes along with the packed_b layout. + // By default, the layout is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + auto wi_slice_c = joint_matrix_get_slice(sub_c); // M.get_wi_slice() + for (int i = 0; i < wi_slice.length(); i++) { + wi_slice_c.data[i] *= 1; + } + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i+2*j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i+j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << C[i][j] << ", "; + std::cout << "\n"; + } + std::cout << std::endl; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << D[i][j] << ", "; + std::cout << "\n"; + } +} +#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) From 4803165a4b8336e04f146845d3962556de6ab22e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 24 Nov 2021 21:15:54 +0800 Subject: [PATCH 2/8] Change implementation by using wi_elem --- sycl/include/CL/__spirv/spirv_ops.hpp | 15 +++- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 83 ++++++++++++------- sycl/test/matrix/matrix-int8-test-slice.cpp | 6 +- 3 files changed, 69 insertions(+), 35 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 951c1d66901c7..b87139d52d240 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -96,8 +96,19 @@ extern SYCL_EXTERNAL __spirv_wi_slice_t &__spirv_JointMatrixGetSliceData( template -extern SYCL_EXTERNAL size_t __spirv_JointMatrixGetSliceLength( - __spv::__spirv_JointMatrixINTEL *); +extern SYCL_EXTERNAL +size_t __spirv_JointMatrixGetSliceLength(__spv::__spirv_JointMatrixINTEL*); + +template +extern SYCL_EXTERNAL T __spirv_JointMatrixGetSliceElem( + __spv::__spirv_JointMatrixINTEL *, size_t i); + +template +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_JointMatrixSetSliceElem( + __spv::__spirv_JointMatrixINTEL *, size_t i, T val); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 1f5ba441bd37f..f7a19b50ddeff 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -44,6 +44,11 @@ template struct spv_scope_traits> { constexpr static auto value = __spv::Scope::Workgroup; }; +template +class wi_slice; + template @@ -58,6 +63,11 @@ struct joint_matrix { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + + inline __SYCL_ALWAYS_INLINE wi_slice + get_wi_slice() { + return wi_slice(*this); + } }; template using wi_slice_t = T __attribute__((ext_vector_type(0xffffff))); #else template -using wi_slice_t __attribute__((vector_size(0xffffff))) = T; +using wi_slice_t __attribute__((vector_size(0x800000))) = T; #endif // __clang__ -// dummy value for initializing wi_slice::data in host code. -wi_slice_t dummy_i32; -wi_slice_t dummy_i8; -wi_slice_t dummy_u8; -wi_slice_t dummy_u16; -wi_slice_t dummy_f32; - -template wi_slice_t &getDummy() {} -template <> wi_slice_t &getDummy() { return dummy_i32; } -template <> wi_slice_t &getDummy() { return dummy_i8; } -template <> wi_slice_t &getDummy() { return dummy_u8; } -template <> wi_slice_t &getDummy() { return dummy_f32; } -template <> wi_slice_t &getDummy() { return dummy_u16; } - template -class wi_slice { +class wi_elem { joint_matrix &M; + std::size_t idx; public: - wi_slice(joint_matrix &Mat) - : M(Mat), + wi_elem(joint_matrix &Mat, std::size_t i) + : M(Mat), idx(i) {} + operator T() { #ifdef __SYCL_DEVICE_ONLY__ - data(__spirv_JointMatrixGetSliceData(Mat.spvm)) { + return __spirv_JointMatrixGetSliceElem(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ } + wi_elem &operator=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_JointMatrixSetSliceElem(M.spvm, idx, rhs); + return *this; #else - data(getDummy()) { + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ } + wi_elem &operator*=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_JointMatrixSetSliceElem( + M.spvm, idx, __spirv_JointMatrixGetSliceElem(M.spvm, idx) * rhs); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ - wi_slice_t &data; + } + // TODO: add other arithmetic operators +}; + +template +class wi_slice { + joint_matrix &M; + +public: + wi_slice(joint_matrix &Mat) : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_JointMatrixGetSliceLength(M.spvm); #else - throw runtime_error("wi_slice is not supported on host device.", + throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + wi_elem operator[](size_t i) { + return wi_elem(M, i); + } }; -// TODO: must be a member function of joint_matrix class. -template -inline __SYCL_ALWAYS_INLINE wi_slice -joint_matrix_get_slice(joint_matrix &M) { - return wi_slice(M); -} - } // namespace experimental::matrix } // namespace oneapi } // namespace ext diff --git a/sycl/test/matrix/matrix-int8-test-slice.cpp b/sycl/test/matrix/matrix-int8-test-slice.cpp index 5e5f519f56c8c..c4cf3a0fa649f 100644 --- a/sycl/test/matrix/matrix-int8-test-slice.cpp +++ b/sycl/test/matrix/matrix-int8-test-slice.cpp @@ -72,9 +72,9 @@ void matrix_multiply(big_matrix &C, big_matrix Date: Mon, 29 Nov 2021 18:07:30 +0800 Subject: [PATCH 3/8] Fix preCI's fail and address dounia's comments --- sycl/include/CL/__spirv/spirv_ops.hpp | 4 +- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 4 +- ...test-slice.cpp => matrix-elemwise-ops.cpp} | 45 ++++++++++--------- 3 files changed, 28 insertions(+), 25 deletions(-) rename sycl/test/matrix/{matrix-int8-test-slice.cpp => matrix-elemwise-ops.cpp} (83%) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index b87139d52d240..9595c0fc2dd02 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -96,8 +96,8 @@ extern SYCL_EXTERNAL __spirv_wi_slice_t &__spirv_JointMatrixGetSliceData( template -extern SYCL_EXTERNAL -size_t __spirv_JointMatrixGetSliceLength(__spv::__spirv_JointMatrixINTEL*); +extern SYCL_EXTERNAL size_t __spirv_JointMatrixGetSliceLength( + __spv::__spirv_JointMatrixINTEL *); template diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index f7a19b50ddeff..39d2fcd7c4dfb 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -65,8 +65,8 @@ struct joint_matrix { } inline __SYCL_ALWAYS_INLINE wi_slice - get_wi_slice() { - return wi_slice(*this); + get_wi_data() { + return wi_slice(*this); } }; diff --git a/sycl/test/matrix/matrix-int8-test-slice.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp similarity index 83% rename from sycl/test/matrix/matrix-int8-test-slice.cpp rename to sycl/test/matrix/matrix-elemwise-ops.cpp index c4cf3a0fa649f..83aa24d36c5a6 100644 --- a/sycl/test/matrix/matrix-int8-test-slice.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -7,26 +7,28 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; #define TILE_SZ 16 -#define TM (TILE_SZ-4) -#define TN (TILE_SZ-4) -#define TK (4 * TILE_SZ-16) +#define TM (TILE_SZ - 4) +#define TN (TILE_SZ - 4) +#define TK (4 * TILE_SZ - 16) #define SG_SZ 16 -template struct big_matrix{ +template struct big_matrix { public: T *mat; public: T *get_data() { return mat; } void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) { - } + big_matrix(T *data) : mat(data) {} }; -template -void matrix_multiply(big_matrix &C, big_matrix &A, big_matrix &B) { +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { size_t M = NUM_ROWS_C; size_t N = NUM_COLS_C; size_t K = NUM_COLS_A; @@ -60,9 +62,10 @@ void matrix_multiply(big_matrix &C, big_matrix sub_a(sg); - // For B, since current implementation does not support non-packed layout, - // users need to specify the updated VNNI sizes along with the packed_b layout. - // By default, the layout is row_major and size is (TK, TN). + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). joint_matrix sub_b(sg); joint_matrix sub_c(sg); @@ -72,10 +75,6 @@ void matrix_multiply(big_matrix &C, big_matrix &C, big_matrix MC((int32_t *)&C); big_matrix MD((int32_t *)&D); big_matrix MA((int8_t *)&A); - big_matrix MB((int8_t *)&B); + big_matrix MB((int8_t *)&B); matrix_multiply(MC, MA, MB); matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, - MATRIX_N, MATRIX_K / 4); + MATRIX_N, MATRIX_K / 4); bool res = true; for (int i = 0; i < MATRIX_M; i++) { From d236a563c716b95dfe1407da4bf6d18413dd7441 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Mon, 29 Nov 2021 23:31:05 +0800 Subject: [PATCH 4/8] Remove useless wi_slice_t --- sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 39d2fcd7c4dfb..c7ea8613e1e07 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -202,14 +202,6 @@ joint_matrix_mad(Group sg, joint_matrix &mA, #endif // __SYCL_DEVICE_ONLY__ } -#ifdef __clang__ -template -using wi_slice_t = T __attribute__((ext_vector_type(0xffffff))); -#else -template -using wi_slice_t __attribute__((vector_size(0x800000))) = T; -#endif // __clang__ - template From a894d046f4fa2d7895e9b4c4832b3902b453d524 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 30 Nov 2021 10:05:59 +0800 Subject: [PATCH 5/8] Remove useless comments --- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 83aa24d36c5a6..4e08ef2eb0251 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -85,7 +85,7 @@ void matrix_multiply(big_matrix &C, sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - auto wi_slice_c = sub_c.get_wi_data(); // M.get_wi_data() + auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 1; } From 93c3e22dd6ebe5a603c496f2a0df879637676403 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 21 Dec 2021 00:20:17 +0800 Subject: [PATCH 6/8] Address Dounia&Alexey's comments --- sycl/include/CL/__spirv/spirv_ops.hpp | 16 ++++---------- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 22 +++++++++---------- sycl/test/matrix/matrix-elemwise-ops.cpp | 1 + 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 9595c0fc2dd02..b80c3ebeb43ab 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -86,29 +86,21 @@ __spirv_JointMatrixSUMadINTEL( __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); -template -using __spirv_wi_slice_t = T __attribute__((ext_vector_type(0xffffff))); - -template -extern SYCL_EXTERNAL __spirv_wi_slice_t &__spirv_JointMatrixGetSliceData( - __spv::__spirv_JointMatrixINTEL *); - template -extern SYCL_EXTERNAL size_t __spirv_JointMatrixGetSliceLength( +extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL( __spv::__spirv_JointMatrixINTEL *); template -extern SYCL_EXTERNAL T __spirv_JointMatrixGetSliceElem( +extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( __spv::__spirv_JointMatrixINTEL *, size_t i); template extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_JointMatrixSetSliceElem( - __spv::__spirv_JointMatrixINTEL *, size_t i, T val); +__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, + T val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index c7ea8613e1e07..ad4095a38c880 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -205,24 +205,24 @@ joint_matrix_mad(Group sg, joint_matrix &mA, template -class wi_elem { +class wi_element { joint_matrix &M; std::size_t idx; public: - wi_elem(joint_matrix &Mat, std::size_t i) + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator T() { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_JointMatrixGetSliceElem(M.spvm, idx); + return __spirv_VectorExtractDynamic(M.spvm, idx); #else throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_elem &operator=(const T &rhs) { + wi_element &operator=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ - M.spvm = __spirv_JointMatrixSetSliceElem(M.spvm, idx, rhs); + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); return *this; #else (void)rhs; @@ -230,10 +230,10 @@ class wi_elem { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_elem &operator*=(const T &rhs) { + wi_element &operator*=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ - M.spvm = __spirv_JointMatrixSetSliceElem( - M.spvm, idx, __spirv_JointMatrixGetSliceElem(M.spvm, idx) * rhs); + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) * rhs, idx); return *this; #else (void)rhs; @@ -253,14 +253,14 @@ class wi_slice { wi_slice(joint_matrix &Mat) : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_JointMatrixGetSliceLength(M.spvm); + return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); #else throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_elem operator[](size_t i) { - return wi_elem(M, i); + wi_element operator[](size_t i) { + return wi_element(M, i); } }; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 4e08ef2eb0251..319b95c5a2735 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -1,4 +1,5 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out +// XFAIL: * #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include From 3ffe9590595e7cdfcbce44f450f8970234f43aa3 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 21 Dec 2021 22:55:10 +0800 Subject: [PATCH 7/8] Fix clang-format issue --- sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index ad4095a38c880..d6cd2e41ed308 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -210,7 +210,8 @@ class wi_element { std::size_t idx; public: - wi_element(joint_matrix &Mat, std::size_t i) + wi_element(joint_matrix &Mat, + std::size_t i) : M(Mat), idx(i) {} operator T() { #ifdef __SYCL_DEVICE_ONLY__ From 1d2dab3b22a2cc3bbbcf9cae20f9a6e40ac392d6 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 21 Dec 2021 23:31:24 +0800 Subject: [PATCH 8/8] choose a different number for elemwice multiplication --- sycl/test/matrix/matrix-elemwise-ops.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 319b95c5a2735..081b0f6dfb63a 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -86,10 +86,10 @@ void matrix_multiply(big_matrix &C, sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - auto wi_slice_c = sub_c.get_wi_data(); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 1; - } + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + @@ -121,6 +121,7 @@ void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, } *(C_mem + m * N + n) = acc; } + *(C_mem + m * N + n) *= 2; } }