diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index c9f388fdd6..d9bd6b5b2b 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -123,6 +123,10 @@ def sum(arr, axis=None, dtype=None, keepdims=False): res_usm_type = arr.usm_type if arr.size == 0: + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res_shape = tuple(res_shape[i] for i in inv_perm) return dpt.zeros( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 3e25cfb27a..76230fc655 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -133,3 +133,26 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) assert dpt.asnumpy(r) == 1 + + +def test_sum_keepdims_zero_size(): + """See gh-1293""" + get_queue_or_skip() + n = 10 + a = dpt.ones((n, 0, n)) + + s1 = dpt.sum(a, keepdims=True) + assert s1.shape == (1, 1, 1) + + s2 = dpt.sum(a, axis=(0, 1), keepdims=True) + assert s2.shape == (1, 1, n) + + s3 = dpt.sum(a, axis=(1, 2), keepdims=True) + assert s3.shape == (n, 1, 1) + + s4 = dpt.sum(a, axis=(0, 2), keepdims=True) + assert s4.shape == (1, 0, 1) + + a0 = a[0] + s5 = dpt.sum(a0, keepdims=True) + assert s5.shape == (1, 1)