Skip to content

standardize the stride in matmul function #1828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,45 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
return gemm_batch_event;
}

/**
 * @brief Replace every zero stride with the stride a C-contiguous layout
 *        would have along that axis.
 *
 * When the extent of an array along some dimension is 1, the stride along
 * that dimension is undefined and may be reported as zero. This function
 * standardizes such strides to a non-zero value: a zero stride at axis `i`
 * becomes the product of `shape[i + 1] ... shape[ndim - 1]`, which is 1 for
 * the last axis. Non-zero strides are left untouched.
 *
 * Note that the decision is made on the stride value only; the extent of the
 * axis itself is not inspected.
 *
 * @param strides Strides to fix up in place; its size defines the number of
 *                dimensions. An empty vector is accepted and left unchanged.
 * @param shape   Pointer to at least `strides.size()` extents. It is only
 *                read when at least one stride is zero.
 */
void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
                                    const py::ssize_t *shape)
{
    // Nothing to do unless some stride is zero. A direct scan avoids
    // multiplying the strides together, which could wrap around.
    bool has_zero_stride = false;
    for (const py::ssize_t stride : strides) {
        if (stride == 0) {
            has_zero_stride = true;
            break;
        }
    }
    if (!has_zero_stride) {
        return;
    }

    // Walk from the innermost axis outwards, carrying the product of the
    // extents of all axes to the right of the current one. The product is
    // kept in py::ssize_t so that large extents are not narrowed to int.
    py::ssize_t trailing_extent = 1;
    for (std::size_t i = strides.size(); i-- > 0;) {
        if (strides[i] == 0) {
            strides[i] = trailing_extent;
        }
        // shape[0] never contributes to any stride, so skip the final
        // multiplication instead of computing an unused product.
        if (i > 0) {
            trailing_extent *= shape[i];
        }
    }
}

/**
 * @brief Force the stride of every axis whose extent is at most 1 to zero.
 *
 * When the extent of an array along some dimension is 1, the stride along
 * that dimension is undefined. This function standardizes such strides by
 * defining them as zero. For batch multiplication this means that, instead
 * of copying the array along the additional batch dimension, zero is used as
 * the stride between consecutive matrices, so the same array is reused for
 * every batch entry.
 *
 * @param strides Strides to adjust in place; its size defines the number of
 *                dimensions.
 * @param shape   Pointer to at least `strides.size()` extents.
 */
void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
                                 const py::ssize_t *shape)
{
    // Advance a cursor over the extents in lock-step with the strides.
    const py::ssize_t *dim_extent = shape;
    for (py::ssize_t &axis_stride : strides) {
        if (*dim_extent <= 1) {
            axis_stride = 0;
        }
        ++dim_extent;
    }
}

std::tuple<sycl::event, sycl::event, bool>
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
Expand Down Expand Up @@ -240,10 +279,15 @@ std::tuple<sycl::event, sycl::event, bool>
std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
standardize_strides_to_zero(a_stride, a_shape);
standardize_strides_to_zero(b_stride, b_shape);
standardize_strides_to_zero(c_stride, c_shape);
const std::int64_t stridea = a_stride[0];
const std::int64_t strideb = b_stride[0];
const std::int64_t stridec = c_stride[0];

standardize_strides_to_nonzero(a_stride, a_shape);
standardize_strides_to_nonzero(b_stride, b_shape);
bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];

Expand Down
23 changes: 23 additions & 0 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _define_contig_flag(x):
if x.ndim < 2:
return True

x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
if x_is_c_contiguous or x_is_f_contiguous:
Expand Down Expand Up @@ -1371,6 +1372,28 @@ def _shape_error(a, b, core_dim, err_msg):
)


def _standardize_strides_to_nonzero(strides, shape):
"""
Standardizing the strides.
When shape of an array along any particular dimension is 1, the stride
along that dimension is undefined. This function standardize the strides
by calculating the non-zero value of the strides.
"""

ndim = len(strides)
if numpy.prod(strides) == 0:
stndrd_strides = tuple(
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
for i in range(ndim - 1)
)
last_stride = 1 if strides[ndim - 1] == 0 else strides[ndim - 1]
stndrd_strides += (last_stride,)
else:
stndrd_strides = strides

return stndrd_strides


def _transpose_ex(a, axeses):
"""
Copied from _transpose_ex in cupy/core/_einsum.py
Expand Down