Skip to content

Commit 035addc

Browse files
committed
[SYCL] Add tests for user-defined reductions extension
Spec: intel/llvm#7202 Implementation: intel/llvm#7436
1 parent a4b77f0 commit 035addc

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
#include <numeric>
5+
6+
#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp>
7+
#include <sycl/sycl.hpp>
8+
9+
template <typename T> struct UserDefinedSum {
10+
T operator()(T a, T b) { return a + b; }
11+
};
12+
13+
using namespace sycl;
14+
15+
int main() {
16+
const int N = 1024;
17+
queue q;
18+
19+
buffer<int, 1> inputBuf(N);
20+
buffer<int> outputBuf{2};
21+
{
22+
// Initialize buffer on the host with 0, 1, 2, 3, ..., 1023
23+
host_accessor a{inputBuf};
24+
std::iota(a.begin(), a.end(), 0);
25+
}
26+
{
27+
q.submit([&](sycl::handler &h) {
28+
auto inputValues = sycl::accessor(inputBuf, h);
29+
accessor outputValues{outputBuf, h, write_only, no_init};
30+
31+
constexpr size_t group_size = 16;
32+
33+
// Create enough local memory for the algorithm
34+
constexpr size_t temp_memory_size = group_size * sizeof(int);
35+
auto scratch = sycl::local_accessor<std::byte, 1>(temp_memory_size, h);
36+
37+
h.parallel_for(
38+
sycl::nd_range<1>(range<1>(16), range<1>(16)), [=](nd_item<1> it) {
39+
// Create a handle that associates the group with an allocation it
40+
// can use
41+
auto handle =
42+
sycl::ext::oneapi::experimental::group_with_scratchpad(
43+
it.get_group(), sycl::span(&scratch[0], temp_memory_size));
44+
45+
int *first = inputValues.get_pointer();
46+
int *last = first + 1024;
47+
// Pass the handle as the first argument to the group algorithm
48+
int sum_joint_reduce =
49+
sycl::ext::oneapi::experimental::joint_reduce(
50+
handle, first, last, UserDefinedSum<int>{});
51+
outputValues[0] = sum_joint_reduce;
52+
53+
int sum_reduce_over_group =
54+
sycl::ext::oneapi::experimental::reduce_over_group(
55+
handle, inputValues[it.get_global_id(0)], 0,
56+
UserDefinedSum<int>{});
57+
outputValues[1] = sum_reduce_over_group;
58+
});
59+
});
60+
q.wait();
61+
}
62+
assert(outputBuf.get_host_access()[0] == 523776);
63+
assert(outputBuf.get_host_access()[1] == 120);
64+
65+
// TODO: add tests for reduce_over_group overloads ith native binary_op
66+
67+
return 0;
68+
}

0 commit comments

Comments
 (0)