Commit 857091d

Fixes out keyword in matmul for cases where axes are appended to inputs
1 parent d7c54e4 commit 857091d
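Note: when an operand of matmul is one-dimensional, a size-1 axis is temporarily prepended (first operand) or appended (second operand) so the product can be computed, and that axis is removed from the final result. The internal res_shape still carries these appended axes, so checking out.shape directly against it rejected correctly shaped output arrays. The change below derives the final result shape with the appended axes dropped, validates out against that shape, and expands out over those axes before the result is written to it.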

File tree: 1 file changed (+11 −2 lines)

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 11 additions & 2 deletions
@@ -741,12 +741,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
         if not out.flags.writable:
             raise ValueError("provided `out` array is read-only")
 
-        if out.shape != res_shape:
+        final_res_shape = tuple(
+            res_shape[i]
+            for i in range(-len(res_shape), 0)
+            if i not in appended_axes
+        )
+        if out.shape != final_res_shape:
             raise ValueError(
                 "The shape of input and output arrays are inconsistent. "
-                f"Expected output shape is {res_shape}, got {out.shape}"
+                f"Expected output shape is {final_res_shape}, got {out.shape}"
             )
 
+        if appended_axes:
+            out = dpt.expand_dims(out, appended_axes)
+            orig_out = out
+
         if res_dt != out.dtype:
             raise ValueError(
                 f"Output array of type {res_dt} is needed," f"got {out.dtype}"

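For context, here is a minimal usage sketch (not part of the commit; it assumes dpctl is installed with a working SYCL device, and the shapes, dtypes, and values are illustrative) showing the case the fix targets: a 1-D second operand, whose internally appended axis is absent from both the result and a correctly shaped `out` array.

```python
# Minimal sketch (assumes dpctl with a working SYCL device); shapes and
# values are illustrative, not taken from the commit itself.
import dpctl.tensor as dpt

x1 = dpt.ones((3, 4), dtype="f4")
x2 = dpt.ones(4, dtype="f4")   # 1-D operand: an axis is appended internally

res = dpt.matmul(x1, x2)
print(res.shape)               # (3,) -- the appended axis is removed from the result

# Before this commit, `out` was validated against the internal result shape
# that still included the appended axis, so a correctly shaped (3,) array
# was rejected; with the fix it is accepted and filled in place.
out = dpt.empty((3,), dtype="f4")
dpt.matmul(x1, x2, out=out)
```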