@@ -166,6 +166,45 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
166
166
return gemm_batch_event;
167
167
}
168
168
169
+ void standardize_strides_to_nonzero (std::vector<py::ssize_t > &strides,
170
+ const py::ssize_t *shape)
171
+ {
172
+ // When shape of an array along any particular dimension is 1, the stride
173
+ // along that dimension is undefined. This function standardize the strides
174
+ // by calculating the non-zero value of the strides.
175
+ std::size_t ndim = strides.size ();
176
+ bool has_zero_stride = std::accumulate (strides.begin (), strides.end (), 1 ,
177
+ std::multiplies<py::ssize_t >{}) == 0 ;
178
+
179
+ if (has_zero_stride) {
180
+ for (std::size_t i = 0 ; i < ndim - 1 ; ++i) {
181
+ strides[i] = strides[i] == 0
182
+ ? std::accumulate (shape + i + 1 , shape + ndim, 1 ,
183
+ std::multiplies<py::ssize_t >{})
184
+ : strides[i];
185
+ }
186
+ strides[ndim - 1 ] = strides[ndim - 1 ] == 0 ? 1 : strides[ndim - 1 ];
187
+ }
188
+ }
189
+
190
+ void standardize_strides_to_zero (std::vector<py::ssize_t > &strides,
191
+ const py::ssize_t *shape)
192
+ {
193
+ // When shape of an array along any particular dimension is 1, the stride
194
+ // along that dimension is undefined. This function standardize the strides
195
+ // by defining such a stride as zero. This is because for these cases,
196
+ // instead of copying the array into the additional dimension for batch
197
+ // multiplication, we choose to use zero as the stride between different
198
+ // matrices. Therefore, the same array is used repeatedly.
199
+ std::size_t ndim = strides.size ();
200
+
201
+ for (size_t i = 0 ; i < ndim; ++i) {
202
+ if (shape[i] <= 1 ) {
203
+ strides[i] = 0 ;
204
+ }
205
+ }
206
+ }
207
+
169
208
std::tuple<sycl::event, sycl::event, bool >
170
209
gemm_batch (sycl::queue &exec_q,
171
210
dpctl::tensor::usm_ndarray matrixA,
@@ -240,10 +279,15 @@ std::tuple<sycl::event, sycl::event, bool>
240
279
std::vector<py::ssize_t > a_stride = matrixA.get_strides_vector ();
241
280
std::vector<py::ssize_t > b_stride = matrixB.get_strides_vector ();
242
281
std::vector<py::ssize_t > c_stride = resultC.get_strides_vector ();
282
+ standardize_strides_to_zero (a_stride, a_shape);
283
+ standardize_strides_to_zero (b_stride, b_shape);
284
+ standardize_strides_to_zero (c_stride, c_shape);
243
285
const std::int64_t stridea = a_stride[0 ];
244
286
const std::int64_t strideb = b_stride[0 ];
245
287
const std::int64_t stridec = c_stride[0 ];
246
288
289
+ standardize_strides_to_nonzero (a_stride, a_shape);
290
+ standardize_strides_to_nonzero (b_stride, b_shape);
247
291
bool A_base_is_f_contig = a_stride[1 ] == 1 && a_stride[2 ] == a_shape[1 ];
248
292
bool B_base_is_f_contig = b_stride[1 ] == 1 && b_stride[2 ] == b_shape[1 ];
249
293
0 commit comments