diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index bebb1889f4..890af46339 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev): def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): - # if the kind of result is different from - # the kind of input, use the default data - # we use default dtype for the resulting kind. - # This guarantees alignment of reciprocal and - # divide output types. + # if the kind of result is different from the kind of input, we use the + # default floating-point dtype for the resulting kind. This guarantees + # alignment of reciprocal and divide output types. if buf_dt.kind != arg_dtype.kind: default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev) if res_dt == default_dt: diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 5de9024b6f..411040bada 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -102,8 +102,7 @@ using AbsContigFunctor = template struct AbsOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -119,6 +118,8 @@ template struct AbsOutputType td_ns::TypeMapResultEntry, float>, td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -140,9 +141,7 @@ template struct AbsContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AbsOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -191,9 +190,7 @@ template struct AbsStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AbsOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index a9bf000a20..a90f4e699f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -145,14 +145,15 @@ using AcosStridedFunctor = elementwise_common:: template struct AcosOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -174,9 +175,7 @@ template struct AcosContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AcosOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -222,9 +221,7 @@ template struct AcosStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AcosOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index 94c9c5e56e..8af3708427 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -172,14 +172,15 @@ using AcoshStridedFunctor = elementwise_common:: template struct AcoshOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -201,9 +202,7 @@ template struct AcoshContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AcoshOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -249,9 +248,7 @@ template struct AcoshStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AcoshOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index e77068b5e1..c06e98f3e5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -132,8 +132,7 @@ using AddStridedFunctor = template struct AddOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct AddOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct AddContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -273,9 +272,7 @@ template struct AddStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -324,12 +321,12 @@ struct AddContigMatrixContigRowBroadcastFactory { fnT get() { - using resT = typename AddOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!AddOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename AddOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -371,12 +368,12 @@ struct AddContigRowContigMatrixBroadcastFactory { fnT get() { - using resT = typename AddOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!AddOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename AddOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -438,6 +435,50 @@ template class add_inplace_contig_kernel; +/* @brief Types supported by in-place add */ +template struct AddInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct AddInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x += y */ + std::enable_if_t::value, int> get() + { + if constexpr (AddInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event add_inplace_contig_impl(sycl::queue &exec_q, @@ -457,9 +498,7 @@ template struct AddInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -497,9 +536,7 @@ struct AddInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -544,8 +581,7 @@ struct AddInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename AddOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 9622a7a207..034b71438f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -91,11 +91,12 @@ using AngleStridedFunctor = elementwise_common:: template struct AngleOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, float>, td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -117,9 +118,7 @@ template struct AngleContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AngleOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -165,9 +164,7 @@ template struct AngleStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AngleOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 3a0d6efecf..35c381aa84 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -165,14 +165,15 @@ using AsinStridedFunctor = elementwise_common:: template struct AsinOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -194,9 +195,7 @@ template struct AsinContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AsinOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -242,9 +241,7 @@ template struct AsinStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AsinOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index d64f3f0233..7373dc39d5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -148,14 +148,15 @@ using AsinhStridedFunctor = elementwise_common:: template struct AsinhOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -177,9 +178,7 @@ template struct AsinhContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AsinhOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -225,9 +224,7 @@ template struct AsinhStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AsinhOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 64cd3c316f..fbba3fc436 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -155,14 +155,15 @@ using AtanStridedFunctor = elementwise_common:: template struct AtanOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -184,9 +185,7 @@ template struct AtanContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AtanOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -232,9 +231,7 @@ template struct AtanStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AtanOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp index f5cacb9178..1a694527dd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp @@ -90,8 +90,7 @@ using Atan2StridedFunctor = template struct Atan2OutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct Atan2OutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct Atan2ContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename Atan2OutputType::value_type, void>) - { + if constexpr (!Atan2OutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -149,7 +148,6 @@ template struct Atan2TypeMapFactory std::enable_if_t::value, int> get() { using rT = typename Atan2OutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -183,9 +181,7 @@ template struct Atan2StridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename Atan2OutputType::value_type, void>) - { + if constexpr (!Atan2OutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 5002a18b19..340e72b11c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -149,14 +149,15 @@ using AtanhStridedFunctor = elementwise_common:: template struct AtanhOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -178,9 +179,7 @@ template struct AtanhContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AtanhOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -226,9 +225,7 @@ template struct AtanhStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AtanhOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp index ffe80f622e..da32b17183 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -113,9 +113,7 @@ using BitwiseAndStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct BitwiseAndOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct BitwiseAndOutputType std::int64_t, std::int64_t>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct BitwiseAndContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -245,10 +242,7 @@ struct BitwiseAndStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -322,6 +316,41 @@ template class bitwise_and_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise AND */ +template +struct BitwiseAndInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseAndInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x &= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseAndInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_and_inplace_contig_impl(sycl::queue &exec_q, @@ -343,10 +372,7 @@ struct BitwiseAndInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -385,10 +411,7 @@ struct BitwiseAndInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp index 6def3a511c..d6c1bc72db 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp @@ -100,8 +100,7 @@ using BitwiseInvertStridedFunctor = template struct BitwiseInvertOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -112,6 +111,8 @@ template struct BitwiseInvertOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -135,10 +136,7 @@ template struct BitwiseInvertContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseInvertOutputType::value_type, - void>) - { + if constexpr (!BitwiseInvertOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -186,10 +184,7 @@ template struct BitwiseInvertStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseInvertOutputType::value_type, - void>) - { + if constexpr (!BitwiseInvertOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp index 18a87e5287..a987c8d604 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -123,9 +123,7 @@ using BitwiseLeftShiftStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct BitwiseLeftShiftOutputType { using ResT = T1; - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct BitwiseLeftShiftOutputType std::uint64_t, std::uint64_t>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template ::value_type, - void>) - { + if constexpr (!BitwiseLeftShiftOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -256,10 +253,7 @@ struct BitwiseLeftShiftStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!BitwiseLeftShiftOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -336,6 +330,41 @@ template class bitwise_left_shift_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise left shift */ +template +struct BitwiseLeftShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseLeftShiftInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x <<= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseLeftShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_left_shift_inplace_contig_impl( sycl::queue &exec_q, @@ -357,9 +386,8 @@ struct BitwiseLeftShiftInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; @@ -399,9 +427,8 @@ struct BitwiseLeftShiftInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp index aad31aae95..71f3e809d9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -112,9 +112,7 @@ using BitwiseOrStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct BitwiseOrOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct BitwiseOrOutputType std::int64_t, std::int64_t>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct BitwiseOrContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -241,10 +238,7 @@ template struct BitwiseOrStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -318,6 +312,39 @@ template class bitwise_or_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise OR */ +template struct BitwiseOrInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseOrInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x |= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseOrInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_or_inplace_contig_impl(sycl::queue &exec_q, @@ -339,10 +366,7 @@ struct BitwiseOrInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -381,10 +405,7 @@ struct BitwiseOrInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp index 2fbee2e49d..e4dfee2ed6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -125,9 +125,7 @@ using BitwiseRightShiftStridedFunctor = template struct BitwiseRightShiftOutputType { using ResT = T1; - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct BitwiseRightShiftOutputType std::uint64_t, std::uint64_t>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template ::value_type, - void>) - { + if constexpr (!BitwiseRightShiftOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -258,10 +255,7 @@ struct BitwiseRightShiftStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!BitwiseRightShiftOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -340,6 +334,41 @@ template class bitwise_right_shift_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise right shift */ +template +struct BitwiseRightShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseRightShiftInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x >>= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseRightShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_right_shift_inplace_contig_impl( sycl::queue &exec_q, @@ -361,9 +390,8 @@ struct BitwiseRightShiftInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseRightShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; @@ -403,9 +431,8 @@ struct BitwiseRightShiftInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseRightShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp index fd0ca880c9..d035b31170 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -113,9 +113,7 @@ using BitwiseXorStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct BitwiseXorOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct BitwiseXorOutputType std::int64_t, std::int64_t>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct BitwiseXorContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -245,10 +242,7 @@ struct BitwiseXorStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -322,6 +316,41 @@ template class bitwise_xor_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise XOR */ +template +struct BitwiseXorInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseXorInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x ^= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseXorInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_xor_inplace_contig_impl(sycl::queue &exec_q, @@ -343,10 +372,7 @@ struct BitwiseXorInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -385,10 +411,7 @@ struct BitwiseXorInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp index 083b7bee9d..4f2634f17a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp @@ -82,12 +82,13 @@ using CbrtStridedFunctor = elementwise_common:: template struct CbrtOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -109,9 +110,7 @@ template struct CbrtContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CbrtOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -157,9 +156,7 @@ template struct CbrtStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CbrtOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 3cb90df632..59bc630720 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -95,20 +95,21 @@ using CeilStridedFunctor = elementwise_common:: template struct CeilOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::DefaultResultEntry>::result_type; + using value_type = + typename std::disjunction, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -130,9 +131,7 @@ template struct CeilContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CeilOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -178,9 +177,7 @@ template struct CeilStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CeilOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 5348173856..4953feedb2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -99,8 +99,7 @@ using ConjStridedFunctor = elementwise_common:: template struct ConjOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -116,6 +115,8 @@ template struct ConjOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -137,9 +138,7 @@ template struct ConjContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ConjOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -185,9 +184,7 @@ template struct ConjStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ConjOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp index d226422494..92997b572b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp @@ -104,8 +104,7 @@ using CopysignStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct CopysignOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct CopysignOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct CopysignContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename CopysignOutputType::value_type, - void>) - { + if constexpr (!CopysignOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -197,10 +195,7 @@ template struct CopysignStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename CopysignOutputType::value_type, - void>) - { + if constexpr (!CopysignOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 14b2345788..8b6b0c5fbe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -180,8 +180,7 @@ using CosStridedFunctor = elementwise_common:: template struct CosOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -189,6 +188,8 @@ template struct CosOutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -210,9 +211,7 @@ template struct CosContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CosOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -257,9 +256,7 @@ template struct CosStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CosOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 866bd2731d..cff1038ed9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -170,14 +170,15 @@ using CoshStridedFunctor = elementwise_common:: template struct CoshOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -199,9 +200,7 @@ template struct CoshContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CoshOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -247,9 +246,7 @@ template struct CoshStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!CoshOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index 5086a89cec..d368658afc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -141,8 +141,7 @@ using EqualStridedFunctor = template struct EqualOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -186,6 +185,8 @@ template struct EqualOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct EqualContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename EqualOutputType::value_type, void>) - { + if constexpr (!EqualOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -268,9 +267,7 @@ template struct EqualStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename EqualOutputType::value_type, void>) - { + if constexpr (!EqualOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 38abed80cb..7e613c9731 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -139,14 +139,15 @@ using ExpStridedFunctor = elementwise_common:: template struct ExpOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -168,9 +169,7 @@ template struct ExpContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ExpOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -215,9 +214,7 @@ template struct ExpStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ExpOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index 9d244a0375..b436bb3855 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -141,14 +141,15 @@ using Exp2StridedFunctor = elementwise_common:: template struct Exp2OutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -170,9 +171,7 @@ template struct Exp2ContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Exp2OutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -218,9 +217,7 @@ template struct Exp2StridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Exp2OutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 2ab077ab76..9a9d0a1562 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -153,8 +153,7 @@ using Expm1StridedFunctor = elementwise_common:: template struct Expm1OutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -162,6 +161,8 @@ template struct Expm1OutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -183,9 +184,7 @@ template struct Expm1ContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Expm1OutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -202,7 +201,6 @@ template struct Expm1TypeMapFactory std::enable_if_t::value, int> get() { using rT = typename Expm1OutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -232,9 +230,7 @@ template struct Expm1StridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Expm1OutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index 90e6941bdd..530dd3d9aa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -95,20 +95,21 @@ using FloorStridedFunctor = elementwise_common:: template struct FloorOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::DefaultResultEntry>::result_type; + using value_type = + typename std::disjunction, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -130,9 +131,7 @@ template struct FloorContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!FloorOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -178,9 +177,7 @@ template struct FloorStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!FloorOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index ce89b0778f..72ee3a789a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -148,8 +148,7 @@ using FloorDivideStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct FloorDivideOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct FloorDivideOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template ::value_type, - void>) - { + if constexpr (!FloorDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -285,10 +283,7 @@ struct FloorDivideStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename FloorDivideOutputType::value_type, - void>) - { + if constexpr (!FloorDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -398,6 +393,43 @@ template class floor_divide_inplace_contig_kernel; +/* @brief Types supported by in-place floor division */ +template +struct FloorDivideInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct FloorDivideInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x //= y */ + std::enable_if_t::value, int> get() + { + if constexpr (FloorDivideInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event floor_divide_inplace_contig_impl(sycl::queue &exec_q, @@ -419,10 +451,7 @@ struct FloorDivideInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename FloorDivideOutputType::value_type, - void>) - { + if constexpr (!FloorDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -461,10 +490,7 @@ struct FloorDivideInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename FloorDivideOutputType::value_type, - void>) - { + if constexpr (!FloorDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index ea53d575ba..05c2a36b0c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -142,8 +142,7 @@ using GreaterStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct GreaterOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -187,6 +186,8 @@ template struct GreaterOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct GreaterContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename GreaterOutputType::value_type, void>) - { + if constexpr (!GreaterOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -269,9 +268,7 @@ template struct GreaterStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename GreaterOutputType::value_type, void>) - { + if constexpr (!GreaterOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index b429a3d00c..43e4e98db1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -143,8 +143,7 @@ using GreaterEqualStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct GreaterEqualOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -188,6 +187,8 @@ template struct GreaterEqualOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template ::value_type, - void>) - { + if constexpr (!GreaterEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -276,10 +274,7 @@ struct GreaterEqualStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename GreaterEqualOutputType::value_type, - void>) - { + if constexpr (!GreaterEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp index 52498f76d5..c5b68644a9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp @@ -106,8 +106,7 @@ using HypotStridedFunctor = template struct HypotOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct HypotOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct HypotContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename HypotOutputType::value_type, void>) - { + if constexpr (!HypotOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -165,7 +164,6 @@ template struct HypotTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename HypotOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -199,9 +197,7 @@ template struct HypotStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename HypotOutputType::value_type, void>) - { + if constexpr (!HypotOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 6b937b3071..e918bc0ac7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -95,8 +95,7 @@ using ImagStridedFunctor = elementwise_common:: template struct ImagOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -112,6 +111,8 @@ template struct ImagOutputType td_ns::TypeMapResultEntry, float>, td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -133,9 +134,7 @@ template struct ImagContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ImagOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -181,9 +180,7 @@ template struct ImagStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ImagOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 88d0e6e19f..0b26342563 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -140,8 +140,7 @@ using LessStridedFunctor = template struct LessOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -185,6 +184,8 @@ template struct LessOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LessContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LessOutputType::value_type, void>) - { + if constexpr (!LessOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -267,9 +266,7 @@ template struct LessStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LessOutputType::value_type, void>) - { + if constexpr (!LessOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 97400aa475..01289ae98f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ -141,8 +141,7 @@ using LessEqualStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct LessEqualOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -186,6 +185,8 @@ template struct LessEqualOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LessEqualContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LessEqualOutputType::value_type, - void>) - { + if constexpr (!LessEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -269,10 +267,7 @@ template struct LessEqualStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LessEqualOutputType::value_type, - void>) - { + if constexpr (!LessEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index 6a4fd4e34e..a3e28ef5d7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -95,8 +95,7 @@ using LogStridedFunctor = elementwise_common:: template struct LogOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -104,6 +103,8 @@ template struct LogOutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -125,9 +126,7 @@ template struct LogContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!LogOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -144,7 +143,6 @@ template struct LogTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename LogOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -173,9 +171,7 @@ template struct LogStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!LogOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index af2ad072c5..793b910f69 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -114,8 +114,7 @@ using Log10StridedFunctor = elementwise_common:: template struct Log10OutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -123,6 +122,8 @@ template struct Log10OutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -144,9 +145,7 @@ template struct Log10ContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log10OutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -163,7 +162,6 @@ template struct Log10TypeMapFactory std::enable_if_t::value, int> get() { using rT = typename Log10OutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -193,9 +191,7 @@ template struct Log10StridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log10OutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index fe08dc805f..19238e7e37 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -119,8 +119,7 @@ using Log1pStridedFunctor = elementwise_common:: template struct Log1pOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -128,6 +127,8 @@ template struct Log1pOutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -149,9 +150,7 @@ template struct Log1pContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log1pOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -168,7 +167,6 @@ template struct Log1pTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename Log1pOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -198,9 +196,7 @@ template struct Log1pStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log1pOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index d4ea0aca47..69d0022c72 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -115,8 +115,7 @@ using Log2StridedFunctor = elementwise_common:: template struct Log2OutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -124,6 +123,8 @@ template struct Log2OutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -145,9 +146,7 @@ template struct Log2ContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log2OutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -164,7 +163,6 @@ template struct Log2TypeMapFactory std::enable_if_t::value, int> get() { using rT = typename Log2OutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -194,9 +192,7 @@ template struct Log2StridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!Log2OutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index d94724edc6..b0be45ea54 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -121,8 +121,7 @@ using LogAddExpStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct LogAddExpOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct LogAddExpOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LogAddExpContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogAddExpOutputType::value_type, - void>) - { + if constexpr (!LogAddExpOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -181,7 +179,6 @@ template struct LogAddExpTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename LogAddExpOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -215,10 +212,7 @@ template struct LogAddExpStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogAddExpOutputType::value_type, - void>) - { + if constexpr (!LogAddExpOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp index 403a8a2799..f15caa02e6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp @@ -115,9 +115,7 @@ using LogicalAndStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct LogicalAndOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -157,6 +155,8 @@ template struct LogicalAndOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LogicalAndContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalAndOutputType::value_type, - void>) - { + if constexpr (!LogicalAndOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -244,10 +241,7 @@ struct LogicalAndStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalAndOutputType::value_type, - void>) - { + if constexpr (!LogicalAndOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp index 4706c7936c..43e02f2102 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp @@ -114,9 +114,7 @@ using LogicalOrStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct LogicalOrOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -156,6 +154,8 @@ template struct LogicalOrOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LogicalOrContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalOrOutputType::value_type, - void>) - { + if constexpr (!LogicalOrOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -240,10 +237,7 @@ template struct LogicalOrStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalOrOutputType::value_type, - void>) - { + if constexpr (!LogicalOrOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp index a444bc5395..dc41760985 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp @@ -116,9 +116,7 @@ using LogicalXorStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct LogicalXorOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by - // DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -158,6 +156,8 @@ template struct LogicalXorOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct LogicalXorContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalXorOutputType::value_type, - void>) - { + if constexpr (!LogicalXorOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -245,10 +242,7 @@ struct LogicalXorStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename LogicalXorOutputType::value_type, - void>) - { + if constexpr (!LogicalXorOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp index 9081cd1c6a..e73704bad8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -118,8 +118,7 @@ using MaximumStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct MaximumOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct MaximumOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct MaximumContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename MaximumOutputType::value_type, void>) - { + if constexpr (!MaximumOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -228,7 +227,6 @@ template struct MaximumTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename MaximumOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -262,9 +260,7 @@ template struct MaximumStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename MaximumOutputType::value_type, void>) - { + if constexpr (!MaximumOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index f736697997..590c0b6486 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -118,8 +118,7 @@ using MinimumStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct MinimumOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct MinimumOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct MinimumContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename MinimumOutputType::value_type, void>) - { + if constexpr (!MinimumOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -228,7 +227,6 @@ template struct MinimumTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename MinimumOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -262,9 +260,7 @@ template struct MinimumStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename MinimumOutputType::value_type, void>) - { + if constexpr (!MinimumOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index 147f62f53e..1af284f55b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -120,8 +120,7 @@ using MultiplyStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct MultiplyOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct MultiplyOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct MultiplyContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -231,7 +229,6 @@ template struct MultiplyTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename MultiplyOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -265,10 +262,7 @@ template struct MultiplyStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -317,12 +311,12 @@ struct MultiplyContigMatrixContigRowBroadcastFactory { fnT get() { - using resT = typename MultiplyOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!MultiplyOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename MultiplyOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -365,12 +359,12 @@ struct MultiplyContigRowContigMatrixBroadcastFactory { fnT get() { - using resT = typename MultiplyOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!MultiplyOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename MultiplyOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -434,6 +428,50 @@ template class multiply_inplace_contig_kernel; +/* @brief Types supported by in-place multiplication */ +template struct MultiplyInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MultiplyInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x *= y */ + std::enable_if_t::value, int> get() + { + if constexpr (MultiplyInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event multiply_inplace_contig_impl(sycl::queue &exec_q, @@ -455,10 +493,7 @@ struct MultiplyInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -497,10 +532,7 @@ struct MultiplyInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -545,8 +577,7 @@ struct MultiplyInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename MultiplyOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp index d52c2f33ee..83f17dd47b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp @@ -78,8 +78,7 @@ using NegativeContigFunctor = template struct NegativeOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -94,6 +93,8 @@ template struct NegativeOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -116,9 +117,7 @@ template struct NegativeContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!NegativeOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -135,7 +134,6 @@ template struct NegativeTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename NegativeOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -170,9 +168,7 @@ template struct NegativeStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!NegativeOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp index d5137f0c6d..5dc9ea40b3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp @@ -104,8 +104,7 @@ using NextafterStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct NextafterOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct NextafterOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct NextafterContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename NextafterOutputType::value_type, - void>) - { + if constexpr (!NextafterOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -164,7 +162,6 @@ template struct NextafterTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename NextafterOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -198,10 +195,7 @@ template struct NextafterStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename NextafterOutputType::value_type, - void>) - { + if constexpr (!NextafterOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index 437ceb2da8..c1b920193b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -125,8 +125,7 @@ using NotEqualStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct NotEqualOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns:: BinaryTypeMapResultEntry, @@ -170,6 +169,8 @@ template struct NotEqualOutputType std::complex, bool>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct NotEqualContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename NotEqualOutputType::value_type, - void>) - { + if constexpr (!NotEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -253,10 +251,7 @@ template struct NotEqualStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename NotEqualOutputType::value_type, - void>) - { + if constexpr (!NotEqualOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp index 92eaf3c0d2..ae2711ed0e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp @@ -93,8 +93,7 @@ using PositiveContigFunctor = template struct PositiveOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -109,6 +108,8 @@ template struct PositiveOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -131,9 +132,7 @@ template struct PositiveContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PositiveOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -150,7 +149,6 @@ template struct PositiveTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename PositiveOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -185,9 +183,7 @@ template struct PositiveStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PositiveOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index a21b2d4318..bb462dceae 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -173,8 +173,7 @@ using PowStridedFunctor = template struct PowOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct PowOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct PowContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -282,7 +281,6 @@ template struct PowTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename PowOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -314,9 +312,7 @@ template struct PowStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -446,6 +442,49 @@ template class pow_inplace_contig_kernel; +/* @brief Types supported by in-place pow */ +template struct PowInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct PowInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x **= y */ + std::enable_if_t::value, int> get() + { + if constexpr (PowInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event pow_inplace_contig_impl(sycl::queue &exec_q, @@ -465,9 +504,7 @@ template struct PowInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -505,9 +542,7 @@ struct PowInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index 92a8535309..2c3dce0c9c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -108,11 +108,12 @@ using ProjStridedFunctor = elementwise_common:: template struct ProjOutputType { - // disjunction is C++17 feature, supported by DPC++ using value_type = typename std::disjunction< td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -134,9 +135,7 @@ template struct ProjContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ProjOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -188,9 +187,7 @@ template struct ProjStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!ProjOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 8949a79955..c66e4003cb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -95,8 +95,7 @@ using RealStridedFunctor = elementwise_common:: template struct RealOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -112,6 +111,8 @@ template struct RealOutputType td_ns::TypeMapResultEntry, float>, td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -133,9 +134,7 @@ template struct RealContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RealOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -181,9 +180,7 @@ template struct RealStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RealOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index 76c48e173c..4d4b70fd4f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -101,14 +101,15 @@ using ReciprocalStridedFunctor = template struct ReciprocalOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -131,9 +132,7 @@ template struct ReciprocalContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename ReciprocalOutputType::value_type, void>) - { + if constexpr (!ReciprocalOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -181,9 +180,7 @@ template struct ReciprocalStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename ReciprocalOutputType::value_type, void>) - { + if constexpr (!ReciprocalOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp index 585d1c6d7f..7bb070cc00 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp @@ -166,8 +166,7 @@ using RemainderStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct RemainderOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct RemainderOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct RemainderContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -299,10 +297,7 @@ template struct RemainderStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -424,6 +419,41 @@ template class remainder_inplace_contig_kernel; +/* @brief Types supported by in-place remainder */ +template struct RemainderInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct RemainderInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x %= y */ + std::enable_if_t::value, int> get() + { + if constexpr (RemainderInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event remainder_inplace_contig_impl(sycl::queue &exec_q, @@ -445,10 +475,7 @@ struct RemainderInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -487,10 +514,7 @@ struct RemainderInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 49cbbf682c..241f75c1bb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -104,8 +104,7 @@ using RoundStridedFunctor = elementwise_common:: template struct RoundOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -120,6 +119,8 @@ template struct RoundOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -141,9 +142,7 @@ template struct RoundContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RoundOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -189,9 +188,7 @@ template struct RoundStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RoundOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp index 541b036931..61aafb13d9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp @@ -85,12 +85,13 @@ using RsqrtStridedFunctor = elementwise_common:: template struct RsqrtOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -112,9 +113,7 @@ template struct RsqrtContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RsqrtOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -160,9 +159,7 @@ template struct RsqrtStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!RsqrtOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index 554398ae56..651f7d5d9a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -116,8 +116,7 @@ using SignContigFunctor = template struct SignOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -132,6 +131,8 @@ template struct SignOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -153,9 +154,7 @@ template struct SignContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SignOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -172,7 +171,6 @@ template struct SignTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename SignOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; @@ -206,9 +204,7 @@ template struct SignStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SignOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp index ee4a97d9b6..e8ac7709ad 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp @@ -91,12 +91,13 @@ using SignbitStridedFunctor = elementwise_common:: template struct SignbitOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -118,9 +119,7 @@ template struct SignbitContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SignbitOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -167,9 +166,7 @@ template struct SignbitStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SignbitOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index ba46affae6..8bc12097a8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -203,14 +203,15 @@ using SinStridedFunctor = elementwise_common:: template struct SinOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -232,9 +233,7 @@ template struct SinContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SinOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -279,9 +278,7 @@ template struct SinStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SinOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 01a87d923f..e83626e56d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -172,14 +172,15 @@ using SinhStridedFunctor = elementwise_common:: template struct SinhOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -201,9 +202,7 @@ template struct SinhContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SinhOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -249,9 +248,7 @@ template struct SinhStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SinhOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 6b63f74fe7..5adb41b20d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -97,8 +97,7 @@ using SqrtStridedFunctor = elementwise_common:: template struct SqrtOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -106,6 +105,8 @@ template struct SqrtOutputType td_ns:: TypeMapResultEntry, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -127,9 +128,7 @@ template struct SqrtContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SqrtOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -175,9 +174,7 @@ template struct SqrtStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SqrtOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index 72f3bda389..4b096cc291 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -114,8 +114,7 @@ using SquareStridedFunctor = elementwise_common:: template struct SquareOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, @@ -131,6 +130,8 @@ template struct SquareOutputType td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -152,9 +153,7 @@ template struct SquareContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SquareOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -200,9 +199,7 @@ template struct SquareStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!SquareOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 4a8cfb50a7..4ee3ae089b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -107,8 +107,7 @@ using SubtractStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct SubtractOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct SubtractOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct SubtractContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -250,10 +248,7 @@ template struct SubtractStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -313,10 +308,7 @@ struct SubtractContigMatrixContigRowBroadcastFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -366,12 +358,12 @@ struct SubtractContigRowContigMatrixBroadcastFactory { fnT get() { - using resT = typename SubtractOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!SubtractOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename SubtractOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -435,6 +427,49 @@ template class subtract_inplace_contig_kernel; +/* @brief Types supported by in-place subtraction */ +template struct SubtractInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SubtractInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x -= y */ + std::enable_if_t::value, int> get() + { + if constexpr (SubtractInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event subtract_inplace_contig_impl(sycl::queue &exec_q, @@ -456,10 +491,7 @@ struct SubtractInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -498,10 +530,7 @@ struct SubtractInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -546,8 +575,7 @@ struct SubtractInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename SubtractOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index c745bc011b..4364d81fb7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -147,14 +147,15 @@ using TanStridedFunctor = elementwise_common:: template struct TanOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -176,9 +177,7 @@ template struct TanContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TanOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -223,9 +222,7 @@ template struct TanStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TanOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 2e23e46f4a..0af4e4e628 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -141,14 +141,15 @@ using TanhStridedFunctor = elementwise_common:: template struct TanhOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry>, td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -170,9 +171,7 @@ template struct TanhContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TanhOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -218,9 +217,7 @@ template struct TanhStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TanhOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 396e5c995e..53db1e163c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -134,8 +134,7 @@ using TrueDivideStridedFunctor = elementwise_common::BinaryStridedFunctor< template struct TrueDivideOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct TrueDivideOutputType double, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct TrueDivideContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename TrueDivideOutputType::value_type, - void>) - { + if constexpr (!TrueDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -260,10 +258,7 @@ struct TrueDivideStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename TrueDivideOutputType::value_type, - void>) - { + if constexpr (!TrueDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -323,10 +318,7 @@ struct TrueDivideContigMatrixContigRowBroadcastFactory { fnT get() { - if constexpr (std::is_same_v< - typename TrueDivideOutputType::value_type, - void>) - { + if constexpr (!TrueDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -376,12 +368,12 @@ struct TrueDivideContigRowContigMatrixBroadcastFactory { fnT get() { - using resT = typename TrueDivideOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!TrueDivideOutputType::is_defined) { fnT fn = nullptr; return fn; } else { + using resT = typename TrueDivideOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -439,52 +431,43 @@ template struct TrueDivideInplaceFunctor } }; -// cannot use the out of place table, as it permits real lhs and complex rhs -// T1 corresponds to the type of the rhs, while T2 corresponds to the lhs -// the type of the result must be the same as T2 -template struct TrueDivideInplaceOutputType +/* @brief Types supported by in-place divide */ +template +struct TrueDivideInplaceTypePairSupport { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - std::complex>, - td_ns::DefaultResultEntry>::result_type; + + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; }; -template +template struct TrueDivideInplaceTypeMapFactory { /*! @brief get typeid for output type of divide(T1 x, T2 y) */ std::enable_if_t::value, int> get() { - using rT = typename TrueDivideInplaceOutputType::value_type; - static_assert(std::is_same_v || std::is_same_v); - return td_ns::GetTypeid{}.get(); + if constexpr (TrueDivideInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } } }; @@ -537,10 +520,7 @@ struct TrueDivideInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -579,10 +559,7 @@ struct TrueDivideInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -627,8 +604,7 @@ struct TrueDivideInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename TrueDivideInplaceOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 5740bc0ef2..55c8493880 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -92,20 +92,21 @@ using TruncStridedFunctor = elementwise_common:: template struct TruncOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, - td_ns::DefaultResultEntry>::result_type; + using value_type = + typename std::disjunction, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template @@ -127,9 +128,7 @@ template struct TruncContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TruncOutputType::is_defined) { fnT fn = nullptr; return fn; } @@ -175,9 +174,7 @@ template struct TruncStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TruncOutputType::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp index 4cf2ec6fe9..f9d62f3e4b 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp @@ -70,7 +70,6 @@ template struct TypePairSupportDataForLogSumExpAccumulation { static constexpr bool is_defined = std::disjunction< - // disjunction is C++17 feature, supported by DPC++ input bool td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp index 7982e0ff41..6880069fe5 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp @@ -69,8 +69,6 @@ static accumulate_strided_impl_fn_ptr_t template struct TypePairSupportDataForProdAccumulation { - - // disjunction is C++17 feature, supported by DPC++ input bool static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp index 59e5830fe6..f023833982 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -69,8 +69,6 @@ static accumulate_strided_impl_fn_ptr_t template struct TypePairSupportDataForSumAccumulation { - - // disjunction is C++17 feature, supported by DPC++ input bool static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp index 0cce42cdb5..823647aefd 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp @@ -67,7 +67,9 @@ namespace add_fn_ns = dpctl::tensor::kernels::add; static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int add_output_id_table[td_ns::num_types][td_ns::num_types]; +static int add_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -154,6 +156,11 @@ void populate_add_dispatch_tables(void) AddInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::AddInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(add_inplace_output_id_table); }; } // namespace impl @@ -199,6 +206,7 @@ void init_add(py::module_ m) m.def("_add_result_type", add_result_type_pyapi, ""); using impl::add_inplace_contig_dispatch_table; + using impl::add_inplace_output_id_table; using impl::add_inplace_row_matrix_dispatch_table; using impl::add_inplace_strided_dispatch_table; @@ -206,7 +214,7 @@ void init_add(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, add_output_id_table, + src, dst, exec_q, depends, add_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) add_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp index afd7ffe469..15ccf22ea2 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp @@ -66,7 +66,10 @@ namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; static binary_contig_impl_fn_ptr_t bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_and_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_and_dispatch_tables(void) BitwiseAndInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_and_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseAndInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_and_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_and(py::module_ m) m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); using impl::bitwise_and_inplace_contig_dispatch_table; + using impl::bitwise_and_inplace_output_id_table; using impl::bitwise_and_inplace_strided_dispatch_table; - auto bitwise_and_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_and_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_and_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_and_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_and_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_and_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_and_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_and_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_and_inplace", bitwise_and_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp index 777573a738..833b78e697 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp @@ -67,8 +67,11 @@ namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; static binary_contig_impl_fn_ptr_t bitwise_left_shift_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int bitwise_left_shift_output_id_table[td_ns::num_types] [td_ns::num_types]; +static int bitwise_left_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_left_shift_strided_dispatch_table[td_ns::num_types] @@ -120,6 +123,12 @@ void populate_bitwise_left_shift_dispatch_tables(void) dtb5; dtb5.populate_dispatch_table( bitwise_left_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseLeftShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_left_shift_inplace_output_id_table); }; } // namespace impl @@ -169,6 +178,7 @@ void init_bitwise_left_shift(py::module_ m) bitwise_left_shift_result_type_pyapi, ""); using impl::bitwise_left_shift_inplace_contig_dispatch_table; + using impl::bitwise_left_shift_inplace_output_id_table; using impl::bitwise_left_shift_inplace_strided_dispatch_table; auto bitwise_left_shift_inplace_pyapi = @@ -176,7 +186,7 @@ void init_bitwise_left_shift(py::module_ m) const event_vecT &depends = {}) { return py_binary_inplace_ufunc( src, dst, exec_q, depends, - bitwise_left_shift_output_id_table, + bitwise_left_shift_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) bitwise_left_shift_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp index 468d887392..ecdaeb6577 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp @@ -66,7 +66,10 @@ namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; static binary_contig_impl_fn_ptr_t bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_or_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_or_dispatch_tables(void) BitwiseOrInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_or_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseOrInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_or_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_or(py::module_ m) m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); using impl::bitwise_or_inplace_contig_dispatch_table; + using impl::bitwise_or_inplace_output_id_table; using impl::bitwise_or_inplace_strided_dispatch_table; - auto bitwise_or_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_or_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_or_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_or_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_or_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_or_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_or_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_or_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_or_inplace", bitwise_or_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp index 85fc7b99a9..9f74037f41 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp @@ -68,8 +68,11 @@ namespace bitwise_right_shift_fn_ns = static binary_contig_impl_fn_ptr_t bitwise_right_shift_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int bitwise_right_shift_output_id_table[td_ns::num_types] [td_ns::num_types]; +static int bitwise_right_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_right_shift_strided_dispatch_table[td_ns::num_types] @@ -121,6 +124,12 @@ void populate_bitwise_right_shift_dispatch_tables(void) dtb5; dtb5.populate_dispatch_table( bitwise_right_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseRightShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_right_shift_inplace_output_id_table); }; } // namespace impl @@ -170,6 +179,7 @@ void init_bitwise_right_shift(py::module_ m) bitwise_right_shift_result_type_pyapi, ""); using impl::bitwise_right_shift_inplace_contig_dispatch_table; + using impl::bitwise_right_shift_inplace_output_id_table; using impl::bitwise_right_shift_inplace_strided_dispatch_table; auto bitwise_right_shift_inplace_pyapi = @@ -177,7 +187,7 @@ void init_bitwise_right_shift(py::module_ m) const event_vecT &depends = {}) { return py_binary_inplace_ufunc( src, dst, exec_q, depends, - bitwise_right_shift_output_id_table, + bitwise_right_shift_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) bitwise_right_shift_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp index 18e9fb695e..6124f305e0 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp @@ -66,7 +66,10 @@ namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; static binary_contig_impl_fn_ptr_t bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_xor_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_xor_dispatch_tables(void) BitwiseXorInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_xor_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseXorInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_xor_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_xor(py::module_ m) m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); using impl::bitwise_xor_inplace_contig_dispatch_table; + using impl::bitwise_xor_inplace_output_id_table; using impl::bitwise_xor_inplace_strided_dispatch_table; - auto bitwise_xor_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_xor_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_xor_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_xor_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_xor_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_xor_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_xor_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_xor_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_xor_inplace", bitwise_xor_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp index fdbffbb3be..921b494fd1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp @@ -66,7 +66,10 @@ namespace floor_divide_fn_ns = dpctl::tensor::kernels::floor_divide; static binary_contig_impl_fn_ptr_t floor_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; +static int floor_divide_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_floor_divide_dispatch_tables(void) FloorDivideInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::FloorDivideInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(floor_divide_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_floor_divide(py::module_ m) m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); using impl::floor_divide_inplace_contig_dispatch_table; + using impl::floor_divide_inplace_output_id_table; using impl::floor_divide_inplace_strided_dispatch_table; - auto floor_divide_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - floor_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - floor_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto floor_divide_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, floor_divide_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + floor_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + floor_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp index 75438e1031..22fc293c98 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp @@ -67,7 +67,9 @@ namespace multiply_fn_ns = dpctl::tensor::kernels::multiply; static binary_contig_impl_fn_ptr_t multiply_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int multiply_output_id_table[td_ns::num_types][td_ns::num_types]; +static int multiply_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t multiply_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -155,6 +157,11 @@ void populate_multiply_dispatch_tables(void) MultiplyInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(multiply_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::MultiplyInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(multiply_inplace_output_id_table); }; } // namespace impl @@ -200,6 +207,7 @@ void init_multiply(py::module_ m) m.def("_multiply_result_type", multiply_result_type_pyapi, ""); using impl::multiply_inplace_contig_dispatch_table; + using impl::multiply_inplace_output_id_table; using impl::multiply_inplace_row_matrix_dispatch_table; using impl::multiply_inplace_strided_dispatch_table; @@ -207,7 +215,7 @@ void init_multiply(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, multiply_output_id_table, + src, dst, exec_q, depends, multiply_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) multiply_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp index 347b3e298a..25e96eee7f 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp @@ -67,7 +67,9 @@ namespace pow_fn_ns = dpctl::tensor::kernels::pow; static binary_contig_impl_fn_ptr_t pow_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int pow_output_id_table[td_ns::num_types][td_ns::num_types]; +static int pow_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t pow_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -114,6 +116,11 @@ void populate_pow_dispatch_tables(void) PowInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(pow_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::PowInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(pow_inplace_output_id_table); }; } // namespace impl @@ -159,13 +166,14 @@ void init_pow(py::module_ m) m.def("_pow_result_type", pow_result_type_pyapi, ""); using impl::pow_inplace_contig_dispatch_table; + using impl::pow_inplace_output_id_table; using impl::pow_inplace_strided_dispatch_table; auto pow_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, pow_output_id_table, + src, dst, exec_q, depends, pow_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) pow_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp index d08c7ebda8..a2d3f48099 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp @@ -67,7 +67,10 @@ namespace remainder_fn_ns = dpctl::tensor::kernels::remainder; static binary_contig_impl_fn_ptr_t remainder_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int remainder_output_id_table[td_ns::num_types][td_ns::num_types]; +static int remainder_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t remainder_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_remainder_dispatch_tables(void) RemainderInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(remainder_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::RemainderInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(remainder_inplace_output_id_table); } } // namespace impl @@ -160,13 +168,14 @@ void init_remainder(py::module_ m) m.def("_remainder_result_type", remainder_result_type_pyapi, ""); using impl::remainder_inplace_contig_dispatch_table; + using impl::remainder_inplace_output_id_table; using impl::remainder_inplace_strided_dispatch_table; auto remainder_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, remainder_output_id_table, + src, dst, exec_q, depends, remainder_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) remainder_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp index 0c81ace53e..f1c1324a58 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp @@ -66,7 +66,9 @@ namespace subtract_fn_ns = dpctl::tensor::kernels::subtract; static binary_contig_impl_fn_ptr_t subtract_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int subtract_output_id_table[td_ns::num_types][td_ns::num_types]; +static int subtract_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t subtract_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -154,6 +156,11 @@ void populate_subtract_dispatch_tables(void) SubtractInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(subtract_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::SubtractInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(subtract_inplace_output_id_table); }; } // namespace impl @@ -199,6 +206,7 @@ void init_subtract(py::module_ m) m.def("_subtract_result_type", subtract_result_type_pyapi, ""); using impl::subtract_inplace_contig_dispatch_table; + using impl::subtract_inplace_output_id_table; using impl::subtract_inplace_row_matrix_dispatch_table; using impl::subtract_inplace_strided_dispatch_table; @@ -206,7 +214,7 @@ void init_subtract(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, subtract_output_id_table, + src, dst, exec_q, depends, subtract_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) subtract_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp index f78c428152..ffb2afc3ea 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp @@ -137,7 +137,7 @@ void populate_true_divide_dispatch_tables(void) dtb5.populate_dispatch_table( true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - // which input types are supported, and what is the type of the result + // which types are supported by the in-place kernels using fn_ns::TrueDivideInplaceTypeMapFactory; DispatchTableBuilder dtb6; dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp index 3ce0ae1264..35d2ada54a 100644 --- a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -43,8 +43,7 @@ namespace td_ns = dpctl::tensor::type_dispatch; template struct DotAtomicOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry struct DotAtomicOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; // add separate type support lists for atomic vs. temps // gemm, gevm, and dot product share output type struct template struct DotNoAtomicOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct DotNoAtomicOutputType std::complex, std::complex>, td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; }; template struct DotTypeMapFactory @@ -179,13 +181,13 @@ template struct GemmBatchAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_batch_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = gemm_batch_impl; return fn; } @@ -197,13 +199,13 @@ struct GemmBatchContigAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_batch_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = gemm_batch_contig_impl; return fn; } @@ -214,13 +216,13 @@ template struct GemmAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = gemm_impl; return fn; } @@ -231,13 +233,13 @@ template struct GemmContigAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = gemm_contig_impl; return fn; } @@ -248,13 +250,13 @@ template struct GemmTempsFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = gemm_tree_impl; return fn; } @@ -265,13 +267,13 @@ template struct GemmContigTempsFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = gemm_contig_tree_impl; return fn; } @@ -282,13 +284,13 @@ template struct GemmBatchTempsFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_batch_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = gemm_batch_tree_impl; return fn; } @@ -300,13 +302,13 @@ struct GemmBatchContigTempsFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::gemm_batch_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = gemm_batch_contig_tree_impl; return fn; } @@ -317,13 +319,13 @@ template struct DotProductAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::dot_product_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = dot_product_impl; return fn; } @@ -335,13 +337,13 @@ struct DotProductNoAtomicFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::dot_product_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = dot_product_tree_impl; return fn; } @@ -353,13 +355,13 @@ struct DotProductContigAtomicFactory { fnT get() { - using T3 = typename DotAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::dot_product_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; fnT fn = dot_product_contig_impl; return fn; } @@ -371,13 +373,13 @@ struct DotProductContigNoAtomicFactory { fnT get() { - using T3 = typename DotNoAtomicOutputType::value_type; - if constexpr (std::is_same_v) { + if constexpr (!DotNoAtomicOutputType::is_defined) { fnT fn = nullptr; return fn; } else { using dpctl::tensor::kernels::dot_product_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; fnT fn = dot_product_contig_tree_impl; return fn; } diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.cpp b/dpctl/tensor/libtensor/source/reductions/argmax.cpp index acfeb95087..90977307e8 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmax.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmax.cpp @@ -68,9 +68,7 @@ template struct TypePairSupportForArgmaxReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, // input int8_t td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.cpp b/dpctl/tensor/libtensor/source/reductions/argmin.cpp index 8e9c0106ac..342298936f 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmin.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmin.cpp @@ -69,9 +69,7 @@ template struct TypePairSupportForArgminReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, // input int8_t td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp index d36b715f32..5037f8f00f 100644 --- a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp +++ b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp @@ -67,9 +67,7 @@ template struct TypePairSupportDataForLogSumExpReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< #if 1 td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/max.cpp b/dpctl/tensor/libtensor/source/reductions/max.cpp index 8036d873aa..896191c1dd 100644 --- a/dpctl/tensor/libtensor/source/reductions/max.cpp +++ b/dpctl/tensor/libtensor/source/reductions/max.cpp @@ -78,10 +78,8 @@ static reduction_contig_impl_fn_ptr template struct TypePairSupportDataForMaxReductionAtomic { - /* value is true if a kernel for must be instantiated, false * otherwise */ - // disjunction is C++17 feature, supported by DPC++ static constexpr bool is_defined = std::disjunction< // input int32 td_ns::TypePairDefinedEntry, @@ -102,8 +100,6 @@ struct TypePairSupportDataForMaxReductionAtomic template struct TypePairSupportDataForMaxReductionTemps { - - // disjunction is C++17 feature, supported by DPC++ static constexpr bool is_defined = std::disjunction< // input bool td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/min.cpp b/dpctl/tensor/libtensor/source/reductions/min.cpp index e612e59b8f..24e71794b4 100644 --- a/dpctl/tensor/libtensor/source/reductions/min.cpp +++ b/dpctl/tensor/libtensor/source/reductions/min.cpp @@ -78,10 +78,8 @@ static reduction_contig_impl_fn_ptr template struct TypePairSupportDataForMinReductionAtomic { - /* value is true if a kernel for must be instantiated, false * otherwise */ - // disjunction is C++17 feature, supported by DPC++ static constexpr bool is_defined = std::disjunction< // input int32 td_ns::TypePairDefinedEntry, @@ -102,8 +100,6 @@ struct TypePairSupportDataForMinReductionAtomic template struct TypePairSupportDataForMinReductionTemps { - - // disjunction is C++17 feature, supported by DPC++ static constexpr bool is_defined = std::disjunction< // input bool td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 9b8df53a01..5299ae83b8 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -79,9 +79,7 @@ struct TypePairSupportDataForProductReductionAtomic /* value if true a kernel for must be instantiated, false * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -121,9 +119,7 @@ template struct TypePairSupportDataForProductReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp index 159b992307..1b9aeeff50 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp +++ b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -67,9 +67,7 @@ template struct TypePairSupportDataForHypotReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index bb24da9287..0221abce67 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -79,9 +79,7 @@ struct TypePairSupportDataForSumReductionAtomic /* value if true a kernel for must be instantiated, false * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -121,9 +119,7 @@ template struct TypePairSupportDataForSumReductionTemps { - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool + static constexpr bool is_defined = std::disjunction< td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry,