Skip to content

Commit f0f23ac

Browse files
Add squeeze func
1 parent e30e7a4 commit f0f23ac

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from dpctl.tensor._ctors import asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
28-
from dpctl.tensor._manipulation_functions import expand_dims, permute_dims
28+
from dpctl.tensor._manipulation_functions import (
29+
expand_dims,
30+
permute_dims,
31+
squeeze,
32+
)
2933
from dpctl.tensor._reshape import reshape
3034
from dpctl.tensor._usmarray import usm_ndarray
3135

@@ -39,6 +43,7 @@
3943
"reshape",
4044
"permute_dims",
4145
"expand_dims",
46+
"squeeze",
4247
"from_numpy",
4348
"to_numpy",
4449
"asnumpy",

dpctl/tensor/_manipulation_functions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,37 @@ def expand_dims(X, axes):
6868
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
6969

7070
return dpt.reshape(X, shape)
71+
72+
73+
def squeeze(X, axes=None):
74+
"""
75+
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
76+
77+
Removes singleton dimensions (axes) from X; returns a view, if possible,
78+
a copy otherwise, but with all or a subset of the dimensions
79+
of length 1 removed.
80+
"""
81+
if not isinstance(X, dpt.usm_ndarray):
82+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
83+
X_shape = X.shape
84+
if axes is not None:
85+
if not isinstance(axes, (tuple, list)):
86+
axes = (axes,)
87+
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
88+
new_shape = []
89+
for i, x in enumerate(X_shape):
90+
if i not in axes:
91+
new_shape.append(x)
92+
else:
93+
if x != 1:
94+
raise ValueError(
95+
"Cannot select an axis to squeeze out "
96+
"which has size not equal to one."
97+
)
98+
new_shape = tuple(new_shape)
99+
else:
100+
new_shape = tuple(axis for axis in X_shape if axis != 1)
101+
if new_shape == X.shape:
102+
return X
103+
else:
104+
return dpt.reshape(X, new_shape)

0 commit comments

Comments
 (0)