diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 2885867797..159cd60394 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -599,11 +599,16 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): matrices on which to perform matrix multiplication. out (Optional[usm_ndarray]): the array into which the result of the matrix product is written. - If `None` then a new array is returned. + The data type of `out` must match the expected data type of the + result or (if provided) `dtype`. + If `None` then a new array is returned. Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the data type of the + returned array is determined by the Type Promotion Rules. + Default: `None`. order (["K", "C", "F", "A"]): memory layout of the output array, if `out` is `None`, otherwise - the `order` parameter value is not used. - + the `order` parameter value is not used. Default: `K`. Returns: usm_ndarray: * if both `x1` and `x2` are one-dimensional arrays with shape @@ -613,8 +618,8 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): a two-dimensional array with shape `(K, N)`, returned array is a two-dimensional array with shape `(M, N)` and contains the conventional matrix product. - * if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an - array with shape `(..., K, N)`, returned array contains the + * if `x1` is a one-dimensional array with shape `(K,)` and `x2` is + an array with shape `(..., K, N)`, returned array contains the conventional matrix product and has shape `(..., N)`. * if `x1` is an array with shape `(..., M, K)` and `x2` is a one-dimensional array with shape `(K,)`, returned array has shape @@ -741,12 +746,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): if not out.flags.writable: raise ValueError("provided `out` array is read-only") - if out.shape != res_shape: + final_res_shape = tuple( + res_shape[i] + for i in range(-len(res_shape), 0) + if i not in appended_axes + ) + if out.shape != final_res_shape: raise ValueError( "The shape of input and output arrays are inconsistent. " - f"Expected output shape is {res_shape}, got {out.shape}" + f"Expected output shape is {final_res_shape}, got {out.shape}" ) + if appended_axes: + out = dpt.expand_dims(out, appended_axes) + orig_out = out + if res_dt != out.dtype: raise ValueError( f"Output array of type {res_dt} is needed," f"got {out.dtype}" diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 7d7ba15a50..c36c195769 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -980,3 +980,28 @@ def test_vecdot_contig_small(): res = dpt.vecdot(v1, v2) assert dpt.all(res[:-1] == 0) assert res[-1] == n + + +def test_matmul_out_appended_axes(): + get_queue_or_skip() + + n0, n1, n2 = 4, 10, 5 + # vm + x1 = dpt.ones(n1, dtype="i4") + x2 = dpt.ones((n0, n1, n2), dtype="i4") + out = dpt.empty((n0, n2), dtype="i4") + + dpt.matmul(x1, x2, out=out) + assert dpt.all(out == n1) + + # mv + x2 = x2.mT + x1, x2 = x2, x1 + dpt.matmul(x1, x2, out=out) + assert dpt.all(out == n1) + + # vv + x1 = dpt.ones(n1, dtype="i4") + out = dpt.empty((), dtype="i4") + dpt.matmul(x1, x2, out=out) + assert out == n1