-
Notifications
You must be signed in to change notification settings - Fork 795
[SYCL][CUDA][Matrix] Initial Tensorcore matrix ext impl #4696
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fde5488
1f200d7
5a70623
a223fd2
04dc06a
9817ae6
62f28c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
#pragma once | ||
|
||
#include <CL/sycl/detail/defines_elementary.hpp> | ||
#include <immintrin.h> | ||
|
||
__SYCL_INLINE_NAMESPACE(cl) { | ||
namespace sycl { | ||
namespace ext { | ||
namespace oneapi { | ||
namespace experimental::matrix { | ||
|
||
// Role of a fragment in the Tensor Core MMA operation D = A * B + C:
// 'a' and 'b' are the multiplicand fragments, 'accumulator' is the C/D
// fragment.
enum class matrix_use { a, b, accumulator };

// Memory layouts accepted by joint_matrix. The packed_* layouts exist for
// interface parity with other backends; the CUDA backend supports only
// row_major and col_major (enforced with enable_if on every specialization
// below).
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
|
||
// Primary joint_matrix template: a sub-group-scoped matrix fragment.
// The per-element-type/shape/layout specializations below supply the actual
// per-work-item storage; this primary template is an empty fallback for
// combinations that have no Tensor Core support. Cond is an enable_if hook
// used by the specializations to restrict the accepted layouts.
template <typename T, matrix_use MT, size_t Rows = sycl::dynamic_extent,
          size_t Cols = sycl::dynamic_extent,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group, typename Cond = void>
struct joint_matrix {
  joint_matrix(Group g) {}
};
|
||
// The enable_if_t usage in this file is used to disable the
// matrix_layout::packed case which is not compatible with the Nvidia cuda
// backend.
//
// m8n8k4 double "a" fragment (8x4 multiplicand); each work item of the
// sub-group holds one double of the fragment.
template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[1];
};
|
||
template <matrix_layout Layout> | ||
struct joint_matrix< | ||
double, matrix_use::b, 4, 8, Layout, sycl::sub_group, | ||
typename std::enable_if_t<(Layout == matrix_layout::row_major || | ||
Layout == matrix_layout::col_major)>> { | ||
double data[1]; | ||
}; | ||
|
||
// m8n8k4 double accumulator fragment (8x8 C/D operand); each work item of
// the sub-group holds two doubles of the fragment.
template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[2];
};
|
||
} // namespace experimental::matrix | ||
|
||
namespace detail { | ||
using namespace experimental; | ||
|
||
// Backend-dispatch functor for joint_matrix_load. The primary template only
// declares load() without defining it, so any (type, use, shape, layout)
// combination without a specialization below fails at link/instantiation
// time. Cond is an enable_if hook for the specializations.
template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_load_impl {
  void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
            multi_ptr<T, Space> src, size_t stride);
};
|
||
// Maps a matrix_layout onto the integer layout selector expected by the
// NVPTX __dmma_* builtins: 0 = row_major, 1 = col_major. Only these two
// layouts are specialized, so the packed layouts cannot reach the builtins.
template <matrix::matrix_layout Layout> constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
  return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
  return 1;
}
|
||
// Load of the m8n8k4 double "a" (8x4) fragment through the CUDA Tensor Core
// load builtin. 'stride' is the leading dimension of the source matrix in
// elements.
template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::a, 8, 4, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {

    // The builtin exists only in the NVPTX device compilation pass; on any
    // other pass this member is compiled to a no-op.
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};
|
||
// Load of the m8n8k4 double "b" (4x8) fragment through the CUDA Tensor Core
// load builtin. 'stride' is the leading dimension of the source matrix in
// elements. No-op outside the NVPTX device compilation pass.
template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::b, 4, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};
|
||
// Load of the m8n8k4 double accumulator (8x8 "c") fragment through the CUDA
// Tensor Core load builtin. 'stride' is the leading dimension of the source
// matrix in elements. No-op outside the NVPTX device compilation pass.
template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                                 Layout> &res,
            multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};
|
||
// Backend-dispatch functor for joint_matrix_store. Only accumulator
// fragments can be stored. As with joint_matrix_load_impl, the primary
// template declares store() without defining it; unsupported combinations
// fail at link/instantiation time.
template <typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_store_impl {
  void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
                                  NumCols, Layout> &src,
             multi_ptr<T, Space> dst, size_t stride);
};
|
||
// Store of the m8n8k4 double accumulator (8x8 "d") fragment through the CUDA
// Tensor Core store builtin. 'stride' is the leading dimension of the
// destination matrix in elements. No-op outside the NVPTX device compilation
// pass.
template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_store_impl<
    double, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                                  Layout> &src,
             multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
                           get_layout_id<Layout>());
#endif
#endif
  }
};
|
||
// Backend-dispatch functor for joint_matrix_mad (D = A * B + C). T1 is the
// multiplicand element type, T2 the accumulator element type; M x K times
// K x N yields an M x N accumulator. The primary template declares mad()
// without defining it; unsupported combinations fail at link/instantiation
// time.
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
          matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
          matrix::matrix_layout LayoutC, typename Cond = void>
struct joint_matrix_mad_impl {
  matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
  mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
      matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
      matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
          C);
};
|
||
// Maps the (LayoutA, LayoutB) pair onto the single integer selector expected
// by the NVPTX __dmma_*_mma builtins:
//   0 = row/row, 1 = row/col, 2 = col/row, 3 = col/col.
// Only row_major/col_major pairs are specialized.
template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::row_major>() {
  return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::col_major>() {
  return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::row_major>() {
  return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::col_major>() {
  return 3;
}
|
||
template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB, | ||
matrix::matrix_layout LayoutC> | ||
struct joint_matrix_mad_impl< | ||
double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC, | ||
typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major || | ||
LayoutA == matrix::matrix_layout::col_major) && | ||
(LayoutB == matrix::matrix_layout::row_major || | ||
LayoutB == matrix::matrix_layout::col_major) && | ||
(LayoutC == matrix::matrix_layout::row_major || | ||
LayoutC == matrix::matrix_layout::col_major)>> { | ||
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC> | ||
mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A, | ||
matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B, | ||
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, | ||
LayoutC> | ||
C) { | ||
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC> | ||
D; | ||
|
||
#ifdef __NVPTX__ | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, | ||
get_layout_pair_id<LayoutA, LayoutB>(), 0); | ||
#endif | ||
#endif | ||
|
||
return D; | ||
} | ||
}; | ||
|
||
} // namespace detail | ||
|
||
namespace experimental::matrix { | ||
|
||
// Collectively loads the fragment 'res' for the calling sub-group from
// 'src' with leading dimension 'stride' (in elements). 'sg' identifies the
// cooperating group; it is not forwarded because the NVPTX builtins are
// implicitly warp-scoped. Dispatches to the backend specialization of
// joint_matrix_load_impl.
template <typename Group, typename T, matrix_use MT, size_t NumRows,
          size_t NumCols, matrix_layout Layout, access::address_space Space>
void joint_matrix_load(
    Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
    multi_ptr<T, Space> src, size_t stride) {
  detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
      res, src, stride);
}
|
||
// Collectively stores the accumulator fragment 'src' for the calling
// sub-group to 'dst' with leading dimension 'stride' (in elements). Only
// accumulator fragments can be stored. 'sg' is not forwarded (see
// joint_matrix_load). Dispatches to the backend specialization of
// joint_matrix_store_impl.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
void joint_matrix_store(Group sg,
                        joint_matrix<T, matrix_use::accumulator, NumRows,
                                     NumCols, Layout, Group> &src,
                        multi_ptr<T, Space> dst, size_t stride) {
  detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
      src, dst, stride);
}
|
||
// Collectively computes D = A * B + C for the calling sub-group and returns
// the new accumulator fragment D. A is M x K, B is K x N, C and D are
// M x N. 'sg' is not forwarded (see joint_matrix_load). Dispatches to the
// backend specialization of joint_matrix_mad_impl.
template <typename Group, typename T1, typename T2, std::size_t M,
          std::size_t K, std::size_t N, matrix_layout LayoutA,
          matrix_layout LayoutB, matrix_layout LayoutC>
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_mad(
    Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
    joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
    joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
  return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
                                       LayoutC>{}
      .mad(A, B, C);
}
|
||
} // namespace experimental::matrix | ||
} // namespace oneapi | ||
} // namespace ext | ||
} // namespace sycl | ||
} // __SYCL_INLINE_NAMESPACE(cl) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,6 @@ | |
#include <sycl/ext/oneapi/matrix/matrix-jit.hpp> | ||
#include <sycl/ext/oneapi/matrix/static-query.hpp> | ||
#endif | ||
#if (SYCL_EXT_ONEAPI_MATRIX == 3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation can also benefit from the static query we have as well. Besides that the query can give the user information about what the implementation support, it can also construct the matrices and make the sizes optional for the user. We should probably add this to matrix-jit.hpp and fork to using the AOT tensorcore implementation based on some option (AOT for tensorcore). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think that there should be a single header for all of the definitions of joint_matrix, joint_matrix_load, joint_matrix_store, joint_matrix_mad, and then backend dependent specializations of these functions can be in separate files? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, if you can use the same things as in matrix/matrix-jit.hpp like matrix_layout and not redefine them, that would be better. As you know, we are planning on adding the new "use" argument for AMX and DPAS as well. Once we do that, there will be one definition of joint_matrix type/joint_matrix_load/store/mad. If you make this change now, later, there will be one place for us to change (remove the old joint_matrix,load,store,mad that do not have "use" argument). And we won't need to touch the tensorcores specific specifications that will be in a different file. Also, when this convergence happens, there will be no need for the feature test macro. Since this is an experimental interface, we don't need to keep track of "old" versions of the interface. We will remove AOT AMX (SYCL_EXT_ONEAPI_MATRIX=1), we only keep matrix-jit.hpp that enables DPAS, AMX and tensorecores. 
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. matrix_layout has an identical definition as in matrix-jit.hpp.
I'm not sure what you are asking me to do here? : if I add the definitions of joint_matrix_* used in matrix-tensorcore.hpp into matrix-jit.hpp they will be a redeclaration of the intel specific functions already defined in matrix-jit.hpp that do not use the matrix_use template parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @dkhaldi , We would like to get this merged. Could you clarify what you would like me to change? Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for late reply, I was thinking you can have these defined under the new test macro = 3 in the same file so they don't get redefined. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK sure, I think that keeping them separate is a good idea for now. |
||
#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp> | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// REQUIRES: gpu, cuda | ||
|
||
// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s | ||
|
||
#include <CL/sycl.hpp> | ||
|
||
using namespace sycl; | ||
using namespace sycl::ext::oneapi::experimental::matrix; | ||
|
||
// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation:
//   a is M x K, b is K x N, accumulator is M x N.
// NOTE(review): the original comments had M and N descriptions swapped
// (masked here because M == N == 8); corrected below.
constexpr int M = 8; // number of rows of accumulator,
                     // number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
                     // number of cols of b.
constexpr int K = 4; // number of cols of a/number of rows of b.

// Host-side backing storage for the four matrices.
double A[M * K];
double B[K * N];
double C[M * N];
double D[M * N];
|
||
int main() {

  buffer<double, 1> bufA(A, range<1>(M * K));
  buffer<double, 1> bufB(B, range<1>(K * N));
  buffer<double, 1> bufC(C, range<1>(M * N));
  buffer<double, 1> bufD(D, range<1>(M * N));

  queue q;

  // Kernel 1: all fragments in row_major layout. One sub-group (32 work
  // items) performs D = A * B + C. The CHECK lines pin the NVVM WMMA
  // intrinsics the compiler must emit; they are FileCheck directives and
  // must not be edited.
  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_row>(
        nd_range<2>({1, 32}, {1, 32}), [=
    ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;

          // Row-major strides are the row lengths: N for c/b, K for a.
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  // Kernel 2: same computation with all fragments in col_major layout.
  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class col_col>(
        nd_range<2>({1, 32}, {1, 32}), [=
    ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::col_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          // Col-major strides are the column lengths: M for c/a, K for b.
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
        });
  });

  return 0;
};
Uh oh!
There was an error while loading. Please reload this page.