From 564d1dfed7bfb804fe78cd4a4385933bb14ecea6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 2 Jun 2023 14:58:53 +0200 Subject: [PATCH 1/3] Add res_shape param to _empty_like_pair_orderK --- dpctl/tensor/_elementwise_common.py | 8 ++++---- dpctl/tensor/_type_utils.py | 23 +++++++++++++---------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index a775376d95..b5303e49c1 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -433,7 +433,7 @@ def __call__(self, o1, o2, out=None, order="K"): if out is None: if order == "K": out = _empty_like_pair_orderK( - src1, src2, res_dt, res_usm_type, exec_q + src1, src2, res_dt, res_shape, res_usm_type, exec_q ) else: if order == "A": @@ -482,7 +482,7 @@ def __call__(self, o1, o2, out=None, order="K"): if out is None: if order == "K": out = _empty_like_pair_orderK( - src1, buf2, res_dt, res_usm_type, exec_q + src1, buf2, res_dt, res_shape, res_usm_type, exec_q ) else: out = dpt.empty( @@ -524,7 +524,7 @@ def __call__(self, o1, o2, out=None, order="K"): if out is None: if order == "K": out = _empty_like_pair_orderK( - buf1, src2, res_dt, res_usm_type, exec_q + buf1, src2, res_dt, res_shape, res_usm_type, exec_q ) else: out = dpt.empty( @@ -578,7 +578,7 @@ def __call__(self, o1, o2, out=None, order="K"): if out is None: if order == "K": out = _empty_like_pair_orderK( - buf1, buf2, res_dt, res_usm_type, exec_q + buf1, buf2, res_dt, res_shape, res_usm_type, exec_q ) else: out = dpt.empty( diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index d33f2eba06..1de70a4952 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -158,38 +158,41 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None): return dpt.permute_dims(R, inv_perm) -def _empty_like_pair_orderK(X1, X2, dt, usm_type, dev): +def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev): if not isinstance(X1, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray, got {type(X1)}") if not isinstance(X2, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray, got {type(X2)}") nd1 = X1.ndim nd2 = X2.ndim - if nd1 > nd2: + if nd1 > nd2 and X1.shape == res_shape: return _empty_like_orderK(X1, dt, usm_type, dev) - elif nd1 < nd2: + elif nd1 < nd2 and X2.shape == res_shape: return _empty_like_orderK(X2, dt, usm_type, dev) fl1 = X1.flags fl2 = X2.flags if fl1["C"] or fl2["C"]: - return dpt.empty_like( - X1, dtype=dt, usm_type=usm_type, device=dev, order="C" + return dpt.empty( + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C" ) if fl1["F"] and fl2["F"]: - return dpt.empty_like( - X1, dtype=dt, usm_type=usm_type, device=dev, order="F" + return dpt.empty( + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F" ) st1 = list(X1.strides) st2 = list(X2.strides) + max_ndim = max(nd1, nd2) + st1 += [0] * (max_ndim - len(st1)) + st2 += [0] * (max_ndim - len(st2)) perm = sorted( - range(nd1), + range(max_ndim), key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), reverse=True, ) - inv_perm = sorted(range(nd1), key=lambda i: perm[i]) + inv_perm = sorted(range(max_ndim), key=lambda i: perm[i]) st1_sorted = [st1[i] for i in perm] st2_sorted = [st2[i] for i in perm] - sh = X1.shape + sh = res_shape sh_sorted = tuple(sh[i] for i in perm) R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") if max(min(st1_sorted), min(st2_sorted)) < 0: From 814307e828a3f2dd76550aa5cea78108f00f4481 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 2 Jun 2023 15:01:50 +0200 Subject: [PATCH 2/3] Update tests fot dpctl.tensor.add --- dpctl/tests/elementwise/test_add.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index fa97b1c1c7..119da8f45f 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -156,6 +156,43 @@ def test_add_broadcasting(): assert (dpt.asnumpy(r4) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() +def test_add_broadcasting_new_shape(): + get_queue_or_skip() + + ar1 = dpt.ones((6, 1), dtype="i4") + ar2 = dpt.arange(6, dtype="i4") + + r = dpt.add(ar1, ar2) + assert (dpt.asnumpy(r) == np.arange(1, 7, dtype="i4")[np.newaxis, :]).all() + + r1 = dpt.add(ar2, ar1) + assert (dpt.asnumpy(r1) == np.arange(1, 7, dtype="i4")[np.newaxis, :]).all() + + r2 = dpt.add(ar1[::2], ar2[::2]) + assert ( + dpt.asnumpy(r2) == np.arange(1, 7, dtype="i4")[::2][np.newaxis, :] + ).all() + + r3 = dpt.empty_like(ar1) + with pytest.raises(TypeError): + dpt.add(ar1, ar2, out=r3) + + ar3 = dpt.ones((6, 1), dtype="i4") + ar4 = dpt.ones((1, 6), dtype="i4") + + r4 = dpt.add(ar3, ar4) + assert (dpt.asnumpy(r4) == np.full((6, 6), 2, dtype="i4")).all() + + r5 = dpt.add(ar4, ar3) + assert (dpt.asnumpy(r5) == np.full((6, 6), 2, dtype="i4")).all() + + r6 = dpt.add(ar3[::2], ar4[:, ::2]) + assert (dpt.asnumpy(r6) == np.full((3, 3), 2, dtype="i4")).all() + + r7 = dpt.add(ar3[::2], ar4) + assert (dpt.asnumpy(r7) == np.full((3, 6), 2, dtype="i4")).all() + + def test_add_broadcasting_error(): get_queue_or_skip() m = dpt.ones((10, 10), dtype="i4") From 5087297ea323527c891e54b470ea852ea774e59a Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 5 Jun 2023 10:57:15 +0200 Subject: [PATCH 3/3] Update test for _empty_like_pair_orderK --- dpctl/tests/elementwise/test_type_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index c040713925..4c2fd3cb60 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -89,6 +89,7 @@ def test_type_utils_empty_like_orderK_invalid_args(): 3, ), dpt.int32, + (3,), "device", None, ) @@ -105,6 +106,7 @@ def test_type_utils_empty_like_orderK_invalid_args(): 3, ), dpt.int32, + (10,), "device", None, )