diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py
index 7be651bbfd..c9f388fdd6 100644
--- a/dpctl/tensor/_reduction.py
+++ b/dpctl/tensor/_reduction.py
@@ -122,10 +122,12 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
         res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
 
     res_usm_type = arr.usm_type
-    if red_nd == 0:
+    if arr.size == 0:
         return dpt.zeros(
             res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
         )
+    if red_nd == 0:
+        return dpt.astype(arr, res_dt, copy=False)
 
     host_tasks_list = []
     if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q):
diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py
index e53c51766d..3e25cfb27a 100644
--- a/dpctl/tests/test_tensor_sum.py
+++ b/dpctl/tests/test_tensor_sum.py
@@ -106,3 +106,30 @@ def test_sum_keepdims():
     assert isinstance(s, dpt.usm_ndarray)
     assert s.shape == (3, 1, 1, 6, 1)
     assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all()
+
+
+def test_sum_scalar():
+    get_queue_or_skip()
+
+    m = dpt.ones(())
+    s = dpt.sum(m)
+
+    assert isinstance(s, dpt.usm_ndarray)
+    assert m.sycl_queue == s.sycl_queue
+    assert s.shape == ()
+    assert dpt.asnumpy(s) == np.full((), 1)
+
+
+@pytest.mark.parametrize("arg_dtype", _all_dtypes)
+@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
+def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(arg_dtype, q)
+    skip_if_dtype_not_supported(out_dtype, q)
+
+    m = dpt.ones((), dtype=arg_dtype)
+    r = dpt.sum(m, dtype=out_dtype)
+
+    assert isinstance(r, dpt.usm_ndarray)
+    assert r.dtype == dpt.dtype(out_dtype)
+    assert dpt.asnumpy(r) == 1
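
Not part of the patch: a minimal sketch of the behavior distinction the first hunk draws between an empty input (arr.size == 0, which still returns zeros) and a zero-axis reduction such as a 0-d input (red_nd == 0, which now returns the input cast to the result dtype). Assumes a dpctl build that includes this change; the printed values are the expected outcomes, not captured output.

    import dpctl.tensor as dpt

    # 0-d input: there are no axes to reduce (red_nd == 0), so the sum is the
    # value itself, cast to the accumulation dtype. Before this patch the
    # zeros() branch was taken for this case and the result was 0.
    m = dpt.ones((), dtype="i4")
    print(dpt.asnumpy(dpt.sum(m)))          # expected: 1

    # Genuinely empty input (arr.size == 0): the sum over zero elements is 0,
    # which is the case the zeros() branch now exclusively handles.
    e = dpt.empty((0, 3), dtype="f4")
    print(dpt.asnumpy(dpt.sum(e, axis=0)))  # expected: [0. 0. 0.]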