diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py
index b5f356ab30..3473d5cde5 100644
--- a/dpctl/tensor/__init__.py
+++ b/dpctl/tensor/__init__.py
@@ -160,7 +160,7 @@
     tanh,
     trunc,
 )
-from ._reduction import argmax, argmin, max, min, sum
+from ._reduction import argmax, argmin, max, min, prod, sum
 from ._testing import allclose
 
 __all__ = [
@@ -313,4 +313,5 @@
     "min",
     "argmax",
     "argmin",
+    "prod",
 ]
diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py
index 0bbfc262a4..aac1c84677 100644
--- a/dpctl/tensor/_reduction.py
+++ b/dpctl/tensor/_reduction.py
@@ -144,12 +144,12 @@ def _reduction_over_axis(
 def sum(x, axis=None, dtype=None, keepdims=False):
     """sum(x, axis=None, dtype=None, keepdims=False)
 
-    Calculates the sum of the input array `x`.
+    Calculates the sum of elements in the input array `x`.
 
     Args:
         x (usm_ndarray):
             input array.
-        axis (Optional[int, Tuple[int,...]]):
+        axis (Optional[int, Tuple[int, ...]]):
             axis or axes along which sums must be computed. If a tuple
             of unique integers, sums are computed over multiple axes.
             If `None`, the sum is computed over the entire array.
@@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
     )
 
 
+def prod(x, axis=None, dtype=None, keepdims=False):
+    """prod(x, axis=None, dtype=None, keepdims=False)
+
+    Calculates the product of elements in the input array `x`.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int, Tuple[int, ...]]):
+            axis or axes along which products must be computed. If a tuple
+            of unique integers, products are computed over multiple axes.
+            If `None`, the product is computed over the entire array.
+            Default: `None`.
+        dtype (Optional[dtype]):
+            data type of the returned array. If `None`, the default data
+            type is inferred from the "kind" of the input array data type.
+            * If `x` has a real-valued floating-point data type,
+              the returned array will have the default real-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has signed integral data type, the returned array
+              will have the default signed integral type for the device
+              where input array `x` is allocated.
+            * If `x` has unsigned integral data type, the returned array
+              will have the default unsigned integral type for the device
+              where input array `x` is allocated.
+            * If `x` has a complex-valued floating-point data type,
+              the returned array will have the default complex-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has a boolean data type, the returned array will
+              have the default signed integral type for the device
+              where input array `x` is allocated.
+            If the data type (either specified or resolved) differs from the
+            data type of `x`, the input array elements are cast to the
+            specified data type before computing the product. Default: `None`.
+        keepdims (Optional[bool]):
+            if `True`, the reduced axes (dimensions) are included in the result
+            as singleton dimensions, so that the returned array remains
+            compatible with the input arrays according to Array Broadcasting
+            rules. Otherwise, if `False`, the reduced axes are not included in
+            the returned array. Default: `False`.
+    Returns:
+        usm_ndarray:
+            an array containing the products. If the product was computed over
+            the entire array, a zero-dimensional array is returned. The returned
+            array has the data type as described in the `dtype` parameter
+            description above.
+ """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + ti._prod_over_axis, + ti._prod_over_axis_dtype_supported, + _default_reduction_dtype, + _identity=1, + ) + + def _comparison_over_axis(x, axis, keepdims, _reduction_fn): if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") @@ -253,7 +314,7 @@ def max(x, axis=None, keepdims=False): Args: x (usm_ndarray): input array. - axis (Optional[int, Tuple[int,...]]): + axis (Optional[int, Tuple[int, ...]]): axis or axes along which maxima must be computed. If a tuple of unique integers, the maxima are computed over multiple axes. If `None`, the max is computed over the entire array. @@ -281,7 +342,7 @@ def min(x, axis=None, keepdims=False): Args: x (usm_ndarray): input array. - axis (Optional[int, Tuple[int,...]]): + axis (Optional[int, Tuple[int, ...]]): axis or axes along which minima must be computed. If a tuple of unique integers, the minima are computed over multiple axes. If `None`, the min is computed over the entire array. diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index abeef5d669..7cb97cd4f9 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -50,6 +50,14 @@ namespace tensor namespace kernels { +template struct can_use_reduce_over_group +{ + static constexpr bool value = + sycl::has_known_identity::value && + !std::is_same_v && !std::is_same_v && + !std::is_same_v>; +}; + template {iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class reduction_over_group_with_atomics_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; @@ -618,7 +627,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class reduction_axis1_over_group_with_atomics_contig_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, @@ -717,7 +727,8 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class reduction_axis0_over_group_with_atomics_contig_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, @@ -1007,10 +1018,12 @@ sycl::event reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor< @@ -1026,6 +1039,7 @@ sycl::event reduction_over_group_temps_strided_impl( using KernelName = class custom_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT>; + cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), CustomReductionOverGroupNoAtomicFunctor< @@ -1062,68 +1076,67 @@ sycl::event 
reduction_over_group_temps_strided_impl( partially_reduced_tmp + reduction_groups * iter_nelems; } - const sycl::event &first_reduction_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); - using InputIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; - // Only 2*iter_nd entries describing shape and strides of - // iterated dimensions of input array from - // iter_shape_and_strides are going to be accessed by - // inp_indexer - InputIndexerT inp_indexer(iter_nd, iter_arg_offset, - iter_shape_and_strides); - ResIndexerT noop_tmp_indexer{}; + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + ResIndexerT noop_tmp_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - noop_tmp_indexer}; - ReductionIndexerT reduction_indexer{ - red_nd, reduction_arg_offset, reduction_shape_stride}; + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; - auto globalRange = - sycl::range<1>{iter_nelems * reduction_groups * wg}; - auto localRange = sycl::range<1>{wg}; + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { - using KernelName = class reduction_over_group_temps_krn< + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupNoAtomicFunctor< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>( - arg_tp, partially_reduced_tmp, ReductionOpT(), - identity_val, in_out_iter_indexer, - reduction_indexer, reduction_nelems, iter_nelems, - preferrered_reductions_per_wi)); - } - else { - using SlmT = sycl::local_accessor; - SlmT local_memory = SlmT(localRange, cgh); - using KernelName = - class custom_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, SlmT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - CustomReductionOverGroupNoAtomicFunctor< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, SlmT>( - arg_tp, partially_reduced_tmp, ReductionOpT(), - identity_val, in_out_iter_indexer, - reduction_indexer, local_memory, reduction_nelems, 
- iter_nelems, preferrered_reductions_per_wi)); - } - }); + ReductionIndexerT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + local_memory, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + }); size_t remaining_reduction_nelems = reduction_groups; @@ -1165,7 +1178,8 @@ sycl::event reduction_over_group_temps_strided_impl( auto globalRange = sycl::range<1>{iter_nelems * reduction_groups_ * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) { using KernelName = class reduction_over_group_temps_krn< resTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; @@ -1240,7 +1254,8 @@ sycl::event reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; @@ -1831,6 +1846,250 @@ struct SumOverAxis0AtomicContigFactory } }; +// Product + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input half + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using 
ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + // Argmax and Argmin /* = Search reduction using reduce_over_group*/ @@ -2320,7 +2579,8 @@ sycl::event search_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class search_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, true, true>; @@ -2419,7 +2679,8 @@ sycl::event search_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class search_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, true, false>; @@ -2499,7 +2760,8 @@ sycl::event search_reduction_over_group_temps_strided_impl( auto globalRange = sycl::range<1>{iter_nelems * reduction_groups_ * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) { using KernelName = class search_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, IndexOpT, @@ -2582,7 +2844,8 @@ sycl::event search_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - if constexpr (su_ns::IsSyclOp::value) { + if constexpr (can_use_reduce_over_group::value) + { using KernelName = class search_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, false, true>; diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index 3ecfbe67c7..0d4240c516 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include "math_utils.hpp" @@ -272,6 +273,18 @@ struct GetIdentity using IsPlus = std::bool_constant> || std::is_same_v>>; +// Multiplies + +template +using IsMultiplies = + std::bool_constant> || + std::is_same_v>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(1); +}; // Identity @@ -280,13 +293,17 @@ template struct Identity }; template -struct Identity::value>> +using UseBuiltInIdentity = + std::conjunction, sycl::has_known_identity>; + +template +struct Identity::value>> { static constexpr T value = GetIdentity::value; }; template -struct Identity::value>> +struct Identity::value>> { static constexpr T value = sycl::known_identity::value; }; diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp index 346efaa936..c67fcd5ba3 100644 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp @@ -204,6 +204,59 @@ void populate_sum_over_axis_dispatch_tables(void) } // namespace impl +// Product +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + 
[td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + // Argmax namespace impl { @@ -259,6 +312,12 @@ void init_reduction_functions(py::module_ m) namespace impl = dpctl::tensor::py_internal::impl; + using dpctl::tensor::py_internal::py_reduction_dtype_supported; + using dpctl::tensor::py_internal::py_reduction_over_axis; + + using dpctl::tensor::py_internal::check_atomic_support; + using dpctl::tensor::py_internal::fixed_decision; + // MAX { using dpctl::tensor::py_internal::impl:: @@ -269,16 +328,21 @@ void init_reduction_functions(py::module_ m) using impl::max_over_axis_strided_atomic_dispatch_table; using impl::max_over_axis_strided_temps_dispatch_table; + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_reduction_over_axis; return py_reduction_over_axis( src, trailing_dims_to_reduce, dst, exec_q, depends, max_over_axis_strided_atomic_dispatch_table, max_over_axis_strided_temps_dispatch_table, max_over_axis0_contig_atomic_dispatch_table, - max_over_axis1_contig_atomic_dispatch_table); + max_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); }; m.def("_max_over_axis", max_pyapi, "", py::arg("src"), py::arg("trailing_dims_to_reduce"), py::arg("dst"), @@ -295,16 +359,21 @@ void init_reduction_functions(py::module_ m) using impl::min_over_axis_strided_atomic_dispatch_table; using impl::min_over_axis_strided_temps_dispatch_table; + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_reduction_over_axis; return py_reduction_over_axis( src, trailing_dims_to_reduce, dst, exec_q, depends, min_over_axis_strided_atomic_dispatch_table, min_over_axis_strided_temps_dispatch_table, 
min_over_axis0_contig_atomic_dispatch_table, - min_over_axis1_contig_atomic_dispatch_table); + min_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); }; m.def("_min_over_axis", min_pyapi, "", py::arg("src"), py::arg("trailing_dims_to_reduce"), py::arg("dst"), @@ -321,16 +390,21 @@ void init_reduction_functions(py::module_ m) using impl::sum_over_axis_strided_atomic_dispatch_table; using impl::sum_over_axis_strided_temps_dispatch_table; + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_reduction_over_axis; return py_reduction_over_axis( src, trailing_dims_to_reduce, dst, exec_q, depends, sum_over_axis_strided_atomic_dispatch_table, sum_over_axis_strided_temps_dispatch_table, sum_over_axis0_contig_atomic_dispatch_table, - sum_over_axis1_contig_atomic_dispatch_table); + sum_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); }; m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), py::arg("trailing_dims_to_reduce"), py::arg("dst"), @@ -339,17 +413,61 @@ void init_reduction_functions(py::module_ m) auto sum_dtype_supported = [&](const py::dtype &input_dtype, const py::dtype &output_dtype, const std::string &dst_usm_type, sycl::queue &q) { - using dpctl::tensor::py_internal::py_reduction_dtype_supported; return py_reduction_dtype_supported( input_dtype, output_dtype, dst_usm_type, q, sum_over_axis_strided_atomic_dispatch_table, - sum_over_axis_strided_temps_dispatch_table); + sum_over_axis_strided_temps_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); }; m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", py::arg("arg_dtype"), py::arg("out_dtype"), py::arg("dst_usm_type"), py::arg("sycl_queue")); } + // PROD + { + using dpctl::tensor::py_internal::impl:: + populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + 
check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } + // ARGMAX { using dpctl::tensor::py_internal::impl:: diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp index c7bbadd455..1a9cb6f5e7 100644 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp @@ -50,14 +50,15 @@ namespace tensor namespace py_internal { -inline bool check_atomic_support(const sycl::queue &exec_q, - sycl::usm::alloc usm_alloc_type, - bool require_atomic64 = false) +template +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type) { bool supports_atomics = false; const sycl::device &dev = exec_q.get_device(); - if (require_atomic64) { + + if constexpr (require_atomic64) { if (!dev.has(sycl::aspect::atomic64)) return false; } @@ -79,15 +80,24 @@ inline bool check_atomic_support(const sycl::queue &exec_q, return supports_atomics; } +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + /* ====================== dtype supported ======================== */ -template -bool py_reduction_dtype_supported(const py::dtype &input_dtype, - const py::dtype &output_dtype, - const std::string &dst_usm_type, - sycl::queue &q, - const fnT &atomic_dispatch_table, - const fnT &temps_dispatch_table) +template +bool py_reduction_dtype_supported( + const py::dtype &input_dtype, + const py::dtype &output_dtype, + const std::string &dst_usm_type, + sycl::queue &q, + const fnT &atomic_dispatch_table, + const fnT &temps_dispatch_table, + const CheckAtomicSupportFnT &check_atomic_support_size4, + const CheckAtomicSupportFnT &check_atomic_support_size8) { int arg_tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl @@ -135,12 +145,11 @@ bool py_reduction_dtype_supported(const py::dtype &input_dtype, switch (output_dtype.itemsize()) { case sizeof(float): { - supports_atomics = check_atomic_support(q, kind); + supports_atomics = check_atomic_support_size4(q, kind); } break; case sizeof(double): { - constexpr bool require_atomic64 = true; - supports_atomics = check_atomic_support(q, kind, require_atomic64); + supports_atomics = check_atomic_support_size8(q, kind); } break; } @@ -158,7 +167,7 @@ bool py_reduction_dtype_supported(const py::dtype &input_dtype, /* ==================== Generic reductions ====================== */ -template +template std::pair py_reduction_over_axis( const dpctl::tensor::usm_ndarray &src, int trailing_dims_to_reduce, // comp over this many trailing indexes @@ -168,7 +177,9 @@ std::pair py_reduction_over_axis( const strided_fnT &atomic_dispatch_table, const strided_fnT &temps_dispatch_table, const contig_fnT &axis0_dispatch_table, - const contig_fnT &axis1_dispatch_table) + const contig_fnT &axis1_dispatch_table, + const SupportAtomicFnT &check_atomic_support_size4, + const SupportAtomicFnT &check_atomic_support_size8) { int src_nd = src.get_ndim(); int iteration_nd = src_nd - trailing_dims_to_reduce; @@ -243,7 +254,7 @@ std::pair py_reduction_over_axis( void *data_ptr = dst.get_data(); const auto &ctx = exec_q.get_context(); auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - supports_atomics = check_atomic_support(exec_q, usm_type); + supports_atomics = check_atomic_support_size4(exec_q, usm_type); } break; case 
sizeof(double): { @@ -251,9 +262,7 @@ std::pair py_reduction_over_axis( const auto &ctx = exec_q.get_context(); auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - constexpr bool require_atomic64 = true; - supports_atomics = - check_atomic_support(exec_q, usm_type, require_atomic64); + supports_atomics = check_atomic_support_size8(exec_q, usm_type); } break; } diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 8f2bd45362..dc647febf7 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import pytest import dpctl.tensor as dpt @@ -55,11 +54,11 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): assert r.dtype.kind == "f" elif m.dtype.kind == "c": assert r.dtype.kind == "c" - assert (dpt.asnumpy(r) == 100).all() + assert dpt.all(r == 100) m = dpt.ones(200, dtype=arg_dtype)[:1:-2] r = dpt.sum(m) - assert (dpt.asnumpy(r) == 99).all() + assert dpt.all(r == 99) @pytest.mark.parametrize("arg_dtype", _all_dtypes) @@ -74,7 +73,7 @@ def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) - assert (dpt.asnumpy(r) == 100).all() + assert dpt.all(r == 100) def test_sum_empty(): @@ -93,7 +92,7 @@ def test_sum_axis(): assert isinstance(s, dpt.usm_ndarray) assert s.shape == (3, 6) - assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() + assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype="i4")) def test_sum_keepdims(): @@ -104,7 +103,7 @@ def test_sum_keepdims(): assert isinstance(s, dpt.usm_ndarray) assert s.shape == (3, 1, 1, 6, 1) - assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() + assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype=s.dtype)) def test_sum_scalar(): @@ -116,7 +115,7 @@ def test_sum_scalar(): assert isinstance(s, dpt.usm_ndarray) assert m.sycl_queue == s.sycl_queue assert s.shape == () - assert dpt.asnumpy(s) == np.full((), 1) + assert s == dpt.full((), 1) @pytest.mark.parametrize("arg_dtype", _all_dtypes) @@ -131,7 +130,7 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) - assert dpt.asnumpy(r) == 1 + assert r == 1 def test_sum_keepdims_zero_size(): @@ -186,3 +185,66 @@ def test_axis0_bug(): expected = dpt.asarray([[0, 3], [1, 4], [2, 5]]) assert dpt.all(s == expected) + + +@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:]) +def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.prod(m) + + assert isinstance(r, dpt.usm_ndarray) + if m.dtype.kind == "i": + assert r.dtype.kind == "i" + elif m.dtype.kind == "u": + assert r.dtype.kind == "u" + elif m.dtype.kind == "f": + assert r.dtype.kind == "f" + elif m.dtype.kind == "c": + assert r.dtype.kind == "c" + assert dpt.all(r == 1) + + if dpt.isdtype(m.dtype, "unsigned integer"): + m = dpt.tile(dpt.arange(1, 3, dtype=arg_dtype), 10)[:1:-2] + r = dpt.prod(m) + assert dpt.all(r == dpt.asarray(512, dtype=r.dtype)) + else: + m = dpt.full(200, -1, dtype=arg_dtype)[:1:-2] + r = dpt.prod(m) + assert dpt.all(r == dpt.asarray(-1, dtype=r.dtype)) + + +def test_prod_empty(): + get_queue_or_skip() + x = dpt.empty((0,), dtype="u1") + y = dpt.prod(x) + assert y.shape == tuple() + assert int(y) == 1 + + +def test_prod_axis(): + 
    get_queue_or_skip()
+
+    m = dpt.ones((3, 4, 5, 6, 7), dtype="i4")
+    s = dpt.prod(m, axis=(1, 2, -1))
+
+    assert isinstance(s, dpt.usm_ndarray)
+    assert s.shape == (3, 6)
+    assert dpt.all(s == dpt.asarray(1, dtype="i4"))
+
+
+@pytest.mark.parametrize("arg_dtype", _all_dtypes)
+@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
+def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(arg_dtype, q)
+    skip_if_dtype_not_supported(out_dtype, q)
+
+    m = dpt.ones(100, dtype=arg_dtype)
+    r = dpt.prod(m, dtype=out_dtype)
+
+    assert isinstance(r, dpt.usm_ndarray)
+    assert r.dtype == dpt.dtype(out_dtype)
+    assert dpt.all(r == 1)
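For reviewers trying the new API: a minimal usage sketch of `dpctl.tensor.prod`, assuming a SYCL device is available through dpctl's default device selection. Values follow the docstring and the tests above; the concrete default result dtypes depend on the device.

```python
import dpctl.tensor as dpt

# product over the whole array returns a zero-dimensional usm_ndarray
x = dpt.asarray([[1, 2, 3], [4, 5, 6]], dtype="i4")
p = dpt.prod(x)
print(int(p))  # 720

# product over a selected axis, keeping the reduced dimension
p_ax = dpt.prod(x, axis=1, keepdims=True)
print(p_ax.shape)          # (2, 1)
print(dpt.asnumpy(p_ax))   # values 6 and 120

# request an explicit accumulation/output dtype
p_f = dpt.prod(x, dtype="f4")
print(p_f.dtype)           # float32
```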
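A small sketch of the dtype-inference rules spelled out in the `prod` docstring: with `dtype=None`, the result dtype follows the "kind" of the input, and boolean inputs promote to the default signed integer type. Which concrete types are the defaults (for example int64 vs. int32, or float64 availability) depends on the device the array is allocated on, so the printout is device-dependent.

```python
import dpctl.tensor as dpt

# bool, int, unsigned int, float and complex inputs each map to the
# device's default dtype of the matching kind when dtype is omitted
for dt in ["?", "i1", "u2", "f4", "c8"]:
    x = dpt.ones(8, dtype=dt)
    r = dpt.prod(x)
    print(dt, "->", r.dtype)
```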
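A note on the reduction identity: `prod` is wired up with `_identity=1` in `_reduction.py`, so reducing an empty array or an empty axis yields ones rather than raising, mirroring `test_prod_empty`. A quick sketch:

```python
import dpctl.tensor as dpt

x = dpt.empty((0, 4), dtype="f4")
print(int(dpt.prod(x)))                  # 1: product over zero elements
print(dpt.asnumpy(dpt.prod(x, axis=0)))  # [1. 1. 1. 1.]
print(dpt.prod(x, axis=1).shape)         # (0,)
```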
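Finally, the Python-level dtype resolution relies on the new `_prod_over_axis_dtype_supported` binding to report whether a given input/output dtype pair is implemented for the target queue and USM allocation type. A rough sketch of querying it directly, assuming (as `_reduction.py` suggests via its `ti` alias) that the binding lands in the hypothetical-but-likely `dpctl.tensor._tensor_impl` module; the argument order follows the `m.def` signature above:

```python
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti  # assumed module path for the bindings

q = dpctl.SyclQueue()
# does prod support reducing an int32 array into an int64 result
# allocated as "device" USM memory on this queue?
supported = ti._prod_over_axis_dtype_supported(
    dpt.dtype("i4"), dpt.dtype("i8"), "device", q
)
print(supported)  # True on typical devices
```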