diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 374193dca1..7c768ce179 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -120,6 +120,7 @@ struct TypePairSupportDataForProductReductionTemps { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -224,7 +225,7 @@ struct TypePairSupportDataForProductReductionTemps outTy, std::complex>, - // fall-throug + // fall-through td_ns::NotDefinedEntry>::is_defined; }; @@ -255,7 +256,9 @@ struct ProductOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -312,7 +315,9 @@ struct ProductOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -331,7 +336,9 @@ struct ProductOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::multiplies; + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index 47d07a345a..f449a6cde3 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -120,6 +120,7 @@ struct TypePairSupportDataForSumReductionTemps { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -224,7 +225,7 @@ struct TypePairSupportDataForSumReductionTemps outTy, std::complex>, - // fall-throug + // fall-through td_ns::NotDefinedEntry>::is_defined; }; @@ -255,7 +256,9 @@ struct SumOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -312,7 +315,9 @@ struct SumOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -331,7 +336,9 @@ struct SumOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 7be9e3c6d9..29ee3abb1b 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -316,3 +316,17 @@ def test_gh_1468(): a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32) t = dpt.sum(a, dtype="f4") assert t > 0 + + +@pytest.mark.parametrize( + "dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"] +) +def test_gh_1944(dt): + "See https://github.com/IntelPython/dpctl/issues/1944" + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q) + r = dpt.sum(x, dtype="?") + # reduction must be performed in the requested dtype + # if performed in the input type, result is False + assert r