Skip to content

Commit 8a24e8f

Browse files
authored
Merge branch 'master' into fix-vecmat-win-failure
2 parents b2c7a0b + 1a7ce22 commit 8a24e8f

File tree

3 files changed

+3
-51
lines changed

3 files changed

+3
-51
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5555
const std::int64_t,
5656
char *,
5757
const std::int64_t,
58-
#if !defined(USE_ONEMATH_CUBLAS)
5958
const bool,
60-
#endif // !USE_ONEMATH_CUBLAS
6159
const std::vector<sycl::event> &);
6260

6361
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -76,9 +74,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
7674
const std::int64_t ldb,
7775
char *resultC,
7876
const std::int64_t ldc,
79-
#if !defined(USE_ONEMATH_CUBLAS)
8077
const bool is_row_major,
81-
#endif // !USE_ONEMATH_CUBLAS
8278
const std::vector<sycl::event> &depends)
8379
{
8480
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -100,11 +96,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
10096
const Tab *a, const std::int64_t lda, const Tab *b,
10197
const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc,
10298
const std::vector<sycl::event> &deps) -> sycl::event {
103-
#if defined(USE_ONEMATH_CUBLAS)
104-
return mkl_blas::column_major::gemm(q, transA, transB, m, n, k,
105-
alpha, a, lda, b, ldb, beta, c,
106-
ldc, deps);
107-
#else
10899
if (is_row_major) {
109100
return mkl_blas::row_major::gemm(q, transA, transB, m, n, k,
110101
alpha, a, lda, b, ldb, beta, c,
@@ -115,7 +106,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
115106
alpha, a, lda, b, ldb, beta,
116107
c, ldc, deps);
117108
}
118-
#endif // USE_ONEMATH_CUBLAS
119109
};
120110
gemm_event = gemm_func(
121111
exec_q,
@@ -242,7 +232,7 @@ std::tuple<sycl::event, sycl::event, bool>
242232

243233
// cuBLAS supports only column-major storage
244234
#if defined(USE_ONEMATH_CUBLAS)
245-
const bool is_row_major = false;
235+
constexpr bool is_row_major = false;
246236

247237
transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T
248238
: oneapi::mkl::transpose::N;
@@ -320,15 +310,9 @@ std::tuple<sycl::event, sycl::event, bool>
320310
const char *b_typeless_ptr = matrixB.get_data();
321311
char *r_typeless_ptr = resultC.get_data();
322312

323-
#if defined(USE_ONEMATH_CUBLAS)
324-
sycl::event gemm_ev =
325-
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
326-
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
327-
#else
328313
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
329314
a_typeless_ptr, lda, b_typeless_ptr, ldb,
330315
r_typeless_ptr, ldc, is_row_major, depends);
331-
#endif // USE_ONEMATH_CUBLAS
332316

333317
sycl::event args_ev = dpctl::utils::keep_args_alive(
334318
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
6060
const char *,
6161
const char *,
6262
char *,
63-
#if !defined(USE_ONEMATH_CUBLAS)
6463
const bool,
65-
#endif // !USE_ONEMATH_CUBLAS
6664
const std::vector<sycl::event> &);
6765

6866
static gemm_batch_impl_fn_ptr_t
@@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
8583
const char *matrixA,
8684
const char *matrixB,
8785
char *resultC,
88-
#if !defined(USE_ONEMATH_CUBLAS)
8986
const bool is_row_major,
90-
#endif // !USE_ONEMATH_CUBLAS
9187
const std::vector<sycl::event> &depends)
9288
{
9389
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
112108
Tc *c, const std::int64_t ldc, const std::int64_t stridec,
113109
const std::int64_t batch_size,
114110
const std::vector<sycl::event> &deps) -> sycl::event {
115-
#if defined(USE_ONEMATH_CUBLAS)
116-
return mkl_blas::column_major::gemm_batch(
117-
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
118-
strideb, beta, c, ldc, stridec, batch_size, deps);
119-
#else
120111
if (is_row_major) {
121112
return mkl_blas::row_major::gemm_batch(
122113
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
@@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
127118
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
128119
strideb, beta, c, ldc, stridec, batch_size, deps);
129120
}
130-
#endif // USE_ONEMATH_CUBLAS
131121
};
132122
gemm_batch_event = gemm_batch_func(
133123
exec_q,
@@ -317,7 +307,7 @@ std::tuple<sycl::event, sycl::event, bool>
317307

318308
// cuBLAS supports only column-major storage
319309
#if defined(USE_ONEMATH_CUBLAS)
320-
const bool is_row_major = false;
310+
constexpr bool is_row_major = false;
321311

322312
transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
323313
: oneapi::mkl::transpose::N;
@@ -396,17 +386,10 @@ std::tuple<sycl::event, sycl::event, bool>
396386
const char *b_typeless_ptr = matrixB.get_data();
397387
char *r_typeless_ptr = resultC.get_data();
398388

399-
#if defined(USE_ONEMATH_CUBLAS)
400-
sycl::event gemm_batch_ev =
401-
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
402-
strideb, stridec, transA, transB, a_typeless_ptr,
403-
b_typeless_ptr, r_typeless_ptr, depends);
404-
#else
405389
sycl::event gemm_batch_ev =
406390
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
407391
strideb, stridec, transA, transB, a_typeless_ptr,
408392
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
409-
#endif // USE_ONEMATH_CUBLAS
410393

411394
sycl::event args_ev = dpctl::utils::keep_args_alive(
412395
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &,
5353
const std::int64_t,
5454
char *,
5555
const std::int64_t,
56-
#if !defined(USE_ONEMATH_CUBLAS)
5756
const bool,
58-
#endif // !USE_ONEMATH_CUBLAS
5957
const std::vector<sycl::event> &);
6058

6159
static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types];
@@ -71,9 +69,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
7169
const std::int64_t incx,
7270
char *vectorY,
7371
const std::int64_t incy,
74-
#if !defined(USE_ONEMATH_CUBLAS)
7572
const bool is_row_major,
76-
#endif // !USE_ONEMATH_CUBLAS
7773
const std::vector<sycl::event> &depends)
7874
{
7975
type_utils::validate_type_for_device<T>(exec_q);
@@ -93,10 +89,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
9389
const std::int64_t lda, const T *x, const std::int64_t incx,
9490
T beta, T *y, const std::int64_t incy,
9591
const std::vector<sycl::event> &deps) -> sycl::event {
96-
#if defined(USE_ONEMATH_CUBLAS)
97-
return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda,
98-
x, incx, beta, y, incy, deps);
99-
#else
10092
if (is_row_major) {
10193
return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda,
10294
x, incx, beta, y, incy, deps);
@@ -106,7 +98,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
10698
lda, x, incx, beta, y, incy,
10799
deps);
108100
}
109-
#endif // USE_ONEMATH_CUBLAS
110101
};
111102
gemv_event = gemv_func(
112103
exec_q,
@@ -196,7 +187,7 @@ std::pair<sycl::event, sycl::event>
196187

197188
// cuBLAS supports only column-major storage
198189
#if defined(USE_ONEMATH_CUBLAS)
199-
const bool is_row_major = false;
190+
constexpr bool is_row_major = false;
200191
std::int64_t m;
201192
std::int64_t n;
202193

@@ -304,15 +295,9 @@ std::pair<sycl::event, sycl::event>
304295
y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize;
305296
}
306297

307-
#if defined(USE_ONEMATH_CUBLAS)
308-
sycl::event gemv_ev =
309-
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
310-
y_typeless_ptr, incy, depends);
311-
#else
312298
sycl::event gemv_ev =
313299
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
314300
y_typeless_ptr, incy, is_row_major, depends);
315-
#endif // USE_ONEMATH_CUBLAS
316301

317302
sycl::event args_ev = dpctl::utils::keep_args_alive(
318303
exec_q, {matrixA, vectorX, vectorY}, {gemv_ev});

0 commit comments

Comments
 (0)