[SYCL][CUDA][Matrix] Initial Tensorcore matrix ext impl #4696

Merged: 7 commits, Nov 8, 2021

259 changes: 259 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -0,0 +1,259 @@
#pragma once

#include <CL/sycl/detail/defines_elementary.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental::matrix {

enum class matrix_use { a, b, accumulator };

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

template <typename T, matrix_use MT, size_t Rows = sycl::dynamic_extent,
size_t Cols = sycl::dynamic_extent,
matrix_layout Layout = matrix_layout::row_major,
typename Group = sycl::sub_group, typename Cond = void>
struct joint_matrix {
joint_matrix(Group g) {}
};

// The enable_if_t usage in this file disables the matrix_layout::packed_a and
// matrix_layout::packed_b cases, which are not compatible with the NVIDIA CUDA
// backend.
template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::b, 4, 8, Layout, sycl::sub_group,
typename std::enable_if_t<(Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major)>> {
double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
double data[2];
};
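
// Note (a sketch, assuming the CUDA sub-group/warp size of 32): the fragment
// sizes above come from dividing each m8n8k4 operand across the 32 work-items
// of a sub-group. An 8x4 A tile or a 4x8 B tile holds 32 doubles, one per
// work-item (data[1]); the 8x8 accumulator holds 64, two per work-item
// (data[2]).
static_assert(8 * 4 / 32 == 1 && 8 * 8 / 32 == 2,
              "per-work-item fragment sizes for the m8n8k4 double shape");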

} // namespace experimental::matrix

namespace detail {
using namespace experimental;

template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
struct joint_matrix_load_impl {
void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
multi_ptr<T, Space> src, size_t stride);
};

template <matrix::matrix_layout Layout> constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
return 1;
}

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::a, 8, 4, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::b, 4, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void
load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &res,
multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
}
};

template <typename T, size_t NumRows, size_t NumCols,
matrix::matrix_layout Layout, access::address_space Space,
typename Cond = void>
struct joint_matrix_store_impl {
void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
NumCols, Layout> &src,
multi_ptr<T, Space> dst, size_t stride);
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_store_impl<
double, 8, 8, Layout, Space,
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
Layout == matrix::matrix_layout::col_major>> {
void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
Layout> &src,
multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
get_layout_id<Layout>());
#endif
#endif
}
};

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC, typename Cond = void>
struct joint_matrix_mad_impl {
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
C);
};

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::row_major>() {
return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
matrix::matrix_layout::col_major>() {
return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::row_major>() {
return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
matrix::matrix_layout::col_major>() {
return 3;
}

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
matrix::matrix_layout LayoutC>
struct joint_matrix_mad_impl<
double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
LayoutA == matrix::matrix_layout::col_major) &&
(LayoutB == matrix::matrix_layout::row_major ||
LayoutB == matrix::matrix_layout::col_major) &&
(LayoutC == matrix::matrix_layout::row_major ||
LayoutC == matrix::matrix_layout::col_major)>> {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
LayoutC>
C) {
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
D;

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
#endif
#endif

return D;
}
};

} // namespace detail

namespace experimental::matrix {

template <typename Group, typename T, matrix_use MT, size_t NumRows,
size_t NumCols, matrix_layout Layout, access::address_space Space>
void joint_matrix_load(
Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
res, src, stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout, access::address_space Space>
void joint_matrix_store(Group sg,
joint_matrix<T, matrix_use::accumulator, NumRows,
NumCols, Layout, Group> &src,
multi_ptr<T, Space> dst, size_t stride) {
detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
src, dst, stride);
}

template <typename Group, typename T1, typename T2, std::size_t M,
std::size_t K, std::size_t N, matrix_layout LayoutA,
matrix_layout LayoutB, matrix_layout LayoutC>
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_mad(
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
LayoutC>{}
.mad(A, B, C);
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
3 changes: 3 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
@@ -25,3 +25,6 @@
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
#include <sycl/ext/oneapi/matrix/static-query.hpp>
#endif
#if (SYCL_EXT_ONEAPI_MATRIX == 3)
Contributor (dkhaldi):

This implementation can also benefit from the static query we have. Besides giving the user information about what the implementation supports, the query can also construct the matrices and make the sizes optional for the user.

We should probably add this to matrix-jit.hpp and fork to the AOT tensorcore implementation based on some option (AOT for tensorcore).
I am asking because we should have one place that holds the interface, to make maintaining the code easy; also, since this interface is experimental, we expect it to change (like the use argument you introduce). Keeping the interface in one place means we only have to modify it in one place.
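
As a rough illustration of "the query can construct the matrices and make the sizes optional", a params type could pin down one valid combination and expose ready-made matrix types. All names below (tensorcore_params and its members) are hypothetical sketches, not the existing static-query.hpp interface:

template <typename Ta, typename Tc> struct tensorcore_params;

// The double-precision combination this patch supports (m8n8k4).
template <> struct tensorcore_params<double, double> {
  static constexpr std::size_t M = 8, N = 8, K = 4;
  using joint_matrix_a =
      joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>;
  using joint_matrix_b =
      joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>;
  using joint_matrix_c = joint_matrix<double, matrix_use::accumulator, M, N,
                                      matrix_layout::row_major>;
};

// The user then never has to spell out 8, 8, 4:
// tensorcore_params<double, double>::joint_matrix_a sub_a;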

Contributor Author (JackAKirk):

Do you think there should be a single header for all of the definitions of joint_matrix, joint_matrix_load, joint_matrix_store, and joint_matrix_mad, with backend-dependent specializations of these functions in separate files?

Contributor (dkhaldi):

Yes, if you can use the same things as in matrix/matrix-jit.hpp, like matrix_layout, and not redefine them, that would be better.
For the things that are different because of the "use" argument, like the definition of the joint_matrix type and joint_matrix_load/store/mad, can you add the use-definitions in matrix-jit.hpp (under the new test macro = 3)?

As you know, we are planning on adding the new "use" argument for AMX and DPAS as well. Once we do that, there will be one definition of the joint_matrix type and of joint_matrix_load/store/mad.

If you make this change now, there will later be one place for us to change (removing the old joint_matrix/load/store/mad that do not have the "use" argument), and we won't need to touch the tensorcore-specific specializations that will live in a different file.

Also, once this convergence happens, there will be no need for the feature test macro. Since this is an experimental interface, we don't need to keep track of "old" versions of the interface. We will remove AOT AMX (SYCL_EXT_ONEAPI_MATRIX=1) and keep only matrix-jit.hpp, which enables DPAS, AMX, and tensorcores.

Contributor Author (@JackAKirk, Oct 19, 2021):

matrix_layout has an identical definition to the one in matrix-jit.hpp.

> For the things that are different because of the "use" argument, like the definition of the joint_matrix type and joint_matrix_load/store/mad, can you add the use-definitions in matrix-jit.hpp (under the new test macro = 3)?
>
> As you know, we are planning on adding the new "use" argument for AMX and DPAS as well. Once we do that, there will be one definition of the joint_matrix type and of joint_matrix_load/store/mad.

I'm not sure what you are asking me to do here: if I add the definitions of the joint_matrix_* functions used in matrix-tensorcore.hpp into matrix-jit.hpp, they will be redeclarations of the Intel-specific functions already defined in matrix-jit.hpp, which do not use the matrix_use template parameter.
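
A minimal sketch of the clash being described (illustrative only; parameter lists simplified): a class template cannot be redeclared in the same scope with a different template parameter list, so the "use"-based templates cannot simply be added next to the existing ones in matrix-jit.hpp.

// matrix-jit.hpp style: no "use" parameter.
template <typename T, std::size_t Rows, std::size_t Cols>
struct joint_matrix {};

// matrix-tensorcore.hpp style: ill-formed if declared in the same scope,
// because it redeclares joint_matrix with a different parameter list.
// template <typename T, matrix_use Use, std::size_t Rows, std::size_t Cols>
// struct joint_matrix {};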

Contributor Author (JackAKirk):

Hi @dkhaldi, we would like to get this merged. Could you clarify what you would like me to change? Thanks.

Contributor (dkhaldi):

Sorry for the late reply. I was thinking you could have these defined under the new test macro = 3 in the same file so they don't get redefined.
However, I think it will be best if we merge these as separate files. Once we add the use argument, we can revisit this and merge both files. What do you think?

Contributor Author (JackAKirk):

OK sure, I think that keeping them separate is a good idea for now.

#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
#endif
101 changes: 101 additions & 0 deletions sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp
@@ -0,0 +1,101 @@
// REQUIRES: gpu, cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o - | FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per sub-group operation.
constexpr int M = 8; // number of rows of accumulator,
// number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
// number of cols of b.
constexpr int K = 4; // number of cols of a / number of rows of b.

double A[M * K];
double B[K * N];
double C[M * N];
double D[M * N];

int main() {

buffer<double, 1> bufA(A, range<1>(M * K));
buffer<double, 1> bufB(B, range<1>(K * N));
buffer<double, 1> bufC(C, range<1>(M * N));
buffer<double, 1> bufD(D, range<1>(M * N));

queue q;

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_row>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<double, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
sub_a;

joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class col_col>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<double, matrix_use::accumulator, M, N,
matrix_layout::col_major>
sub_c;

joint_matrix<double, matrix_use::a, M, K, matrix_layout::col_major>
sub_a;

joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
sub_b;

//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
});
});

return 0;
}