Skip to content

Commit 33b795a

Browse files
committed
address comments
1 parent 94c13eb commit 33b795a

File tree

5 files changed

+90
-91
lines changed

5 files changed

+90
-91
lines changed

dpnp/backend/extensions/blas/syrk.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ namespace py = pybind11;
4444
namespace type_utils = dpctl::tensor::type_utils;
4545

4646
typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &,
47-
oneapi::mkl::transpose,
47+
const oneapi::mkl::transpose,
4848
const std::int64_t,
4949
const std::int64_t,
5050
const char *,
@@ -60,7 +60,7 @@ static syrk_impl_fn_ptr_t syrk_dispatch_vector[dpctl_td_ns::num_types];
6060

6161
template <typename T>
6262
static sycl::event syrk_impl(sycl::queue &exec_q,
63-
oneapi::mkl::transpose transA,
63+
const oneapi::mkl::transpose transA,
6464
const std::int64_t n,
6565
const std::int64_t k,
6666
const char *matrixA,
@@ -107,7 +107,7 @@ static sycl::event syrk_impl(sycl::queue &exec_q,
107107
};
108108

109109
// we pass beta = 0, so passing upper or lower does not matter
110-
oneapi::mkl::uplo uplo = oneapi::mkl::uplo::upper;
110+
static constexpr auto uplo = oneapi::mkl::uplo::upper;
111111
syrk_event = syrk_func(
112112
exec_q,
113113
uplo, // Specifies whether C’s data is stored in its upper
@@ -198,7 +198,7 @@ std::pair<sycl::event, sycl::event>
198198

199199
const bool is_matrixA_f_contig = matrixA.is_f_contiguous();
200200
const bool is_matrixA_c_contig = matrixA.is_c_contiguous();
201-
if (!is_matrixA_f_contig and !is_matrixA_c_contig) {
201+
if (!is_matrixA_f_contig && !is_matrixA_c_contig) {
202202
throw py::value_error(
203203
"Input matrix is not c-contiguous nor f-contiguous.");
204204
}

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -947,8 +947,6 @@ def dpnp_multiplication(
947947
x1_is_2D, x1_is_1D, x1_base_is_1D = _define_dim_flags(x1, axis=-1)
948948
x2_is_2D, x2_is_1D, x2_base_is_1D = _define_dim_flags(x2, axis=-2)
949949

950-
# TODO: investigate usage of syrk function from BLAS in
951-
# case of a.T @ a and a @ a.T to gain performance.
952950
if numpy.prod(result_shape) == 0:
953951
res_shape = result_shape
954952
elif x1_shape[-1] == 1:

dpnp/tests/test_product.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1195,6 +1195,7 @@ def test_syrk(self, dt):
11951195

11961196
iout = dpnp.empty(result.shape, dtype=dt)
11971197
result = dpnp.matmul(ia, ia.mT, out=iout)
1198+
assert result is iout
11981199
assert_dtype_allclose(result, expected)
11991200

12001201
@pytest.mark.parametrize(

dpnp/tests/test_sycl_queue.py

Lines changed: 42 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -629,50 +629,50 @@ def test_bitwise_op_2in(op, device):
629629
assert_sycl_queue_equal(zy.sycl_queue, y.sycl_queue)
630630

631631

632-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
633-
@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32])
634-
@pytest.mark.parametrize(
635-
"shape1, shape2",
636-
[
637-
((2, 4), (4,)),
638-
((4,), (4, 3)),
639-
((2, 4), (4, 3)),
640-
((2, 0), (0, 3)),
641-
((2, 4), (4, 0)),
642-
((4, 2, 3), (4, 3, 5)),
643-
((4, 2, 3), (4, 3, 1)),
644-
((4, 1, 3), (4, 3, 5)),
645-
((6, 7, 4, 3), (6, 7, 3, 5)),
646-
],
647-
ids=[
648-
"((2, 4), (4,))",
649-
"((4,), (4, 3))",
650-
"((2, 4), (4, 3))",
651-
"((2, 0), (0, 3))",
652-
"((2, 4), (4, 0))",
653-
"((4, 2, 3), (4, 3, 5))",
654-
"((4, 2, 3), (4, 3, 1))",
655-
"((4, 1, 3), (4, 3, 5))",
656-
"((6, 7, 4, 3), (6, 7, 3, 5))",
657-
],
658-
)
659-
def test_matmul(device, dtype, shape1, shape2):
660-
# int32 checks dpctl implementation and float32 checks oneMKL
661-
a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device)
662-
b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device)
663-
a, b = a.reshape(shape1), b.reshape(shape2)
664-
result = dpnp.matmul(a, b)
665-
666-
result_queue = result.sycl_queue
667-
assert_sycl_queue_equal(result_queue, a.sycl_queue)
668-
assert_sycl_queue_equal(result_queue, b.sycl_queue)
632+
class TestMatmul:
633+
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
634+
@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32])
635+
@pytest.mark.parametrize(
636+
"shape1, shape2",
637+
[
638+
((2, 4), (4,)),
639+
((4,), (4, 3)),
640+
((2, 4), (4, 3)),
641+
((2, 0), (0, 3)),
642+
((2, 4), (4, 0)),
643+
((4, 2, 3), (4, 3, 5)),
644+
((4, 2, 3), (4, 3, 1)),
645+
((4, 1, 3), (4, 3, 5)),
646+
((6, 7, 4, 3), (6, 7, 3, 5)),
647+
],
648+
ids=[
649+
"((2, 4), (4,))",
650+
"((4,), (4, 3))",
651+
"((2, 4), (4, 3))",
652+
"((2, 0), (0, 3))",
653+
"((2, 4), (4, 0))",
654+
"((4, 2, 3), (4, 3, 5))",
655+
"((4, 2, 3), (4, 3, 1))",
656+
"((4, 1, 3), (4, 3, 5))",
657+
"((6, 7, 4, 3), (6, 7, 3, 5))",
658+
],
659+
)
660+
def test_matmul(self, device, dtype, shape1, shape2):
661+
# int32 checks dpctl implementation and float32 checks oneMKL
662+
a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device)
663+
b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device)
664+
a, b = a.reshape(shape1), b.reshape(shape2)
665+
result = dpnp.matmul(a, b)
669666

667+
result_queue = result.sycl_queue
668+
assert_sycl_queue_equal(result_queue, a.sycl_queue)
669+
assert_sycl_queue_equal(result_queue, b.sycl_queue)
670670

671-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
672-
def test_matmul_syrk(device):
673-
a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5)
674-
result = dpnp.matmul(a, a.mT)
675-
assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
671+
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
672+
def test_matmul_syrk(self, device):
673+
a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5)
674+
result = dpnp.matmul(a, a.mT)
675+
assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
676676

677677

678678
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)

dpnp/tests/test_usm_type.py

Lines changed: 43 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -403,52 +403,52 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y):
403403
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
404404

405405

406-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
407-
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
408-
@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32])
409-
@pytest.mark.parametrize(
410-
"shape1, shape2",
411-
[
412-
((2, 4), (4,)),
413-
((4,), (4, 3)),
414-
((2, 4), (4, 3)),
415-
((2, 0), (0, 3)),
416-
((2, 4), (4, 0)),
417-
((4, 2, 3), (4, 3, 5)),
418-
((4, 2, 3), (4, 3, 1)),
419-
((4, 1, 3), (4, 3, 5)),
420-
((6, 7, 4, 3), (6, 7, 3, 5)),
421-
],
422-
ids=[
423-
"((2, 4), (4,))",
424-
"((4,), (4, 3))",
425-
"((2, 4), (4, 3))",
426-
"((2, 0), (0, 3))",
427-
"((2, 4), (4, 0))",
428-
"((4, 2, 3), (4, 3, 5))",
429-
"((4, 2, 3), (4, 3, 1))",
430-
"((4, 1, 3), (4, 3, 5))",
431-
"((6, 7, 4, 3), (6, 7, 3, 5))",
432-
],
433-
)
434-
def test_matmul(usm_type_x, usm_type_y, dtype, shape1, shape2):
435-
# int32 checks dpctl implementation and float32 checks oneMKL
436-
x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x)
437-
y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y)
438-
x, y = x.reshape(shape1), y.reshape(shape2)
439-
z = dpnp.matmul(x, y)
440-
441-
assert x.usm_type == usm_type_x
442-
assert y.usm_type == usm_type_y
443-
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
406+
class TestMatmul:
407+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
408+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
409+
@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32])
410+
@pytest.mark.parametrize(
411+
"shape1, shape2",
412+
[
413+
((2, 4), (4,)),
414+
((4,), (4, 3)),
415+
((2, 4), (4, 3)),
416+
((2, 0), (0, 3)),
417+
((2, 4), (4, 0)),
418+
((4, 2, 3), (4, 3, 5)),
419+
((4, 2, 3), (4, 3, 1)),
420+
((4, 1, 3), (4, 3, 5)),
421+
((6, 7, 4, 3), (6, 7, 3, 5)),
422+
],
423+
ids=[
424+
"((2, 4), (4,))",
425+
"((4,), (4, 3))",
426+
"((2, 4), (4, 3))",
427+
"((2, 0), (0, 3))",
428+
"((2, 4), (4, 0))",
429+
"((4, 2, 3), (4, 3, 5))",
430+
"((4, 2, 3), (4, 3, 1))",
431+
"((4, 1, 3), (4, 3, 5))",
432+
"((6, 7, 4, 3), (6, 7, 3, 5))",
433+
],
434+
)
435+
def test_basic(self, usm_type_x, usm_type_y, dtype, shape1, shape2):
436+
# int32 checks dpctl implementation and float32 checks oneMKL
437+
x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x)
438+
y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y)
439+
x, y = x.reshape(shape1), y.reshape(shape2)
440+
z = dpnp.matmul(x, y)
444441

442+
assert x.usm_type == usm_type_x
443+
assert y.usm_type == usm_type_y
444+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
445445

446-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
447-
def test_matmul_syrk(usm_type):
448-
x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5)
449-
y = dpnp.matmul(x, x.mT)
446+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
447+
def test_syrk(self, usm_type):
448+
x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5)
449+
y = dpnp.matmul(x, x.mT)
450450

451-
assert y.usm_type == usm_type
451+
assert y.usm_type == usm_type
452452

453453

454454
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)

0 commit comments

Comments
 (0)