Skip to content

Commit e8ae77a

Browse files
authored
Merge branch 'master' into syrk
2 parents 719d200 + 1a7ce22 commit e8ae77a

File tree

6 files changed

+40
-110
lines changed

6 files changed

+40
-110
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5555
const std::int64_t,
5656
char *,
5757
const std::int64_t,
58-
#if !defined(USE_ONEMATH_CUBLAS)
5958
const bool,
60-
#endif // !USE_ONEMATH_CUBLAS
6159
const std::vector<sycl::event> &);
6260

6361
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -76,9 +74,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
7674
const std::int64_t ldb,
7775
char *resultC,
7876
const std::int64_t ldc,
79-
#if !defined(USE_ONEMATH_CUBLAS)
8077
const bool is_row_major,
81-
#endif // !USE_ONEMATH_CUBLAS
8278
const std::vector<sycl::event> &depends)
8379
{
8480
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -100,11 +96,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
10096
const Tab *a, const std::int64_t lda, const Tab *b,
10197
const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc,
10298
const std::vector<sycl::event> &deps) -> sycl::event {
103-
#if defined(USE_ONEMATH_CUBLAS)
104-
return mkl_blas::column_major::gemm(q, transA, transB, m, n, k,
105-
alpha, a, lda, b, ldb, beta, c,
106-
ldc, deps);
107-
#else
10899
if (is_row_major) {
109100
return mkl_blas::row_major::gemm(q, transA, transB, m, n, k,
110101
alpha, a, lda, b, ldb, beta, c,
@@ -115,7 +106,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
115106
alpha, a, lda, b, ldb, beta,
116107
c, ldc, deps);
117108
}
118-
#endif // USE_ONEMATH_CUBLAS
119109
};
120110
gemm_event = gemm_func(
121111
exec_q,
@@ -242,7 +232,7 @@ std::tuple<sycl::event, sycl::event, bool>
242232

243233
// cuBLAS supports only column-major storage
244234
#if defined(USE_ONEMATH_CUBLAS)
245-
const bool is_row_major = false;
235+
constexpr bool is_row_major = false;
246236

247237
transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T
248238
: oneapi::mkl::transpose::N;
@@ -323,15 +313,9 @@ std::tuple<sycl::event, sycl::event, bool>
323313
const char *b_typeless_ptr = matrixB.get_data();
324314
char *r_typeless_ptr = resultC.get_data();
325315

326-
#if defined(USE_ONEMATH_CUBLAS)
327-
sycl::event gemm_ev =
328-
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
329-
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
330-
#else
331316
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
332317
a_typeless_ptr, lda, b_typeless_ptr, ldb,
333318
r_typeless_ptr, ldc, is_row_major, depends);
334-
#endif // USE_ONEMATH_CUBLAS
335319

336320
sycl::event args_ev = dpctl::utils::keep_args_alive(
337321
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
6060
const char *,
6161
const char *,
6262
char *,
63-
#if !defined(USE_ONEMATH_CUBLAS)
6463
const bool,
65-
#endif // !USE_ONEMATH_CUBLAS
6664
const std::vector<sycl::event> &);
6765

6866
static gemm_batch_impl_fn_ptr_t
@@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
8583
const char *matrixA,
8684
const char *matrixB,
8785
char *resultC,
88-
#if !defined(USE_ONEMATH_CUBLAS)
8986
const bool is_row_major,
90-
#endif // !USE_ONEMATH_CUBLAS
9187
const std::vector<sycl::event> &depends)
9288
{
9389
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
112108
Tc *c, const std::int64_t ldc, const std::int64_t stridec,
113109
const std::int64_t batch_size,
114110
const std::vector<sycl::event> &deps) -> sycl::event {
115-
#if defined(USE_ONEMATH_CUBLAS)
116-
return mkl_blas::column_major::gemm_batch(
117-
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
118-
strideb, beta, c, ldc, stridec, batch_size, deps);
119-
#else
120111
if (is_row_major) {
121112
return mkl_blas::row_major::gemm_batch(
122113
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
@@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
127118
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
128119
strideb, beta, c, ldc, stridec, batch_size, deps);
129120
}
130-
#endif // USE_ONEMATH_CUBLAS
131121
};
132122
gemm_batch_event = gemm_batch_func(
133123
exec_q,
@@ -317,7 +307,7 @@ std::tuple<sycl::event, sycl::event, bool>
317307

318308
// cuBLAS supports only column-major storage
319309
#if defined(USE_ONEMATH_CUBLAS)
320-
const bool is_row_major = false;
310+
constexpr bool is_row_major = false;
321311

322312
transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
323313
: oneapi::mkl::transpose::N;
@@ -397,17 +387,10 @@ std::tuple<sycl::event, sycl::event, bool>
397387
const char *b_typeless_ptr = matrixB.get_data();
398388
char *r_typeless_ptr = resultC.get_data();
399389

400-
#if defined(USE_ONEMATH_CUBLAS)
401-
sycl::event gemm_batch_ev =
402-
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
403-
strideb, stridec, transA, transB, a_typeless_ptr,
404-
b_typeless_ptr, r_typeless_ptr, depends);
405-
#else
406390
sycl::event gemm_batch_ev =
407391
gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
408392
strideb, stridec, transA, transB, a_typeless_ptr,
409393
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
410-
#endif // USE_ONEMATH_CUBLAS
411394

412395
sycl::event args_ev = dpctl::utils::keep_args_alive(
413396
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &,
5353
const std::int64_t,
5454
char *,
5555
const std::int64_t,
56-
#if !defined(USE_ONEMATH_CUBLAS)
5756
const bool,
58-
#endif // !USE_ONEMATH_CUBLAS
5957
const std::vector<sycl::event> &);
6058

6159
static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types];
@@ -71,9 +69,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
7169
const std::int64_t incx,
7270
char *vectorY,
7371
const std::int64_t incy,
74-
#if !defined(USE_ONEMATH_CUBLAS)
7572
const bool is_row_major,
76-
#endif // !USE_ONEMATH_CUBLAS
7773
const std::vector<sycl::event> &depends)
7874
{
7975
type_utils::validate_type_for_device<T>(exec_q);
@@ -93,10 +89,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
9389
const std::int64_t lda, const T *x, const std::int64_t incx,
9490
T beta, T *y, const std::int64_t incy,
9591
const std::vector<sycl::event> &deps) -> sycl::event {
96-
#if defined(USE_ONEMATH_CUBLAS)
97-
return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda,
98-
x, incx, beta, y, incy, deps);
99-
#else
10092
if (is_row_major) {
10193
return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda,
10294
x, incx, beta, y, incy, deps);
@@ -106,7 +98,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
10698
lda, x, incx, beta, y, incy,
10799
deps);
108100
}
109-
#endif // USE_ONEMATH_CUBLAS
110101
};
111102
gemv_event = gemv_func(
112103
exec_q,
@@ -215,7 +206,7 @@ std::pair<sycl::event, sycl::event>
215206

216207
// cuBLAS supports only column-major storage
217208
#if defined(USE_ONEMATH_CUBLAS)
218-
const bool is_row_major = false;
209+
constexpr bool is_row_major = false;
219210
std::int64_t m;
220211
std::int64_t n;
221212

@@ -303,15 +294,9 @@ std::pair<sycl::event, sycl::event>
303294
y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize;
304295
}
305296

306-
#if defined(USE_ONEMATH_CUBLAS)
307-
sycl::event gemv_ev =
308-
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
309-
y_typeless_ptr, incy, depends);
310-
#else
311297
sycl::event gemv_ev =
312298
gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx,
313299
y_typeless_ptr, incy, is_row_major, depends);
314-
#endif // USE_ONEMATH_CUBLAS
315300

316301
sycl::event args_ev = dpctl::utils::keep_args_alive(
317302
exec_q, {matrixA, vectorX, vectorY}, {gemv_ev});

dpnp/tests/test_fft.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,7 @@ def test_basic(self, func, dtype, axes):
551551

552552

553553
class TestHfft:
554-
# TODO: include boolean dtype when mkl_fft-gh-180 is merged
555-
@pytest.mark.parametrize(
556-
"dtype", get_all_dtypes(no_none=True, no_bool=True)
557-
)
554+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
558555
@pytest.mark.parametrize("n", [None, 5, 18])
559556
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
560557
def test_basic(self, dtype, n, norm):
@@ -563,10 +560,7 @@ def test_basic(self, dtype, n, norm):
563560

564561
result = dpnp.fft.hfft(ia, n=n, norm=norm)
565562
expected = numpy.fft.hfft(a, n=n, norm=norm)
566-
# TODO: change to the commented line when mkl_fft-2.0.0 is released
567-
# and being used with Intel NumPy >= 2.0.0
568-
flag = True
569-
# flag = True if numpy_version() < "2.0.0" else False
563+
flag = True if numpy_version() < "2.0.0" else False
570564
assert_dtype_allclose(
571565
result, expected, factor=24, check_only_type_kind=flag
572566
)
@@ -609,10 +603,7 @@ def test_basic(self, dtype, n, norm):
609603

610604
result = dpnp.fft.irfft(ia, n=n, norm=norm)
611605
expected = numpy.fft.irfft(a, n=n, norm=norm)
612-
# TODO: change to the commented line when mkl_fft-2.0.0 is released
613-
# and being used with Intel NumPy >= 2.0.0
614-
flag = True
615-
# flag = True if numpy_version() < "2.0.0" else False
606+
flag = True if numpy_version() < "2.0.0" else False
616607
assert_dtype_allclose(
617608
result, expected, factor=24, check_only_type_kind=flag
618609
)
@@ -779,8 +770,7 @@ def test_float16(self):
779770

780771
expected = numpy.fft.rfft(a)
781772
result = dpnp.fft.rfft(ia)
782-
# TODO: change to the commented line when mkl_fft-2.0.0 is released
783-
# and being used with Intel NumPy >= 2.0.0
773+
# TODO: change to the commented line when mkl_fft-gh-204 is resolved
784774
flag = True
785775
# flag = True if numpy_version() < "2.0.0" else False
786776
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
@@ -800,11 +790,10 @@ def test_validate_out(self):
800790

801791

802792
class TestRfft2:
803-
# TODO: add other axes when mkl_fft gh-119 is addressed
804793
@pytest.mark.parametrize(
805794
"dtype", get_all_dtypes(no_none=True, no_complex=True)
806795
)
807-
@pytest.mark.parametrize("axes", [(0, 1)]) # (1, 2),(0, 2),(2, 1),(2, 0)
796+
@pytest.mark.parametrize("axes", [(0, 1), (1, 2), (0, 2), (2, 1), (2, 0)])
808797
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
809798
@pytest.mark.parametrize("order", ["C", "F"])
810799
def test_basic(self, dtype, axes, norm, order):
@@ -859,12 +848,11 @@ def test_error(self, xp):
859848

860849

861850
class TestRfftn:
862-
# TODO: add additional axes when mkl_fft gh-119 is addressed
863851
@pytest.mark.parametrize(
864852
"dtype", get_all_dtypes(no_none=True, no_complex=True)
865853
)
866854
@pytest.mark.parametrize(
867-
"axes", [(0, 1, 2), (-2, -4, -1, -3)] # (-1, -4, -2)
855+
"axes", [(0, 1, 2), (-2, -4, -1, -3), (-1, -4, -2)]
868856
)
869857
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
870858
@pytest.mark.parametrize("order", ["C", "F"])
@@ -965,8 +953,5 @@ def test_1d_array(self):
965953

966954
result = dpnp.fft.irfftn(ia)
967955
expected = numpy.fft.irfftn(a)
968-
# TODO: change to the commented line when mkl_fft-2.0.0 is released
969-
# and being used with Intel NumPy >= 2.0.0
970-
flag = True
971-
# flag = True if numpy_version() < "2.0.0" else False
956+
flag = True if numpy_version() < "2.0.0" else False
972957
assert_dtype_allclose(result, expected, check_only_type_kind=flag)

0 commit comments

Comments
 (0)