Skip to content

Commit ec7bfb6

Browse files
committed
use syrk for int dtypes when possible
1 parent 83b7ada commit ec7bfb6

File tree

2 files changed

+59
-39
lines changed

2 files changed

+59
-39
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,6 @@
5050
]
5151

5252

53-
def _call_syrk(x1, x2):
    """
    Decide whether `syrk` may be used in place of `gemm` for ``x1 @ x2``.

    Both arguments are expected to be usm_ndarray objects that the caller has
    already validated as 2-dimensional and contiguous, so only the remaining
    conditions are tested here:

    * both arrays start at the same memory address;
    * the number of rows of `x1` equals the number of columns of `x2`;
    * the contiguity flags differ, i.e. when one array is C-contiguous the
      other one is F-contiguous.

    Returns ``True`` when every condition holds, ``False`` otherwise.

    """
    same_memory = x1._pointer == x2._pointer
    square_result = x1.shape[0] == x2.shape[1]
    # Opposite memory layouts mean one operand is the transpose of the other.
    opposite_layout = (
        x1.flags.c_contiguous != x2.flags.c_contiguous
        and x1.flags.f_contiguous != x2.flags.f_contiguous
    )

    return bool(same_memory and square_result and opposite_layout)
74-
75-
7653
def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"):
7754
"""
7855
Determines the output array data type.
@@ -541,6 +518,31 @@ def _get_signature(func):
541518
return signature, distinct_core
542519

543520

521+
def _is_syrk_compatible(x1, x2):
    """
    Tell whether ``x1 @ x2`` can be computed with `syrk` rather than `gemm`.

    The caller has already validated that both inputs are 2-dimensional.
    The function returns ``True`` only when `x2` looks like the transpose of
    `x1` taken over the same memory:

    * the underlying usm_ndarray objects report the same ``_pointer``;
    * ``x1.shape[0] == x2.shape[1]`` (the product is square);
    * the strides of `x1` are the strides of `x2` in swapped order;
    * the two arrays are not one and the same view (equal shape and equal
      strides), which would mean no transpose is involved.

    """
    ptr1 = dpnp.get_usm_ndarray(x1)._pointer
    ptr2 = dpnp.get_usm_ndarray(x2)._pointer
    if ptr1 != ptr2:
        # Different starting address: the operands do not share data
        return False

    if x1.shape[0] != x2.shape[1]:
        # The result would not be square
        return False

    swapped_x2_strides = (x2.strides[1], x2.strides[0])
    if (x1.strides[0], x1.strides[1]) != swapped_x2_strides:
        # Strides do not follow the transpose pattern
        return False

    # Identical shape and strides describe the very same view, not a
    # transposed one, so `syrk` does not apply in that case.
    return not (x1.shape == x2.shape and x1.strides == x2.strides)
544+
545+
544546
def _shape_error(shape1, shape2, func, err_msg):
545547
"""Validate the shapes of input and output arrays."""
546548

@@ -983,6 +985,11 @@ def dpnp_multiplication(
983985
x1 = dpnp.reshape(x1, x1_shape[-2:])
984986
x2 = dpnp.reshape(x2, x2_shape[-2:])
985987
res_shape = (x1_shape[-2], x2_shape[-1])
988+
if _is_syrk_compatible(x1, x2):
989+
call_flag = "syrk"
990+
res_dtype_orig = res_dtype
991+
if dpnp.issubdtype(res_dtype, dpnp.integer):
992+
res_dtype = dpnp.default_float_type(x1.device)
986993
elif x1_base_is_1D:
987994
# TODO: implement gemv_batch to use it here with transpose
988995
call_flag = "gemm_batch"
@@ -1088,21 +1095,17 @@ def dpnp_multiplication(
10881095
depends=_manager.submitted_events,
10891096
)
10901097
_manager.add_event_pair(ht_ev, gemv_ev)
1098+
elif call_flag == "syrk":
1099+
_manager = dpu.SequentialOrderManager[exec_q]
1100+
ht_ev, gemv_ev = bi._syrk(
1101+
exec_q,
1102+
dpnp.get_usm_ndarray(x1),
1103+
dpnp.get_usm_ndarray(result),
1104+
depends=_manager.submitted_events,
1105+
)
1106+
_manager.add_event_pair(ht_ev, gemv_ev)
10911107
elif call_flag == "gemm":
1092-
x1_usm = dpnp.get_usm_ndarray(x1)
1093-
x2_usm = dpnp.get_usm_ndarray(x2)
1094-
call_syrk = _call_syrk(x1_usm, x2_usm)
1095-
if call_syrk:
1096-
_manager = dpu.SequentialOrderManager[exec_q]
1097-
ht_ev, gemv_ev = bi._syrk(
1098-
exec_q,
1099-
x1_usm,
1100-
dpnp.get_usm_ndarray(result),
1101-
depends=_manager.submitted_events,
1102-
)
1103-
_manager.add_event_pair(ht_ev, gemv_ev)
1104-
else:
1105-
result = _gemm_matmul(exec_q, x1_usm, x2_usm, result)
1108+
result = _gemm_matmul(exec_q, x1, x2, result)
11061109
else:
11071110
assert call_flag == "gemm_batch"
11081111
result = _gemm_batch_matmul(exec_q, x1, x2, result)
@@ -1130,6 +1133,9 @@ def dpnp_multiplication(
11301133
elif res_shape != result_shape:
11311134
result = dpnp.reshape(result, result_shape)
11321135

1136+
if call_flag == "syrk" and res_dtype_orig != res_dtype:
1137+
result = result.astype(res_dtype_orig)
1138+
11331139
if out is None:
11341140
if axes is not None:
11351141
# Move the data back to the appropriate axes of the result array

dpnp/tests/test_product.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
assert_dtype_allclose,
1313
generate_random_numpy_array,
1414
get_all_dtypes,
15-
get_float_complex_dtypes,
1615
numpy_version,
1716
)
1817
from .third_party.cupy import testing
@@ -1184,7 +1183,7 @@ def test_special_case(self, dt_out, shape1, shape2):
11841183
result = dpnp.matmul(ia, ib, out=iout)
11851184
assert_dtype_allclose(result, expected)
11861185

1187-
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
1186+
@pytest.mark.parametrize("dt", get_all_dtypes())
11881187
def test_syrk(self, dt):
11891188
a = generate_random_numpy_array((6, 9), dtype=dt)
11901189
ia = dpnp.array(a)
@@ -1202,6 +1201,21 @@ def test_syrk(self, dt):
12021201
expected = a.T @ a
12031202
assert_dtype_allclose(result, expected)
12041203

1204+
@pytest.mark.parametrize("dt", [dpnp.int32, dpnp.float32])
def test_syrk_strided(self, dt):
    """Compare matmul of a strided view with its transpose against NumPy."""
    full = generate_random_numpy_array((20, 30), dtype=dt)
    ifull = dpnp.array(full)

    # Keep every second row and column so both operands are strided views
    a, ia = full[::2, ::2], ifull[::2, ::2]

    # a @ a.T through the function call
    assert_dtype_allclose(dpnp.matmul(ia, ia.mT), numpy.matmul(a, a.T))

    # a.T @ a through the operator
    assert_dtype_allclose(ia.mT @ ia, a.T @ a)
1218+
12051219
@pytest.mark.parametrize(
12061220
"order, out_order",
12071221
[("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")],

0 commit comments

Comments
 (0)