Skip to content

Commit 86c8640

Browse files
vtavanaantonwolfy
andauthored
standardize the stride in matmul function (#1828)
* standardize the stride in matmul function * remove unused varibale --------- Co-authored-by: Anton <[email protected]>
1 parent 3345fdf commit 86c8640

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,45 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
166166
return gemm_batch_event;
167167
}
168168

169+
void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
170+
const py::ssize_t *shape)
171+
{
172+
// When shape of an array along any particular dimension is 1, the stride
173+
// along that dimension is undefined. This function standardize the strides
174+
// by calculating the non-zero value of the strides.
175+
std::size_t ndim = strides.size();
176+
bool has_zero_stride = std::accumulate(strides.begin(), strides.end(), 1,
177+
std::multiplies<py::ssize_t>{}) == 0;
178+
179+
if (has_zero_stride) {
180+
for (std::size_t i = 0; i < ndim - 1; ++i) {
181+
strides[i] = strides[i] == 0
182+
? std::accumulate(shape + i + 1, shape + ndim, 1,
183+
std::multiplies<py::ssize_t>{})
184+
: strides[i];
185+
}
186+
strides[ndim - 1] = strides[ndim - 1] == 0 ? 1 : strides[ndim - 1];
187+
}
188+
}
189+
190+
void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
191+
const py::ssize_t *shape)
192+
{
193+
// When shape of an array along any particular dimension is 1, the stride
194+
// along that dimension is undefined. This function standardize the strides
195+
// by defining such a stride as zero. This is because for these cases,
196+
// instead of copying the array into the additional dimension for batch
197+
// multiplication, we choose to use zero as the stride between different
198+
// matrices. Therefore, the same array is used repeatedly.
199+
std::size_t ndim = strides.size();
200+
201+
for (size_t i = 0; i < ndim; ++i) {
202+
if (shape[i] <= 1) {
203+
strides[i] = 0;
204+
}
205+
}
206+
}
207+
169208
std::tuple<sycl::event, sycl::event, bool>
170209
gemm_batch(sycl::queue &exec_q,
171210
dpctl::tensor::usm_ndarray matrixA,
@@ -240,10 +279,15 @@ std::tuple<sycl::event, sycl::event, bool>
240279
std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
241280
std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
242281
std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
282+
standardize_strides_to_zero(a_stride, a_shape);
283+
standardize_strides_to_zero(b_stride, b_shape);
284+
standardize_strides_to_zero(c_stride, c_shape);
243285
const std::int64_t stridea = a_stride[0];
244286
const std::int64_t strideb = b_stride[0];
245287
const std::int64_t stridec = c_stride[0];
246288

289+
standardize_strides_to_nonzero(a_stride, a_shape);
290+
standardize_strides_to_nonzero(b_stride, b_shape);
247291
bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
248292
bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
249293

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def _define_contig_flag(x):
297297
if x.ndim < 2:
298298
return True
299299

300+
x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
300301
x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
301302
x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
302303
if x_is_c_contiguous or x_is_f_contiguous:
@@ -1371,6 +1372,28 @@ def _shape_error(a, b, core_dim, err_msg):
13711372
)
13721373

13731374

1375+
def _standardize_strides_to_nonzero(strides, shape):
1376+
"""
1377+
Standardizing the strides.
1378+
When shape of an array along any particular dimension is 1, the stride
1379+
along that dimension is undefined. This function standardize the strides
1380+
by calculating the non-zero value of the strides.
1381+
"""
1382+
1383+
ndim = len(strides)
1384+
if numpy.prod(strides) == 0:
1385+
stndrd_strides = tuple(
1386+
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
1387+
for i in range(ndim - 1)
1388+
)
1389+
last_stride = 1 if strides[ndim - 1] == 0 else strides[ndim - 1]
1390+
stndrd_strides += (last_stride,)
1391+
else:
1392+
stndrd_strides = strides
1393+
1394+
return stndrd_strides
1395+
1396+
13741397
def _transpose_ex(a, axeses):
13751398
"""
13761399
Copied from _transpose_ex in cupy/core/_einsum.py

0 commit comments

Comments
 (0)