From f7ac1f19717d40725e807b09092fe43626cfa20c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 24 Jul 2023 14:30:16 -0500 Subject: [PATCH 1/2] Advanced boolean indexing support using int32/int64 cumsum temporaries This PR changes mask_positions to work with cumsum temporary of dtype int32 as well as int64. Similarly, other functions such as _nonzero, _extract, _place also support cumsum array of int32 as well as int64. Support for int32 allows to improve performance for indexing into smaller arrays. --- .../kernels/boolean_advanced_indexing.hpp | 131 +++++++--- .../source/boolean_advanced_indexing.cpp | 234 +++++++++++++----- .../source/boolean_advanced_indexing.hpp | 4 +- 3 files changed, 279 insertions(+), 90 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp index d0db095d15..bb2ddc5ad6 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp @@ -393,7 +393,7 @@ size_t mask_positions_contig_impl(sycl::queue q, throw std::bad_alloc(); } sycl::event copy_e = - q.copy(last_elem, last_elem_host_usm, 1, {comp_ev}); + q.copy(last_elem, last_elem_host_usm, 1, {comp_ev}); copy_e.wait(); size_t return_val = static_cast(*last_elem_host_usm); sycl::free(last_elem_host_usm, q); @@ -401,7 +401,16 @@ size_t mask_positions_contig_impl(sycl::queue q, return return_val; } -template struct MaskPositionsContigFactory +template struct MaskPositionsContigFactoryForInt32 +{ + fnT get() + { + fnT fn = mask_positions_contig_impl; + return fn; + } +}; + +template struct MaskPositionsContigFactoryForInt64 { fnT get() { @@ -452,7 +461,7 @@ size_t mask_positions_strided_impl(sycl::queue q, throw std::bad_alloc(); } sycl::event copy_e = - q.copy(last_elem, last_elem_host_usm, 1, {comp_ev}); + q.copy(last_elem, last_elem_host_usm, 1, {comp_ev}); copy_e.wait(); size_t return_val = static_cast(*last_elem_host_usm); sycl::free(last_elem_host_usm, q); @@ -460,7 +469,16 @@ size_t mask_positions_strided_impl(sycl::queue q, return return_val; } -template struct MaskPositionsStridedFactory +template struct MaskPositionsStridedFactoryForInt32 +{ + fnT get() + { + fnT fn = mask_positions_strided_impl; + return fn; + } +}; + +template struct MaskPositionsStridedFactoryForInt64 { fnT get() { @@ -611,7 +629,18 @@ sycl::event masked_extract_some_slices_strided_impl( return comp_ev; } -template struct MaskExtractAllSlicesStridedFactory +template +struct MaskExtractAllSlicesStridedFactoryForInt32 +{ + fnT get() + { + fnT fn = masked_extract_all_slices_strided_impl; + return fn; + } +}; + +template +struct MaskExtractAllSlicesStridedFactoryForInt64 { fnT get() { @@ -620,7 +649,18 @@ template struct MaskExtractAllSlicesStridedFactory } }; -template struct MaskExtractSomeSlicesStridedFactory +template +struct MaskExtractSomeSlicesStridedFactoryForInt32 +{ + fnT get() + { + fnT fn = masked_extract_some_slices_strided_impl; + return fn; + } +}; + +template +struct MaskExtractSomeSlicesStridedFactoryForInt64 { fnT get() { @@ -763,7 +803,18 @@ sycl::event masked_place_some_slices_strided_impl( return comp_ev; } -template struct MaskPlaceAllSlicesStridedFactory +template +struct MaskPlaceAllSlicesStridedFactoryForInt32 +{ + fnT get() + { + fnT fn = masked_place_all_slices_strided_impl; + return fn; + } +}; + +template +struct MaskPlaceAllSlicesStridedFactoryForInt64 { fnT get() { @@ -772,7 +823,18 @@ template struct 
MaskPlaceAllSlicesStridedFactory } }; -template struct MaskPlaceSomeSlicesStridedFactory +template +struct MaskPlaceSomeSlicesStridedFactoryForInt32 +{ + fnT get() + { + fnT fn = masked_place_some_slices_strided_impl; + return fn; + } +}; + +template +struct MaskPlaceSomeSlicesStridedFactoryForInt64 { fnT get() { @@ -783,7 +845,17 @@ template struct MaskPlaceSomeSlicesStridedFactory // Non-zero -class non_zero_indexes_krn; +template class non_zero_indexes_krn; + +typedef sycl::event (*non_zero_indexes_fn_ptr_t)( + sycl::queue, + py::ssize_t, + py::ssize_t, + int, + const char *, + char *, + const py::ssize_t *, + std::vector const &); template sycl::event non_zero_indexes_impl(sycl::queue exec_q, @@ -800,28 +872,29 @@ sycl::event non_zero_indexes_impl(sycl::queue exec_q, sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - cgh.parallel_for( - sycl::range<1>(iter_size), [=](sycl::id<1> idx) { - auto i = idx[0]; - - auto cs_curr_val = cumsum_data[i] - 1; - auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0); - bool cond = (cs_curr_val == cs_prev_val); - - py::ssize_t i_ = static_cast(i); - for (int dim = nd; --dim > 0;) { - auto sd = mask_shape[dim]; - py::ssize_t q = i_ / sd; - py::ssize_t r = (i_ - q * sd); - if (cond) { - indexes_data[cs_curr_val + dim * nz_elems] = - static_cast(r); - } - i_ = q; - } + cgh.parallel_for>( + sycl::range<1>(iter_size), [=](sycl::id<1> idx) + { + auto i = idx[0]; + + auto cs_curr_val = cumsum_data[i] - 1; + auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0); + bool cond = (cs_curr_val == cs_prev_val); + + py::ssize_t i_ = static_cast(i); + for (int dim = nd; --dim > 0;) { + auto sd = mask_shape[dim]; + py::ssize_t q = i_ / sd; + py::ssize_t r = (i_ - q * sd); if (cond) { - indexes_data[cs_curr_val] = static_cast(i_); + indexes_data[cs_curr_val + dim * nz_elems] = + static_cast(r); } + i_ = q; + } + if (cond) { + indexes_data[cs_curr_val] = static_cast(i_); + } }); }); diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp index 59f62af5f1..0c9f8656d0 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp @@ -97,25 +97,45 @@ namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::kernels::indexing::mask_positions_contig_impl_fn_ptr_t; static mask_positions_contig_impl_fn_ptr_t - mask_positions_contig_dispatch_vector[td_ns::num_types]; + mask_positions_contig_i64_dispatch_vector[td_ns::num_types]; +static mask_positions_contig_impl_fn_ptr_t + mask_positions_contig_i32_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::indexing::mask_positions_strided_impl_fn_ptr_t; static mask_positions_strided_impl_fn_ptr_t - mask_positions_strided_dispatch_vector[td_ns::num_types]; + mask_positions_strided_i64_dispatch_vector[td_ns::num_types]; +static mask_positions_strided_impl_fn_ptr_t + mask_positions_strided_i32_dispatch_vector[td_ns::num_types]; void populate_mask_positions_dispatch_vectors(void) { - using dpctl::tensor::kernels::indexing::MaskPositionsContigFactory; + using dpctl::tensor::kernels::indexing::MaskPositionsContigFactoryForInt64; td_ns::DispatchVectorBuilder + MaskPositionsContigFactoryForInt64, + td_ns::num_types> dvb1; - dvb1.populate_dispatch_vector(mask_positions_contig_dispatch_vector); + dvb1.populate_dispatch_vector(mask_positions_contig_i64_dispatch_vector); - using 
dpctl::tensor::kernels::indexing::MaskPositionsStridedFactory; - td_ns::DispatchVectorBuilder + using dpctl::tensor::kernels::indexing::MaskPositionsContigFactoryForInt32; + td_ns::DispatchVectorBuilder dvb2; - dvb2.populate_dispatch_vector(mask_positions_strided_dispatch_vector); + dvb2.populate_dispatch_vector(mask_positions_contig_i32_dispatch_vector); + + using dpctl::tensor::kernels::indexing::MaskPositionsStridedFactoryForInt64; + td_ns::DispatchVectorBuilder + dvb3; + dvb3.populate_dispatch_vector(mask_positions_strided_i64_dispatch_vector); + + using dpctl::tensor::kernels::indexing::MaskPositionsStridedFactoryForInt32; + td_ns::DispatchVectorBuilder + dvb4; + dvb4.populate_dispatch_vector(mask_positions_strided_i32_dispatch_vector); return; } @@ -163,15 +183,20 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask, int mask_typeid = array_types.typenum_to_lookup_id(mask_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); - // cumsum must be int64_t only + // cumsum must be int32_t/int64_t only + constexpr int int32_typeid = static_cast(td_ns::typenum_t::INT32); constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); - if (cumsum_typeid != int64_typeid) { + if (cumsum_typeid != int32_typeid && cumsum_typeid != int64_typeid) { throw py::value_error( - "Cumulative sum array must have int64 data-type."); + "Cumulative sum array must have int32 or int64 data-type."); } + const bool use_i32 = (cumsum_typeid == int32_typeid); + if (mask.is_c_contiguous()) { - auto fn = mask_positions_contig_dispatch_vector[mask_typeid]; + auto fn = (use_i32) + ? mask_positions_contig_i32_dispatch_vector[mask_typeid] + : mask_positions_contig_i64_dispatch_vector[mask_typeid]; return fn(exec_q, mask_size, mask_data, cumsum_data, depends); } @@ -192,13 +217,17 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask, offset); if (nd == 1 && simplified_strides[0] == 1) { - auto fn = mask_positions_contig_dispatch_vector[mask_typeid]; + auto fn = (use_i32) + ? mask_positions_contig_i32_dispatch_vector[mask_typeid] + : mask_positions_contig_i64_dispatch_vector[mask_typeid]; return fn(exec_q, mask_size, mask_data, cumsum_data, depends); } // Strided implementation - auto strided_fn = mask_positions_strided_dispatch_vector[mask_typeid]; + auto strided_fn = + (use_i32) ? 
mask_positions_strided_i32_dispatch_vector[mask_typeid] + : mask_positions_strided_i64_dispatch_vector[mask_typeid]; std::vector host_task_events; using dpctl::tensor::offset_utils::device_allocate_and_pack; @@ -239,31 +268,59 @@ using dpctl::tensor::kernels::indexing:: masked_extract_all_slices_strided_impl_fn_ptr_t; static masked_extract_all_slices_strided_impl_fn_ptr_t - masked_extract_all_slices_strided_impl_dispatch_vector[td_ns::num_types]; + masked_extract_all_slices_strided_i32_impl_dispatch_vector + [td_ns::num_types]; +static masked_extract_all_slices_strided_impl_fn_ptr_t + masked_extract_all_slices_strided_i64_impl_dispatch_vector + [td_ns::num_types]; using dpctl::tensor::kernels::indexing:: masked_extract_some_slices_strided_impl_fn_ptr_t; static masked_extract_some_slices_strided_impl_fn_ptr_t - masked_extract_some_slices_strided_impl_dispatch_vector[td_ns::num_types]; + masked_extract_some_slices_strided_i32_impl_dispatch_vector + [td_ns::num_types]; +static masked_extract_some_slices_strided_impl_fn_ptr_t + masked_extract_some_slices_strided_i64_impl_dispatch_vector + [td_ns::num_types]; void populate_masked_extract_dispatch_vectors(void) { - using dpctl::tensor::kernels::indexing::MaskExtractAllSlicesStridedFactory; + using dpctl::tensor::kernels::indexing:: + MaskExtractAllSlicesStridedFactoryForInt32; td_ns::DispatchVectorBuilder< masked_extract_all_slices_strided_impl_fn_ptr_t, - MaskExtractAllSlicesStridedFactory, td_ns::num_types> + MaskExtractAllSlicesStridedFactoryForInt32, td_ns::num_types> dvb1; dvb1.populate_dispatch_vector( - masked_extract_all_slices_strided_impl_dispatch_vector); + masked_extract_all_slices_strided_i32_impl_dispatch_vector); - using dpctl::tensor::kernels::indexing::MaskExtractSomeSlicesStridedFactory; + using dpctl::tensor::kernels::indexing:: + MaskExtractAllSlicesStridedFactoryForInt64; td_ns::DispatchVectorBuilder< - masked_extract_some_slices_strided_impl_fn_ptr_t, - MaskExtractSomeSlicesStridedFactory, td_ns::num_types> + masked_extract_all_slices_strided_impl_fn_ptr_t, + MaskExtractAllSlicesStridedFactoryForInt64, td_ns::num_types> dvb2; dvb2.populate_dispatch_vector( - masked_extract_some_slices_strided_impl_dispatch_vector); + masked_extract_all_slices_strided_i64_impl_dispatch_vector); + + using dpctl::tensor::kernels::indexing:: + MaskExtractSomeSlicesStridedFactoryForInt32; + td_ns::DispatchVectorBuilder< + masked_extract_some_slices_strided_impl_fn_ptr_t, + MaskExtractSomeSlicesStridedFactoryForInt32, td_ns::num_types> + dvb3; + dvb3.populate_dispatch_vector( + masked_extract_some_slices_strided_i32_impl_dispatch_vector); + + using dpctl::tensor::kernels::indexing:: + MaskExtractSomeSlicesStridedFactoryForInt64; + td_ns::DispatchVectorBuilder< + masked_extract_some_slices_strided_impl_fn_ptr_t, + MaskExtractSomeSlicesStridedFactoryForInt64, td_ns::num_types> + dvb4; + dvb4.populate_dispatch_vector( + masked_extract_some_slices_strided_i64_impl_dispatch_vector); } std::pair @@ -357,12 +414,15 @@ py_extract(dpctl::tensor::usm_ndarray src, int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); + constexpr int int32_typeid = static_cast(td_ns::typenum_t::INT32); constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); - if (cumsum_typeid != int64_typeid) { - throw py::value_error( - "Unexact data type of cumsum array, expecting 'int64'"); + if (cumsum_typeid != int32_typeid && cumsum_typeid != int64_typeid) { + throw py::value_error("Unexpected 
data type of cumsum array, expecting " + "'int32' or 'int64'"); } + const bool use_i32 = (cumsum_typeid == int32_typeid); + if (src_typeid != dst_typeid) { throw py::value_error( "Destination array must have the same elemental data types"); @@ -383,7 +443,11 @@ py_extract(dpctl::tensor::usm_ndarray src, if (axis_start == 0 && axis_end == src_nd) { // empty orthogonal directions auto fn = - masked_extract_all_slices_strided_impl_dispatch_vector[src_typeid]; + (use_i32) + ? masked_extract_all_slices_strided_i32_impl_dispatch_vector + [src_typeid] + : masked_extract_all_slices_strided_i64_impl_dispatch_vector + [src_typeid]; assert(dst_shape_vec.size() == 1); assert(dst_strides_vec.size() == 1); @@ -424,7 +488,11 @@ py_extract(dpctl::tensor::usm_ndarray src, else { // non-empty othogonal directions auto fn = - masked_extract_some_slices_strided_impl_dispatch_vector[src_typeid]; + (use_i32) + ? masked_extract_some_slices_strided_i32_impl_dispatch_vector + [src_typeid] + : masked_extract_some_slices_strided_i64_impl_dispatch_vector + [src_typeid]; int masked_src_nd = mask_span_sz; int ortho_nd = src_nd - masked_src_nd; @@ -532,31 +600,55 @@ using dpctl::tensor::kernels::indexing:: masked_place_all_slices_strided_impl_fn_ptr_t; static masked_place_all_slices_strided_impl_fn_ptr_t - masked_place_all_slices_strided_impl_dispatch_vector[td_ns::num_types]; + masked_place_all_slices_strided_i32_impl_dispatch_vector[td_ns::num_types]; +static masked_place_all_slices_strided_impl_fn_ptr_t + masked_place_all_slices_strided_i64_impl_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::indexing:: masked_place_some_slices_strided_impl_fn_ptr_t; static masked_place_some_slices_strided_impl_fn_ptr_t - masked_place_some_slices_strided_impl_dispatch_vector[td_ns::num_types]; + masked_place_some_slices_strided_i32_impl_dispatch_vector[td_ns::num_types]; +static masked_place_some_slices_strided_impl_fn_ptr_t + masked_place_some_slices_strided_i64_impl_dispatch_vector[td_ns::num_types]; void populate_masked_place_dispatch_vectors(void) { - using dpctl::tensor::kernels::indexing::MaskPlaceAllSlicesStridedFactory; + using dpctl::tensor::kernels::indexing:: + MaskPlaceAllSlicesStridedFactoryForInt32; td_ns::DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector( - masked_place_all_slices_strided_impl_dispatch_vector); + masked_place_all_slices_strided_i32_impl_dispatch_vector); - using dpctl::tensor::kernels::indexing::MaskPlaceSomeSlicesStridedFactory; - td_ns::DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector( - masked_place_some_slices_strided_impl_dispatch_vector); + masked_place_all_slices_strided_i64_impl_dispatch_vector); + + using dpctl::tensor::kernels::indexing:: + MaskPlaceSomeSlicesStridedFactoryForInt32; + td_ns::DispatchVectorBuilder + dvb3; + dvb3.populate_dispatch_vector( + masked_place_some_slices_strided_i32_impl_dispatch_vector); + + using dpctl::tensor::kernels::indexing:: + MaskPlaceSomeSlicesStridedFactoryForInt64; + td_ns::DispatchVectorBuilder + dvb4; + dvb4.populate_dispatch_vector( + masked_place_some_slices_strided_i64_impl_dispatch_vector); } /* @@ -651,13 +743,15 @@ py_place(dpctl::tensor::usm_ndarray dst, int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); + constexpr int int32_typeid = static_cast(td_ns::typenum_t::INT32); constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); - if (cumsum_typeid != int64_typeid) { - throw py::value_error( - "Unexact data type of cumsum 
array, expecting 'int64'"); + if (cumsum_typeid != int32_typeid && cumsum_typeid != int64_typeid) { + throw py::value_error("Unexpected data type of cumsum array, expecting " + "'int32' or 'int64'"); } - // FIXME: should types be the same? + const bool use_i32 = (cumsum_typeid == int32_typeid); + if (dst_typeid != rhs_typeid) { throw py::value_error( "Destination array must have the same elemental data types"); @@ -677,8 +771,11 @@ py_place(dpctl::tensor::usm_ndarray dst, std::vector host_task_events{}; if (axis_start == 0 && axis_end == dst_nd) { // empty orthogonal directions - auto fn = - masked_place_all_slices_strided_impl_dispatch_vector[dst_typeid]; + auto fn = (use_i32) + ? masked_place_all_slices_strided_i32_impl_dispatch_vector + [dst_typeid] + : masked_place_all_slices_strided_i64_impl_dispatch_vector + [dst_typeid]; assert(rhs_shape_vec.size() == 1); assert(rhs_strides_vec.size() == 1); @@ -719,7 +816,11 @@ py_place(dpctl::tensor::usm_ndarray dst, else { // non-empty othogonal directions auto fn = - masked_place_some_slices_strided_impl_dispatch_vector[dst_typeid]; + (use_i32) + ? masked_place_some_slices_strided_i32_impl_dispatch_vector + [dst_typeid] + : masked_place_some_slices_strided_i64_impl_dispatch_vector + [dst_typeid]; int masked_dst_nd = mask_span_sz; int ortho_nd = dst_nd - masked_dst_nd; @@ -820,13 +921,15 @@ py_place(dpctl::tensor::usm_ndarray dst, // Non-zero -std::pair py_nonzero( - dpctl::tensor::usm_ndarray cumsum, // int64 input array, 1D, C-contiguous - dpctl::tensor::usm_ndarray indexes, // int64 2D output array, C-contiguous - std::vector - mask_shape, // shape of array from which cumsum was computed - sycl::queue exec_q, - std::vector const &depends) +std::pair +py_nonzero(dpctl::tensor::usm_ndarray + cumsum, // int32/int64 input array, 1D, C-contiguous + dpctl::tensor::usm_ndarray + indexes, // int32/int64 2D output array, C-contiguous + std::vector + mask_shape, // shape of array from which cumsum was computed + sycl::queue exec_q, + std::vector const &depends) { if (!dpctl::utils::queues_are_compatible(exec_q, {cumsum, indexes})) { throw py::value_error( @@ -874,11 +977,15 @@ std::pair py_nonzero( int cumsum_typenum = cumsum.get_typenum(); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); - // cumsum must be int64_t only + constexpr int int32_typeid = static_cast(td_ns::typenum_t::INT32); constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); - if (cumsum_typeid != int64_typeid || indexes_typeid != int64_typeid) { - throw py::value_error( - "Cumulative sum array and index array must have int64 data-type"); + + // cumsum must be int32_t or int64_t only + if ((cumsum_typeid != int32_typeid && cumsum_typeid != int64_typeid) || + (indexes_typeid != int32_typeid && indexes_typeid != int64_typeid)) + { + throw py::value_error("Cumulative sum array and index array must have " + "int32 or int64 data-type"); } if (cumsum_sz == 0) { @@ -923,12 +1030,21 @@ std::pair py_nonzero( all_deps.insert(all_deps.end(), depends.begin(), depends.end()); all_deps.push_back(copy_ev); + using dpctl::tensor::kernels::indexing::non_zero_indexes_fn_ptr_t; using dpctl::tensor::kernels::indexing::non_zero_indexes_impl; + int fn_index = ((cumsum_typeid == int64_typeid) ? 1 : 0) + + ((indexes_typeid == int64_typeid) ? 
2 : 0); + std::array fn_impls = { + non_zero_indexes_impl, + non_zero_indexes_impl, + non_zero_indexes_impl, + non_zero_indexes_impl}; + auto fn = fn_impls[fn_index]; + sycl::event non_zero_indexes_ev = - non_zero_indexes_impl( - exec_q, cumsum_sz, nz_elems, ndim, cumsum.get_data(), - indexes.get_data(), src_shape_device_ptr, all_deps); + fn(exec_q, cumsum_sz, nz_elems, ndim, cumsum.get_data(), + indexes.get_data(), src_shape_device_ptr, all_deps); sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(non_zero_indexes_ev); diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp index e6e8a54ed6..647d78bbbb 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp @@ -68,8 +68,8 @@ py_place(dpctl::tensor::usm_ndarray dst, extern void populate_masked_place_dispatch_vectors(void); extern std::pair py_nonzero( - dpctl::tensor::usm_ndarray cumsum, // int64 input array, 1D, C-contiguous - dpctl::tensor::usm_ndarray indexes, // int64 2D output array, C-contiguous + dpctl::tensor::usm_ndarray cumsum, // int32 input array, 1D, C-contiguous + dpctl::tensor::usm_ndarray indexes, // int32 2D output array, C-contiguous std::vector mask_shape, // shape of array from which cumsum was computed sycl::queue exec_q, From 7ade8294e8eec4c20e44548d1f9f3973c019d976 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 24 Jul 2023 14:32:40 -0500 Subject: [PATCH 2/2] If size of mask allows, using int32 type for cumsum to improve performance --- dpctl/tensor/_copy_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 63aca6ad06..541d1d4fae 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -31,6 +31,8 @@ ":class:`dpctl.tensor.usm_ndarray`." ) +int32_t_max = 2147483648 + def _copy_to_numpy(ary): if not isinstance(ary, dpt.usm_ndarray): @@ -482,7 +484,8 @@ def _extract_impl(ary, ary_mask, axis=0): "Parameter p is inconsistent with input array dimensions" ) mask_nelems = ary_mask.size - cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device) + cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64 + cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device) exec_q = cumsum.sycl_queue mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q) dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :] @@ -509,8 +512,9 @@ def _nonzero_impl(ary): exec_q = ary.sycl_queue usm_type = ary.usm_type mask_nelems = ary.size + cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64 cumsum = dpt.empty( - mask_nelems, dtype=dpt.int64, sycl_queue=exec_q, order="C" + mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C" ) mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q) indexes = dpt.empty( @@ -604,7 +608,8 @@ def _place_impl(ary, ary_mask, vals, axis=0): "Parameter p is inconsistent with input array dimensions" ) mask_nelems = ary_mask.size - cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device) + cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64 + cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device) exec_q = cumsum.sycl_queue mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q) expected_vals_shape = (
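# A minimal standalone sketch (not part of the patch hunks above) of the cumsum
# dtype selection that PATCH 2/2 applies in _extract_impl, _nonzero_impl and
# _place_impl. It assumes only that dpctl.tensor is importable as dpt; the helper
# name pick_cumsum_dtype and the variable _INT32_LIMIT are illustrative, not names
# used by the patch (which defines int32_t_max at module scope instead).
import dpctl.tensor as dpt

# 2147483648 == 2**31: element counts strictly below this bound guarantee that
# every prefix-sum value (at most mask_nelems) is representable as int32.
_INT32_LIMIT = 2147483648


def pick_cumsum_dtype(mask_nelems):
    # Prefer the narrower int32 temporary for smaller masks to improve
    # performance of boolean indexing; fall back to int64 otherwise.
    return dpt.int32 if mask_nelems < _INT32_LIMIT else dpt.int64


# Usage mirroring the patched _extract_impl (ti here denotes the private
# dpctl.tensor._tensor_impl module, as imported in _copy_utils.py):
#
#   cumsum = dpt.empty(mask.size, dtype=pick_cumsum_dtype(mask.size),
#                      device=mask.device)
#   mask_count = ti.mask_positions(mask, cumsum, sycl_queue=cumsum.sycl_queue)
#
# The C++ side accepts either dtype: py_mask_positions, py_extract, py_place and
# py_nonzero inspect the cumsum typeid and dispatch to the int32 or int64
# specialization registered by the *ForInt32/*ForInt64 factories.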