
Commit e76afa8

Merge branch 'master' into fix-irfftn
2 parents cbffff6 + afd5c6d commit e76afa8

21 files changed: +747 additions, -278 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
 * Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
 * Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520)
+* Added a new backend routine `syrk` from oneMKL to perform a symmetric rank-k update, a specialized matrix multiplication where the result is a symmetric matrix [#2509](https://github.com/IntelPython/dpnp/pull/2509)
 
 ### Changed

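For context on the `syrk` entry above: a symmetric rank-k update computes C <- alpha * A * A^T + beta * C, and the result is symmetric by construction, which is why a dedicated routine can do roughly half the work of a general `gemm`. The minimal sketch below is plain standard C++ that only illustrates the operation; it does not call oneMKL or dpnp, and every name in it is made up for the example.

// What a symmetric rank-k update (syrk) computes: C <- alpha * A * A^T + beta * C,
// with A of shape n x k and C of shape n x n. The result is symmetric, so a real
// syrk implementation only needs to fill one triangle of C.
#include <cstdio>

int main()
{
    constexpr int n = 3, k = 2;
    const double alpha = 1.0;                        // beta is taken as 0 here
    const double A[n][k] = {{1, 2}, {3, 4}, {5, 6}}; // n x k input
    double C[n][n] = {};                             // n x n result, zero-initialized

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int p = 0; p < k; ++p)
                C[i][j] += alpha * A[i][p] * A[j][p]; // (A * A^T)(i, j)

    // C[i][j] == C[j][i] for every i, j
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j)
            std::printf("%6.1f", C[i][j]);
        std::printf("\n");
    }
    return 0;
}
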
dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@ set(_module_src
     ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp
 )
 
 pybind11_add_module(${python_module_name} MODULE ${_module_src})
@@ -61,6 +62,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN
 
 target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
 target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
+target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
 
 target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
 target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 16 additions & 6 deletions
@@ -36,6 +36,7 @@
 #include "dotu.hpp"
 #include "gemm.hpp"
 #include "gemv.hpp"
+#include "syrk.hpp"
 
 namespace blas_ns = dpnp::extensions::blas;
 namespace py = pybind11;
@@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void)
     blas_ns::init_gemm_batch_dispatch_table();
     blas_ns::init_gemm_dispatch_table();
     blas_ns::init_gemv_dispatch_vector();
+    blas_ns::init_syrk_dispatch_vector();
 }
 
 static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m)
     };
 
     m.def("_dot", dot_pyapi,
-          "Call `dot` from OneMKL BLAS library to compute "
+          "Call `dot` from oneMKL BLAS library to compute "
          "the dot product of two real-valued vectors.",
          py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
          py::arg("result"), py::arg("depends") = py::list());
@@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m)
     };
 
     m.def("_dotc", dotc_pyapi,
-          "Call `dotc` from OneMKL BLAS library to compute "
+          "Call `dotc` from oneMKL BLAS library to compute "
          "the dot product of two complex vectors, "
          "conjugating the first vector.",
          py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -110,37 +112,45 @@ PYBIND11_MODULE(_blas_impl, m)
     };
 
     m.def("_dotu", dotu_pyapi,
-          "Call `dotu` from OneMKL BLAS library to compute "
+          "Call `dotu` from oneMKL BLAS library to compute "
          "the dot product of two complex vectors.",
          py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
          py::arg("result"), py::arg("depends") = py::list());
 }
 
 {
     m.def("_gemm", &blas_ns::gemm,
-          "Call `gemm` from OneMKL BLAS library to compute "
+          "Call `gemm` from oneMKL BLAS library to compute "
          "the matrix-matrix product with 2-D matrices.",
          py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
          py::arg("resultC"), py::arg("depends") = py::list());
 }
 
 {
     m.def("_gemm_batch", &blas_ns::gemm_batch,
-          "Call `gemm_batch` from OneMKL BLAS library to compute "
+          "Call `gemm_batch` from oneMKL BLAS library to compute "
          "the matrix-matrix product for a batch of 2-D matrices.",
          py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
          py::arg("resultC"), py::arg("depends") = py::list());
 }
 
 {
     m.def("_gemv", &blas_ns::gemv,
-          "Call `gemv` from OneMKL BLAS library to compute "
+          "Call `gemv` from oneMKL BLAS library to compute "
          "the matrix-vector product with a general matrix.",
          py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
          py::arg("vectorY"), py::arg("transpose"),
          py::arg("depends") = py::list());
 }
 
+{
+    m.def("_syrk", &blas_ns::syrk,
+          "Call `syrk` from oneMKL BLAS library to compute "
+          "a symmetric rank-k update, producing a symmetric matrix.",
+          py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"),
+          py::arg("depends") = py::list());
+}
+
 {
     m.def(
         "_using_onemath",

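A side note on the `dotu`/`dotc` docstrings above: `dotu` is the unconjugated complex dot product sum_i x_i * y_i, while `dotc` conjugates the first vector, sum_i conj(x_i) * y_i. The short standard-C++ sketch below only illustrates that difference; it does not go through the bindings registered in this file.

// dotu vs dotc for complex vectors (illustration only, no oneMKL involved).
#include <complex>
#include <cstdio>

int main()
{
    using cd = std::complex<double>;
    const cd x[] = {{1, 2}, {3, -1}};
    const cd y[] = {{0, 1}, {2, 2}};

    cd dotu{0, 0};
    cd dotc{0, 0};
    for (int i = 0; i < 2; ++i) {
        dotu += x[i] * y[i];            // unconjugated: sum of x_i * y_i
        dotc += std::conj(x[i]) * y[i]; // conjugated:   sum of conj(x_i) * y_i
    }
    std::printf("dotu = (%g, %g)\n", dotu.real(), dotu.imag());
    std::printf("dotc = (%g, %g)\n", dotc.real(), dotc.imag());
    return 0;
}
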
dpnp/backend/extensions/blas/dot_common.hpp

Lines changed: 2 additions & 1 deletion
@@ -128,7 +128,8 @@ std::pair<sycl::event, sycl::event>
     dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
     if (dot_fn == nullptr) {
         throw py::value_error(
-            "Types of input vectors and result array are mismatched.");
+            "No dot implementation is available for the specified data type "
+            "of the input and output arrays.");
     }
 
     char *x_typeless_ptr = vectorX.get_data();

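The changed error message above sits inside the dispatch pattern used throughout these extensions: a table of function pointers indexed by a runtime type id, where a nullptr entry means the requested data type has no kernel. The sketch below shows the pattern in isolation; the names and the type-id mapping are hypothetical, not dpnp's.

// Dispatch-vector pattern: function pointers indexed by a runtime type id;
// a nullptr entry means "no implementation for this data type".
#include <cstdio>
#include <stdexcept>

using impl_fn_ptr_t = double (*)(double);

constexpr int num_types = 3; // stand-in for dpctl_td_ns::num_types
static impl_fn_ptr_t dispatch_vector[num_types] = {};

static double double_impl(double v) { return 2.0 * v; }

static void init_dispatch_vector()
{
    dispatch_vector[1] = double_impl; // pretend type id 1 means "double"
}

int main()
{
    init_dispatch_vector();

    const int type_id = 1; // would normally be derived from the array's dtype
    impl_fn_ptr_t fn = dispatch_vector[type_id];
    if (fn == nullptr) {
        throw std::runtime_error(
            "No implementation is available for the specified data type.");
    }
    std::printf("%g\n", fn(21.0)); // prints 42
    return 0;
}
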
dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 8 additions & 21 deletions
@@ -55,9 +55,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
                                           const std::int64_t,
                                           char *,
                                           const std::int64_t,
-#if !defined(USE_ONEMATH_CUBLAS)
                                           const bool,
-#endif // !USE_ONEMATH_CUBLAS
                                           const std::vector<sycl::event> &);
 
 static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -76,9 +74,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
                              const std::int64_t ldb,
                              char *resultC,
                              const std::int64_t ldc,
-#if !defined(USE_ONEMATH_CUBLAS)
                              const bool is_row_major,
-#endif // !USE_ONEMATH_CUBLAS
                              const std::vector<sycl::event> &depends)
 {
     type_utils::validate_type_for_device<Tab>(exec_q);
@@ -100,11 +96,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
         const Tab *a, const std::int64_t lda, const Tab *b,
         const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc,
         const std::vector<sycl::event> &deps) -> sycl::event {
-#if defined(USE_ONEMATH_CUBLAS)
-        return mkl_blas::column_major::gemm(q, transA, transB, m, n, k,
-                                            alpha, a, lda, b, ldb, beta, c,
-                                            ldc, deps);
-#else
         if (is_row_major) {
             return mkl_blas::row_major::gemm(q, transA, transB, m, n, k,
                                              alpha, a, lda, b, ldb, beta, c,
@@ -115,7 +106,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
                                                 alpha, a, lda, b, ldb, beta,
                                                 c, ldc, deps);
         }
-#endif // USE_ONEMATH_CUBLAS
     };
     gemm_event = gemm_func(
         exec_q,
@@ -129,8 +119,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
         Tab(1), // Scaling factor for the product of matrices A and B.
         a,      // Pointer to matrix A.
         lda,    // Leading dimension of matrix A, which is the
-                // stride between successive rows (for row major
-                // layout).
+                // stride between successive rows (for row major layout).
         b,      // Pointer to matrix B.
         ldb,    // Leading dimension of matrix B, similar to lda.
         Tab(0), // Scaling factor for matrix C.
@@ -168,7 +157,8 @@ std::tuple<sycl::event, sycl::event, bool>
     const int resultC_nd = resultC.get_ndim();
 
     if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
-        throw py::value_error("Input matrices must be two-dimensional.");
+        throw py::value_error(
+            "Input and output matrices must be two-dimensional.");
     }
 
     auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
@@ -242,7 +232,7 @@ std::tuple<sycl::event, sycl::event, bool>
 
     // cuBLAS supports only column-major storage
 #if defined(USE_ONEMATH_CUBLAS)
-    const bool is_row_major = false;
+    constexpr bool is_row_major = false;
 
     transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T
                                  : oneapi::mkl::transpose::N;
@@ -286,6 +276,8 @@ std::tuple<sycl::event, sycl::event, bool>
         }
     }
     else {
+        // both A and B are f_contig so using column-major gemm and
+        // no transpose is needed
         transA = oneapi::mkl::transpose::N;
         transB = oneapi::mkl::transpose::N;
         lda = m;
@@ -313,22 +305,17 @@
         gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
     if (gemm_fn == nullptr) {
         throw py::value_error(
-            "Types of input matrices and result matrix are mismatched.");
+            "No gemm implementation is available for the specified data type "
+            "of the input and output arrays.");
     }
 
     const char *a_typeless_ptr = matrixA.get_data();
     const char *b_typeless_ptr = matrixB.get_data();
     char *r_typeless_ptr = resultC.get_data();
 
-#if defined(USE_ONEMATH_CUBLAS)
-    sycl::event gemm_ev =
-        gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
-                b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
-#else
     sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
                                   a_typeless_ptr, lda, b_typeless_ptr, ldb,
                                   r_typeless_ptr, ldc, is_row_major, depends);
-#endif // USE_ONEMATH_CUBLAS
 
     sycl::event args_ev = dpctl::utils::keep_args_alive(
         exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

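A note on the is_row_major / transpose handling that the gemm diff above touches: when the backend accepts only column-major storage (the cuBLAS path), a C-contiguous (row-major) matrix can still be passed by describing it as the transpose of a column-major matrix, which is why the code sets transA to oneapi::mkl::transpose::T for a C-contiguous input. The standalone sketch below demonstrates the underlying index identity in plain C++; it makes no oneMKL call.

// Reading a row-major n x m buffer with column-major indexing yields the
// m x n transpose: B(i, j) = buffer[j * m + i] = A(j, i), i.e. B = A^T.
#include <cstdio>

int main()
{
    constexpr int n = 2, m = 3;
    // Row-major (C-contiguous) 2 x 3 matrix A = [[1, 2, 3], [4, 5, 6]].
    const double A[n * m] = {1, 2, 3, 4, 5, 6};

    // Print the same buffer interpreted as a column-major m x n matrix.
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j)
            std::printf("%4.1f", A[j * m + i]); // column-major read
        std::printf("\n");
    }
    // Output is A^T (3 x 2), so a column-major-only gemm can treat the
    // row-major buffer as the transposed operand instead of copying it.
    return 0;
}
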
dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 3 additions & 19 deletions
@@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
     const char *,
     const char *,
     char *,
-#if !defined(USE_ONEMATH_CUBLAS)
     const bool,
-#endif // !USE_ONEMATH_CUBLAS
     const std::vector<sycl::event> &);
 
 static gemm_batch_impl_fn_ptr_t
@@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
                                    const char *matrixA,
                                    const char *matrixB,
                                    char *resultC,
-#if !defined(USE_ONEMATH_CUBLAS)
                                    const bool is_row_major,
-#endif // !USE_ONEMATH_CUBLAS
                                    const std::vector<sycl::event> &depends)
 {
     type_utils::validate_type_for_device<Tab>(exec_q);
@@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
         Tc *c, const std::int64_t ldc, const std::int64_t stridec,
         const std::int64_t batch_size,
         const std::vector<sycl::event> &deps) -> sycl::event {
-#if defined(USE_ONEMATH_CUBLAS)
-        return mkl_blas::column_major::gemm_batch(
-            q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
-            strideb, beta, c, ldc, stridec, batch_size, deps);
-#else
         if (is_row_major) {
             return mkl_blas::row_major::gemm_batch(
                 q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
@@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
                 q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                 strideb, beta, c, ldc, stridec, batch_size, deps);
         }
-#endif // USE_ONEMATH_CUBLAS
     };
     gemm_batch_event = gemm_batch_func(
         exec_q,
@@ -317,7 +307,7 @@ std::tuple<sycl::event, sycl::event, bool>
 
     // cuBLAS supports only column-major storage
 #if defined(USE_ONEMATH_CUBLAS)
-    const bool is_row_major = false;
+    constexpr bool is_row_major = false;
 
     transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
                                 : oneapi::mkl::transpose::N;
@@ -389,24 +379,18 @@ std::tuple<sycl::event, sycl::event, bool>
         gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
     if (gemm_batch_fn == nullptr) {
        throw py::value_error(
-            "Types of input matrices and result matrix are mismatched.");
+            "No gemm_batch implementation is available for the specified data "
+            "type of the input and output arrays.");
     }
 
     const char *a_typeless_ptr = matrixA.get_data();
     const char *b_typeless_ptr = matrixB.get_data();
     char *r_typeless_ptr = resultC.get_data();
 
-#if defined(USE_ONEMATH_CUBLAS)
-    sycl::event gemm_batch_ev =
-        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
-                      strideb, stridec, transA, transB, a_typeless_ptr,
-                      b_typeless_ptr, r_typeless_ptr, depends);
-#else
     sycl::event gemm_batch_ev =
         gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
                       strideb, stridec, transA, transB, a_typeless_ptr,
                       b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
-#endif // USE_ONEMATH_CUBLAS
 
     sycl::event args_ev = dpctl::utils::keep_args_alive(
         exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

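For the batch variant, the stridea/strideb/stridec and batch_size arguments describe a strided batch layout: matrix i of the batch is assumed to start at the base pointer plus i times the corresponding stride, measured in elements. A minimal sketch of that addressing rule with made-up data (plain C++, no oneMKL call):

// Strided-batch addressing: matrix i lives at base + i * stride elements.
#include <cstdio>

int main()
{
    constexpr int batch_size = 2, m = 2, k = 2;
    constexpr int stridea = m * k; // elements between consecutive A matrices
    const double A[batch_size * stridea] = {1, 2, 3, 4,  // A_0 (row-major 2 x 2)
                                            5, 6, 7, 8}; // A_1 (row-major 2 x 2)

    for (int i = 0; i < batch_size; ++i) {
        const double *Ai = A + i * stridea; // the same rule applies to a, b and c
        std::printf("A_%d(0,0) = %g, A_%d(1,1) = %g\n", i, Ai[0], i, Ai[m + 1]);
    }
    return 0;
}
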