diff --git a/dpctl/tensor/libtensor/source/sum_reductions.cpp b/dpctl/tensor/libtensor/source/sum_reductions.cpp index 7628813c6d..13ab268b55 100644 --- a/dpctl/tensor/libtensor/source/sum_reductions.cpp +++ b/dpctl/tensor/libtensor/source/sum_reductions.cpp @@ -218,7 +218,9 @@ std::pair py_sum_over_axis( return std::make_pair(keep_args_event, sum_over_axis_contig_ev); } } - else if (is_src_f_contig & is_dst_c_contig) { + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid] [dst_typeid]; if (fn != nullptr) { diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index fc2a0ec8de..403a823324 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -172,3 +172,18 @@ def test_largish_reduction(arg_dtype, n): assert dpt.all(dpt.equal(y1, y2)) assert dpt.all(dpt.equal(y1, n * m)) + + +def test_axis0_bug(): + "gh-1391" + get_queue_or_skip() + + sh = (1, 2, 3) + a = dpt.arange(sh[0] * sh[1] * sh[2], dtype="i4") + a = dpt.reshape(a, sh) + aT = dpt.permute_dims(a, (2, 1, 0)) + + s = dpt.sum(aT, axis=2) + expected = dpt.asarray([[0, 3], [1, 4], [2, 5]]) + + assert dpt.all(s == expected)