diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index cf53bec8f943..b16e902c3f2e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -22,6 +22,7 @@ enum class matrix_layout { row_major, col_major, packed_a, packed_b }; namespace precision { class tf32 {}; +class b1 {}; } // namespace precision template +struct joint_matrix< + precision::b1, matrix_use::a, 8, 128, Layout, sycl::sub_group, + typename std::enable_if_t> { + joint_matrix() { + static_assert((Layout == matrix_layout::row_major), + "For the matrix_use::a case, matrix_layout::row_major must " + "be used for Bitwise MAD"); + }; + int32_t data; +}; + +template +struct joint_matrix< + precision::b1, matrix_use::b, 128, 8, Layout, sycl::sub_group, + typename std::enable_if_t> { + joint_matrix() { + static_assert((Layout == matrix_layout::col_major), + "For the matrix_use::b case, matrix_layout::col_major must " + "be used for Bitwise MAD"); + }; + int32_t data; +}; #undef __SYCL_JOINT_MATRIX_OVERLOAD template ()); + } else if constexpr (NumRows == 8 && NumCols == 8) { + __bmma_m8n8k128_ld_c(destptr, src.get(), stride, + get_layout_id()); } } else if constexpr (std::is_same::value) { if constexpr (std::is_same::value) { @@ -381,6 +414,16 @@ struct joint_matrix_load_impl< matrix_use::accumulator) { __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); } + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 8 && NumCols == 128) { + __bmma_m8n8k128_ld_a_b1(&res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 128 && NumCols == 8) { + __bmma_m8n8k128_ld_b_b1(&res.data, tileptr, stride, + get_layout_id()); + } } } }; @@ -458,6 +501,10 @@ struct joint_matrix_store_impl< __dmma_m8n8k4_st_c_f64(dst.get(), reinterpret_cast(&src.wi_marray), stride, 
get_layout_id()); + } else if constexpr (std::is_same::value) { + __bmma_m8n8k128_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } } }; @@ -486,6 +533,33 @@ struct joint_matrix_mad_impl { C); }; +template +struct joint_matrix_bmad_impl { + sycl::ext::oneapi::experimental::matrix::joint_matrix< + int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + M, N, LayoutC, sycl::sub_group> + bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + sycl::ext::oneapi::experimental::matrix::precision::b1, + sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, + sycl::sub_group> + A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + sycl::ext::oneapi::experimental::matrix::precision::b1, + sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, + sycl::sub_group> + B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + int32_t, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, LayoutC, sycl::sub_group> + C, + BinaryOperation Op); +}; + template constexpr int get_layout_pair_id(); @@ -686,6 +760,63 @@ struct joint_matrix_mad_impl< }; #endif // __cplusplus >= 201703L +#if __cplusplus >= 201703L // if constexpr usage +template +struct joint_matrix_bmad_impl< + M, K, N, LayoutC, BinaryOperation, + typename std::enable_if_t<( + LayoutC == + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major || + LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: + col_major)>> { + sycl::ext::oneapi::experimental::matrix::joint_matrix< + int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + M, N, LayoutC, sycl::sub_group> + bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + sycl::ext::oneapi::experimental::matrix::precision::b1, + 
sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, + sycl::sub_group> + A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + sycl::ext::oneapi::experimental::matrix::precision::b1, + sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, + sycl::sub_group> + B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + int32_t, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, LayoutC, sycl::sub_group> + C, + BinaryOperation Op) { + sycl::ext::oneapi::experimental::matrix::joint_matrix< + int32_t, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + LayoutC, sycl::sub_group> + D; + + if constexpr (std::is_same< + BinaryOperation, + sycl::bit_and>::value) { + __bmma_m8n8k128_mma_and_popc_b1( + reinterpret_cast(&D.wi_marray), &A.data, &B.data, + reinterpret_cast(&C.wi_marray), 1); + } else if constexpr (std::is_same< + BinaryOperation, + sycl::bit_xor>::value) { + __bmma_m8n8k128_mma_xor_popc_b1( + reinterpret_cast(&D.wi_marray), &A.data, &B.data, + reinterpret_cast(&C.wi_marray), 1); + } + return D; + } +}; +#endif // __cplusplus >= 201703L } // namespace detail namespace experimental { @@ -696,7 +827,9 @@ template ::value || (std::is_same::value && - std::is_same::value), + std::is_same::value) || + (std::is_same::value && + std::is_same::value), bool> = true> void joint_matrix_load( Group sg, joint_matrix &res, @@ -777,6 +910,35 @@ float round_to_tf32(float a) { #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } +template +joint_matrix +joint_matrix_bmad( + Group sg, + joint_matrix + A, + joint_matrix + B, + joint_matrix C, + BinaryOperation Op) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return sycl::ext::oneapi::detail::joint_matrix_bmad_impl{} + .bmad(A, B, C, Op); +#else + std::ignore = sg; + std::ignore = A; + 
std::ignore = B; + std::ignore = C; + std::ignore = Op; + throw runtime_error("joint_matrix_bmad is " + "only supported by CUDA devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + } // namespace matrix } // namespace experimental } // namespace oneapi diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-single-bit-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-single-bit-test.cpp new file mode 100644 index 000000000000..627971124b57 --- /dev/null +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-single-bit-test.cpp @@ -0,0 +1,78 @@ +// REQUIRES: cuda + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +// M, N, K define the sizes of dimensions of the three matrix types (a, +// b, accumulator) used per subgroup operation. +constexpr int M = 8; // number of rows of accumulator, + // number of rows of a. +constexpr int N = 8; // number of cols of accumulator, + // number of cols of b. +constexpr int K = 128; // number of cols of a/number of rows of b, in bits + +// Each bit of each uint32_t A/B array element is an element of a single-bit +// matrix. joint_matrix_bmad performs Binary Dot Products on these matrices (see +// M. Rastegari et al. Computer Vision – ECCV 2016, 525-542 and A. Li et al. 
+// IEEE Transactions on Parallel and Distributed Systems, 32(7):1878-1891, +// 2021). +uint32_t A[M * (K / 32)]; +uint32_t B[(K / 32) * N]; +int32_t C[M * N]; +int32_t D[M * N]; + +int main() { + + buffer bufA(A, range<1>(M * (K / 32))); + buffer bufB(B, range<1>((K / 32) * N)); + buffer bufC(C, range<1>(M * N)); + buffer bufD(D, range<1>(M * N)); + + queue q; + + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 8) + joint_matrix_load(sg, sub_c, accC.get_pointer(), N); + //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) + joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) + joint_matrix_load(sg, sub_b, accB.get_pointer(), K); + //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2) + sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c, + sycl::bit_xor()); + //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7) + sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c, + sycl::bit_and()); + //CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %9, i32 %10, i32 8) + joint_matrix_store(sg, sub_c, accD.get_pointer(), N); + }); + }); + + return 0; +};