Skip to content

Commit 30e98fe

Browse files
Add tests for permute_dims func
1 parent 19f4581 commit 30e98fe

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
import numpy as np
19+
import pytest
20+
from numpy.testing import assert_array_equal
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
25+
26+
def test_permute_dims_incorrect_type():
27+
X_list = list([[1, 2, 3], [4, 5, 6]])
28+
X_tuple = tuple(X_list)
29+
Xnp = np.array(X_list)
30+
31+
pytest.raises(TypeError, dpt.permute_dims, X_list, (1, 0))
32+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, (1, 0))
33+
pytest.raises(TypeError, dpt.permute_dims, Xnp, (1, 0))
34+
35+
36+
def test_permute_dims_empty_array():
37+
try:
38+
q = dpctl.SyclQueue()
39+
except dpctl.SyclQueueCreationError:
40+
pytest.skip("Queue could not be created")
41+
42+
Xnp = np.empty((10, 0))
43+
X = dpt.asarray(Xnp, sycl_queue=q)
44+
Y = dpt.permute_dims(X, (1, 0))
45+
Ynp = np.transpose(Xnp, (1, 0))
46+
assert_array_equal(Ynp, dpt.asnumpy(Y))
47+
48+
49+
def test_permute_dims_0d_1d():
50+
try:
51+
q = dpctl.SyclQueue()
52+
except dpctl.SyclQueueCreationError:
53+
pytest.skip("Queue could not be created")
54+
55+
Xnp_0d = np.array(1, dtype="int64")
56+
X_0d = dpt.asarray(Xnp_0d, sycl_queue=q)
57+
Y_0d = dpt.permute_dims(X_0d, ())
58+
assert_array_equal(dpt.asnumpy(Y_0d), dpt.asnumpy(X_0d))
59+
60+
Xnp_1d = np.random.randint(0, 2, size=6, dtype="int64")
61+
X_1d = dpt.asarray(Xnp_1d, sycl_queue=q)
62+
Y_1d = dpt.permute_dims(X_1d, (0))
63+
assert_array_equal(dpt.asnumpy(Y_1d), dpt.asnumpy(X_1d))
64+
65+
pytest.raises(ValueError, dpt.permute_dims, X_1d, ())
66+
pytest.raises(IndexError, dpt.permute_dims, X_1d, (1))
67+
pytest.raises(ValueError, dpt.permute_dims, X_1d, (1, 0))
68+
pytest.raises(
69+
ValueError, dpt.permute_dims, dpt.reshape(X_1d, (2, 3)), (1, 1)
70+
)
71+
72+
73+
@pytest.mark.parametrize("shapes", [(2, 2), (1, 4), (3, 3, 3), (4, 1, 3)])
74+
def test_permute_dims_2d_3d(shapes):
75+
try:
76+
q = dpctl.SyclQueue()
77+
except dpctl.SyclQueueCreationError:
78+
pytest.skip("Queue could not be created")
79+
80+
Xnp_size = np.prod(shapes)
81+
82+
Xnp = np.random.randint(0, 2, size=Xnp_size, dtype="int64").reshape(shapes)
83+
X = dpt.asarray(Xnp, sycl_queue=q)
84+
X_ndim = X.ndim
85+
if X_ndim == 2:
86+
Y = dpt.permute_dims(X, (1, 0))
87+
Ynp = np.transpose(Xnp, (1, 0))
88+
elif X_ndim == 3:
89+
X = dpt.asarray(Xnp, sycl_queue=q)
90+
Y = dpt.permute_dims(X, (2, 0, 1))
91+
Ynp = np.transpose(Xnp, (2, 0, 1))
92+
assert_array_equal(Ynp, dpt.asnumpy(Y))

0 commit comments

Comments
 (0)