diff --git a/sycl/include/CL/sycl/group.hpp b/sycl/include/CL/sycl/group.hpp index 7fc777a53e644..addb8b9548979 100644 --- a/sycl/include/CL/sycl/group.hpp +++ b/sycl/include/CL/sycl/group.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -264,10 +265,13 @@ template class group { device_event async_work_group_copy(local_ptr dest, global_ptr src, size_t numElements) const { - __ocl_event_t e = - OpGroupAsyncCopyGlobalToLocal( - __spv::Scope::Workgroup, - dest.get(), src.get(), numElements, 1, 0); + using T = detail::ConvertToOpenCLType_t; + using DestT = detail::ConvertToOpenCLType_t; + using SrcT = detail::ConvertToOpenCLType_t; + + __ocl_event_t e = OpGroupAsyncCopyGlobalToLocal( + __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), + numElements, 1, 0); return device_event(&e); } @@ -275,10 +279,13 @@ template class group { device_event async_work_group_copy(global_ptr dest, local_ptr src, size_t numElements) const { - __ocl_event_t e = - OpGroupAsyncCopyLocalToGlobal( - __spv::Scope::Workgroup, - dest.get(), src.get(), numElements, 1, 0); + using T = detail::ConvertToOpenCLType_t; + using DestT = detail::ConvertToOpenCLType_t; + using SrcT = detail::ConvertToOpenCLType_t; + + __ocl_event_t e = OpGroupAsyncCopyLocalToGlobal( + __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), + numElements, 1, 0); return device_event(&e); } @@ -287,10 +294,13 @@ template class group { global_ptr src, size_t numElements, size_t srcStride) const { - __ocl_event_t e = - OpGroupAsyncCopyGlobalToLocal( - __spv::Scope::Workgroup, - dest.get(), src.get(), numElements, srcStride, 0); + using T = detail::ConvertToOpenCLType_t; + using DestT = detail::ConvertToOpenCLType_t; + using SrcT = detail::ConvertToOpenCLType_t; + + __ocl_event_t e = OpGroupAsyncCopyGlobalToLocal( + __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), + numElements, srcStride, 0); return device_event(&e); } @@ -299,10 +309,13 @@ template class group { local_ptr src, size_t numElements, size_t destStride) const { - __ocl_event_t e = - OpGroupAsyncCopyLocalToGlobal( - __spv::Scope::Workgroup, - dest.get(), src.get(), numElements, destStride, 0); + using T = detail::ConvertToOpenCLType_t; + using DestT = detail::ConvertToOpenCLType_t; + using SrcT = detail::ConvertToOpenCLType_t; + + __ocl_event_t e = OpGroupAsyncCopyLocalToGlobal( + __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), + numElements, destStride, 0); return device_event(&e); } diff --git a/sycl/test/regression/group.cpp b/sycl/test/regression/group.cpp index 264283181b79c..c53dc3cc64360 100644 --- a/sycl/test/regression/group.cpp +++ b/sycl/test/regression/group.cpp @@ -162,10 +162,96 @@ bool group__get_linear_id() { return Pass; } +// Tests group::async_work_group_copy() +bool group__async_work_group_copy() { + std::cout << "+++ Running group::async_work_group_copy() test...\n"; + constexpr int DIMS = 2; + bool Pass = true; + + std::vector, range>> ranges; + ranges.push_back({{3, 1}, {2, 3}}); + ranges.push_back({{1, 3}, {3, 2}}); + + for (const auto &i : ranges) { + const auto LocalRange = i.first; + const auto GroupRange = i.second; + const range GlobalRange = LocalRange * GroupRange; + using DataType = vec; + const int DataLen = GlobalRange.size(); + std::unique_ptr Data(new DataType[DataLen]); + std::memset(Data.get(), 0, DataLen * sizeof(DataType)); + + try { + buffer Buf(Data.get(), DataLen); + queue Q(AsyncHandler{}); + + Q.submit([&](handler &cgh) { + auto AccGlobal = Buf.get_access(cgh); + accessor + AccLocal(LocalRange, cgh); + + cgh.parallel_for( + nd_range<2>{GlobalRange, LocalRange}, + [=](nd_item I) { + const auto Group = I.get_group(); + const auto NumElem = AccLocal.get_count(); + const auto Off = Group[0] * I.get_group_range(1) * NumElem + + Group[1] * I.get_local_range(1); + auto PtrGlobal = AccGlobal.get_pointer() + Off; + auto PtrLocal = AccLocal.get_pointer(); + if (I.get_local_range(0) == 1) { + Group.async_work_group_copy(PtrLocal, PtrGlobal, NumElem); + } else { + Group.async_work_group_copy(PtrLocal, PtrGlobal, NumElem, + I.get_global_range(1)); + } + AccLocal[I.get_local_id()][0] += I.get_global_id(0); + AccLocal[I.get_local_id()][1] += I.get_global_id(1); + if (I.get_local_range(0) == 1) { + Group.async_work_group_copy(PtrGlobal, PtrLocal, NumElem); + } else { + Group.async_work_group_copy(PtrGlobal, PtrLocal, NumElem, + I.get_global_range(1)); + } + }); + }); + } catch (cl::sycl::exception const &E) { + std::cout << "SYCL exception caught: " << E.what() << '\n'; + return 2; + } + const size_t SIZE_Y = GlobalRange.get(0); + const size_t SIZE_X = GlobalRange.get(1); + int ErrCnt = 0; + + for (size_t Y = 0; Y < SIZE_Y; Y++) { + for (size_t X = 0; X < SIZE_X; X++) { + const size_t Ind = Y * SIZE_X + X; + const auto Test0 = Data[Ind][0]; + const auto Test1 = Data[Ind][1]; + const auto Gold0 = Y; + const auto Gold1 = X; + const bool Ok = (Test0 == Gold0 && Test1 == Gold1); + Pass &= Ok; + + if (!Ok && ErrCnt++ < 10) { + std::cout << "*** ERROR at [" << Y << "][" << X << "]: "; + std::cout << Test0 << " " << Test1 << " != "; + std::cout << Gold0 << " " << Gold1 << "\n"; + } + } + } + } + if (Pass) + std::cout << " pass\n"; + return Pass; +} + int main() { bool Pass = 1; Pass &= group__get_group_range(); Pass &= group__get_linear_id(); + Pass &= group__async_work_group_copy(); if (!Pass) { std::cout << "FAILED\n";