File tree Expand file tree Collapse file tree 2 files changed +39
-1
lines changed Expand file tree Collapse file tree 2 files changed +39
-1
lines changed Original file line number Diff line number Diff line change 25
25
from dpctl .tensor ._ctors import asarray , empty
26
26
from dpctl .tensor ._device import Device
27
27
from dpctl .tensor ._dlpack import from_dlpack
28
- from dpctl .tensor ._manipulation_functions import expand_dims , permute_dims
28
+ from dpctl .tensor ._manipulation_functions import (
29
+ expand_dims ,
30
+ permute_dims ,
31
+ squeeze ,
32
+ )
29
33
from dpctl .tensor ._reshape import reshape
30
34
from dpctl .tensor ._usmarray import usm_ndarray
31
35
39
43
"reshape" ,
40
44
"permute_dims" ,
41
45
"expand_dims" ,
46
+ "squeeze" ,
42
47
"from_numpy" ,
43
48
"to_numpy" ,
44
49
"asnumpy" ,
Original file line number Diff line number Diff line change @@ -84,3 +84,36 @@ def expand_dims(X, axes):
84
84
shape = tuple (1 if ax in axes else next (shape_it ) for ax in range (out_ndim ))
85
85
86
86
return dpt .reshape (X , shape )
87
+
88
+
89
+ def squeeze (X , axes = None ):
90
+ """
91
+ squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
92
+
93
+ Removes singleton dimensions (axes) from X; returns a view, if possible,
94
+ a copy otherwise, but with all or a subset of the dimensions
95
+ of length 1 removed.
96
+ """
97
+ if not isinstance (X , dpt .usm_ndarray ):
98
+ raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
99
+ X_shape = X .shape
100
+ if axes is not None :
101
+ if not isinstance (axes , (tuple , list )):
102
+ axes = (axes ,)
103
+ axes = normalize_axis_tuple (axes , X .ndim if X .ndim != 0 else X .ndim + 1 )
104
+ new_shape = []
105
+ for i , x in enumerate (X_shape ):
106
+ if i not in axes :
107
+ new_shape .append (x )
108
+ else :
109
+ if x != 1 :
110
+ raise ValueError (
111
+ "Cannot select an axis to squeeze out "
112
+ "which has size not equal to one."
113
+ )
114
+ else :
115
+ new_shape = [axis for axis in X_shape if axis != 1 ]
116
+ if new_shape == X .shape :
117
+ return X
118
+ else :
119
+ return dpt .reshape (X , tuple (new_shape ))
You can’t perform that action at this time.
0 commit comments