diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp
index 6a1247c4c3e6..0d8ad1a67432 100644
--- a/dpnp/backend/extensions/blas/gemm_batch.cpp
+++ b/dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -166,6 +166,45 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
     return gemm_batch_event;
 }
 
+void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
+                                    const py::ssize_t *shape)
+{
+    // When shape of an array along any particular dimension is 1, the stride
+    // along that dimension is undefined. This function standardize the strides
+    // by calculating the non-zero value of the strides.
+    std::size_t ndim = strides.size();
+    bool has_zero_stride = std::accumulate(strides.begin(), strides.end(), 1,
+                                           std::multiplies<py::ssize_t>{}) == 0;
+
+    if (has_zero_stride) {
+        for (std::size_t i = 0; i < ndim - 1; ++i) {
+            strides[i] = strides[i] == 0
+                             ? std::accumulate(shape + i + 1, shape + ndim, 1,
+                                               std::multiplies<py::ssize_t>{})
+                             : strides[i];
+        }
+        strides[ndim - 1] = strides[ndim - 1] == 0 ? 1 : strides[ndim - 1];
+    }
+}
+
+void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
+                                 const py::ssize_t *shape)
+{
+    // When shape of an array along any particular dimension is 1, the stride
+    // along that dimension is undefined. This function standardize the strides
+    // by defining such a stride as zero. This is because for these cases,
+    // instead of copying the array into the additional dimension for batch
+    // multiplication, we choose to use zero as the stride between different
+    // matrices. Therefore, the same array is used repeatedly.
+    std::size_t ndim = strides.size();
+
+    for (size_t i = 0; i < ndim; ++i) {
+        if (shape[i] <= 1) {
+            strides[i] = 0;
+        }
+    }
+}
+
 std::tuple<sycl::event, sycl::event, bool>
     gemm_batch(sycl::queue &exec_q,
                dpctl::tensor::usm_ndarray matrixA,
@@ -240,10 +279,15 @@ std::tuple<sycl::event, sycl::event, bool>
     std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
     std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
     std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
+    standardize_strides_to_zero(a_stride, a_shape);
+    standardize_strides_to_zero(b_stride, b_shape);
+    standardize_strides_to_zero(c_stride, c_shape);
     const std::int64_t stridea = a_stride[0];
     const std::int64_t strideb = b_stride[0];
     const std::int64_t stridec = c_stride[0];
 
+    standardize_strides_to_nonzero(a_stride, a_shape);
+    standardize_strides_to_nonzero(b_stride, b_shape);
     bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
     bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
 
diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
index 8d4509fd0deb..5269114d609e 100644
--- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
+++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
@@ -297,6 +297,7 @@ def _define_contig_flag(x):
     if x.ndim < 2:
         return True
 
+    x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
     x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
     x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
     if x_is_c_contiguous or x_is_f_contiguous:
@@ -1371,6 +1372,28 @@ def _shape_error(a, b, core_dim, err_msg):
     )
 
 
+def _standardize_strides_to_nonzero(strides, shape):
+    """
+    Standardizing the strides.
+    When shape of an array along any particular dimension is 1, the stride
+    along that dimension is undefined. This function standardize the strides
+    by calculating the non-zero value of the strides.
+    """
+
+    ndim = len(strides)
+    if numpy.prod(strides) == 0:
+        stndrd_strides = tuple(
+            numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
+            for i in range(ndim - 1)
+        )
+        last_stride = 1 if strides[ndim - 1] == 0 else strides[ndim - 1]
+        stndrd_strides += (last_stride,)
+    else:
+        stndrd_strides = strides
+
+    return stndrd_strides
+
+
 def _transpose_ex(a, axeses):
     """
     Copied from _transpose_ex in cupy/core/_einsum.py