Skip to content

Commit feda1b1

Browse files
Add squeeze func
1 parent 02bf7d2 commit feda1b1

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from dpctl.tensor._ctors import asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
28-
from dpctl.tensor._manipulation_functions import expand_dims, permute_dims
28+
from dpctl.tensor._manipulation_functions import (
29+
expand_dims,
30+
permute_dims,
31+
squeeze,
32+
)
2933
from dpctl.tensor._reshape import reshape
3034
from dpctl.tensor._usmarray import usm_ndarray
3135

@@ -39,6 +43,7 @@
3943
"reshape",
4044
"permute_dims",
4145
"expand_dims",
46+
"squeeze",
4247
"from_numpy",
4348
"to_numpy",
4449
"asnumpy",

dpctl/tensor/_manipulation_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,36 @@ def expand_dims(X, axes):
8484
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
8585

8686
return dpt.reshape(X, shape)
87+
88+
89+
def squeeze(X, axes=None):
90+
"""
91+
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
92+
93+
Removes singleton dimensions (axes) from X; returns a view, if possible,
94+
a copy otherwise, but with all or a subset of the dimensions
95+
of length 1 removed.
96+
"""
97+
if not isinstance(X, dpt.usm_ndarray):
98+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
99+
X_shape = X.shape
100+
if axes is not None:
101+
if not isinstance(axes, (tuple, list)):
102+
axes = (axes,)
103+
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
104+
new_shape = []
105+
for i, x in enumerate(X_shape):
106+
if i not in axes:
107+
new_shape.append(x)
108+
else:
109+
if x != 1:
110+
raise ValueError(
111+
"Cannot select an axis to squeeze out "
112+
"which has size not equal to one."
113+
)
114+
else:
115+
new_shape = [axis for axis in X_shape if axis != 1]
116+
if new_shape == X.shape:
117+
return X
118+
else:
119+
return dpt.reshape(X, tuple(new_shape))

0 commit comments

Comments
 (0)