From c87a11ab780b13b70fafd1c49a6e5876009ec918 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 9 Jan 2025 13:33:24 -0800 Subject: [PATCH 1/2] Add bool->bool loops to dpt.sum and dpt.prod This is done to fix edge cases where the input type is not bool and the output type is bool, which ends up falling back on loops in the input data type, which are cast to bool Leads to incorrect results in edge cases, i.e., ``` import dpctl.tensor as dpt a = dpt.asarray([-1, 1], dtype=dpt.int32) dpt.sum(a, dtype=dpt.bool) # usm_ndarray(False) ``` --- dpctl/tensor/libtensor/source/reductions/prod.cpp | 15 +++++++++++---- dpctl/tensor/libtensor/source/reductions/sum.cpp | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 374193dca1..7c768ce179 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -120,6 +120,7 @@ struct TypePairSupportDataForProductReductionTemps { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -224,7 +225,7 @@ struct TypePairSupportDataForProductReductionTemps outTy, std::complex>, - // fall-throug + // fall-through td_ns::NotDefinedEntry>::is_defined; }; @@ -255,7 +256,9 @@ struct ProductOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -312,7 +315,9 @@ struct ProductOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -331,7 +336,9 @@ struct ProductOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index 47d07a345a..f449a6cde3 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -120,6 +120,7 @@ struct TypePairSupportDataForSumReductionTemps { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -224,7 +225,7 @@ struct TypePairSupportDataForSumReductionTemps outTy, std::complex>, - // fall-throug + // fall-through td_ns::NotDefinedEntry>::is_defined; }; @@ -255,7 +256,9 @@ struct SumOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -312,7 +315,9 @@ struct SumOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -331,7 +336,9 @@ struct SumOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; From f9f4d6c26b1562230ee227ee76a5e03ecd4b8fe1 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 9 Jan 2025 14:58:10 -0800 Subject: [PATCH 2/2] Add test for example given in gh-1944 --- dpctl/tests/test_tensor_sum.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 7be9e3c6d9..29ee3abb1b 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -316,3 +316,17 @@ def test_gh_1468(): a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32) t = dpt.sum(a, dtype="f4") assert t > 0 + + +@pytest.mark.parametrize( + "dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"] +) +def test_gh_1944(dt): + "See https://github.com/IntelPython/dpctl/issues/1944" + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q) + r = dpt.sum(x, dtype="?") + # reduction must be performed in the requested dtype + # if performed in the input type, result is False + assert r