diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
index 5c6df9114b161..4aa9ff0effc4a 100644
--- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -18,6 +18,10 @@ enum class matrix_use { a, b, accumulator };
 
 enum class matrix_layout { row_major, col_major, packed_a, packed_b };
 
+namespace precision {
+class tf32 {};
+} // namespace precision
+
 template struct joint_matrix_load_impl {
   void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-                T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
             multi_ptr<T, Space> src, size_t stride);
 };
@@ -111,18 +120,19 @@ constexpr int get_layout_id<
   return 1;
 }
 
-template struct joint_matrix_load_impl<
-    T, Use, NumRows, NumCols, Layout, Space,
+    S, T, Use, NumRows, NumCols, Layout, Space,
     typename std::enable_if_t> {
   void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
-                T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
             multi_ptr<T, Space> src, size_t stride) {
     if constexpr (std::is_same::value) {
       int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
@@ -247,15 +257,27 @@ struct joint_matrix_load_impl<
                               get_layout_id());
       }
     } else if constexpr (std::is_same<T, float>::value) {
-      if constexpr (NumRows == 16 && NumCols == 16) {
-        __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
-                                  get_layout_id());
-      } else if constexpr (NumRows == 8 && NumCols == 32) {
-        __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
-                                 get_layout_id());
-      } else if constexpr (NumRows == 32 && NumCols == 8) {
-        __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
-                                 get_layout_id());
+      if (std::is_same<S, float>::value) {
+        if constexpr (NumRows == 16 && NumCols == 16) {
+          __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
+                                    get_layout_id());
+        } else if constexpr (NumRows == 8 && NumCols == 32) {
+          __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
+                                   get_layout_id());
+        } else if constexpr (NumRows == 32 && NumCols == 8) {
+          __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
+                                   get_layout_id());
+        }
+      } else if (std::is_same<S, precision::tf32>::value) {
+        int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
+        if constexpr (NumRows == 16 && NumCols == 8) {
+          __mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
+                                   tileptr, stride, get_layout_id());
+        } else if constexpr (NumRows == 8 && NumCols == 16) {
+          __mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
+                                   tileptr, stride, get_layout_id());
+        }
+      }
     } else if constexpr (std::is_same::value) {
       if constexpr (Use ==
@@ -495,6 +517,10 @@ struct joint_matrix_mad_impl<
                               get_layout_pair_id(), 0);
         }
       }
+    } else if constexpr (M == 16 && N == 16 && K == 8) {
+      __mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
+                                  reinterpret_cast<int32_t *>(B.data), C.data,
+                                  get_layout_pair_id(), 0);
     } else if constexpr (std::is_same::value) {
       __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
                             get_layout_pair_id(), 0);
@@ -507,13 +533,18 @@
 
 namespace experimental::matrix {
 
-template
+template <typename Group, typename S, typename T, matrix_use Use,
+          size_t NumRows, size_t NumCols, matrix_layout Layout,
+          access::address_space Space,
+          std::enable_if_t<std::is_same<S, T>::value ||
+                               (std::is_same<S, precision::tf32>::value &&
+                                std::is_same<T, float>::value),
+                           bool> = true>
 void joint_matrix_load(
-    Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
+    Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
     multi_ptr<T, Space> src, size_t stride) {
 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
-  sycl::ext::oneapi::detail::joint_matrix_load_impl{}
       .load(res, src, stride);
 #else
@@ -573,6 +604,21 @@ joint_matrix_mad(
 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
 }
 
+// This function rounds the input up or down based on its bottom 13 bits,
+// and then zeros out those bottom 13 bits
+float round_to_tf32(float a) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  int32_t tmp_int = __nvvm_f2tf32_rna(a);
+  return __nvvm_bitcast_i2f(tmp_int);
+#else
+  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
+  tmp_uint += 0x1000u;
+  tmp_uint &= 0xFFFFE000u;
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
 } // namespace experimental::matrix
 } // namespace oneapi
 } // namespace ext
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
new file mode 100644
index 0000000000000..9cdd5e739b00a
--- /dev/null
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp
@@ -0,0 +1,141 @@
+// REQUIRES: cuda
+
+// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+
+// IMPORTANT: before updating sm version support beyond sm_86 read the
+// following NOTE!
+
+// NOTE: Technically the 'wrong' PTX instruction is called by
+// joint_matrix_load/joint_matrix_store in this case: notice that the load and
+// store instructions use the shape m16n16k16 rather than the correct shape
+// m16n16k8. The 'wrong' PTX instruction is used because it produces the
+// correct SASS instructions for all currently supported sm versions: sm_80
+// and sm_86. This PTX instruction redundancy stems from the PTX naming
+// convention for the mnk shape triple; however, we cannot know a priori that
+// future sm versions will behave in the same way and that this redundancy
+// will persist as new architectures are released. This should be validated
+// before supporting any sm version beyond sm_86. We choose the m16n16k16
+// instruction because it gives the significant advantage of a portable
+// interface across the Intel and Nvidia backends.
+
+#include
+
+using namespace sycl;
+using namespace sycl::ext::oneapi::experimental::matrix;
+
+// M, N, K define the sizes of the dimensions of the three matrix types (a, b,
+// accumulator) used per subgroup operation.
+constexpr int M = 16; // number of rows of accumulator,
+                      // number of rows of a.
+constexpr int N = 16; // number of cols of accumulator,
+                      // number of cols of b.
+constexpr int K = 8;  // number of cols of a / number of rows of b.
+
+// float is used in this test as the storage type for tf32
+float A[M * K];
+float B[K * N];
+float C[M * N];
+float D[M * N];
+
+int main() {
+
+  buffer bufA(A, range<1>(M * K)); // will be used as tf32
+  buffer bufB(B, range<1>(K * N)); // will be used as tf32
+  buffer bufC(C, range<1>(M * N));
+  buffer bufD(D, range<1>(M * N));
+
+  queue q;
+
+  q.submit([&](handler &cgh) {
+    auto accA = bufA.get_access(cgh);
+    auto accB = bufB.get_access(cgh);
+    auto accC = bufC.get_access(cgh);
+    auto accD = bufD.get_access(cgh);
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          joint_matrix
+              sub_c;
+
+          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
+          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
+          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
+
+          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
+          // Round a, b to tf32
+          for (auto i = 0; i < 4; ++i)
+            sub_a.data[i] = round_to_tf32(sub_a.data[i]);
+
+          for (auto i = 0; i < 4; ++i)
+            sub_b.data[i] = round_to_tf32(sub_b.data[i]);
+
+          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
+        });
+  });
+
+  q.submit([&](handler &cgh) {
+    auto accA = bufA.get_access(cgh);
+    auto accB = bufB.get_access(cgh);
+    auto accC = bufC.get_access(cgh);
+    auto accD = bufD.get_access(cgh);
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          joint_matrix
+              sub_c;
+
+          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
+          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
+          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
+
+          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
+          // Round a, b to tf32
+          for (auto i = 0; i < 4; ++i)
+            sub_a.data[i] = round_to_tf32(sub_a.data[i]);
+
+          for (auto i = 0; i < 4; ++i)
+            sub_b.data[i] = round_to_tf32(sub_b.data[i]);
+
+          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
+        });
+  });
+
+  return 0;
+};
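For reference, the following standalone sketch (not part of the patch) illustrates the arithmetic performed by the non-NVPTX fallback branch of round_to_tf32 above: the float's bit pattern is rounded at mantissa bit 13 and the low 13 mantissa bits are then cleared, leaving the 10 explicit mantissa bits that tf32 keeps. The helper name host_round_to_tf32 and the use of std::memcpy instead of reinterpret_cast are illustrative choices only, not part of the SYCL extension.

// Standalone illustration of the host fallback used by round_to_tf32 above.
// host_round_to_tf32 is a hypothetical helper, not part of the extension API.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float host_round_to_tf32(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits)); // reinterpret the float's bit pattern
  bits += 0x1000u;                      // if bit 12 is set, carry rounds bit 13 up
  bits &= 0xFFFFE000u;                  // keep sign, exponent, top 10 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // 1 + 2^-10 is representable in tf32; the extra 2^-12 must be rounded away.
  float x = 1.0f + 1.0f / 1024.0f + 1.0f / 4096.0f;
  std::printf("%.10f -> %.10f\n", x, host_round_to_tf32(x));
  return 0;
}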