Skip to content

Commit 202f460

Browse files
author
Diptorup Deb
authored
Merge pull request #1391 from IntelPython/device_func_unit_tests
Verify global_barrier, indexing, private array inside device_func
2 parents 97f9069 + d89104d commit 202f460

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,29 @@ def _kernel(nd_item: NdItem, a):
2929
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a)
3030

3131
assert a[0] == N * 2
32+
33+
34+
def test_group_barrier_device_func():
35+
"""A test for group_barrier function."""
36+
37+
@dpex_exp.device_func
38+
def _increment_value(nd_item: NdItem, a):
39+
i = nd_item.get_global_id(0)
40+
41+
a[i] += 1
42+
group_barrier(nd_item.get_group(), MemoryScope.DEVICE)
43+
44+
if i == 0:
45+
for idx in range(1, a.size):
46+
a[0] += a[idx]
47+
48+
@dpex_exp.kernel
49+
def _kernel(nd_item: NdItem, a):
50+
_increment_value(nd_item, a)
51+
52+
N = 16
53+
a = dpnp.ones(N, dtype=dpnp.int32)
54+
55+
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a)
56+
57+
assert a[0] == N * 2

numba_dpex/tests/experimental/test_private_array.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,29 @@ def test_private_array(call_kernel, decorator, kernel):
8282
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)
8383

8484
assert np.array_equal(want, a.asnumpy())
85+
86+
87+
@pytest.mark.parametrize(
88+
"func",
89+
[
90+
private_array_kernel,
91+
private_array_kernel_fill_true,
92+
private_array_kernel_fill_false,
93+
private_2d_array_kernel,
94+
],
95+
)
96+
def test_private_array_in_device_func(func):
97+
98+
_df = dpex_exp.device_func(func)
99+
100+
@dpex_exp.kernel
101+
def _kernel(item: Item, a):
102+
_df(item, a)
103+
104+
a = dpnp.empty(10, dtype=dpnp.float32)
105+
dpex_exp.call_kernel(_kernel, Range(a.size), a)
106+
107+
# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6
108+
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)
109+
110+
assert np.array_equal(want, a.asnumpy())

0 commit comments

Comments
 (0)