@@ -90,3 +90,79 @@ def test_permute_dims_2d_3d(shapes):
90
90
Y = dpt .permute_dims (X , (2 , 0 , 1 ))
91
91
Ynp = np .transpose (Xnp , (2 , 0 , 1 ))
92
92
assert_array_equal (Ynp , dpt .asnumpy (Y ))
93
+
94
+
95
+ def test_expand_dims_incorrect_type ():
96
+ X_list = list ([1 , 2 , 3 , 4 , 5 ])
97
+ X_tuple = tuple (X_list )
98
+ Xnp = np .array (X_list )
99
+
100
+ pytest .raises (TypeError , dpt .permute_dims , X_list , 1 )
101
+ pytest .raises (TypeError , dpt .permute_dims , X_tuple , 1 )
102
+ pytest .raises (TypeError , dpt .permute_dims , Xnp , 1 )
103
+
104
+
105
+ def test_expand_dims_0d ():
106
+ try :
107
+ q = dpctl .SyclQueue ()
108
+ except dpctl .SyclQueueCreationError :
109
+ pytest .skip ("Queue could not be created" )
110
+
111
+ Xnp = np .array (1 , dtype = "int64" )
112
+ X = dpt .asarray (Xnp , sycl_queue = q )
113
+ Y = dpt .expand_dims (X , 0 )
114
+ Ynp = np .expand_dims (Xnp , 0 )
115
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
116
+
117
+ Y = dpt .expand_dims (X , - 1 )
118
+ Ynp = np .expand_dims (Xnp , - 1 )
119
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
120
+
121
+ pytest .raises (np .AxisError , dpt .expand_dims , X , 1 )
122
+ pytest .raises (np .AxisError , dpt .expand_dims , X , - 2 )
123
+
124
+
125
+ @pytest .mark .parametrize ("shapes" , [(3 ,), (3 , 3 ), (3 , 3 , 3 )])
126
+ def test_expand_dims_1d_3d (shapes ):
127
+ try :
128
+ q = dpctl .SyclQueue ()
129
+ except dpctl .SyclQueueCreationError :
130
+ pytest .skip ("Queue could not be created" )
131
+
132
+ Xnp_size = np .prod (shapes )
133
+
134
+ Xnp = np .random .randint (0 , 2 , size = Xnp_size , dtype = "int64" ).reshape (shapes )
135
+ X = dpt .asarray (Xnp , sycl_queue = q )
136
+ shape_len = len (shapes )
137
+ for axis in range (- shape_len - 1 , shape_len ):
138
+ Y = dpt .expand_dims (X , axis )
139
+ Ynp = np .expand_dims (Xnp , axis )
140
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
141
+
142
+ pytest .raises (np .AxisError , dpt .expand_dims , X , shape_len + 1 )
143
+ pytest .raises (np .AxisError , dpt .expand_dims , X , - shape_len - 2 )
144
+
145
+
146
+ @pytest .mark .parametrize (
147
+ "axes" , [(0 , 1 , 2 ), (0 , - 1 , - 2 ), (0 , 3 , 5 ), (0 , - 3 , - 5 )]
148
+ )
149
+ def test_expand_dims_tuple (axes ):
150
+ try :
151
+ q = dpctl .SyclQueue ()
152
+ except dpctl .SyclQueueCreationError :
153
+ pytest .skip ("Queue could not be created" )
154
+
155
+ Xnp = np .empty ((3 , 3 , 3 ))
156
+ X = dpt .asarray (Xnp , sycl_queue = q )
157
+ Y = dpt .expand_dims (X , axes )
158
+ Ynp = np .expand_dims (Xnp , axes )
159
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
160
+
161
+
162
+ def test_expand_dims_incorrect_tuple ():
163
+
164
+ X = dpt .empty ((3 , 3 , 3 ), dtype = "i4" )
165
+ pytest .raises (np .AxisError , dpt .expand_dims , X , (0 , - 6 ))
166
+ pytest .raises (np .AxisError , dpt .expand_dims , X , (0 , 5 ))
167
+
168
+ pytest .raises (ValueError , dpt .expand_dims , X , (1 , 1 ))
0 commit comments