From e67ed63e8f1769f2a73cfc63a412b14e8d28d18b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 26 Aug 2024 13:32:18 -0700 Subject: [PATCH 1/5] Implement overloads for `dpctl.tensor.pow` which use `sycl::pown` Improves performance for specific edge cases where the base array is of a floating-point data type and the exponent is 32-bit integer --- .../kernels/elementwise_functions/pow.hpp | 157 ++++++++++++++++-- .../source/elementwise_functions/pow.cpp | 18 +- 2 files changed, 154 insertions(+), 21 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index a21b2d4318..959fa47170 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -53,15 +53,16 @@ namespace tu_ns = dpctl::tensor::type_utils; template struct PowFunctor { - using supports_sg_loadstore = std::negation< std::disjunction, tu_ns::is_complex>>; using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; + using pown_exp_t = std::int32_t; + resT operator()(const argT1 &in1, const argT2 &in2) const { - if constexpr (std::is_integral_v || std::is_integral_v) { + if constexpr (std::is_integral_v && std::is_integral_v) { auto tmp1 = in1; auto tmp2 = in2; if constexpr (std::is_signed_v) { @@ -92,6 +93,17 @@ template struct PowFunctor return exprm_ns::pow(exprm_ns::complex(in1), exprm_ns::complex(in2)); } + else if constexpr ((std::is_floating_point_v || + std::is_same_v) && + std::is_integral_v) + { + if constexpr (std::is_same_v) { + return sycl::pown(in1, in2); + } + else { + return sycl::pown(in1, static_cast(in2)); + } + } else { return sycl::pow(in1, in2); } @@ -102,7 +114,7 @@ template struct PowFunctor operator()(const sycl::vec &in1, const sycl::vec &in2) const { - if constexpr (std::is_integral_v || std::is_integral_v) { + if constexpr (std::is_integral_v && std::is_integral_v) { sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { @@ -131,6 +143,40 @@ template struct PowFunctor } return res; } + else if constexpr ((std::is_floating_point_v || + std::is_same_v) && + std::is_integral_v) + { + if constexpr (std::is_same_v) { + auto res = sycl::pown(in1, in2); + if constexpr (std::is_same_v< + resT, typename decltype(res)::element_type>) + { + return res; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + return vec_cast(res); + } + } + else { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = vec_cast(in2); + auto res = sycl::pown(in1, tmp); + if constexpr (std::is_same_v< + resT, typename decltype(res)::element_type>) + { + return res; + } + else { + return vec_cast(res); + } + } + } else { auto res = sycl::pow(in1, in2); if constexpr (std::is_same_v struct PowOutputType { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ + using value_type = typename std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ td_ns::BinaryTypeMapResultEntry struct PowOutputType T2, std::int64_t, std::int64_t>, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, @@ -329,15 +384,16 @@ template struct PowStridedFactory template struct PowInplaceFunctor { - using supports_sg_loadstore = std::negation< std::disjunction, tu_ns::is_complex>>; using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; + using pown_exp_t = std::int32_t; + void operator()(resT &res, const argT &in) { - if constexpr (std::is_integral_v || std::is_integral_v) { + if constexpr (std::is_integral_v && std::is_integral_v) { auto tmp1 = res; auto tmp2 = in; if constexpr (std::is_signed_v) { @@ -373,17 +429,27 @@ template struct PowInplaceFunctor res = exprm_ns::pow(exprm_ns::complex(res), exprm_ns::complex(in)); } + else if constexpr ((std::is_floating_point_v || + std::is_same_v) && + std::is_integral_v) + { + if constexpr (std::is_same_v) { + res = sycl::pown(res, in); + } + else { + res = sycl::pown(res, static_cast(in)); + } + } else { res = sycl::pow(res, in); } - return; } template void operator()(sycl::vec &res, const sycl::vec &in) { - if constexpr (std::is_integral_v || std::is_integral_v) { + if constexpr (std::is_integral_v && std::is_integral_v) { #pragma unroll for (int i = 0; i < vec_sz; ++i) { auto tmp1 = res[i]; @@ -413,12 +479,75 @@ template struct PowInplaceFunctor res[i] = res_tmp; } } + else if constexpr ((std::is_floating_point_v || + std::is_same_v) && + std::is_integral_v) + { + if constexpr (std::is_same_v) { + res = sycl::pown(res, in); + } + else { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = vec_cast(in); + res = sycl::pown(res, tmp); + } + } else { res = sycl::pow(res, in); } } }; +/* @brief Types supported by in-place pow */ +template struct PowInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct PowInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of divide(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + if constexpr (PowInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template struct PowInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -478,7 +605,7 @@ template struct PowInplaceContigFactory } }; -template +template class pow_inplace_strided_kernel; template @@ -505,9 +632,7 @@ struct PowInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!PowInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp index 347b3e298a..e1fbf54772 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp @@ -68,6 +68,8 @@ namespace pow_fn_ns = dpctl::tensor::kernels::pow; static binary_contig_impl_fn_ptr_t pow_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; static int pow_output_id_table[td_ns::num_types][td_ns::num_types]; +static int pow_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t pow_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -101,19 +103,24 @@ void populate_pow_dispatch_tables(void) dtb3; dtb3.populate_dispatch_table(pow_contig_dispatch_table); + // which input types are supported, and what is the type of the result + using fn_ns::PowInplaceTypeMapFactory; + DispatchTableBuilder dtb4; + dtb4.populate_dispatch_table(pow_inplace_output_id_table); + // function pointers for inplace operation on general strided arrays using fn_ns::PowInplaceStridedFactory; DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(pow_inplace_strided_dispatch_table); + dtb5; + dtb5.populate_dispatch_table(pow_inplace_strided_dispatch_table); // function pointers for inplace operation on contiguous inputs and output using fn_ns::PowInplaceContigFactory; DispatchTableBuilder - dtb5; - dtb5.populate_dispatch_table(pow_inplace_contig_dispatch_table); + dtb6; + dtb6.populate_dispatch_table(pow_inplace_contig_dispatch_table); }; } // namespace impl @@ -160,12 +167,13 @@ void init_pow(py::module_ m) using impl::pow_inplace_contig_dispatch_table; using impl::pow_inplace_strided_dispatch_table; + using impl::pow_inplace_output_id_table; auto pow_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, pow_output_id_table, + src, dst, exec_q, depends, pow_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) pow_inplace_contig_dispatch_table, From 6f9abe0e62dd7a8399d4d9e9b4ecf72deecc7025 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 26 Aug 2024 13:32:46 -0700 Subject: [PATCH 2/5] Refactor in-place division to use `TypePairDefinedEntry` This makes the code easier to understand --- .../elementwise_functions/true_divide.hpp | 82 ++++++++----------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 396e5c995e..d8f4bf672f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -439,52 +439,45 @@ template struct TrueDivideInplaceFunctor } }; -// cannot use the out of place table, as it permits real lhs and complex rhs -// T1 corresponds to the type of the rhs, while T2 corresponds to the lhs -// the type of the result must be the same as T2 -template struct TrueDivideInplaceOutputType +/* @brief Types supported by in-place divide */ +template +struct TrueDivideInplaceTypePairSupport { - using value_type = typename std::disjunction< // disjunction is C++17 - // feature, supported by DPC++ - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - td_ns::BinaryTypeMapResultEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapResultEntry, - std::complex>, - td_ns::DefaultResultEntry>::result_type; + + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; }; -template +template struct TrueDivideInplaceTypeMapFactory { /*! @brief get typeid for output type of divide(T1 x, T2 y) */ std::enable_if_t::value, int> get() { - using rT = typename TrueDivideInplaceOutputType::value_type; - static_assert(std::is_same_v || std::is_same_v); - return td_ns::GetTypeid{}.get(); + if constexpr (TrueDivideInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } } }; @@ -537,10 +530,7 @@ struct TrueDivideInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -579,10 +569,7 @@ struct TrueDivideInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -627,8 +614,7 @@ struct TrueDivideInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename TrueDivideInplaceOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!TrueDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } From 80184c00ff163e1311df6794fcc3c7b0bfb509bc Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 26 Aug 2024 13:33:43 -0700 Subject: [PATCH 3/5] Decouples in-place and out-of-place type support tables Improves readability of in-place code --- .../kernels/elementwise_functions/add.hpp | 58 +++++++++++++++--- .../elementwise_functions/bitwise_and.hpp | 48 ++++++++++++--- .../bitwise_left_shift.hpp | 48 +++++++++++++-- .../elementwise_functions/bitwise_or.hpp | 46 +++++++++++--- .../bitwise_right_shift.hpp | 48 +++++++++++++-- .../elementwise_functions/bitwise_xor.hpp | 48 ++++++++++++--- .../elementwise_functions/floor_divide.hpp | 50 +++++++++++++--- .../elementwise_functions/multiply.hpp | 60 +++++++++++++++---- .../elementwise_functions/remainder.hpp | 48 ++++++++++++--- .../elementwise_functions/subtract.hpp | 59 ++++++++++++++---- .../source/elementwise_functions/add.cpp | 10 +++- .../elementwise_functions/bitwise_and.cpp | 44 ++++++++------ .../bitwise_left_shift.cpp | 12 +++- .../elementwise_functions/bitwise_or.cpp | 44 ++++++++------ .../bitwise_right_shift.cpp | 12 +++- .../elementwise_functions/bitwise_xor.cpp | 44 ++++++++------ .../elementwise_functions/floor_divide.cpp | 44 ++++++++------ .../source/elementwise_functions/multiply.cpp | 10 +++- .../source/elementwise_functions/pow.cpp | 7 +-- .../elementwise_functions/remainder.cpp | 11 +++- .../source/elementwise_functions/subtract.cpp | 10 +++- .../elementwise_functions/true_divide.cpp | 2 +- 22 files changed, 604 insertions(+), 159 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index e77068b5e1..0ccbe4dfa7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -438,6 +438,53 @@ template class add_inplace_contig_kernel; +/* @brief Types supported by in-place add */ +template struct AddInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct AddInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x += y */ + std::enable_if_t::value, int> get() + { + if constexpr (AddInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event add_inplace_contig_impl(sycl::queue &exec_q, @@ -457,9 +504,7 @@ template struct AddInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -497,9 +542,7 @@ struct AddInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) - { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -544,8 +587,7 @@ struct AddInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename AddOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!AddInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp index ffe80f622e..971bfce7fc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -322,6 +322,44 @@ template class bitwise_and_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise AND */ +template +struct BitwiseAndInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseAndInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x &= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseAndInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_and_inplace_contig_impl(sycl::queue &exec_q, @@ -343,10 +381,7 @@ struct BitwiseAndInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -385,10 +420,7 @@ struct BitwiseAndInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseAndOutputType::value_type, - void>) - { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp index 18a87e5287..f9b74a506c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -336,6 +336,44 @@ template class bitwise_left_shift_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise left shift */ +template +struct BitwiseLeftShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseLeftShiftInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x <<= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseLeftShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_left_shift_inplace_contig_impl( sycl::queue &exec_q, @@ -357,9 +395,8 @@ struct BitwiseLeftShiftInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; @@ -399,9 +436,8 @@ struct BitwiseLeftShiftInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp index aad31aae95..70ed86bd0f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -318,6 +318,42 @@ template class bitwise_or_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise OR */ +template struct BitwiseOrInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseOrInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x |= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseOrInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_or_inplace_contig_impl(sycl::queue &exec_q, @@ -339,10 +375,7 @@ struct BitwiseOrInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -381,10 +414,7 @@ struct BitwiseOrInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseOrOutputType::value_type, - void>) - { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp index 2fbee2e49d..5119034544 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -340,6 +340,44 @@ template class bitwise_right_shift_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise right shift */ +template +struct BitwiseRightShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseRightShiftInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x >>= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseRightShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_right_shift_inplace_contig_impl( sycl::queue &exec_q, @@ -361,9 +399,8 @@ struct BitwiseRightShiftInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseRightShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; @@ -403,9 +440,8 @@ struct BitwiseRightShiftInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) + if constexpr (!BitwiseRightShiftInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp index fd0ca880c9..36d383aa48 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -322,6 +322,44 @@ template class bitwise_xor_inplace_contig_kernel; +/* @brief Types supported by in-place bitwise XOR */ +template +struct BitwiseXorInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseXorInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x ^= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseXorInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event bitwise_xor_inplace_contig_impl(sycl::queue &exec_q, @@ -343,10 +381,7 @@ struct BitwiseXorInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -385,10 +420,7 @@ struct BitwiseXorInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename BitwiseXorOutputType::value_type, - void>) - { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index ce89b0778f..b103f98b7a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -398,6 +398,46 @@ template class floor_divide_inplace_contig_kernel; +/* @brief Types supported by in-place floor division */ +template +struct FloorDivideInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct FloorDivideInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x //= y */ + std::enable_if_t::value, int> get() + { + if constexpr (FloorDivideInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event floor_divide_inplace_contig_impl(sycl::queue &exec_q, @@ -419,10 +459,7 @@ struct FloorDivideInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename FloorDivideOutputType::value_type, - void>) - { + if constexpr (!FloorDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -461,10 +498,7 @@ struct FloorDivideInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename FloorDivideOutputType::value_type, - void>) - { + if constexpr (!FloorDivideInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index 147f62f53e..c0e2e0ae58 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -434,6 +434,53 @@ template class multiply_inplace_contig_kernel; +/* @brief Types supported by in-place multiplication */ +template struct MultiplyInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MultiplyInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x *= y */ + std::enable_if_t::value, int> get() + { + if constexpr (MultiplyInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event multiply_inplace_contig_impl(sycl::queue &exec_q, @@ -455,10 +502,7 @@ struct MultiplyInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -497,10 +541,7 @@ struct MultiplyInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename MultiplyOutputType::value_type, - void>) - { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -545,8 +586,7 @@ struct MultiplyInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename MultiplyOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!MultiplyInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp index 585d1c6d7f..32c64ca9cf 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp @@ -424,6 +424,44 @@ template class remainder_inplace_contig_kernel; +/* @brief Types supported by in-place remainder */ +template struct RemainderInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct RemainderInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x %= y */ + std::enable_if_t::value, int> get() + { + if constexpr (RemainderInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event remainder_inplace_contig_impl(sycl::queue &exec_q, @@ -445,10 +483,7 @@ struct RemainderInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -487,10 +522,7 @@ struct RemainderInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename RemainderOutputType::value_type, - void>) - { + if constexpr (!RemainderInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 4a8cfb50a7..8845f0657a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -435,6 +435,52 @@ template class subtract_inplace_contig_kernel; +/* @brief Types supported by in-place subtraction */ +template struct SubtractInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< // disjunction is + // C++17 feature, + // supported by + // DPC++ input + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SubtractInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x -= y */ + std::enable_if_t::value, int> get() + { + if constexpr (SubtractInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + template sycl::event subtract_inplace_contig_impl(sycl::queue &exec_q, @@ -456,10 +502,7 @@ struct SubtractInplaceContigFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -498,10 +541,7 @@ struct SubtractInplaceStridedFactory { fnT get() { - if constexpr (std::is_same_v< - typename SubtractOutputType::value_type, - void>) - { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } @@ -546,8 +586,7 @@ struct SubtractInplaceRowMatrixBroadcastFactory { fnT get() { - using resT = typename SubtractOutputType::value_type; - if constexpr (!std::is_same_v) { + if constexpr (!SubtractInplaceTypePairSupport::is_defined) { fnT fn = nullptr; return fn; } diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp index 0cce42cdb5..823647aefd 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp @@ -67,7 +67,9 @@ namespace add_fn_ns = dpctl::tensor::kernels::add; static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int add_output_id_table[td_ns::num_types][td_ns::num_types]; +static int add_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -154,6 +156,11 @@ void populate_add_dispatch_tables(void) AddInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::AddInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(add_inplace_output_id_table); }; } // namespace impl @@ -199,6 +206,7 @@ void init_add(py::module_ m) m.def("_add_result_type", add_result_type_pyapi, ""); using impl::add_inplace_contig_dispatch_table; + using impl::add_inplace_output_id_table; using impl::add_inplace_row_matrix_dispatch_table; using impl::add_inplace_strided_dispatch_table; @@ -206,7 +214,7 @@ void init_add(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, add_output_id_table, + src, dst, exec_q, depends, add_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) add_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp index afd7ffe469..15ccf22ea2 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp @@ -66,7 +66,10 @@ namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; static binary_contig_impl_fn_ptr_t bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_and_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_and_dispatch_tables(void) BitwiseAndInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_and_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseAndInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_and_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_and(py::module_ m) m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); using impl::bitwise_and_inplace_contig_dispatch_table; + using impl::bitwise_and_inplace_output_id_table; using impl::bitwise_and_inplace_strided_dispatch_table; - auto bitwise_and_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_and_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_and_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_and_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_and_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_and_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_and_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_and_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_and_inplace", bitwise_and_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp index 777573a738..833b78e697 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp @@ -67,8 +67,11 @@ namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; static binary_contig_impl_fn_ptr_t bitwise_left_shift_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int bitwise_left_shift_output_id_table[td_ns::num_types] [td_ns::num_types]; +static int bitwise_left_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_left_shift_strided_dispatch_table[td_ns::num_types] @@ -120,6 +123,12 @@ void populate_bitwise_left_shift_dispatch_tables(void) dtb5; dtb5.populate_dispatch_table( bitwise_left_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseLeftShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_left_shift_inplace_output_id_table); }; } // namespace impl @@ -169,6 +178,7 @@ void init_bitwise_left_shift(py::module_ m) bitwise_left_shift_result_type_pyapi, ""); using impl::bitwise_left_shift_inplace_contig_dispatch_table; + using impl::bitwise_left_shift_inplace_output_id_table; using impl::bitwise_left_shift_inplace_strided_dispatch_table; auto bitwise_left_shift_inplace_pyapi = @@ -176,7 +186,7 @@ void init_bitwise_left_shift(py::module_ m) const event_vecT &depends = {}) { return py_binary_inplace_ufunc( src, dst, exec_q, depends, - bitwise_left_shift_output_id_table, + bitwise_left_shift_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) bitwise_left_shift_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp index 468d887392..ecdaeb6577 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp @@ -66,7 +66,10 @@ namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; static binary_contig_impl_fn_ptr_t bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_or_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_or_dispatch_tables(void) BitwiseOrInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_or_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseOrInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_or_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_or(py::module_ m) m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); using impl::bitwise_or_inplace_contig_dispatch_table; + using impl::bitwise_or_inplace_output_id_table; using impl::bitwise_or_inplace_strided_dispatch_table; - auto bitwise_or_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_or_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_or_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_or_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_or_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_or_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_or_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_or_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_or_inplace", bitwise_or_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp index 85fc7b99a9..9f74037f41 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp @@ -68,8 +68,11 @@ namespace bitwise_right_shift_fn_ns = static binary_contig_impl_fn_ptr_t bitwise_right_shift_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; + static int bitwise_right_shift_output_id_table[td_ns::num_types] [td_ns::num_types]; +static int bitwise_right_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_right_shift_strided_dispatch_table[td_ns::num_types] @@ -121,6 +124,12 @@ void populate_bitwise_right_shift_dispatch_tables(void) dtb5; dtb5.populate_dispatch_table( bitwise_right_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseRightShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_right_shift_inplace_output_id_table); }; } // namespace impl @@ -170,6 +179,7 @@ void init_bitwise_right_shift(py::module_ m) bitwise_right_shift_result_type_pyapi, ""); using impl::bitwise_right_shift_inplace_contig_dispatch_table; + using impl::bitwise_right_shift_inplace_output_id_table; using impl::bitwise_right_shift_inplace_strided_dispatch_table; auto bitwise_right_shift_inplace_pyapi = @@ -177,7 +187,7 @@ void init_bitwise_right_shift(py::module_ m) const event_vecT &depends = {}) { return py_binary_inplace_ufunc( src, dst, exec_q, depends, - bitwise_right_shift_output_id_table, + bitwise_right_shift_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) bitwise_right_shift_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp index 18e9fb695e..6124f305e0 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp @@ -66,7 +66,10 @@ namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; static binary_contig_impl_fn_ptr_t bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_xor_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_bitwise_xor_dispatch_tables(void) BitwiseXorInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(bitwise_xor_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseXorInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_xor_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_bitwise_xor(py::module_ m) m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); using impl::bitwise_xor_inplace_contig_dispatch_table; + using impl::bitwise_xor_inplace_output_id_table; using impl::bitwise_xor_inplace_strided_dispatch_table; - auto bitwise_xor_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, bitwise_xor_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - bitwise_xor_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - bitwise_xor_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto bitwise_xor_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_xor_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_xor_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_xor_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_bitwise_xor_inplace", bitwise_xor_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp index fdbffbb3be..921b494fd1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp @@ -66,7 +66,10 @@ namespace floor_divide_fn_ns = dpctl::tensor::kernels::floor_divide; static binary_contig_impl_fn_ptr_t floor_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; +static int floor_divide_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_floor_divide_dispatch_tables(void) FloorDivideInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::FloorDivideInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(floor_divide_inplace_output_id_table); }; } // namespace impl @@ -160,25 +168,27 @@ void init_floor_divide(py::module_ m) m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); using impl::floor_divide_inplace_contig_dispatch_table; + using impl::floor_divide_inplace_output_id_table; using impl::floor_divide_inplace_strided_dispatch_table; - auto floor_divide_inplace_pyapi = - [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - floor_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - floor_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; + auto floor_divide_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, floor_divide_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + floor_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + floor_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp index 75438e1031..22fc293c98 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp @@ -67,7 +67,9 @@ namespace multiply_fn_ns = dpctl::tensor::kernels::multiply; static binary_contig_impl_fn_ptr_t multiply_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int multiply_output_id_table[td_ns::num_types][td_ns::num_types]; +static int multiply_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t multiply_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -155,6 +157,11 @@ void populate_multiply_dispatch_tables(void) MultiplyInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(multiply_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::MultiplyInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(multiply_inplace_output_id_table); }; } // namespace impl @@ -200,6 +207,7 @@ void init_multiply(py::module_ m) m.def("_multiply_result_type", multiply_result_type_pyapi, ""); using impl::multiply_inplace_contig_dispatch_table; + using impl::multiply_inplace_output_id_table; using impl::multiply_inplace_row_matrix_dispatch_table; using impl::multiply_inplace_strided_dispatch_table; @@ -207,7 +215,7 @@ void init_multiply(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, multiply_output_id_table, + src, dst, exec_q, depends, multiply_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) multiply_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp index e1fbf54772..5f3ceae813 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp @@ -68,8 +68,7 @@ namespace pow_fn_ns = dpctl::tensor::kernels::pow; static binary_contig_impl_fn_ptr_t pow_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; static int pow_output_id_table[td_ns::num_types][td_ns::num_types]; -static int pow_inplace_output_id_table[td_ns::num_types] - [td_ns::num_types]; +static int pow_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t pow_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -103,7 +102,7 @@ void populate_pow_dispatch_tables(void) dtb3; dtb3.populate_dispatch_table(pow_contig_dispatch_table); - // which input types are supported, and what is the type of the result + // which types are supported by the in-place kernels using fn_ns::PowInplaceTypeMapFactory; DispatchTableBuilder dtb4; dtb4.populate_dispatch_table(pow_inplace_output_id_table); @@ -166,8 +165,8 @@ void init_pow(py::module_ m) m.def("_pow_result_type", pow_result_type_pyapi, ""); using impl::pow_inplace_contig_dispatch_table; - using impl::pow_inplace_strided_dispatch_table; using impl::pow_inplace_output_id_table; + using impl::pow_inplace_strided_dispatch_table; auto pow_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp index d08c7ebda8..a2d3f48099 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp @@ -67,7 +67,10 @@ namespace remainder_fn_ns = dpctl::tensor::kernels::remainder; static binary_contig_impl_fn_ptr_t remainder_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int remainder_output_id_table[td_ns::num_types][td_ns::num_types]; +static int remainder_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t remainder_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -115,6 +118,11 @@ void populate_remainder_dispatch_tables(void) RemainderInplaceContigFactory, num_types> dtb5; dtb5.populate_dispatch_table(remainder_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::RemainderInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(remainder_inplace_output_id_table); } } // namespace impl @@ -160,13 +168,14 @@ void init_remainder(py::module_ m) m.def("_remainder_result_type", remainder_result_type_pyapi, ""); using impl::remainder_inplace_contig_dispatch_table; + using impl::remainder_inplace_output_id_table; using impl::remainder_inplace_strided_dispatch_table; auto remainder_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, remainder_output_id_table, + src, dst, exec_q, depends, remainder_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) remainder_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp index 0c81ace53e..f1c1324a58 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp @@ -66,7 +66,9 @@ namespace subtract_fn_ns = dpctl::tensor::kernels::subtract; static binary_contig_impl_fn_ptr_t subtract_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + static int subtract_output_id_table[td_ns::num_types][td_ns::num_types]; +static int subtract_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t subtract_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -154,6 +156,11 @@ void populate_subtract_dispatch_tables(void) SubtractInplaceRowMatrixBroadcastFactory, num_types> dtb8; dtb8.populate_dispatch_table(subtract_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::SubtractInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(subtract_inplace_output_id_table); }; } // namespace impl @@ -199,6 +206,7 @@ void init_subtract(py::module_ m) m.def("_subtract_result_type", subtract_result_type_pyapi, ""); using impl::subtract_inplace_contig_dispatch_table; + using impl::subtract_inplace_output_id_table; using impl::subtract_inplace_row_matrix_dispatch_table; using impl::subtract_inplace_strided_dispatch_table; @@ -206,7 +214,7 @@ void init_subtract(py::module_ m) sycl::queue &exec_q, const event_vecT &depends = {}) { return py_binary_inplace_ufunc( - src, dst, exec_q, depends, subtract_output_id_table, + src, dst, exec_q, depends, subtract_inplace_output_id_table, // function pointers to handle inplace operation on // contiguous arrays (pointers may be nullptr) subtract_inplace_contig_dispatch_table, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp index f78c428152..ffb2afc3ea 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp @@ -137,7 +137,7 @@ void populate_true_divide_dispatch_tables(void) dtb5.populate_dispatch_table( true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - // which input types are supported, and what is the type of the result + // which types are supported by the in-place kernels using fn_ns::TrueDivideInplaceTypeMapFactory; DispatchTableBuilder dtb6; dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); From 5afa60589f3855d5e1bbd5951e8875aefc1a158b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 26 Aug 2024 13:34:28 -0700 Subject: [PATCH 4/5] Fixes a comment in `_acceptance_fn_reciprocal` --- dpctl/tensor/_type_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index bebb1889f4..890af46339 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev): def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): - # if the kind of result is different from - # the kind of input, use the default data - # we use default dtype for the resulting kind. - # This guarantees alignment of reciprocal and - # divide output types. + # if the kind of result is different from the kind of input, we use the + # default floating-point dtype for the resulting kind. This guarantees + # alignment of reciprocal and divide output types. if buf_dt.kind != arg_dtype.kind: default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev) if res_dt == default_dt: From c746b3e2a53e39e0387c431cf2031649e630dc8f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 27 Aug 2024 11:09:49 -0700 Subject: [PATCH 5/5] Update dpt.pow tests to account for `sycl::pown` overload in dtype matrix --- dpctl/tests/elementwise/test_pow.py | 80 +++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/dpctl/tests/elementwise/test_pow.py b/dpctl/tests/elementwise/test_pow.py index e298ed2347..eda4216d23 100644 --- a/dpctl/tests/elementwise/test_pow.py +++ b/dpctl/tests/elementwise/test_pow.py @@ -21,12 +21,31 @@ import dpctl import dpctl.tensor as dpt -from dpctl.tensor._type_utils import _can_cast +from dpctl.tensor._type_utils import _can_cast, _find_buf_dtype2 from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported from .utils import _all_dtypes, _compare_dtypes, _usm_types +def _using_pown(lhs_dt, rhs_dt, dev): + # find if type combination is legal with use of `sycl::pown` + dt1, dt2, res_dt = _find_buf_dtype2( + lhs_dt, + rhs_dt, + dpt.pow.get_type_result_resolver_function(), + dev, + dpt.pow.get_type_promotion_path_acceptance_function(), + ) + if res_dt: + if dt1 is None: + dt1 = lhs_dt + if dt2 is None: + dt2 = rhs_dt + if dt1.kind == "f" and dt2.kind == "i": + return True + return False + + @pytest.mark.parametrize("op1_dtype", _all_dtypes[1:]) @pytest.mark.parametrize("op2_dtype", _all_dtypes[1:]) def test_power_dtype_matrix(op1_dtype, op2_dtype): @@ -40,10 +59,18 @@ def test_power_dtype_matrix(op1_dtype, op2_dtype): r = dpt.pow(ar1, ar2) assert isinstance(r, dpt.usm_ndarray) - expected = np.power( - np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) - ) - assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + dev = q.sycl_device + if _using_pown(ar1.dtype, ar2.dtype, dev): + expected = np.power( + np.ones(1, dtype=op1_dtype), + np.ones(1, dtype=op2_dtype), + dtype=r.dtype, + ) + else: + expected = np.power( + np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) assert r.shape == ar1.shape assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() assert r.sycl_queue == ar1.sycl_queue @@ -53,10 +80,17 @@ def test_power_dtype_matrix(op1_dtype, op2_dtype): r = dpt.pow(ar3[::-1], ar4[::2]) assert isinstance(r, dpt.usm_ndarray) - expected = np.power( - np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) - ) - assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + if _using_pown(ar3.dtype, ar4.dtype, dev): + expected = np.power( + np.ones(1, dtype=op1_dtype), + np.ones(1, dtype=op2_dtype), + dtype=r.dtype, + ) + else: + expected = np.power( + np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) assert r.shape == ar3.shape assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() @@ -169,6 +203,25 @@ def test_pow_inplace_python_scalar(dtype): X **= complex(1) +def _using_pown_inplace(lhs_dt, rhs_dt, dev): + # find if type combination is legal with use of `sycl::pown` + dt1, dt2, res_dt = _find_buf_dtype2( + lhs_dt, + rhs_dt, + dpt.pow.get_type_result_resolver_function(), + dev, + dpt.pow.get_type_promotion_path_acceptance_function(), + ) + if res_dt: + if dt1 is None: + dt1 = lhs_dt + if dt2 is None: + dt2 = rhs_dt + if dt1.kind == "f" and dt2.kind == "i" and dt1 == lhs_dt: + return True + return False + + @pytest.mark.parametrize("op1_dtype", _all_dtypes[1:]) @pytest.mark.parametrize("op2_dtype", _all_dtypes[1:]) def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype): @@ -183,7 +236,12 @@ def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype): dev = q.sycl_device _fp16 = dev.has_aspect_fp16 _fp64 = dev.has_aspect_fp64 - if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64): + ar1_dt = ar1.dtype + ar2_dt = ar2.dtype + # need to check for `pown` overload + if _can_cast(ar2_dt, ar1_dt, _fp16, _fp64) or _using_pown_inplace( + ar1_dt, ar2_dt, dev + ): ar1 **= ar2 assert ( dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype) @@ -192,7 +250,7 @@ def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype): ar3 = dpt.ones(sz, dtype=op1_dtype) ar4 = dpt.ones(2 * sz, dtype=op2_dtype) - ar3[::-1] *= ar4[::2] + ar3[::-1] **= ar4[::2] assert ( dpt.asnumpy(ar3) == np.full(ar3.shape, 1, dtype=ar3.dtype) ).all()