From 6bf0ce6f284148d01d773a865b9d4c0c0037b834 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 18 Mar 2022 16:45:21 +0800 Subject: [PATCH 01/14] [Matrix][SYCL] Add use argument for joint_matrix and add another feature macro for it --- sycl/include/CL/__spirv/spirv_ops.hpp | 76 ++- sycl/include/CL/__spirv/spirv_types.hpp | 13 +- sycl/include/CL/sycl/feature_test.hpp.in | 2 +- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 511 ++++++++++++++++++ .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 85 +-- .../include/sycl/ext/oneapi/matrix/matrix.hpp | 4 + .../ext/oneapi/matrix/static-query-use.hpp | 435 +++++++++++++++ 7 files changed, 1061 insertions(+), 65 deletions(-) create mode 100644 sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp create mode 100644 sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 3c6ce1c639960..fde9cdbca50e7 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -24,88 +24,106 @@ #ifdef __SYCL_DEVICE_ONLY__ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); template extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL( - T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, + T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUSMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixSUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_CompositeConstruct(const T v); -template extern SYCL_EXTERNAL size_t 
__spirv_JointMatrixWorkItemLengthINTEL( - __spv::__spirv_JointMatrixINTEL *); + __spv::__spirv_JointMatrixINTEL *); -template extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( - __spv::__spirv_JointMatrixINTEL *, size_t i); + __spv::__spirv_JointMatrixINTEL *, size_t i); -template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, T val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index 4fb0fa8886796..0ba4ac0036c23 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -115,6 +115,13 @@ enum class MatrixLayout : uint32_t { PackedB = 3 }; +enum class MatrixUse : uint32_t { + MatrixA = 0, + MatrixB = 1, + Accumulator = 2, + Unnecessary = 3 +}; + // TODO: replace the following W/A with a better solution when we have it. // The following structure is used to represent the joint matrix type in the // LLVM IR. The structure has a pointer to a multidimensional array member which @@ -129,10 +136,12 @@ enum class MatrixLayout : uint32_t { // information to SPIRV translator. // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. -template struct __spirv_JointMatrixINTEL { - T (*Value)[R][C][static_cast(U) + 1][static_cast(S) + 1]; + T(*Value) + [R][C][static_cast(L) + 1][static_cast(U) + 1] + [static_cast(S) + 1]; }; } // namespace __spv diff --git a/sycl/include/CL/sycl/feature_test.hpp.in b/sycl/include/CL/sycl/feature_test.hpp.in index 6c58b98d29f1f..7af8e1b207452 100644 --- a/sycl/include/CL/sycl/feature_test.hpp.in +++ b/sycl/include/CL/sycl/feature_test.hpp.in @@ -38,7 +38,7 @@ namespace sycl { // 2- provides JIT implementation (target agnostic) for the // experimental matrix extension #ifndef SYCL_EXT_ONEAPI_MATRIX -#define SYCL_EXT_ONEAPI_MATRIX 2 +#define SYCL_EXT_ONEAPI_MATRIX 3 #endif #define SYCL_EXT_ONEAPI_ASSERT 1 #define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1 diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp new file mode 100644 index 0000000000000..3d2ffd47d80eb --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -0,0 +1,511 @@ +//==------------------ matrix-jit-use.hpp - SYCL matrix ----------------*- C++ +//-*---==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +#pragma once + +#include +#include +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { +namespace experimental::matrix { + +enum class matrix_layout { row_major, col_major, packed_a, packed_b }; + +template struct spv_matrix_layout_traits { + static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor; +}; + +#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \ + template <> struct spv_matrix_layout_traits { \ + static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \ + }; + +SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major, + __spv::MatrixLayout::RowMajor) +SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, + __spv::MatrixLayout::ColumnMajor) +SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) +SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) + +enum class matrix_use { matrix_a, matrix_b, accumulator, unnecessary }; + +template struct spv_matrix_use_traits { + static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA; +}; + +#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \ + template <> struct spv_matrix_use_traits { \ + static constexpr __spv::MatrixUse value = SPV_USE; \ + }; + +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_a, __spv::MatrixUse::MatrixA) +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_b, __spv::MatrixUse::MatrixB) +SPV_MATRIX_USE_TRAITS(matrix_use::accumulator, __spv::MatrixUse::Accumulator) +SPV_MATRIX_USE_TRAITS(matrix_use::unnecessary, __spv::MatrixUse::Unnecessary) + +template struct spv_scope_traits {}; +template <> struct spv_scope_traits { + constexpr static auto value = __spv::Scope::Subgroup; +}; +template struct spv_scope_traits> { + constexpr static auto value = __spv::Scope::Workgroup; +}; + +template +class wi_slice; +template +struct joint_matrix { +public: + __spv::__spirv_JointMatrixINTEL::value, + spv_matrix_use_traits::value> *spvm; + joint_matrix(Group sg) { +#ifndef __SYCL_DEVICE_ONLY__ + (void)sg; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + inline __SYCL_ALWAYS_INLINE wi_slice + get_wi_data() { + return wi_slice(*this); + } +}; + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_load(Group sg, + joint_matrix &res, + multi_ptr src, size_t stride, matrix_layout MemL) { +#ifdef __SYCL_DEVICE_ONLY__ + T *Ptr = src.get(); + switch (MemL) { + default: + assert(false && "Invalid Memory Layout!"); + case matrix_layout::row_major: + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); + break; + case matrix_layout::col_major: + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); + break; + case matrix_layout::packed_a: + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); + break; + case matrix_layout::packed_b: + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); + break; + } +#else + (void)sg; + (void)res; + (void)src; + (void)stride; + (void)MemL; + throw runtime_error("joint matrix 
is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_store(Group sg, + joint_matrix &src, + multi_ptr res, size_t stride, matrix_layout MemL) { +#ifdef __SYCL_DEVICE_ONLY__ + T *Ptr = res.get(); + switch (MemL) { + default: + assert(false && "Invalid Memory Layout!"); + case matrix_layout::row_major: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); + break; + case matrix_layout::col_major: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); + break; + case matrix_layout::packed_a: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); + break; + case matrix_layout::packed_b: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); + break; + } +#else + (void)sg; + (void)src; + (void)res; + (void)stride; + (void)MemL; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE joint_matrix +joint_matrix_mad(Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { +#ifdef __SYCL_DEVICE_ONLY__ + joint_matrix res(sg); + if constexpr (std::is_same::value && + std::is_same::value && + std::is_same::value) + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_signed::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_signed::value) + res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + return res; +#else + (void)sg; + (void)mA; + (void)mB; + (void)mC; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 v) { + // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // functions + (void)sg; +#ifdef __SYCL_DEVICE_ONLY__ + res.spvm = __spirv_CompositeConstruct::value, + spv_matrix_use_traits::value>( + static_cast(v)); + +#else + (void)res; + (void)v; +#endif // __SYCL_DEVICE_ONLY__ +} + +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator T() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast(0); +#else + 
throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + template wi_element &operator=(const T2 &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast(rhs), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(op) \ + template wi_element &operator op##=(const T2 &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, \ + static_cast(__spirv_VectorExtractDynamic(M.spvm, idx) \ + op static_cast(rhs)), \ + idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(op) \ + template wi_element &operator op##=(const T2 &rhs) { \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+) + OP(-) + OP(*) + OP(/) +#undef OP +}; + +// Note that similarly to the other matrix functions, uint16_t is used here to +// represent bf16 type. Since the AMX and DPAS implementations don't support +// uint16_t, this interpretation is possible. This design choice was made before +// the introduction of SYCL experimental bfloat16 type. Our plan is to move +// towards using the SYCL bfloat16. But since it is still experimental, we will +// probably keep both uint16 interpretation and SYCL bfloat16. +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator uint16_t() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return std::fabs(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))) >= + std::numeric_limits::epsilon(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=( + const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + // We use here the following functions for conversion (bf16=>fp32 and + // fp32=>bf16). 
This is a workaround until we are able to use + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are + // supported in the CPU backend + static float make_fp32(uint16_t x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; + } + + static uint16_t make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (uint16_t)*res; + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(op) \ + wi_element &operator op##=(const uint16_t &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, \ + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx) \ + op make_fp32(rhs))), \ + idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(op) \ + wi_element &operator op##=(const uint16_t &rhs) { \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+) + OP(-) + OP(*) + OP(/) +#undef OP + + template struct Converter { + static T2 convert(const T1 &from) { return static_cast(from); } + }; + + template struct Converter { + static uint16_t convert(const T &from) { return make_bf16(from); } + }; +#if __SYCL_DEVICE_ONLY__ +#define OP(input_type, type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const uint16_t &rhs) { \ + return Converter::convert(make_fp32( \ + __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \ + } \ + friend type operator op( \ + const uint16_t &lhs, \ + const wi_element &rhs) { \ + return Converter::convert(make_fp32( \ + __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(input_type, type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const uint16_t &rhs) { \ + (void)lhs; \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_INVALID_DEVICE); \ + } \ + friend type operator op( \ + const uint16_t &lhs, \ + const wi_element &rhs) { \ + (void)lhs; \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(float, uint16_t, +) + OP(float, uint16_t, -) + OP(float, uint16_t, *) + OP(float, uint16_t, /) + OP(bool, bool, ==) + OP(bool, bool, !=) + OP(bool, bool, <) + OP(bool, bool, >) + OP(bool, bool, <=) + OP(bool, bool, >=) +#undef OP +}; + +template +class wi_slice { + joint_matrix &M; + +public: + wi_slice(joint_matrix &Mat) + : M(Mat) {} + size_t length() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element operator[](size_t i) { + return wi_element(M, i); + } +}; + +} // namespace experimental::matrix +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index fa467934c0cfd..724662bc03e34 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -36,6 +36,21 @@ SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) +enum class matrix_use { matrix_a, matrix_b, accumulator, 
unnecessary }; + +template struct spv_matrix_use_traits { + static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA; +}; + +#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \ + template <> struct spv_matrix_use_traits { \ + static constexpr __spv::MatrixUse value = SPV_USE; \ + }; + +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_a, __spv::MatrixUse::MatrixA) +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_b, __spv::MatrixUse::MatrixB) +SPV_MATRIX_USE_TRAITS(matrix_use::accumulator, __spv::MatrixUse::Accumulator) +SPV_MATRIX_USE_TRAITS(matrix_use::unnecessary, __spv::MatrixUse::Unnecessary) template struct spv_scope_traits {}; template <> struct spv_scope_traits { constexpr static auto value = __spv::Scope::Subgroup; @@ -55,7 +70,8 @@ template ::value> *spvm; + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -83,32 +99,32 @@ joint_matrix_load(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -135,26 +151,30 @@ joint_matrix_store(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - __spirv_JointMatrixStoreINTEL::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: - __spirv_JointMatrixStoreINTEL::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: - __spirv_JointMatrixStoreINTEL::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: - __spirv_JointMatrixStoreINTEL::value>( + 
__spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -212,10 +232,9 @@ joint_matrix_fill(Group sg, // functions (void)sg; #ifdef __SYCL_DEVICE_ONLY__ - res.spvm = - __spirv_CompositeConstruct::value>( - static_cast(v)); + res.spvm = __spirv_CompositeConstruct< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_matrix_use_traits::value>(static_cast(v)); #else (void)res; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index 52ca60eab19b9..c6c9500682595 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -26,5 +26,9 @@ #include #endif #if (SYCL_EXT_ONEAPI_MATRIX == 3) +#include +#include +#endif +#if (SYCL_EXT_ONEAPI_MATRIX == 4) #include #endif diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp new file mode 100644 index 0000000000000..490f34c20cf9a --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -0,0 +1,435 @@ +//===-------------- static-query.hpp - SYCL matrix ------------*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // +// This file implements the static query interface for the joint_matrix +// experimental extension. AMX, DPAS and different other TPUs support different +// logical sizes and types. The query interface is used to validate user code +// and inform them about supported types, sizes, scope, and layouts by the +// current implementation. Note that this query interface is a compile-time +// query, so there will be no runtime errors. The query interface provides +// three functionalities: +// 1- At compile time, inform the user whether a specific +// combination is valid or not. +// 2- Construct the matrices using a default shape +// if user does not provide a combination +// 3- General query interface for sizes, types, +// static/dynamic, scope. This is needed to void padding by the user, +// for tuning, and efficient code generation if used by a library. 
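+//
+// A minimal usage sketch of the three query modes (illustrative only: it
+// assumes the defaulted tpu_params<tpu::amx, Ta, Tb, Tc, M, N, K> form
+// declared below, and borrows the int8 combination from the AMX
+// combinations table in this file):
+//
+//   // 1- Validity: compiles only if the combination is supported.
+//   using valid = tpu_params<tpu::amx, int8_t, int8_t, int, 16, 16, 64>;
+//   // 2- Default shapes: only types given; for int8 on AMX this yields
+//   //    defaultM = 16, defaultN = 16, defaultK = 64.
+//   using params = tpu_params<tpu::amx, int8_t, int8_t, int>;
+//   static_assert(params::defaultK == 64, "int8 implies K == 64 on AMX");
+//   // 3- General query: no types given; inspect supported combinations.
+//   using general = tpu_params<tpu::amx>;
+//   constexpr int n = general::num_combinations;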
+ +#pragma once + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { +namespace experimental::matrix { + +enum class tpu { + dpas, + amx, +}; +enum class matrix_type { + bf8, + bf16, + fp16, + fp19, // tfloat32 + fp32, + fp64, + sint2, + sint4, + sint8, + sint16, + sint32, + sint64, + uint2, + uint4, + uint8, + uint16, + uint32, + uint64 +}; + +enum class scope_t { sub_group, work_group }; + +template +struct tpu_params; + +#if __cplusplus >= 201703L +template +constexpr bool is_combination_valid_amx(int M, int N, int K) { + // is_same_v is a C++17 feature + if ((std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + // bf16 + (std::is_same_v && + std::is_same_v && std::is_same_v && + M <= 16 && N <= 16 && K <= 32)) + return true; + else + return false; +} + +template +constexpr bool are_types_valid_amx() { + if ((std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && std::is_same_v)) + return true; + else + return false; +} +#endif + +// General query: +// types are not given, no default sizes and no implicit matrix construction +template +struct tpu_params { + static constexpr std::size_t defaultM = -1; // depends on the type + static constexpr std::size_t defaultN = -1; + static constexpr std::size_t defaultK = -1; + + bool dynamic_p = false; // should be true in future implementations because + // AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + {16, 16, 64, mt::sint8, mt::sint8, mt::sint32}, + {16, 16, 64, mt::sint8, mt::uint8, mt::sint32}, + {16, 16, 64, mt::uint8, mt::sint8, mt::sint32}, + {16, 16, 64, mt::uint8, mt::uint8, mt::sint32}, + {16, 16, 32, mt::bf16, mt::bf16, mt::fp32}}; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +#if __cplusplus >= 201703L +// Sizes-only query +// Specialization for when only types are given, need to query only sizes +template +struct tpu_params && + !std::is_same_v && + !std::is_same_v)>::type> { + static_assert((are_types_valid_amx()), + "Invalid types for AMX, supported types are int8_t, uint8_t, " + "and bf16 (Note that unsigned short should be used in the" + "DPC++ code to implement bf16) "); + + // construct the matrices using the default sizes + static constexpr std::size_t defaultM = 16; + static constexpr std::size_t defaultN = 16; + static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 
64 : 32); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = + joint_matrix; + + bool dynamic_p = false; // should be true in future implementations because + // AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + static constexpr combination combinations[] = { + {16, 16, (sizeof(Ta) == 1) ? 64 : 32}}; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Valid or not: +// Specialization when both types and sizes are given +template +struct tpu_params< + tpu::amx, Ta, Tb, Tc, M, N, K, + typename std::enable_if<( + !std::is_same_v && !std::is_same_v && + !std::is_same_v && M != 0 && N != 0 && K != 0)>::type> { + // Validate that parameters are supported + static_assert( + (M == 0 && N == 0 && K == 0) || + (is_combination_valid_amx(M, N, K)), + "Invalid parameters for AMX, query valid types and maximum sizes " + "using: tpu_params myparams; and then check out " + "myparams.combinations array"); + + // if combination is valid, construct the matrices + + static constexpr std::size_t defaultM = (M != 0) ? M : 16; + static constexpr std::size_t defaultN = (N != 0) ? N : 16; + static constexpr std::size_t defaultK = + (K != 0) ? K : ((sizeof(Ta) == 1) ? 64 : 32); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = + joint_matrix; + + bool dynamic_p = false; // should be true in future implementations + // because AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; +}; + +// DPAS case +// The DPAS implementation supports the logical capability support of the HW +// So in this case, M, N, K sizes returned by the query represent the logical +// capabilities of the DPAS hardware. 
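+// For example (values taken from is_combination_valid_dpas below), int8
+// inputs with an int32 accumulator are accepted for any M in {1, 2, 4, 8}
+// with N == 8 and K == 32, while half/bf16 inputs with a float accumulator
+// use K == 16 for the same M and N. A commented sketch:
+//
+//   static_assert(is_combination_valid_dpas<int8_t, int8_t, int>(8, 8, 32));
+//   static_assert(is_combination_valid_dpas<unsigned short, unsigned short,
+//                                           float>(4, 8, 16));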
+ +template +constexpr bool is_combination_valid_dpas(int M, int N, int K) { + if ((std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 16) || + (std::is_same_v && + std::is_same_v && std::is_same_v && + (M == 1 || M == 2 || M == 4 || M == 8) && N == 8 && K == 16)) + return true; + else + return false; +} + +template +constexpr bool are_types_valid_dpas() { + if ((std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && std::is_same_v)) + return true; + else + return false; +} +#endif + +// General Query +// specialization for when types are not given --> no default values +template +struct tpu_params { + static constexpr std::size_t defaultM = -1; // depends on the type + static constexpr std::size_t defaultN = -1; + static constexpr std::size_t defaultK = -1; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; + + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 1, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 2, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 4, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 8, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 1, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 2, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 4, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 8, 8, 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Sizes-only query: +// Specialization for when only 
types are given, need to query only sizes + +#if __cplusplus >= 201703L +template +struct tpu_params && + !std::is_same_v && + !std::is_same_v)>::type> { + static_assert((are_types_valid_dpas()), + "Invalid types for DPAS, supported types are int8_t, uint8_t, " + "half, and bf16 (Note that unsigned short should be used in the" + "DPC++ code to implement bf16)"); + + // construct the matrices using the default sizes + + static constexpr std::size_t defaultM = 8; + static constexpr std::size_t defaultN = 8; + static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 32 : 16); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = + joint_matrix; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + // The types used in the initialization below are fake and not used. In + // this case, users already chose the types, they are only looking for the + // sizes + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 4, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 8, 8, (sizeof(Ta) == 1) ? 32 : 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Valid or not: +// Specialization when both types and sizes are given +template +struct tpu_params< + tpu::dpas, Ta, Tb, Tc, M, N, K, + typename std::enable_if<((!std::is_same_v && M != 0))>::type> { + // Validate that parameters are supported + static_assert((M == 0 && N == 0 && K == 0) || + (is_combination_valid_dpas(M, N, K)), + "Invalid parameters for DPAS, query valid combinations " + "using: tpu_params myparams; and then check out " + "myparams.combinations array"); + + // if combination is valid, construct the matrices + static constexpr std::size_t defaultM = (M != 0) ? M : 8; + static constexpr std::size_t defaultN = (N != 0) ? N : 8; + static constexpr std::size_t defaultK = + (K != 0) ? K : ((sizeof(Ta) == 1) ? 
32 : 16); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = + joint_matrix; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; +}; +#endif +} // namespace experimental::matrix +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) From 4f18f7ff6dc8a4a65f12e3ffe291fe17b8149185 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 23 Mar 2022 10:13:57 +0800 Subject: [PATCH 02/14] small fix for testcase and set them xfail --- sycl/include/CL/__spirv/spirv_types.hpp | 4 +- .../sycl/ext/oneapi/matrix/matrix-aot-amx.hpp | 448 ------------------ .../include/sycl/ext/oneapi/matrix/matrix.hpp | 9 +- sycl/test/matrix/matrix-amx-bf16-test.cpp | 192 -------- sycl/test/matrix/matrix-amx-int8-test.cpp | 176 ------- sycl/test/matrix/matrix-bf16-test-SG-16.cpp | 5 +- sycl/test/matrix/matrix-bf16-test.cpp | 5 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 5 +- sycl/test/matrix/matrix-int8-test-SG-16.cpp | 5 +- sycl/test/matrix/matrix-int8-test.cpp | 5 +- sycl/test/matrix/query.cpp | 2 +- 11 files changed, 15 insertions(+), 841 deletions(-) delete mode 100644 sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp delete mode 100644 sycl/test/matrix/matrix-amx-bf16-test.cpp delete mode 100644 sycl/test/matrix/matrix-amx-int8-test.cpp diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index 0ba4ac0036c23..427012bb064a8 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -140,8 +140,8 @@ template struct __spirv_JointMatrixINTEL { T(*Value) - [R][C][static_cast(L) + 1][static_cast(U) + 1] - [static_cast(S) + 1]; + [R][C][static_cast(L) + 1][static_cast(S) + 1] + [static_cast(U) + 1]; }; } // namespace __spv diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp deleted file mode 100644 index 21583b6a3e2f6..0000000000000 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp +++ /dev/null @@ -1,448 +0,0 @@ -//===------------ matrix-aot-amx.hpp - SYCL matrix ------------*- C++ -*---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// ===--------------------------------------------------------------------=== // -/// -/// We provide new interfaces for matrix muliply in this patch: -/// 1. A new class called joint_matrix is introduced, and the user needs to -/// specify the type of the elements, sizes, and the memory layout. -/// -/// 2. joint_matrix_load is used for loading data from main memory to tiles of -/// AMX or kernel's local memory. -/// -/// 3. joint_matrix_store is used for storing data tiles of AMX or kernel's -/// local memory to main memory. -/// -/// 4. joint_matrix_mad is used for the matrix multiply and add function. -/// It performs the multiply operation on the matrices A and B, accumulates the -/// result with C and returns the result. -/// -/// The following operation can be realized with the interfaces: -/// C = A*B+C -/// 1. All cases where A(int8, any-size, row_major), B(int8, any-size, -/// packed_b), C(int32, any-size, row_major) -/// 2. 
All cases where A(bf16, any-size, row_major), B(bf16, any-size, -/// packed_b), C(float, any-size, row_major) -/// -/// -// ===--------------------------------------------------------------------=== // - -#pragma once - -#include -#include - -__SYCL_INLINE_NAMESPACE(cl) { -namespace sycl { -namespace ext { -namespace intel { -namespace detail { -template class submatrix { -public: - _tile1024i tile; - short rows, cols; -}; - -// TODO: we are adding it this way until sycl::dynamic_extent gets implemented. -constexpr size_t dynamic_extent = std::numeric_limits::max(); - -template struct elems_per_dword { - static constexpr size_t value = 1; -}; - -#define ELEMS_PER_DWORD(TYPE, NUM) \ - template <> struct elems_per_dword { \ - static constexpr size_t value = NUM; \ - }; - -ELEMS_PER_DWORD(int8_t, 4) -ELEMS_PER_DWORD(unsigned short, 2) - -} // namespace detail - -namespace experimental::matrix { -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" _tile1024i -_tileloadd64_internal(short row, short col, char *buf, size_t stride); -SYCL_EXTERNAL extern "C" _tile1024i -_tdpbssd_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2); -SYCL_EXTERNAL extern "C" _tile1024i -_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2); -SYCL_EXTERNAL extern "C" void _tilestored64_internal(short row, short col, - char *buf, size_t stride, - _tile1024i tile); -static _tile1024i tileloadd64_internal(short row, short col, char *buf, - size_t stride) { - return _tileloadd64_internal(row, col, buf, stride); -} -static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return _tdpbssd_internal(m, n, k, dst, src1, src2); -} -static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return _tdpbf16ps_internal(m, n, k, dst, src1, src2); -} -static void tilestored64_internal(short row, short col, char *buf, - size_t stride, _tile1024i tile) { - return _tilestored64_internal(row, col, buf, stride, tile); -} -#else -static _tile1024i tileloadd64_internal(short row, short col, char *buf, - size_t stride) { - return __builtin_ia32_tileloadd64_internal(row, col, buf, stride); -} -static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2); -} -static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2); -} -static void tilestored64_internal(short row, short col, char *buf, - size_t stride, _tile1024i tile) { - __builtin_ia32_tilestored64_internal(row, col, buf, stride, tile); -} -#endif - -enum class matrix_layout { row_major, col_major, packed_a, packed_b }; - -inline constexpr size_t tile_size = 16; - -template -struct joint_matrix { - joint_matrix(Group sg) {} - joint_matrix(Group sg, size_t Size) { - static_assert((NumRows != detail::dynamic_extent && - NumCols != detail::dynamic_extent), - "AMX implementation does not support dynamic allocation"); - } - joint_matrix(Group sg, size_t Rows, size_t Cols) { - static_assert((NumRows != detail::dynamic_extent && - NumCols != detail::dynamic_extent), - "AMX 
implementation does not support dynamic allocation"); - } -}; - -// This template specialization handles cases where matrix can't be accommodated -// by a tile. In this case, we create raw_storage for the matrix and the size -// is the multiply of (TILE*TILE*4). -template -struct joint_matrix< - Group, T, NumRows, NumCols, Layout, - typename std::enable_if::type> { -public: - // trows: Num of tiles in row. - // If T=int8, NumRows==33, trows should be 3=(33+15)/16 - static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size; - // tcols: Num of tiles in column. - static constexpr size_t tcols = - (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size; - // if T=int8, NumRows==33, NumCols==33*4, tile_size==16, then size of - // raw_storage should be 48*48*4. - // FIXME: Greedy Regalloc for tile seems has some limitation and currently we - // do tileload for (16,16*4) instead of varying shapes, so raw_storage's size - // is multiple of (16*16*4) - static constexpr size_t size = trows * tcols * tile_size * tile_size * 4; - // stride is aligned to T instead of int8 - static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T); - int8_t raw_storage[size]; - static constexpr bool isSmall = false; - -public: - matrix_layout layout; - // We do zero-padding for matrix whose size is not fitted into tiles in ctor. - joint_matrix(Group sg) { memset(raw_storage, 0x00, size); } -}; - -// This template specialization handles cases where matrix can be put into a -// tile and users specify layout is packed_a or packed_b -template -struct joint_matrix< - Group, T, NumRows, NumCols, Layout, - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size)>::type> { -public: - static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size; - // tcols: Num of tiles in column. - static constexpr size_t tcols = - (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size; - static constexpr size_t size = trows * tcols * tile_size * tile_size * 4; - // stride is aligned to T instead of int8 - static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T); - _tile1024i tile; - static constexpr bool isSmall = true; - matrix_layout layout; - // We do zero-padding for matrix whose size is not fitted into tiles in ctor. 
- joint_matrix(Group sg) {} -}; - -} // namespace experimental::matrix - -namespace detail { - -using namespace experimental; - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows > matrix::tile_size) || - (NumCols * sizeof(T) / 4 > matrix::tile_size), - void>::type - submatrix_load(detail::submatrix &sub_m, - matrix::joint_matrix jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - uint32_t offset = (row * stride + col); - T *ptr = reinterpret_cast(jm.raw_storage); - ptr += offset; - stride *= sizeof(T); - sub_m.rows = matrix::tile_size; - sub_m.cols = matrix::tile_size * 4; - sub_m.tile = matrix::tileloadd64_internal( - sub_m.rows, sub_m.cols, reinterpret_cast(ptr), stride); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows <= matrix::tile_size) && - (NumCols * sizeof(T) / 4 <= matrix::tile_size), - void>::type - submatrix_load(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - if (shouldreload) { - // Force sub_m.tile's shape to be matrix::tile_size * - // matrix::tile_size * 4 - int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; - matrix::tilestored64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(NewjmC), - matrix::tile_size * 4, jm.tile); - sub_m.rows = matrix::tile_size; - sub_m.cols = matrix::tile_size * 4; - sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols, - reinterpret_cast(NewjmC), - matrix::tile_size * 4); - return; - } - sub_m.rows = NumRows; - sub_m.cols = NumCols * sizeof(T); - sub_m.tile = jm.tile; -} - -// This handles cases where T1 is int8, T2 is int32. -inline __SYCL_ALWAYS_INLINE static void -submatrix_mad(detail::submatrix &sub_ma, - detail::submatrix &sub_mb, - detail::submatrix &sub_mc) { - sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, - sub_mc.tile, sub_ma.tile, sub_mb.tile); -} - -// This handles cases where T1 is int16(bfloat16), T2 is float. 
-inline __SYCL_ALWAYS_INLINE static void -submatrix_mad(detail::submatrix &sub_ma, - detail::submatrix &sub_mb, - detail::submatrix &sub_mc) { - sub_mc.tile = - matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, - sub_mc.tile, sub_ma.tile, sub_mb.tile); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows > matrix::tile_size) || - (NumCols * sizeof(T) / 4 > matrix::tile_size), - void>::type - submatrix_store(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - uint32_t offset = (row * stride + col); - T *ptr = reinterpret_cast(jm.raw_storage); - ptr += offset; - stride *= sizeof(T); - matrix::tilestored64_internal(sub_m.rows, sub_m.cols, - reinterpret_cast(ptr), stride, - sub_m.tile); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows <= matrix::tile_size) && - (NumCols * sizeof(T) / 4 <= matrix::tile_size), - void>::type - submatrix_store(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - if (shouldreload) { - int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; - matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4, - reinterpret_cast(NewjmC), - matrix::tile_size * 4, sub_m.tile); - jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(NewjmC), - matrix::tile_size * 4); - return; - } - jm.tile = sub_m.tile; -} - -} // namespace detail - -namespace experimental::matrix { - -// This handles cases where matrix can't be accommodated by a tile -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type -joint_matrix_load(Group sg, - joint_matrix &jm, - multi_ptr src, size_t stride, - matrix_layout layout) { - T *mem = src.get(); - // memcpy from mem to jm.raw_storage - for (int i = 0; i < NumRows; ++i) { - char *srcptr = reinterpret_cast(mem) + i * stride * sizeof(T); - char *dstptr = - reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); - // TODO: we may reformat layout. - memcpy(dstptr, srcptr, NumCols * sizeof(T)); - } - jm.layout = layout; -} - -// This handles cases where matrix can be put into a tile -template -inline __SYCL_ALWAYS_INLINE - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size), - void>::type - joint_matrix_load(Group sg, - joint_matrix &jm, - multi_ptr src, size_t stride, - matrix_layout layout) { - T *mem = src.get(); - // tileload happens! - jm.tile = - tileloadd64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(mem), stride * sizeof(T)); - jm.layout = layout; -} - -// This handles cases where matrix can't be accommodated by a tile -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type -joint_matrix_store(Group sg, - joint_matrix &jm, - multi_ptr dst, size_t stride, - matrix_layout layout) { - T *mem = dst.get(); - for (int i = 0; i < NumRows; ++i) { - char *dstptr = reinterpret_cast(mem) + i * stride * sizeof(T); - char *srcptr = - reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); - // TODO: we may reformat layout. 
- memcpy(dstptr, srcptr, NumCols * sizeof(T)); - } - return; -} - -// This handles cases where matrix can be put into a tile -template -inline __SYCL_ALWAYS_INLINE - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size), - void>::type - joint_matrix_store(Group sg, - joint_matrix &jm, - multi_ptr dst, size_t stride, - matrix_layout layout) { - T *mem = dst.get(); - // tilestore happens! - tilestored64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(mem), stride * sizeof(T), - jm.tile); - return; -} - -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - ((std::is_same::value && std::is_same::value) || - (std::is_same::value && - std::is_same::value)) && - (LayoutA == matrix_layout::row_major) && - (LayoutB == matrix_layout::packed_b) && - (LayoutC == matrix_layout::row_major), - joint_matrix>::type -joint_matrix_mad(Group sg, - joint_matrix &jmA, - joint_matrix &jmB, - joint_matrix &jmC) { - joint_matrix res(jmC); - constexpr size_t epd = detail::elems_per_dword::value; - // If A is large and C is small, in joint_matrix_load, we do memcpy for A, and - // we do tileload for C whose shape is not tile_size*tile_size*4. In - // joint_matrix_mad, we do tileload for A and shape is tile_size*tile_size*4. - // So we need to reshape C before we do dpbssd. - bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall; - bool Ashouldreload = jmA.isSmall && !jmB.isSmall; - bool Bshouldreload = jmB.isSmall && !jmA.isSmall; - - for (int m = 0; m < res.trows; ++m) { - for (int n = 0; n < res.tcols; ++n) { - detail::submatrix sub_c; - - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride, - matrix_layout::row_major, Cshouldreload); - for (int k = 0; k < jmA.tcols; ++k) { // K->int8_t - detail::submatrix sub_a; - detail::submatrix sub_b; - submatrix_load(sub_a, jmA, m * tile_size, k * tile_size * epd, - jmA.stride, matrix_layout::packed_a, Ashouldreload); - // Assume we alreay in vnni format. 
- submatrix_load(sub_b, jmB, k * tile_size, n * tile_size * epd, - jmB.stride, matrix_layout::packed_b, Bshouldreload); - submatrix_mad(sub_a, sub_b, sub_c); - } - submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride, - matrix_layout::row_major, Cshouldreload); - } - } - return res; -} - -} // namespace experimental::matrix -} // namespace intel -} // namespace ext -} // namespace sycl -} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index c6c9500682595..04f04fa15700d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -17,18 +17,13 @@ #include #if (SYCL_EXT_ONEAPI_MATRIX == 1) -#if defined(__AMXTILE__) && defined(__AMXINT8__) && defined(__AMXBF16__) -#include -#endif -#endif -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include #include #endif -#if (SYCL_EXT_ONEAPI_MATRIX == 3) +#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include #include #endif -#if (SYCL_EXT_ONEAPI_MATRIX == 4) +#if (SYCL_EXT_ONEAPI_MATRIX == 3) #include #endif diff --git a/sycl/test/matrix/matrix-amx-bf16-test.cpp b/sycl/test/matrix/matrix-amx-bf16-test.cpp deleted file mode 100644 index 7c1f89350f9b9..0000000000000 --- a/sycl/test/matrix/matrix-amx-bf16-test.cpp +++ /dev/null @@ -1,192 +0,0 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 1) -#include - -using namespace sycl; -using namespace sycl::ext::intel; -using namespace sycl::ext::intel::experimental::matrix; - -#define TILE_SZ 16 -#define TM (3 * TILE_SZ - 1) -#define TN (3 * TILE_SZ - 1) -#define TK (9 * TILE_SZ + 2) - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void matrix_multiply(big_matrix &C, - big_matrix &A, - big_matrix &B) { - size_t M = NUM_ROWS_C; - size_t N = NUM_COLS_C; - size_t K = NUM_COLS_A; - // B => K/4 x N*4, A => M x K, C => M, N - // stride should be X's cols, e.g., B's stirde = N*4 - assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); - size_t NDRangeM = M / TM; - size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC((float *)C.get_data(), range<2>(M, N)); - - queue q; - q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(1)]] - - { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx; - const auto sg_starty = global_idy; - - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a( - sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix - sub_b(sg); - joint_matrix sub_c(sg); - - // Only the leader perform AMX computation. 
- if (spmd_item.get_local_id(1) % TILE_SZ) - return; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { // K->int8_t - joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assume we alreay in vnni format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty * TN * 2, - N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - }); // parallel for - }).wait(); -} - -static constexpr size_t MATRIX_M = TM * 2; -static constexpr size_t MATRIX_N = TN * 2; -static constexpr size_t MATRIX_K = TK * 2; -unsigned short A[MATRIX_M][MATRIX_K]; -unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; -float C[MATRIX_M][MATRIX_N]; -float D[MATRIX_M][MATRIX_N]; - -float make_fp32(short x) { - unsigned int y = x; - y = y << 16; - float *res = reinterpret_cast(&y); - return *res; -} - -unsigned short make_bf16(float x) { - int *res = reinterpret_cast(&x); - *res = *res >> 16; - return (unsigned short)*res; -} - -void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, - int K) { - // tiling - for (int m = 0; m < M; m++) - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - short *va = (short *)(A_mem + m * K + k); - short *vb = (short *)(B_mem + k * N + n); - float acc = *((float *)(C_mem + m * N + n)); - // FIXME: Should we do reduce-add in another version? 
- for (int i = 0; i < 2; i++) { - acc += (make_fp32(va[i]) * make_fp32(vb[i])); - } - *((float *)(C_mem + m * N + n)) = acc; - } - } -} - -int main() { - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = make_bf16(1.0f * (i + j)); - } - } - for (int i = 0; i < MATRIX_K / 2; i++) { - for (int j = 0; j < MATRIX_N * 2; j++) { - B[i][j] = make_bf16(2.0f * i + 3.0f * j); - } - } - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - C[i][j] = 1.0; - D[i][j] = 1.0; - } - } - - big_matrix MC((float *)&C); - big_matrix MD((float *)&D); - big_matrix MA((unsigned short *)&A); - big_matrix MB( - (unsigned short *)&B); - matrix_multiply(MC, MA, MB); - matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, - MATRIX_N, MATRIX_K / 2); - - bool res = true; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - if (C[i][j] != D[i][j]) - res = false; - } - } - if (res) - std::cout << "passed\n"; - else - std::cout << "failed\n"; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) - std::cout << C[i][j] << ", "; - std::cout << "\n"; - } - std::cout << std::endl; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) - std::cout << D[i][j] << ", "; - std::cout << "\n"; - } -} -#endif diff --git a/sycl/test/matrix/matrix-amx-int8-test.cpp b/sycl/test/matrix/matrix-amx-int8-test.cpp deleted file mode 100644 index 4d3ec4e0ead3a..0000000000000 --- a/sycl/test/matrix/matrix-amx-int8-test.cpp +++ /dev/null @@ -1,176 +0,0 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 1) -#include - -using namespace sycl; -using namespace sycl::ext::intel; -using namespace sycl::ext::intel::experimental::matrix; - -#define TILE_SZ 16 -#define TM (4 * TILE_SZ - 4) -#define TN (4 * TILE_SZ - 4) -#define TK (4 * TILE_SZ - 16) - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void matrix_multiply(big_matrix &C, - big_matrix &A, - big_matrix &B) { - size_t M = NUM_ROWS_C; - size_t N = NUM_COLS_C; - size_t K = NUM_COLS_A; - // B => K/4 x N*4, A => M x K, C => M, N - // stride should be X's cols, e.g., B's stirde = N*4 - assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); - size_t NDRangeM = M / TM; - size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC(C.get_data(), range<2>(M, N)); - - queue q; - q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(1)]] - - { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx; - const auto sg_starty = global_idy; - - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the 
packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix - sub_b(sg); - joint_matrix sub_c(sg); - - // Only the leader perform AMX computation. - if (spmd_item.get_local_id(1) % TILE_SZ) - return; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { // K->int8_t - joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::packed_a); - // Assume we alreay in vnni format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - }); // parallel for - }).wait(); -} - -static constexpr size_t MATRIX_M = TM * 2; -static constexpr size_t MATRIX_N = TN * 2; -static constexpr size_t MATRIX_K = TK * 2; -int8_t A[MATRIX_M][MATRIX_K]; -int8_t B[MATRIX_K / 4][MATRIX_N * 4]; -int32_t C[MATRIX_M][MATRIX_N]; -int32_t D[MATRIX_M][MATRIX_N]; - -void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, - int N, int K) { - // tiling - for (int m = 0; m < M; m++) - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - char *va = (char *)(A_mem + m * K + k); - char *vb = (char *)(B_mem + k * N + n); - int acc = *(C_mem + m * N + n); - for (int i = 0; i < 4; i++) { - acc += (va[i] * vb[i]); - } - *(C_mem + m * N + n) = acc; - } - } -} - -int main() { - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = i + 2 * j; - } - } - for (int i = 0; i < MATRIX_K / 4; i++) { - for (int j = 0; j < MATRIX_N * 4; j++) { - B[i][j] = i + j; - } - } - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - C[i][j] = 1; - D[i][j] = 1; - } - } - - big_matrix MC((int32_t *)&C); - big_matrix MD((int32_t *)&D); - big_matrix MA((int8_t *)&A); - big_matrix MB((int8_t *)&B); - matrix_multiply(MC, MA, MB); - matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, - MATRIX_N, MATRIX_K / 4); - - bool res = true; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - if (C[i][j] != D[i][j]) - res = false; - } - } - if (res) - std::cout << "passed\n"; - else - std::cout << "failed\n"; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) - std::cout << C[i][j] << ", "; - std::cout << "\n"; - } - std::cout << std::endl; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) - std::cout << D[i][j] << ", "; - std::cout << "\n"; - } -} -#endif diff --git a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp index 0150c420e3bd9..4be3ed90809a2 100644 --- a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp @@ -1,6 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 +// XFAIL: * #include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; @@ -180,4 +180,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-bf16-test.cpp b/sycl/test/matrix/matrix-bf16-test.cpp index 
efa118a33bff7..fb1b89f03adb7 100644 --- a/sycl/test/matrix/matrix-bf16-test.cpp +++ b/sycl/test/matrix/matrix-bf16-test.cpp @@ -1,6 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 +// XFAIL: * #include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; @@ -180,4 +180,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 31a26362a6f7f..f4a2273fea40f 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -1,7 +1,7 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 +// XFAIL: * #include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; @@ -174,4 +174,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/matrix-int8-test-SG-16.cpp index 2905a2ea66311..0348a0c3bd330 100644 --- a/sycl/test/matrix/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-int8-test-SG-16.cpp @@ -1,6 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 +// XFAIL: * #include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; @@ -165,4 +165,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index ba4327a8e6cd5..9f4bf3ea1e810 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,11 +1,11 @@ -// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s +// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s +// XFAIL: * // CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } // CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } // CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } #include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; @@ -167,4 +167,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/query.cpp b/sycl/test/matrix/query.cpp index d75ea156632dd..883b7c3335217 100644 --- a/sycl/test/matrix/query.cpp +++ b/sycl/test/matrix/query.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -fsycl -o query %s +// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -o query %s #include #include From 299ed922af9603feceb661103d41add650c835f4 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 1 Apr 2022 21:59:40 +0800 Subject: [PATCH 03/14] change the order of use and layout and address dounia's comments --- sycl/include/CL/__spirv/spirv_ops.hpp | 52 +++--- sycl/include/CL/__spirv/spirv_types.hpp | 2 +- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 132 +++++++------- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 66 ++++--- .../include/sycl/ext/oneapi/matrix/matrix.hpp | 2 + .../ext/oneapi/matrix/static-query-use.hpp | 48 ++--- sycl/test/matrix/matrix-int8-test-use.cpp | 172 ++++++++++++++++++ 7 files changed, 332 insertions(+), 142 deletions(-) create mode 100644 sycl/test/matrix/matrix-int8-test-use.cpp diff --git 
a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 982b8db49a4c2..fc73d9ebd024b 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -23,20 +23,20 @@ #ifdef __SYCL_DEVICE_ONLY__ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); template extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL( - T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, + T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); @@ -48,11 +48,11 @@ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUSMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixSUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_CompositeConstruct(const T v); template extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL( - __spv::__spirv_JointMatrixINTEL *); + __spv::__spirv_JointMatrixINTEL *); template extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( - __spv::__spirv_JointMatrixINTEL *, size_t i); + __spv::__spirv_JointMatrixINTEL *, size_t i); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, T val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index 427012bb064a8..215f2990bf261 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ 
-136,7 +136,7 @@ enum class MatrixUse : uint32_t { // information to SPIRV translator. // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. -template struct __spirv_JointMatrixINTEL { T(*Value) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index 3d2ffd47d80eb..828c308e7ec15 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -19,6 +19,8 @@ namespace ext { namespace oneapi { namespace experimental::matrix { +// packed_a and packed_b will be replaced by packed once the use implementation +// is stable. enum class matrix_layout { row_major, col_major, packed_a, packed_b }; template struct spv_matrix_layout_traits { @@ -37,6 +39,8 @@ SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) +// unnecessary was introduced for backward compatibility. +// Once the use implementation is stable, "unnecessary" value will be omitted enum class matrix_use { matrix_a, matrix_b, accumulator, unnecessary }; template struct spv_matrix_use_traits { @@ -62,19 +66,19 @@ template struct spv_scope_traits> { }; template class wi_slice; template struct joint_matrix { public: - __spv::__spirv_JointMatrixINTEL::value, - spv_matrix_use_traits::value> *spvm; + __spv::__spirv_JointMatrixINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -83,18 +87,19 @@ struct joint_matrix { #endif // __SYCL_DEVICE_ONLY__ } - inline __SYCL_ALWAYS_INLINE wi_slice + inline __SYCL_ALWAYS_INLINE wi_slice get_wi_data() { - return wi_slice(*this); + return wi_slice(*this); } }; template + access::address_space Space> inline __SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, - joint_matrix &res, + joint_matrix &res, multi_ptr src, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = src.get(); @@ -104,32 +109,32 @@ joint_matrix_load(Group sg, case matrix_layout::row_major: res.spvm = __spirv_JointMatrixLoadINTEL::value, - spv_matrix_use_traits::value>( + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: res.spvm = __spirv_JointMatrixLoadINTEL::value, - spv_matrix_use_traits::value>( + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: res.spvm = __spirv_JointMatrixLoadINTEL::value, - spv_matrix_use_traits::value>( + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: res.spvm = __spirv_JointMatrixLoadINTEL::value, - spv_matrix_use_traits::value>( + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -146,11 +151,12 @@ joint_matrix_load(Group sg, } template + access::address_space Space> inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, - joint_matrix &src, + joint_matrix &src, multi_ptr res, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = res.get(); 
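[Illustrative sketch, not part of the patch: after this reordering, the use argument precedes the layout argument on the SPIR-V intrinsics and on joint_matrix itself. Assuming a declaration order of joint_matrix<T, Use, Rows, Cols, Layout> (the exact parameter lists are truncated in the hunks above), kernel code would declare its fragments roughly as follows, where TM/TK/TN are placeholder tile sizes as in the tests:
  joint_matrix<int8_t, matrix_use::matrix_a, TM, TK> sub_a(sg);
  joint_matrix<int8_t, matrix_use::matrix_b, TK, TN, matrix_layout::packed_b> sub_b(sg);
  joint_matrix<int32_t, matrix_use::accumulator, TM, TN> sub_c(sg);
The default layout for sub_a and sub_c is assumed here; only packed B needs it spelled out.]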
@@ -158,30 +164,30 @@ joint_matrix_store(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -198,15 +204,14 @@ joint_matrix_store(Group sg, } template -inline __SYCL_ALWAYS_INLINE joint_matrix -joint_matrix_mad(Group sg, joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + size_t K, size_t N, matrix_use UseA, matrix_use UseB, matrix_use UseC, + matrix_layout LayoutA, matrix_layout LayoutB, matrix_layout LayoutC> +inline __SYCL_ALWAYS_INLINE joint_matrix +joint_matrix_mad(Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); + joint_matrix res(sg); if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) @@ -234,16 +239,17 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, - joint_matrix &res, + joint_matrix &res, const T2 v) { // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ // functions (void)sg; #ifdef __SYCL_DEVICE_ONLY__ - res.spvm = __spirv_CompositeConstruct::value, - spv_matrix_use_traits::value>( - static_cast(v)); + res.spvm = + __spirv_CompositeConstruct::value, + spv_matrix_layout_traits::value>( + static_cast(v)); #else (void)res; @@ -252,15 +258,15 @@ joint_matrix_fill(Group sg, } template class wi_element { - joint_matrix &M; + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator T() { @@ -293,7 +299,7 @@ class wi_element { } wi_element & - operator=(const wi_element &rhs) { + operator=(const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -336,14 +342,14 @@ class wi_element { // the introduction of SYCL experimental bfloat16 type. Our plan is to move // towards using the SYCL bfloat16. But since it is still experimental, we will // probably keep both uint16 interpretation and SYCL bfloat16. 
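[Illustrative sketch, not part of the patch: the wi_element accessors above, reached through get_wi_data(), give each work item element-wise access to its share of a fragment, so an accumulator can be post-processed along the lines of:
  auto wi = sub_c.get_wi_data();
  for (size_t i = 0; i < wi.length(); ++i)
    wi[i] *= 2; // each work item scales only the elements it owns
The fragment name sub_c is assumed, mirroring the tests in this series.]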
-template -class wi_element { - joint_matrix &M; +class wi_element { + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator uint16_t() { @@ -377,7 +383,7 @@ class wi_element { } wi_element &operator=( - const wi_element &rhs) { + const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -440,21 +446,21 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \ } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \ } #else // __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ (void)lhs; \ (void)rhs; \ @@ -463,7 +469,7 @@ class wi_element { } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ (void)lhs; \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ @@ -483,13 +489,13 @@ class wi_element { #undef OP }; -template +template class wi_slice { - joint_matrix &M; + joint_matrix &M; public: - wi_slice(joint_matrix &Mat) + wi_slice(joint_matrix &Mat) : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ @@ -499,8 +505,8 @@ class wi_slice { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_element operator[](size_t i) { - return wi_element(M, i); + wi_element operator[](size_t i) { + return wi_element(M, i); } }; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 3f8bba16aba08..ce568abf2031a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -70,8 +70,9 @@ template ::value, - spv_matrix_use_traits::value> *spvm; + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -100,29 +101,33 @@ joint_matrix_load(Group sg, assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case 
matrix_layout::packed_b: res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -152,31 +157,35 @@ joint_matrix_store(Group sg, assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -233,8 +242,9 @@ joint_matrix_fill(Group sg, (void)sg; #ifdef __SYCL_DEVICE_ONLY__ res.spvm = __spirv_CompositeConstruct< - T, NumRows, NumCols, spv_matrix_layout_traits::value, - spv_matrix_use_traits::value>(static_cast(v)); + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(static_cast(v)); #else (void)res; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index 04f04fa15700d..78e2d5252f256 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -16,6 +16,8 @@ #include +// the default is matrix-jit-use but existing tests in llvm-test-suite won't +// fail because we have the "unnecessary" use value #if (SYCL_EXT_ONEAPI_MATRIX == 1) #include #include diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 490f34c20cf9a..0889f0742b9be 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -152,16 +152,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations because // AMX hardware supports dynamic sizes @@ -209,16 
+209,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations // because AMX hardware supports dynamic sizes @@ -352,16 +352,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS @@ -412,16 +412,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp new file mode 100644 index 0000000000000..35c6aec48bd26 --- /dev/null +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -0,0 +1,172 @@ +// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s +// XFAIL: * + +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define TILE_SZ 16 +#define TM (TILE_SZ - 4) +#define TN (TILE_SZ - 4) +#define TK (4 * TILE_SZ - 16) + +#define SG_SZ 16 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stride = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with
// the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 0; + D[i][j] = 0; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << C[i][j] << ", "; + std::cout << "\n"; + } + std::cout << std::endl; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << D[i][j] << ", "; + std::cout << "\n"; + } +} From ad6adf26d1cc11f5c7b84b080e16a19da9c649f1 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 5 Apr 2022 18:19:54 +0800 Subject: [PATCH 04/14] small fix --- sycl/include/CL/sycl/feature_test.hpp.in | 2 +- sycl/test/matrix/matrix-int8-test-use.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/CL/sycl/feature_test.hpp.in b/sycl/include/CL/sycl/feature_test.hpp.in index c74ebbb914ef2..89596995fcef1 100644 --- a/sycl/include/CL/sycl/feature_test.hpp.in +++ b/sycl/include/CL/sycl/feature_test.hpp.in @@ -38,7 +38,7 @@ namespace sycl { // 2- provides JIT implementation (target agnostic) for the // experimental matrix extension #ifndef SYCL_EXT_ONEAPI_MATRIX -#define SYCL_EXT_ONEAPI_MATRIX 3 +#define SYCL_EXT_ONEAPI_MATRIX 2 #endif #define SYCL_EXT_ONEAPI_ASSERT 1 #define 
SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1 diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp index 35c6aec48bd26..42d9a4f771202 100644 --- a/sycl/test/matrix/matrix-int8-test-use.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -54,7 +54,7 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a From 253e018294cd65c7def20856a578626f7a8b86ff Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 5 Apr 2022 18:58:12 +0800 Subject: [PATCH 05/14] fix the lint issue --- sycl/test/matrix/matrix-int8-test-use.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp index 42d9a4f771202..53537a30bc18a 100644 --- a/sycl/test/matrix/matrix-int8-test-use.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -54,7 +54,8 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a From 42feae69af65ff0cedd9c230d78b3dd9b0e911ba Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 6 Apr 2022 01:29:11 +0800 Subject: [PATCH 06/14] small fix again --- sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 0889f0742b9be..ef48e480f58e7 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -1,4 +1,4 @@ -//===-------------- static-query.hpp - SYCL matrix ------------*- C++ -*---===// +//===---------- static-query-use.hpp - SYCL matrix ------------*- C++ -*---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From 29ff3fbaa8def5f68ae71f19f2b944dc8396da83 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 17 Aug 2022 16:19:40 +0800 Subject: [PATCH 07/14] remove matrix layout && make "use" non-optional && remove use template arguments from mad function --- sycl/include/CL/__spirv/spirv_types.hpp | 2 + .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 168 ++++++++---------- .../ext/oneapi/matrix/static-query-use.hpp | 36 ++-- 3 files changed, 88 insertions(+), 118 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index 76e40d47c5cd0..77fab5e92636d 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -136,6 +136,8 @@ enum class MatrixUse : uint32_t { // information to SPIRV translator. // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. +// We keep the Layout here since backend hasn't removed Layout support totally. +// Once backend removed Layout, we will remove the Layout here.
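[Illustrative sketch, not part of the patch: with use now mandatory on joint_matrix and the use template arguments dropped from joint_matrix_mad, the multiply-add call site needs no explicit template arguments. Assuming fragments declared with matrix_use::matrix_a, matrix_use::matrix_b and matrix_use::accumulator as in the tests, it is simply
  sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
with the uses and shapes deduced from the operand types.]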
template struct __spirv_JointMatrixINTEL { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index bc0b583198a29..a49bf8c5eb029 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -65,20 +65,16 @@ template struct spv_scope_traits> { constexpr static auto value = __spv::Scope::Workgroup; }; -template class wi_slice; -template struct joint_matrix { public: __spv::__spirv_JointMatrixINTEL< T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value> *spvm; + spv_matrix_layout_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -87,19 +83,16 @@ struct joint_matrix { #endif // __SYCL_DEVICE_ONLY__ } - inline __SYCL_ALWAYS_INLINE wi_slice + inline __SYCL_ALWAYS_INLINE wi_slice get_wi_data() { - return wi_slice(*this); + return wi_slice(*this); } }; template + matrix_use Use, access::address_space Space> inline __SYCL_ALWAYS_INLINE void -joint_matrix_load(Group sg, - joint_matrix &res, +joint_matrix_load(Group sg, joint_matrix &res, multi_ptr src, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = src.get(); @@ -107,36 +100,32 @@ joint_matrix_load(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -151,12 +140,9 @@ joint_matrix_load(Group sg, } template + matrix_use Use, access::address_space Space> inline __SYCL_ALWAYS_INLINE void -joint_matrix_store(Group sg, - joint_matrix &src, +joint_matrix_store(Group sg, joint_matrix &src, multi_ptr res, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = res.get(); @@ -164,30 +150,30 @@ joint_matrix_store(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + 
__spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -204,14 +190,15 @@ joint_matrix_store(Group sg, } template -inline __SYCL_ALWAYS_INLINE joint_matrix -joint_matrix_mad(Group sg, joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + size_t K, size_t N> +inline __SYCL_ALWAYS_INLINE + joint_matrix + joint_matrix_mad( + Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); + joint_matrix res(sg); if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) @@ -236,20 +223,18 @@ joint_matrix_mad(Group sg, joint_matrix &mA, } template + matrix_use Use, typename T2> inline __SYCL_ALWAYS_INLINE void -joint_matrix_fill(Group sg, - joint_matrix &res, +joint_matrix_fill(Group sg, joint_matrix &res, const T2 v) { // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ // functions (void)sg; #ifdef __SYCL_DEVICE_ONLY__ - res.spvm = - __spirv_CompositeConstruct::value, - spv_matrix_layout_traits::value>( - static_cast(v)); + res.spvm = __spirv_CompositeConstruct< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + static_cast(v)); #else (void)res; @@ -257,17 +242,14 @@ joint_matrix_fill(Group sg, #endif // __SYCL_DEVICE_ONLY__ } -template class wi_element { - joint_matrix &M; + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, - std::size_t i) + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator T() { #ifdef __SYCL_DEVICE_ONLY__ @@ -299,7 +281,7 @@ class wi_element { } wi_element & - operator=(const wi_element &rhs) { + operator=(const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -342,14 +324,13 @@ class wi_element { // the introduction of SYCL experimental bfloat16 type. Our plan is to move // towards using the SYCL bfloat16. But since it is still experimental, we will // probably keep both uint16 interpretation and SYCL bfloat16. 
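[Illustrative sketch, not part of the patch: the two interpretations mentioned above mean the same 16 bits can be handled either as a raw uint16_t or as the experimental bfloat16 type:
  uint16_t raw = 0x3F80; // bit pattern of 1.0f in bfloat16
  auto bf = sycl::ext::oneapi::experimental::bfloat16::from_bits(raw);
  float f = static_cast<float>(bf); // widens back to 1.0f
from_bits is also how the bfloat16 test added later in this series builds its inputs, since direct float-to-bfloat16 conversion was not yet allowed.]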
-template -class wi_element { - joint_matrix &M; +template +class wi_element { + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator uint16_t() { @@ -382,8 +363,8 @@ class wi_element { #endif // __SYCL_DEVICE_ONLY__ } - wi_element &operator=( - const wi_element &rhs) { + wi_element & + operator=(const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -446,21 +427,21 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \ } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \ } #else // __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ (void)lhs; \ (void)rhs; \ @@ -469,7 +450,7 @@ class wi_element { } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ (void)lhs; \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ @@ -490,13 +471,12 @@ class wi_element { }; template + typename Group> class wi_slice { - joint_matrix &M; + joint_matrix &M; public: - wi_slice(joint_matrix &Mat) - : M(Mat) {} + wi_slice(joint_matrix &Mat) : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); @@ -505,8 +485,8 @@ class wi_slice { PI_ERROR_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_element operator[](size_t i) { - return wi_element(M, i); + wi_element operator[](size_t i) { + return wi_element(M, i); } }; diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 7867797f92cac..14655c3875552 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -152,16 +152,13 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations because // AMX hardware supports dynamic sizes @@ -209,16 +206,13 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations // because AMX hardware supports dynamic sizes @@ -352,16 +346,13 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS @@ -412,16 +403,13 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - 
joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS From ff083c90022e099ad6298732b7b6e8a1a5a2ee14 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 23 Aug 2022 13:55:45 +0800 Subject: [PATCH 08/14] add matrix-use support for bfloat16 --- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 168 +++++++++++++++- sycl/test/matrix/matrix-bfloat16-test-use.cpp | 189 ++++++++++++++++++ 2 files changed, 350 insertions(+), 7 deletions(-) create mode 100644 sycl/test/matrix/matrix-bfloat16-test-use.cpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index a49bf8c5eb029..4f9731289c7d5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -11,13 +11,15 @@ #include #include +#include #include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { namespace ext { namespace oneapi { -namespace experimental::matrix { +namespace experimental { +namespace matrix { // packed_a and packed_b will be replaced by packed once the use implementation // is stable. @@ -67,7 +69,7 @@ template struct spv_scope_traits> { template -class wi_slice; +class wi_data; template struct joint_matrix { @@ -83,9 +85,9 @@ struct joint_matrix { #endif // __SYCL_DEVICE_ONLY__ } - inline __SYCL_ALWAYS_INLINE wi_slice + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { - return wi_slice(*this); + return wi_data(*this); } }; @@ -470,13 +472,164 @@ class wi_element { #undef OP }; +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator sycl::ext::oneapi::experimental::bfloat16() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return std::fabs(static_cast(__spirv_VectorExtractDynamic( + M.spvm, idx))) >= std::numeric_limits::epsilon(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(opassign, op) \ + wi_element &operator opassign( \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(opassign, op) \ + wi_element &operator opassign( \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + (void)rhs; 
\ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+=, +) + OP(-=, -) + OP(*=, *) + OP(/=, /) +#undef OP + +#if __SYCL_DEVICE_ONLY__ +#define OP(type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ + const wi_element &rhs) { \ + return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ + } + OP(sycl::ext::oneapi::experimental::bfloat16, +) + OP(sycl::ext::oneapi::experimental::bfloat16, -) + OP(sycl::ext::oneapi::experimental::bfloat16, *) + OP(sycl::ext::oneapi::experimental::bfloat16, /) +#undef OP +#define OP(type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + return type{static_cast(__spirv_VectorExtractDynamic( \ + lhs.M.spvm, lhs.idx)) op static_cast(rhs)}; \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ + const wi_element &rhs) { \ + return type{static_cast(__spirv_VectorExtractDynamic( \ + rhs.M.spvm, rhs.idx)) op static_cast(lhs)}; \ + } + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP +#else // __SYCL_DEVICE_ONLY__ +#define OP(type, op) \ + friend type operator op( \ + const wi_element &, \ + const sycl::ext::oneapi::experimental::bfloat16 &) { \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &, \ + const wi_element &) { \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } + OP(sycl::ext::oneapi::experimental::bfloat16, +) + OP(sycl::ext::oneapi::experimental::bfloat16, -) + OP(sycl::ext::oneapi::experimental::bfloat16, *) + OP(sycl::ext::oneapi::experimental::bfloat16, /) + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP +#endif // __SYCL_DEVICE_ONLY__ +}; + template -class wi_slice { +class wi_data { joint_matrix &M; public: - wi_slice(joint_matrix &Mat) : M(Mat) {} + wi_data(joint_matrix &Mat) : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); @@ -490,7 +643,8 @@ class wi_slice { } }; -} // namespace experimental::matrix +} // namespace matrix +} // namespace experimental } // namespace oneapi } // namespace ext } // __SYCL_INLINE_VER_NAMESPACE(_V1) diff --git a/sycl/test/matrix/matrix-bfloat16-test-use.cpp b/sycl/test/matrix/matrix-bfloat16-test-use.cpp new file mode 100644 index 0000000000000..65a864c29e996 --- /dev/null +++ b/sycl/test/matrix/matrix-bfloat16-test-use.cpp @@ -0,0 +1,189 @@ +// RUN: %clangxx -fsycl -O2 %s -o %t.out +#include +#include + +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + +static constexpr auto TILE_SZ = 16; +static constexpr auto TM = TILE_SZ - 1; +static constexpr auto TN = TILE_SZ - 1; +static constexpr auto TK = 2 * TILE_SZ - 2; + +static constexpr auto SG_SZ = 16; + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void 
matrix_multiply(big_matrix &C,
+                big_matrix &A,
+                big_matrix &B) {
+  size_t M = NUM_ROWS_C;
+  size_t N = NUM_COLS_C;
+  size_t K = NUM_COLS_A;
+  // B => K/2 x N*2, A => M x K, C => M, N
+  // stride should be X's cols, e.g., B's stride = N*2
+  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  sycl::buffer bufA(A.get_data(), sycl::range<2>(M, K));
+  sycl::buffer bufB(B.get_data(), sycl::range<2>(K, N));
+  sycl::buffer bufC((float *)C.get_data(), sycl::range<2>(M, N));
+
+  sycl::queue q;
+  q.submit([&](sycl::handler &cgh) {
+     auto accC = bufC.get_access(cgh);
+     auto accA = bufA.get_access(cgh);
+     auto accB = bufB.get_access(cgh);
+
+     cgh.parallel_for(
+         sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
+         [accA, accB, accC, M, N, K](sycl::nd_item<2> spmd_item)
+
+         {
+           // The submatrix API has to be accessed by all the work-items in a
+           // subgroup. These functions are called once per subgroup, so there
+           // is no code divergence between the work-items.
+           const auto global_idx = spmd_item.get_global_id(0);
+           const auto global_idy = spmd_item.get_global_id(1);
+           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
+           const auto sg_starty = global_idy - spmd_item.get_local_id(1);
+
+           sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group();
+           joint_matrix sub_a(sg);
+           // For B, since the current implementation does not support
+           // non-packed layout, users need to specify the updated VNNI sizes
+           // along with the packed_b layout. By default, the layout is
+           // row_major and the size is (TK, TN).
+           joint_matrix sub_b(sg);
+           joint_matrix sub_c(sg);
+
+           // AMX: 8 register tiles of 1KB each, SMmax x SKmax = 16x64
+           // strideX = X's cols, so strideC = N, strideA = K, strideB = N*2
+           joint_matrix_load(sg, sub_c,
+                             accC.get_pointer() + (sg_startx * TM) * N +
+                                 sg_starty / SG_SZ * TN,
+                             N, matrix_layout::row_major);
+           for (int k = 0; k < K / TK; k += 1) {
+             joint_matrix_load(
+                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
+                 K, matrix_layout::row_major);
+             // Assuming B data is already in VNNI format.
+             joint_matrix_load(sg, sub_b,
+                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
+                                   sg_starty / SG_SZ * TN * 2,
+                               N * 2, matrix_layout::packed_b);
+             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+           }
+           joint_matrix_store(sg, sub_c,
+                              accC.get_pointer() + (sg_startx * TM) * N +
+                                  sg_starty / SG_SZ * TN,
+                              N, matrix_layout::row_major);
+         }); // parallel for
+   }).wait();
+}
+
+static constexpr size_t MATRIX_M = TM * 2;
+static constexpr size_t MATRIX_N = TN * 2;
+static constexpr size_t MATRIX_K = TK * 2;
+bfloat16 A[MATRIX_M][MATRIX_K];
+bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
+unsigned short Aref[MATRIX_M][MATRIX_K];
+unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
+float C[MATRIX_M][MATRIX_N];
+float D[MATRIX_M][MATRIX_N];
+
+float make_fp32(short x) {
+  unsigned int y = x;
+  y = y << 16;
+  float *res = reinterpret_cast(&y);
+  return *res;
+}
+
+unsigned short make_bf16(float x) {
+  int *res = reinterpret_cast(&x);
+  *res = *res >> 16;
+  return (unsigned short)*res;
+}
+
+void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
+                         int K) {
+  // tiling
+  for (int m = 0; m < M; m++)
+    for (int n = 0; n < N; n++) {
+      for (int k = 0; k < K; k++) {
+        short *va = (short *)(A_mem + m * K + k);
+        short *vb = (short *)(B_mem + k * N + n);
+        float acc = *((float *)(C_mem + m * N + n));
+        // FIXME: Should we do reduce-add in another version?
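+        // Each int32 element of A_mem and B_mem packs a VNNI pair of two
+        // bf16 values, so the reference widens both halves to fp32 and
+        // accumulates both partial products for this k step.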
+        for (int i = 0; i < 2; i++) {
+          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
+        }
+        *((float *)(C_mem + m * N + n)) = acc;
+      }
+    }
+}
+
+int main() {
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_K; j++) {
+      // We create bfloat16 from unsigned short since float-to-bfloat's
+      // conversion is not allowed.
+      A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
+      Aref[i][j] = make_bf16(1.0f * (i + j));
+    }
+  }
+  for (int i = 0; i < MATRIX_K / 2; i++) {
+    for (int j = 0; j < MATRIX_N * 2; j++) {
+      B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
+      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
+    }
+  }
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++) {
+      C[i][j] = 1.0;
+      D[i][j] = 1.0;
+    }
+  }
+
+  big_matrix MC((float *)&C);
+  big_matrix MD((float *)&D);
+  big_matrix MA((bfloat16 *)&A);
+  big_matrix MB((bfloat16 *)&B);
+  matrix_multiply(MC, MA, MB);
+  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
+                      MATRIX_N, MATRIX_K / 2);
+
+  bool res = true;
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++) {
+      if (C[i][j] != D[i][j])
+        res = false;
+    }
+  }
+  if (res)
+    std::cout << "passed\n";
+  else
+    std::cout << "failed\n";
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++)
+      std::cout << C[i][j] << ", ";
+    std::cout << "\n";
+  }
+  std::cout << std::endl;
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++)
+      std::cout << D[i][j] << ", ";
+    std::cout << "\n";
+  }
+}

From 2b1fbc9cc3d390476b348ce33eefb9f52427ccb8 Mon Sep 17 00:00:00 2001
From: Bing1 Yu
Date: Fri, 26 Aug 2022 16:04:43 +0800
Subject: [PATCH 09/14] Keep layout on the matrix type

---
 .../lib/SPIRV/libSPIRV/spirv_internal.hpp     |   3 +-
 sycl/include/CL/__spirv/spirv_types.hpp       |   5 +-
 .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 184 ++++++++++--------
 .../ext/oneapi/matrix/static-query-use.hpp    |  36 ++--
 sycl/test/matrix/matrix-int8-test-use.cpp     |   6 +-
 5 files changed, 132 insertions(+), 102 deletions(-)

diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp
index 3f5a1f0850b83..9fb7de7857fbf 100644
--- a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp
+++ b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp
@@ -97,7 +97,8 @@ enum InternalJointMatrixLayout {
   RowMajor = 0,
   ColumnMajor = 1,
   PackedA = 2,
-  PackedB = 3
+  PackedB = 3,
+  Unused = 4
 };

 enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp
index 77fab5e92636d..c210640cb9202 100644
--- a/sycl/include/CL/__spirv/spirv_types.hpp
+++ b/sycl/include/CL/__spirv/spirv_types.hpp
@@ -112,7 +112,8 @@ enum class MatrixLayout : uint32_t {
   RowMajor = 0,
   ColumnMajor = 1,
   PackedA = 2,
-  PackedB = 3
+  PackedB = 3,
+  Unused = 4
 };

 enum class MatrixUse : uint32_t {
@@ -136,8 +137,6 @@ enum class MatrixUse : uint32_t {
 // information to SPIRV translator.
 // The long term solution would be to introduce a matrix type in Clang and use
 // it instead of this member.
-// We keep the Layout here since backend hasn't removed Layout support totally.
-// Once backend removed Layout, we will remove the Layout here.
template struct __spirv_JointMatrixINTEL { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index 4f9731289c7d5..370f027de7695 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -11,7 +11,6 @@ #include #include -#include #include namespace sycl { @@ -23,10 +22,10 @@ namespace matrix { // packed_a and packed_b will be replaced by packed once the use implementation // is stable. -enum class matrix_layout { row_major, col_major, packed_a, packed_b }; +enum class matrix_layout { row_major, col_major, packed_a, packed_b, unused }; template struct spv_matrix_layout_traits { - static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor; + static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::Unused; }; #define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \ @@ -40,6 +39,7 @@ SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, __spv::MatrixLayout::ColumnMajor) SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) +SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::unused, __spv::MatrixLayout::Unused) // unnecessary was introduced for backward compatibility. // Once the use implementation is stable, "unnecessary" value will be omitted @@ -68,15 +68,17 @@ template struct spv_scope_traits> { }; template class wi_data; template struct joint_matrix { public: __spv::__spirv_JointMatrixINTEL< T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value> *spvm; + spv_matrix_layout_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -85,16 +87,18 @@ struct joint_matrix { #endif // __SYCL_DEVICE_ONLY__ } - inline __SYCL_ALWAYS_INLINE wi_data + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { - return wi_data(*this); + return wi_data(*this); } }; template + matrix_use Use, matrix_layout Layout = matrix_layout::unused, + access::address_space Space> inline __SYCL_ALWAYS_INLINE void -joint_matrix_load(Group sg, joint_matrix &res, +joint_matrix_load(Group sg, + joint_matrix &res, multi_ptr src, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = src.get(); @@ -102,32 +106,36 @@ joint_matrix_load(Group sg, joint_matrix &res, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: - res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: - res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + res.spvm = + 
__spirv_JointMatrixLoadINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: - res.spvm = __spirv_JointMatrixLoadINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + res.spvm = + __spirv_JointMatrixLoadINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -142,9 +150,11 @@ joint_matrix_load(Group sg, joint_matrix &res, } template + matrix_use Use, matrix_layout MatL = matrix_layout::unused, + access::address_space Space> inline __SYCL_ALWAYS_INLINE void -joint_matrix_store(Group sg, joint_matrix &src, +joint_matrix_store(Group sg, + joint_matrix &src, multi_ptr res, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = res.get(); @@ -152,30 +162,30 @@ joint_matrix_store(Group sg, joint_matrix &src, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: - __spirv_JointMatrixStoreINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; @@ -192,15 +202,17 @@ joint_matrix_store(Group sg, joint_matrix &src, } template + size_t K, size_t N, matrix_layout LayoutA, matrix_layout LayoutB, + matrix_layout LayoutC> inline __SYCL_ALWAYS_INLINE - joint_matrix + joint_matrix joint_matrix_mad( - Group sg, joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + Group sg, + joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); + joint_matrix res(sg); if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) @@ -225,18 +237,20 @@ inline __SYCL_ALWAYS_INLINE } template + matrix_use Use, matrix_layout Layout, typename T2> inline __SYCL_ALWAYS_INLINE void -joint_matrix_fill(Group sg, joint_matrix &res, +joint_matrix_fill(Group sg, + joint_matrix &res, const T2 v) { // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ // functions (void)sg; #ifdef __SYCL_DEVICE_ONLY__ - res.spvm = __spirv_CompositeConstruct< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value>( - static_cast(v)); + res.spvm = + __spirv_CompositeConstruct::value, + spv_matrix_layout_traits::value>( + static_cast(v)); #else 
(void)res; @@ -245,13 +259,15 @@ joint_matrix_fill(Group sg, joint_matrix &res, } template class wi_element { - joint_matrix &M; + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, std::size_t i) + wi_element(joint_matrix &Mat, + std::size_t i) : M(Mat), idx(i) {} operator T() { #ifdef __SYCL_DEVICE_ONLY__ @@ -283,7 +299,7 @@ class wi_element { } wi_element & - operator=(const wi_element &rhs) { + operator=(const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -326,13 +342,14 @@ class wi_element { // the introduction of SYCL experimental bfloat16 type. Our plan is to move // towards using the SYCL bfloat16. But since it is still experimental, we will // probably keep both uint16 interpretation and SYCL bfloat16. -template -class wi_element { - joint_matrix &M; +template +class wi_element { + joint_matrix &M; std::size_t idx; public: - wi_element(joint_matrix &Mat, + wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} operator uint16_t() { @@ -365,8 +382,8 @@ class wi_element { #endif // __SYCL_DEVICE_ONLY__ } - wi_element & - operator=(const wi_element &rhs) { + wi_element &operator=( + const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -429,21 +446,21 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \ } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ return Converter::convert(make_fp32( \ __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \ } #else // __SYCL_DEVICE_ONLY__ #define OP(input_type, type, op) \ friend type operator op( \ - const wi_element &lhs, \ + const wi_element &lhs, \ const uint16_t &rhs) { \ (void)lhs; \ (void)rhs; \ @@ -452,7 +469,7 @@ class wi_element { } \ friend type operator op( \ const uint16_t &lhs, \ - const wi_element &rhs) { \ + const wi_element &rhs) { \ (void)lhs; \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ @@ -472,17 +489,17 @@ class wi_element { #undef OP }; -template class wi_element { + Use, Layout, Group> { joint_matrix &M; + Layout, Group> &M; std::size_t idx; public: wi_element(joint_matrix &Mat, + NumCols, Use, Layout, Group> &Mat, std::size_t i) : M(Mat), idx(i) {} operator sycl::ext::oneapi::experimental::bfloat16() { @@ -517,7 +534,7 @@ class wi_element &rhs) { + NumCols, Use, Layout, Group> &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); @@ -556,14 +573,14 @@ class wi_element &lhs, \ + NumCols, Use, Layout, Group> &lhs, \ const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ } \ friend type operator op( \ const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ const wi_element &rhs) { \ + NumCols, Use, Layout, Group> &rhs) { \ return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ } OP(sycl::ext::oneapi::experimental::bfloat16, +) @@ -574,7 +591,7 @@ class wi_element &lhs, \ + NumCols, Use, Layout, Group> &lhs, \ const sycl::ext::oneapi::experimental::bfloat16 
&rhs) { \ return type{static_cast(__spirv_VectorExtractDynamic( \ lhs.M.spvm, lhs.idx)) op static_cast(rhs)}; \ @@ -582,7 +599,7 @@ class wi_element &rhs) { \ + NumCols, Use, Layout, Group> &rhs) { \ return type{static_cast(__spirv_VectorExtractDynamic( \ rhs.M.spvm, rhs.idx)) op static_cast(lhs)}; \ } @@ -597,7 +614,7 @@ class wi_element &, \ + NumCols, Use, Layout, Group> &, \ const sycl::ext::oneapi::experimental::bfloat16 &) { \ throw runtime_error("joint matrix is not supported on host device.", \ PI_ERROR_INVALID_DEVICE); \ @@ -605,7 +622,7 @@ class wi_element &) { \ + NumCols, Use, Layout, Group> &) { \ throw runtime_error("joint matrix is not supported on host device.", \ PI_ERROR_INVALID_DEVICE); \ } @@ -624,12 +641,13 @@ class wi_element + matrix_layout Layout, typename Group> class wi_data { - joint_matrix &M; + joint_matrix &M; public: - wi_data(joint_matrix &Mat) : M(Mat) {} + wi_data(joint_matrix &Mat) + : M(Mat) {} size_t length() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); @@ -638,8 +656,8 @@ class wi_data { PI_ERROR_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - wi_element operator[](size_t i) { - return wi_element(M, i); + wi_element operator[](size_t i) { + return wi_element(M, i); } }; diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 14655c3875552..7867797f92cac 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -152,13 +152,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations because // AMX hardware supports dynamic sizes @@ -206,13 +209,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations // because AMX hardware supports dynamic sizes @@ -346,13 +352,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS @@ -403,13 +412,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp index 70697f48faea3..c51b402c03ad2 100644 --- a/sycl/test/matrix/matrix-int8-test-use.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x [1 x i8]]]]] addrspace(4)* } -// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x [3 x i32]]]]] addrspace(4)* } -// CHECK-DAG: 
%"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [1 x [4 x [2 x i8]]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [5 x [4 x [1 x i8]]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [5 x [4 x [3 x i32]]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [5 x [4 x [2 x i8]]]]] addrspace(4)* } #include #include From 39a09f6cd5286253b152acbdb5b426e6a0006541 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 2 Sep 2022 15:42:43 +0800 Subject: [PATCH 10/14] place use at the end of __spirv_JointMatrixINTEL's template parameter --- .../lib/SPIRV/libSPIRV/spirv_internal.hpp | 3 +- sycl/include/CL/__spirv/spirv_ops.hpp | 46 ++++---- sycl/include/CL/__spirv/spirv_types.hpp | 4 +- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 103 ++++++++---------- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 6 +- .../ext/oneapi/matrix/static-query-use.hpp | 48 ++++---- sycl/test/matrix/matrix-bfloat16-test-use.cpp | 14 +-- sycl/test/matrix/matrix-int8-test-use.cpp | 12 +- 8 files changed, 113 insertions(+), 123 deletions(-) diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp index 9fb7de7857fbf..3f5a1f0850b83 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -97,8 +97,7 @@ enum InternalJointMatrixLayout { RowMajor = 0, ColumnMajor = 1, PackedA = 2, - PackedB = 3, - Unused = 4 + PackedB = 3 }; enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 }; diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 1554a4f558246..90b5601304a29 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -26,7 +26,7 @@ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); @@ -36,7 +36,7 @@ template extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL( - T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, + T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); @@ -48,11 +48,11 @@ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUSMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - 
__spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixSUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_CompositeConstruct(const T v); template extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL( - __spv::__spirv_JointMatrixINTEL *); + __spv::__spirv_JointMatrixINTEL *); template extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( - __spv::__spirv_JointMatrixINTEL *, size_t i); + __spv::__spirv_JointMatrixINTEL *, size_t i); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, T val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index c210640cb9202..a83a452c567ee 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -137,8 +137,8 @@ enum class MatrixUse : uint32_t { // information to SPIRV translator. // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. -template +template struct __spirv_JointMatrixINTEL { T(*Value) [R][C][static_cast(L) + 1][static_cast(S) + 1] diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index 370f027de7695..8ad626dec5d2d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -22,9 +22,9 @@ namespace matrix { // packed_a and packed_b will be replaced by packed once the use implementation // is stable. 
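// layout names the memory layout of a matrix (row_major, col_major, packed_a,
// packed_b, or unused); use names the role a matrix plays in the mad
// operation (a, b, or accumulator).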
-enum class matrix_layout { row_major, col_major, packed_a, packed_b, unused }; +enum class layout { row_major, col_major, packed_a, packed_b, unused }; -template struct spv_matrix_layout_traits { +template struct spv_matrix_layout_traits { static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::Unused; }; @@ -33,19 +33,17 @@ template struct spv_matrix_layout_traits { static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \ }; -SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major, - __spv::MatrixLayout::RowMajor) -SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, - __spv::MatrixLayout::ColumnMajor) -SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) -SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) -SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::unused, __spv::MatrixLayout::Unused) +SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor) +SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor) +SPV_MATRIX_LAYOUT_TRAITS(layout::packed_a, __spv::MatrixLayout::PackedA) +SPV_MATRIX_LAYOUT_TRAITS(layout::packed_b, __spv::MatrixLayout::PackedB) +SPV_MATRIX_LAYOUT_TRAITS(layout::unused, __spv::MatrixLayout::Unused) // unnecessary was introduced for backward compatibility. // Once the use implementation is stable, "unnecessary" value will be omitted -enum class matrix_use { matrix_a, matrix_b, accumulator, unnecessary }; +enum class use { a, b, accumulator, unnecessary }; -template struct spv_matrix_use_traits { +template struct spv_matrix_use_traits { static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA; }; @@ -54,10 +52,10 @@ template struct spv_matrix_use_traits { static constexpr __spv::MatrixUse value = SPV_USE; \ }; -SPV_MATRIX_USE_TRAITS(matrix_use::matrix_a, __spv::MatrixUse::MatrixA) -SPV_MATRIX_USE_TRAITS(matrix_use::matrix_b, __spv::MatrixUse::MatrixB) -SPV_MATRIX_USE_TRAITS(matrix_use::accumulator, __spv::MatrixUse::Accumulator) -SPV_MATRIX_USE_TRAITS(matrix_use::unnecessary, __spv::MatrixUse::Unnecessary) +SPV_MATRIX_USE_TRAITS(use::a, __spv::MatrixUse::MatrixA) +SPV_MATRIX_USE_TRAITS(use::b, __spv::MatrixUse::MatrixB) +SPV_MATRIX_USE_TRAITS(use::accumulator, __spv::MatrixUse::Accumulator) +SPV_MATRIX_USE_TRAITS(use::unnecessary, __spv::MatrixUse::Unnecessary) template struct spv_scope_traits {}; template <> struct spv_scope_traits { @@ -67,18 +65,16 @@ template struct spv_scope_traits> { constexpr static auto value = __spv::Scope::Workgroup; }; -template +template class wi_data; -template +template struct joint_matrix { public: __spv::__spirv_JointMatrixINTEL< - T, NumRows, NumCols, spv_matrix_use_traits::value, - spv_matrix_layout_traits::value> *spvm; + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_scope_traits::value, spv_matrix_use_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -93,19 +89,18 @@ struct joint_matrix { } }; -template +template inline __SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix &res, - multi_ptr src, size_t stride, matrix_layout MemL) { + multi_ptr src, size_t stride, layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = src.get(); switch (MemL) { default: assert(false && "Invalid Memory Layout!"); - case matrix_layout::row_major: + case layout::row_major: res.spvm = __spirv_JointMatrixLoadINTEL::value, @@ -113,7 +108,7 @@ joint_matrix_load(Group sg, Ptr, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; - case matrix_layout::col_major: + case 
layout::col_major: res.spvm = __spirv_JointMatrixLoadINTEL::value, @@ -121,7 +116,7 @@ joint_matrix_load(Group sg, Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case matrix_layout::packed_a: + case layout::packed_a: res.spvm = __spirv_JointMatrixLoadINTEL::value, @@ -129,7 +124,7 @@ joint_matrix_load(Group sg, Ptr, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; - case matrix_layout::packed_b: + case layout::packed_b: res.spvm = __spirv_JointMatrixLoadINTEL::value, @@ -149,40 +144,39 @@ joint_matrix_load(Group sg, #endif // __SYCL_DEVICE_ONLY__ } -template +template inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, joint_matrix &src, - multi_ptr res, size_t stride, matrix_layout MemL) { + multi_ptr res, size_t stride, layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = res.get(); switch (MemL) { default: assert(false && "Invalid Memory Layout!"); - case matrix_layout::row_major: + case layout::row_major: __spirv_JointMatrixStoreINTEL::value, spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; - case matrix_layout::col_major: + case layout::col_major: __spirv_JointMatrixStoreINTEL::value, spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case matrix_layout::packed_a: + case layout::packed_a: __spirv_JointMatrixStoreINTEL::value, spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; - case matrix_layout::packed_b: + case layout::packed_b: __spirv_JointMatrixStoreINTEL::value, spv_matrix_layout_traits::value>( @@ -202,17 +196,15 @@ joint_matrix_store(Group sg, } template + size_t K, size_t N, layout LayoutA, layout LayoutB, layout LayoutC> inline __SYCL_ALWAYS_INLINE - joint_matrix + joint_matrix joint_matrix_mad( - Group sg, - joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); + joint_matrix res(sg); if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) @@ -236,8 +228,8 @@ inline __SYCL_ALWAYS_INLINE #endif // __SYCL_DEVICE_ONLY__ } -template +template inline __SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, joint_matrix &res, @@ -258,9 +250,8 @@ joint_matrix_fill(Group sg, #endif // __SYCL_DEVICE_ONLY__ } -template +template class wi_element { joint_matrix &M; std::size_t idx; @@ -342,7 +333,7 @@ class wi_element { // the introduction of SYCL experimental bfloat16 type. Our plan is to move // towards using the SYCL bfloat16. But since it is still experimental, we will // probably keep both uint16 interpretation and SYCL bfloat16. 
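// (For reference, the uint16 interpretation is a raw-bits view of bf16:
// narrowing fp32 keeps only its high 16 bits, and widening shifts those bits
// back into the high half, as the make_bf16 and make_fp32 helpers in the
// matrix tests do.)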
-template class wi_element { joint_matrix &M; @@ -489,7 +480,7 @@ class wi_element { #undef OP }; -template class wi_element { @@ -640,8 +631,8 @@ class wi_element +template class wi_data { joint_matrix &M; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 2425185239ffe..edd8ab43d75fd 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -72,9 +72,9 @@ template ::value, - spv_matrix_layout_traits::value> *spvm; + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_scope_traits::value, + spv_matrix_use_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 7867797f92cac..212f985aa2ad6 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -152,16 +152,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations because // AMX hardware supports dynamic sizes @@ -209,16 +209,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // should be true in future implementations // because AMX hardware supports dynamic sizes @@ -352,16 +352,16 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS @@ -412,16 +412,16 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template using joint_matrix_c = - joint_matrix; + joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS diff --git a/sycl/test/matrix/matrix-bfloat16-test-use.cpp b/sycl/test/matrix/matrix-bfloat16-test-use.cpp index 65a864c29e996..33810b5324c21 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-use.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-use.cpp @@ -60,35 +60,35 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + joint_matrix sub_a(sg); // For B, since current implementation does not support non-packed // layout, users need to specify the updated VNNI sizes along with // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). 
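           // (Concretely for bf16, VNNI packs two rows per 32-bit element:
           // the logical TK x TN tile of B is declared as TK/2 x TN*2 and
           // loaded with stride N * 2 below.)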
- joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K, layout::row_major); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + N * 2, layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp index c51b402c03ad2..0c2b4aadbc27c 100644 --- a/sycl/test/matrix/matrix-int8-test-use.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -66,13 +66,13 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + joint_matrix sub_a(sg); // For B, since current implementation does not support non-packed // layout, users need to specify the updated VNNI sizes along with // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 @@ -80,18 +80,18 @@ void matrix_multiply(big_matrix &C, for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K, layout::row_major); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4, layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } From bd844520d4ed675b6ea95c040bc01fe2693fe00d Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 2 Sep 2022 16:11:04 +0800 Subject: [PATCH 11/14] fix Lint issue --- sycl/include/CL/__spirv/spirv_types.hpp | 3 +- .../ext/oneapi/matrix/static-query-use.hpp | 44 +++++++------------ 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index a83a452c567ee..815d38b349342 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -138,7 +138,8 @@ enum class MatrixUse : uint32_t { // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. 
template + Scope::Flag S = Scope::Flag::Subgroup, + MatrixUse U = MatrixUse::Unnecessary> struct __spirv_JointMatrixINTEL { T(*Value) [R][C][static_cast(L) + 1][static_cast(S) + 1] diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 212f985aa2ad6..01a6d5f78e0fc 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -152,16 +152,13 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template - using joint_matrix_c = - joint_matrix; + using joint_matrix_c = joint_matrix; bool dynamic_p = false; // should be true in future implementations because // AMX hardware supports dynamic sizes @@ -209,16 +206,13 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template - using joint_matrix_c = - joint_matrix; + using joint_matrix_c = joint_matrix; bool dynamic_p = false; // should be true in future implementations // because AMX hardware supports dynamic sizes @@ -352,16 +346,13 @@ struct tpu_params using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template - using joint_matrix_c = - joint_matrix; + using joint_matrix_c = joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS @@ -412,16 +403,13 @@ struct tpu_params< template using joint_matrix_a = - joint_matrix; + joint_matrix; template using joint_matrix_b = - joint_matrix; + joint_matrix; template - using joint_matrix_c = - joint_matrix; + using joint_matrix_c = joint_matrix; bool dynamic_p = false; // no dynamic allocation on the GPU uint32_t numtiles = -1; // does not apply for DPAS From 3c1f7bc1166c4963a0470b611dd09603de5b190c Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 8 Sep 2022 15:59:37 +0800 Subject: [PATCH 12/14] address Lint issue --- sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 4 ++-- sycl/test/matrix/matrix-bf16-test-SG-16.cpp | 2 +- sycl/test/matrix/matrix-bf16-test.cpp | 2 +- sycl/test/matrix/matrix-bfloat16-test-use.cpp | 4 +--- sycl/test/matrix/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/matrix-int8-test-SG-16.cpp | 2 +- sycl/test/matrix/matrix-int8-test-use.cpp | 2 +- sycl/test/matrix/matrix-int8-test.cpp | 2 +- sycl/test/matrix/query.cpp | 2 +- 10 files changed, 11 insertions(+), 13 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index 8ad626dec5d2d..c38f9c2f20cd7 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -90,7 +90,7 @@ struct joint_matrix { }; template + layout Layout, access::address_space Space> inline __SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix &res, @@ -145,7 +145,7 @@ joint_matrix_load(Group sg, } template + layout MatL, access::address_space Space> inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, joint_matrix &src, diff --git a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp index 728351686fda2..89e34eaf59cb6 100644 --- a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp @@ 
-1,6 +1,6 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/matrix-bf16-test.cpp b/sycl/test/matrix/matrix-bf16-test.cpp index a9903c2903e78..32633c240a70f 100644 --- a/sycl/test/matrix/matrix-bf16-test.cpp +++ b/sycl/test/matrix/matrix-bf16-test.cpp @@ -1,6 +1,6 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/matrix-bfloat16-test-use.cpp b/sycl/test/matrix/matrix-bfloat16-test-use.cpp index 33810b5324c21..1a8b101721019 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-use.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-use.cpp @@ -1,6 +1,6 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -#include #include +#include using namespace sycl::ext::oneapi::experimental::matrix; using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; @@ -68,8 +68,6 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index eb36783795130..065895fd5498f 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -1,6 +1,6 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include #include +#include using namespace sycl::ext::oneapi::experimental::matrix; using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 61c84135c40b7..2253654a8d935 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -1,7 +1,7 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/matrix-int8-test-SG-16.cpp index b70aa626390c0..0e9855cb24579 100644 --- a/sycl/test/matrix/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-int8-test-SG-16.cpp @@ -1,6 +1,6 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/matrix-int8-test-use.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp index 23b7d5b9d7a68..dde82891291c0 100644 --- a/sycl/test/matrix/matrix-int8-test-use.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -4,8 +4,8 @@ // CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_4_3_2 = type { [12 x [12 x [5 x [4 x [3 x i32]]]]] addrspace(4)* } // CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_4_3_1 = type { [48 x [12 x [5 x [4 x [2 x i8]]]]] addrspace(4)* } -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 53cf43a6436e2..b3edeea70258f 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -4,8 +4,8 @@ // 
CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3_3 = type { [12 x [12 x [1 x [4 x [4 x i32]]]]] addrspace(4)* } // CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3_3 = type { [48 x [12 x [4 x [4 x [4 x i8]]]]] addrspace(4)* } -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; diff --git a/sycl/test/matrix/query.cpp b/sycl/test/matrix/query.cpp index af4b36a10439b..838ae4026460a 100644 --- a/sycl/test/matrix/query.cpp +++ b/sycl/test/matrix/query.cpp @@ -1,6 +1,6 @@ // RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -o query %s -#include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; From 07c5f28a48dd0befdfa11835e740c41aac6eb60e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Sat, 10 Sep 2022 00:40:01 +0800 Subject: [PATCH 13/14] address JarkAKirk's comments --- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 88 +++++++++---------- 1 file changed, 42 insertions(+), 46 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index c38f9c2f20cd7..0e99a7c4a7af9 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -90,47 +90,43 @@ struct joint_matrix { }; template -inline __SYCL_ALWAYS_INLINE void -joint_matrix_load(Group sg, - joint_matrix &res, - multi_ptr src, size_t stride, layout MemL) { + access::address_space Space> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load( + Group sg, + joint_matrix &res, + multi_ptr src, size_t stride, layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = src.get(); switch (MemL) { default: assert(false && "Invalid Memory Layout!"); case layout::row_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case layout::col_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case layout::packed_a: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case layout::packed_b: - res.spvm = - __spirv_JointMatrixLoadINTEL::value, - spv_matrix_layout_traits::value>( - Ptr, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -145,41 +141,41 @@ joint_matrix_load(Group sg, } template -inline __SYCL_ALWAYS_INLINE void -joint_matrix_store(Group sg, - joint_matrix &src, - multi_ptr res, size_t stride, layout MemL) { + 
access::address_space Space> +inline __SYCL_ALWAYS_INLINE void joint_matrix_store( + Group sg, + joint_matrix &src, + multi_ptr res, size_t stride, layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ T *Ptr = res.get(); switch (MemL) { default: assert(false && "Invalid Memory Layout!"); case layout::row_major: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case layout::col_major: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case layout::packed_a: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case layout::packed_b: - __spirv_JointMatrixStoreINTEL::value, - spv_matrix_layout_traits::value>( + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); break; From 3d3cb3ad6414e418e03626920b7d995e6011144a Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 13 Sep 2022 14:35:43 +0800 Subject: [PATCH 14/14] remove lA, LB, LC from the mad function and replace them with unused --- .../sycl/ext/oneapi/matrix/matrix-jit-use.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp index 0e99a7c4a7af9..f8ff4aeaa047c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -192,15 +192,15 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( } template + size_t K, size_t N> inline __SYCL_ALWAYS_INLINE - joint_matrix + joint_matrix joint_matrix_mad( - Group sg, joint_matrix &mA, - joint_matrix &mB, - joint_matrix &mC) { + Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { #ifdef __SYCL_DEVICE_ONLY__ - joint_matrix res(sg); + joint_matrix res(sg); if constexpr (std::is_same::value && std::is_same::value && std::is_same::value)