From cd4f6946c1a0455f202bef6fcb6d39c5b5229a0d Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 20 Jul 2023 13:41:20 -0500 Subject: [PATCH 1/2] Special case of array_size=True must also handle keepdims=True --- dpctl/tensor/_reduction.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index c9f388fdd6..d9bd6b5b2b 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -123,6 +123,10 @@ def sum(arr, axis=None, dtype=None, keepdims=False): res_usm_type = arr.usm_type if arr.size == 0: + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res_shape = tuple(res_shape[i] for i in inv_perm) return dpt.zeros( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) From a4e357fbecbb831ec8ebd1ba8fa4725c6f0f77be Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 20 Jul 2023 13:47:44 -0500 Subject: [PATCH 2/2] Provide tests for gh-1293 --- dpctl/tests/test_tensor_sum.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 3e25cfb27a..76230fc655 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -133,3 +133,26 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) assert dpt.asnumpy(r) == 1 + + +def test_sum_keepdims_zero_size(): + """See gh-1293""" + get_queue_or_skip() + n = 10 + a = dpt.ones((n, 0, n)) + + s1 = dpt.sum(a, keepdims=True) + assert s1.shape == (1, 1, 1) + + s2 = dpt.sum(a, axis=(0, 1), keepdims=True) + assert s2.shape == (1, 1, n) + + s3 = dpt.sum(a, axis=(1, 2), keepdims=True) + assert s3.shape == (n, 1, 1) + + s4 = dpt.sum(a, axis=(0, 2), keepdims=True) + assert s4.shape == (1, 0, 1) + + a0 = a[0] + s5 = dpt.sum(a0, keepdims=True) + assert s5.shape == (1, 1)