Skip to content

Commit 3b001e8

Browse files
Add test based on application of put_along_axis
Use put_along_axis to form 24 permutation matrices representing elements of S4 (group of permutations of 4 elements). Verify that every element raised to order 12 gives identity.
1 parent 7adcf67 commit 3b001e8

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,51 @@ def test_put_along_axis_validation():
16591659
dpt.put_along_axis(x, ind2, vals)
16601660

16611661

1662+
def test_put_along_axis_application():
1663+
get_queue_or_skip()
1664+
info_ = dpt.__array_namespace_info__()
1665+
def_dtypes = info_.default_dtypes(device=None)
1666+
ind_dt = def_dtypes["indexing"]
1667+
all_perms = dpt.asarray(
1668+
[
1669+
[0, 1, 2, 3],
1670+
[0, 2, 1, 3],
1671+
[2, 0, 1, 3],
1672+
[2, 1, 0, 3],
1673+
[1, 0, 2, 3],
1674+
[1, 2, 0, 3],
1675+
[0, 1, 3, 2],
1676+
[0, 2, 3, 1],
1677+
[2, 0, 3, 1],
1678+
[2, 1, 3, 0],
1679+
[1, 0, 3, 2],
1680+
[1, 2, 3, 0],
1681+
[0, 3, 1, 2],
1682+
[0, 3, 2, 1],
1683+
[2, 3, 0, 1],
1684+
[2, 3, 1, 0],
1685+
[1, 3, 0, 2],
1686+
[1, 3, 2, 0],
1687+
[3, 0, 1, 2],
1688+
[3, 0, 2, 1],
1689+
[3, 2, 0, 1],
1690+
[3, 2, 1, 0],
1691+
[3, 1, 0, 2],
1692+
[3, 1, 2, 0],
1693+
],
1694+
dtype=ind_dt,
1695+
)
1696+
p_mats = dpt.zeros((24, 4, 4), dtype=dpt.int64)
1697+
vals = dpt.ones((24, 4, 1), dtype=p_mats.dtype)
1698+
# form 24 permutation matrices
1699+
dpt.put_along_axis(p_mats, all_perms[..., dpt.newaxis], vals, axis=2)
1700+
p2 = p_mats @ p_mats
1701+
p4 = p2 @ p2
1702+
p8 = p4 @ p4
1703+
expected = dpt.eye(4, dtype=p_mats.dtype)[dpt.newaxis, ...]
1704+
assert dpt.all(p8 @ p4 == expected)
1705+
1706+
16621707
def check__extract_impl_validation(fn):
16631708
x = dpt.ones(10)
16641709
ind = dpt.ones(10, dtype="?")

0 commit comments

Comments
 (0)