From f35d63e2665b402c00d1a0c19ddc4ebaa7a99c74 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 27 Oct 2023 11:46:06 -0700 Subject: [PATCH 1/3] max and min now use MinMaxAtomicSupportFactory These functions were using ArithmeticAtomicSupportFactory, which disables atomics for floating point types --- .../libtensor/source/reductions/reduction_atomic_support.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp index 695f4b73d0..2478545efe 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp @@ -117,12 +117,12 @@ template struct MinMaxAtomicSupportFactory }; template -struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory +struct MaxAtomicSupportFactory : public MinMaxAtomicSupportFactory { }; template -struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory +struct MinAtomicSupportFactory : public MinMaxAtomicSupportFactory { }; From f293713b822c38f047223d726e5b8c22f104c30f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 27 Oct 2023 12:14:17 -0700 Subject: [PATCH 2/3] Resolves #1455 This issue was caused by a typo where when the `axis0` kernels for tree and atomic reductions would be called, the `axis1` kernel would be called instead --- .../libtensor/source/reductions/reduction_over_axis.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp index da8da0938d..aa46f1c02a 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -417,10 +417,10 @@ std::pair py_reduction_over_axis( typename std::remove_all_extents::type; contig_fn_ptr_T fn; if (supports_atomics) { - fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid]; } else { - fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; } if (fn != nullptr) { sycl::event reduction_over_axis0_contig_ev = @@ -727,7 +727,7 @@ std::pair py_tree_reduction_over_axis( } } else if (mat_reduce_over_axis0) { - auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { sycl::event reduction_over_axis0_contig_ev = fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), @@ -929,7 +929,6 @@ std::pair py_search_over_axis( } using dpctl::tensor::py_internal::simplify_iteration_space; - using dpctl::tensor::py_internal::simplify_iteration_space_1; auto const &src_shape_vecs = src.get_shape_vector(); auto const &src_strides_vecs = src.get_strides_vector(); From 891161f582772431756d747b1dc1b28370dc71d3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 27 Oct 2023 13:19:43 -0700 Subject: [PATCH 3/3] Adds tests for #1455 resolution --- dpctl/tests/test_tensor_sum.py | 30 +++++++++++++++++ dpctl/tests/test_usm_ndarray_reductions.py | 39 ++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index fbfd9547e1..749ca055b9 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -212,6 +212,36 @@ def test_axis0_bug(): assert dpt.all(s == expected) +def test_sum_axis1_axis0(): + """See gh-1455""" + get_queue_or_skip() + + # The atomic case is checked in `test_usm_ndarray_reductions` + # This test checks the tree reduction path for correctness + x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5)) + + m = dpt.sum(x, axis=0) + expected = dpt.asarray( + [ + [60, 63, 66, 69, 72], + [75, 78, 81, 84, 87], + [90, 93, 96, 99, 102], + [105, 108, 111, 114, 117], + ], + dtype="f4", + ) + tol = dpt.finfo(m.dtype).resolution + assert dpt.allclose(m, expected, atol=tol, rtol=tol) + + x = dpt.flip(x, axis=2) + m = dpt.sum(x, axis=2) + expected = dpt.asarray( + [[10, 35, 60, 85], [110, 135, 160, 185], [210, 235, 260, 285]], + dtype="f4", + ) + assert dpt.allclose(m, expected, atol=tol, rtol=tol) + + def _any_complex(dtypes): return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 56059e54b8..45afb26aac 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -61,6 +61,20 @@ def test_max_min_axis(): assert dpt.all(m == x[:, 0, 0, :, 0]) +def test_max_axis1_axis0(): + """See gh-1455""" + get_queue_or_skip() + + x = dpt.reshape(dpt.arange(3 * 4 * 5), (3, 4, 5)) + + m = dpt.max(x, axis=0) + assert dpt.all(m == x[-1, :, :]) + + x = dpt.flip(x, axis=2) + m = dpt.max(x, axis=2) + assert dpt.all(m == x[:, :, 0]) + + def test_reduction_keepdims(): get_queue_or_skip() @@ -440,3 +454,28 @@ def test_hypot_complex(): x = dpt.zeros(1, dtype="c8") with pytest.raises(TypeError): dpt.reduce_hypot(x) + + +def test_tree_reduction_axis1_axis0(): + """See gh-1455""" + get_queue_or_skip() + + x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5)) + + m = dpt.logsumexp(x, axis=0) + tol = dpt.finfo(m.dtype).resolution + assert_allclose( + dpt.asnumpy(m), + np.logaddexp.reduce(dpt.asnumpy(x), axis=0, dtype=m.dtype), + rtol=tol, + atol=tol, + ) + + x = dpt.flip(x, axis=2) + m = dpt.logsumexp(x, axis=2) + assert_allclose( + dpt.asnumpy(m), + np.logaddexp.reduce(dpt.asnumpy(x), axis=2, dtype=m.dtype), + rtol=tol, + atol=tol, + )