File tree Expand file tree Collapse file tree 2 files changed +40
-1
lines changed Expand file tree Collapse file tree 2 files changed +40
-1
lines changed Original file line number Diff line number Diff line change 25
25
from dpctl .tensor ._ctors import asarray , empty
26
26
from dpctl .tensor ._device import Device
27
27
from dpctl .tensor ._dlpack import from_dlpack
28
- from dpctl .tensor ._manipulation_functions import expand_dims , permute_dims
28
+ from dpctl .tensor ._manipulation_functions import (
29
+ expand_dims ,
30
+ permute_dims ,
31
+ squeeze ,
32
+ )
29
33
from dpctl .tensor ._reshape import reshape
30
34
from dpctl .tensor ._usmarray import usm_ndarray
31
35
39
43
"reshape" ,
40
44
"permute_dims" ,
41
45
"expand_dims" ,
46
+ "squeeze" ,
42
47
"from_numpy" ,
43
48
"to_numpy" ,
44
49
"asnumpy" ,
Original file line number Diff line number Diff line change @@ -68,3 +68,37 @@ def expand_dims(X, axes):
68
68
shape = tuple (1 if ax in axes else next (shape_it ) for ax in range (out_ndim ))
69
69
70
70
return dpt .reshape (X , shape )
71
+
72
+
73
+ def squeeze (X , axes = None ):
74
+ """
75
+ squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
76
+
77
+ Removes singleton dimensions (axes) from X; returns a view, if possible,
78
+ a copy otherwise, but with all or a subset of the dimensions
79
+ of length 1 removed.
80
+ """
81
+ if not isinstance (X , dpt .usm_ndarray ):
82
+ raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
83
+ X_shape = X .shape
84
+ if axes is not None :
85
+ if not isinstance (axes , (tuple , list )):
86
+ axes = (axes ,)
87
+ axes = normalize_axis_tuple (axes , X .ndim if X .ndim != 0 else X .ndim + 1 )
88
+ new_shape = []
89
+ for i , x in enumerate (X_shape ):
90
+ if i not in axes :
91
+ new_shape .append (x )
92
+ else :
93
+ if x != 1 :
94
+ raise ValueError (
95
+ "Cannot select an axis to squeeze out "
96
+ "which has size not equal to one."
97
+ )
98
+ new_shape = tuple (new_shape )
99
+ else :
100
+ new_shape = tuple (axis for axis in X_shape if axis != 1 )
101
+ if new_shape == X .shape :
102
+ return X
103
+ else :
104
+ return dpt .reshape (X , new_shape )
You can’t perform that action at this time.
0 commit comments