Skip to content

Commit a204eb2

Browse files
Merge pull request #787 from IntelPython/permute_dims
Implementation of permute_dims function
2 parents 3007b26 + 26107bf commit a204eb2

File tree

5 files changed

+161
-1
lines changed

5 files changed

+161
-1
lines changed

.github/workflows/generate-docs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112
git config --global user.email 'github-actions[doc-deploy-bot]@users.noreply.github.com'
113113
git commit -m "Docs for pull request ${PR_NUM}"
114114
git push tokened_docs gh-pages
115-
- name: Unpublished pull-request docs
115+
- name: Unpublish pull-request docs
116116
if: ${{ github.event.pull_request && github.event.action == 'closed' }}
117117
env:
118118
PR_NUM: ${{ github.event.number }}
@@ -122,6 +122,8 @@ jobs:
122122
git fetch tokened_docs
123123
git checkout --track tokened_docs/gh-pages
124124
echo `pwd`
125+
ls
126+
[ -d pulls ] && ls pulls && echo "This is pull/${PR_NUM}"
125127
[ -d pulls/${PR_NUM} ] && git rm -rf pulls/${PR_NUM}
126128
git config --global user.name 'github-actions[doc-deploy-bot]'
127129
git config --global user.email 'github-actions[doc-deploy-bot]@users.noreply.github.com'

docs/generate_rst.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"dpctl.tensor._copy_utils": "Array Construction",
4646
"dpctl.tensor._dlpack": "Array Construction",
4747
"dpctl.tensor._reshape": "Array Manipulation",
48+
"dpctl.tensor._manipulation_functions": "Array Manipulation",
4849
"dpctl.memory._memory": "Functions",
4950
"dpctl.program._program": "Functions",
5051
"dpctl.utils._compute_follows_data": "Functions",

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from dpctl.tensor._ctors import asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
28+
from dpctl.tensor._manipulation_functions import permute_dims
2829
from dpctl.tensor._reshape import reshape
2930
from dpctl.tensor._usmarray import usm_ndarray
3031

@@ -36,6 +37,7 @@
3637
"copy",
3738
"empty",
3839
"reshape",
40+
"permute_dims",
3941
"from_numpy",
4042
"to_numpy",
4143
"asnumpy",
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
import numpy as np
19+
20+
import dpctl.tensor as dpt
21+
22+
23+
def _check_value_of_axes(axes):
24+
axes_len = len(axes)
25+
check_array = np.zeros(axes_len)
26+
for i in axes:
27+
ii = i.__index__()
28+
if ii < 0 or ii > axes_len or check_array[ii] != 0:
29+
return False
30+
check_array[ii] = 1
31+
return True
32+
33+
34+
def permute_dims(X, axes):
35+
"""
36+
permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray
37+
38+
Permute the axes (dimensions) of an array; returns the permuted
39+
array as a view.
40+
"""
41+
if not isinstance(X, dpt.usm_ndarray):
42+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
43+
if not isinstance(axes, (tuple, list)):
44+
axes = (axes,)
45+
if not X.ndim == len(axes):
46+
raise ValueError(
47+
"The length of the passed axes does not match "
48+
"to the number of usm_ndarray dimensions."
49+
)
50+
if not _check_value_of_axes(axes):
51+
raise ValueError(
52+
"The values of the axes must be in the range "
53+
f"from 0 to {X.ndim} and have no duplicates."
54+
)
55+
newstrides = tuple(X.strides[i] for i in axes)
56+
newshape = tuple(X.shape[i] for i in axes)
57+
return dpt.usm_ndarray(
58+
shape=newshape,
59+
dtype=X.dtype,
60+
buffer=X,
61+
strides=newstrides,
62+
offset=X.__sycl_usm_array_interface__.get("offset", 0),
63+
)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
import numpy as np
19+
import pytest
20+
from numpy.testing import assert_array_equal
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
25+
26+
def test_permute_dims_incorrect_type():
27+
X_list = list([[1, 2, 3], [4, 5, 6]])
28+
X_tuple = tuple(X_list)
29+
Xnp = np.array(X_list)
30+
31+
pytest.raises(TypeError, dpt.permute_dims, X_list, (1, 0))
32+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, (1, 0))
33+
pytest.raises(TypeError, dpt.permute_dims, Xnp, (1, 0))
34+
35+
36+
def test_permute_dims_empty_array():
37+
try:
38+
q = dpctl.SyclQueue()
39+
except dpctl.SyclQueueCreationError:
40+
pytest.skip("Queue could not be created")
41+
42+
Xnp = np.empty((10, 0))
43+
X = dpt.asarray(Xnp, sycl_queue=q)
44+
Y = dpt.permute_dims(X, (1, 0))
45+
Ynp = np.transpose(Xnp, (1, 0))
46+
assert_array_equal(Ynp, dpt.asnumpy(Y))
47+
48+
49+
def test_permute_dims_0d_1d():
50+
try:
51+
q = dpctl.SyclQueue()
52+
except dpctl.SyclQueueCreationError:
53+
pytest.skip("Queue could not be created")
54+
55+
Xnp_0d = np.array(1, dtype="int64")
56+
X_0d = dpt.asarray(Xnp_0d, sycl_queue=q)
57+
Y_0d = dpt.permute_dims(X_0d, ())
58+
assert_array_equal(dpt.asnumpy(Y_0d), dpt.asnumpy(X_0d))
59+
60+
Xnp_1d = np.random.randint(0, 2, size=6, dtype="int64")
61+
X_1d = dpt.asarray(Xnp_1d, sycl_queue=q)
62+
Y_1d = dpt.permute_dims(X_1d, (0))
63+
assert_array_equal(dpt.asnumpy(Y_1d), dpt.asnumpy(X_1d))
64+
65+
pytest.raises(ValueError, dpt.permute_dims, X_1d, ())
66+
pytest.raises(IndexError, dpt.permute_dims, X_1d, (1))
67+
pytest.raises(ValueError, dpt.permute_dims, X_1d, (1, 0))
68+
pytest.raises(
69+
ValueError, dpt.permute_dims, dpt.reshape(X_1d, (2, 3)), (1, 1)
70+
)
71+
72+
73+
@pytest.mark.parametrize("shapes", [(2, 2), (1, 4), (3, 3, 3), (4, 1, 3)])
74+
def test_permute_dims_2d_3d(shapes):
75+
try:
76+
q = dpctl.SyclQueue()
77+
except dpctl.SyclQueueCreationError:
78+
pytest.skip("Queue could not be created")
79+
80+
Xnp_size = np.prod(shapes)
81+
82+
Xnp = np.random.randint(0, 2, size=Xnp_size, dtype="int64").reshape(shapes)
83+
X = dpt.asarray(Xnp, sycl_queue=q)
84+
X_ndim = X.ndim
85+
if X_ndim == 2:
86+
Y = dpt.permute_dims(X, (1, 0))
87+
Ynp = np.transpose(Xnp, (1, 0))
88+
elif X_ndim == 3:
89+
X = dpt.asarray(Xnp, sycl_queue=q)
90+
Y = dpt.permute_dims(X, (2, 0, 1))
91+
Ynp = np.transpose(Xnp, (2, 0, 1))
92+
assert_array_equal(Ynp, dpt.asnumpy(Y))

0 commit comments

Comments
 (0)