Skip to content

Commit ff93cfc

Browse files
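Optimize the kernel of tensor.nonzero
Use shared local memory to improve global memory bandwidth utilization: each work-group first stages its slice of the cumulative sum in local memory, so neighbouring work-items no longer each re-read the same value from global memory.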
1 parent a3a070c commit ff93cfc

File tree

1 file changed

+44
-17
lines changed

1 file changed

+44
-17
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -744,31 +744,58 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
744744
const indT1 *cumsum_data = reinterpret_cast<const indT1 *>(cumsum_cp);
745745
indT2 *indexes_data = reinterpret_cast<indT2 *>(indexes_cp);
746746

747+
constexpr std::size_t nominal_lws = 256u;
748+
const std::size_t masked_extent = iter_size;
749+
const std::size_t lws = std::min(masked_extent, nominal_lws);
750+
751+
const std::size_t n_groups = (masked_extent + lws - 1) / lws;
752+
sycl::range<1> gRange{n_groups * lws};
753+
sycl::range<1> lRange{lws};
754+
755+
sycl::nd_range<1> ndRange{gRange, lRange};
756+
747757
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
748758
cgh.depends_on(depends);
749-
cgh.parallel_for<class non_zero_indexes_krn<indT1, indT2>>(
750-
sycl::range<1>(iter_size), [=](sycl::id<1> idx) {
751-
auto i = idx[0];
752759

753-
auto cs_curr_val = cumsum_data[i] - 1;
754-
auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0);
755-
bool cond = (cs_curr_val == cs_prev_val);
760+
const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
761+
sycl::local_accessor<indT1, 1> lacc(lacc_size, cgh);
762+
763+
using KernelName = class non_zero_indexes_krn<indT1, indT2>;
756764

765+
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> ndit) {
766+
const std::size_t group_i = ndit.get_group(0);
767+
const std::uint32_t l_i = ndit.get_local_id(0);
768+
const std::uint32_t lws = ndit.get_local_range(0);
769+
770+
const std::size_t masked_block_start = group_i * lws;
771+
772+
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
773+
const size_t offset = masked_block_start + i;
774+
lacc[i] = (offset == 0) ? indT1(0)
775+
: (offset - 1 < masked_extent)
776+
? cumsum_data[offset - 1]
777+
: cumsum_data[masked_extent - 1] + 1;
778+
}
779+
780+
sycl::group_barrier(ndit.get_group());
781+
782+
const std::size_t i = masked_block_start + l_i;
783+
const auto cs_val = lacc[l_i];
784+
const bool cond = (lacc[l_i + 1] == cs_val + 1);
785+
786+
if (cond && (i < masked_extent)) {
757787
ssize_t i_ = static_cast<ssize_t>(i);
758788
for (int dim = nd; --dim > 0;) {
759-
auto sd = mask_shape[dim];
760-
ssize_t q = i_ / sd;
761-
ssize_t r = (i_ - q * sd);
762-
if (cond) {
763-
indexes_data[cs_curr_val + dim * nz_elems] =
764-
static_cast<indT2>(r);
765-
}
789+
const auto sd = mask_shape[dim];
790+
const ssize_t q = i_ / sd;
791+
const ssize_t r = (i_ - q * sd);
792+
indexes_data[cs_val + dim * nz_elems] =
793+
static_cast<indT2>(r);
766794
i_ = q;
767795
}
768-
if (cond) {
769-
indexes_data[cs_curr_val] = static_cast<indT2>(i_);
770-
}
771-
});
796+
indexes_data[cs_val] = static_cast<indT2>(i_);
797+
}
798+
});
772799
});
773800

774801
return comp_ev;

0 commit comments

Comments
 (0)