From 857091d6e136fe37115036393e2d14f425aefb7e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 26 Mar 2024 17:04:58 -0700 Subject: [PATCH 1/3] Fixes `out` keyword in `matmul` for cases where axes are appended to inputs --- dpctl/tensor/_linear_algebra_functions.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 2885867797..38f9e8c5cf 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -741,12 +741,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): if not out.flags.writable: raise ValueError("provided `out` array is read-only") - if out.shape != res_shape: + final_res_shape = tuple( + res_shape[i] + for i in range(-len(res_shape), 0) + if i not in appended_axes + ) + if out.shape != final_res_shape: raise ValueError( "The shape of input and output arrays are inconsistent. " - f"Expected output shape is {res_shape}, got {out.shape}" + f"Expected output shape is {final_res_shape}, got {out.shape}" ) + if appended_axes: + out = dpt.expand_dims(out, appended_axes) + orig_out = out + if res_dt != out.dtype: raise ValueError( f"Output array of type {res_dt} is needed," f"got {out.dtype}" From 71f83597398aabd2892d92afbdf4029afc0bb15e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 26 Mar 2024 17:34:29 -0700 Subject: [PATCH 2/3] Adds test for fixed matmul `out` kwarg --- dpctl/tests/test_usm_ndarray_linalg.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 7d7ba15a50..c36c195769 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -980,3 +980,28 @@ def test_vecdot_contig_small(): res = dpt.vecdot(v1, v2) assert dpt.all(res[:-1] == 0) assert res[-1] == n + + +def test_matmul_out_appended_axes(): + get_queue_or_skip() + + n0, n1, n2 = 4, 10, 5 + # vm + x1 = dpt.ones(n1, dtype="i4") + x2 = dpt.ones((n0, n1, n2), dtype="i4") + out = dpt.empty((n0, n2), dtype="i4") + + dpt.matmul(x1, x2, out=out) + assert dpt.all(out == n1) + + # mv + x2 = x2.mT + x1, x2 = x2, x1 + dpt.matmul(x1, x2, out=out) + assert dpt.all(out == n1) + + # vv + x1 = dpt.ones(n1, dtype="i4") + out = dpt.empty((), dtype="i4") + dpt.matmul(x1, x2, out=out) + assert out == n1 From 55fef8ed55d1b671177296aced292309568217d8 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 26 Mar 2024 17:56:36 -0700 Subject: [PATCH 3/3] Fix typo in matmul docstring and adds documentation for dtype kwarg --- dpctl/tensor/_linear_algebra_functions.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 38f9e8c5cf..159cd60394 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -599,11 +599,16 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): matrices on which to perform matrix multiplication. out (Optional[usm_ndarray]): the array into which the result of the matrix product is written. - If `None` then a new array is returned. + The data type of `out` must match the expected data type of the + result or (if provided) `dtype`. + If `None` then a new array is returned. Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the data type of the + returned array is determined by the Type Promotion Rules. + Default: `None`. order (["K", "C", "F", "A"]): memory layout of the output array, if `out` is `None`, otherwise - the `order` parameter value is not used. - + the `order` parameter value is not used. Default: `K`. Returns: usm_ndarray: * if both `x1` and `x2` are one-dimensional arrays with shape @@ -613,8 +618,8 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): a two-dimensional array with shape `(K, N)`, returned array is a two-dimensional array with shape `(M, N)` and contains the conventional matrix product. - * if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an - array with shape `(..., K, N)`, returned array contains the + * if `x1` is a one-dimensional array with shape `(K,)` and `x2` is + an array with shape `(..., K, N)`, returned array contains the conventional matrix product and has shape `(..., N)`. * if `x1` is an array with shape `(..., M, K)` and `x2` is a one-dimensional array with shape `(K,)`, returned array has shape