@@ -744,31 +744,58 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
744
744
const indT1 *cumsum_data = reinterpret_cast <const indT1 *>(cumsum_cp);
745
745
indT2 *indexes_data = reinterpret_cast <indT2 *>(indexes_cp);
746
746
747
+ constexpr std::size_t nominal_lws = 256u ;
748
+ const std::size_t masked_extent = iter_size;
749
+ const std::size_t lws = std::min (masked_extent, nominal_lws);
750
+
751
+ const std::size_t n_groups = (masked_extent + lws - 1 ) / lws;
752
+ sycl::range<1 > gRange {n_groups * lws};
753
+ sycl::range<1 > lRange{lws};
754
+
755
+ sycl::nd_range<1 > ndRange{gRange , lRange};
756
+
747
757
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
748
758
cgh.depends_on (depends);
749
- cgh.parallel_for <class non_zero_indexes_krn <indT1, indT2>>(
750
- sycl::range<1 >(iter_size), [=](sycl::id<1 > idx) {
751
- auto i = idx[0 ];
752
759
753
- auto cs_curr_val = cumsum_data[i] - 1 ;
754
- auto cs_prev_val = (i > 0 ) ? cumsum_data[i - 1 ] : indT1 (0 );
755
- bool cond = (cs_curr_val == cs_prev_val);
760
+ const std::size_t lacc_size = std::min (lws, masked_extent) + 1 ;
761
+ sycl::local_accessor<indT1, 1 > lacc (lacc_size, cgh);
762
+
763
+ using KernelName = class non_zero_indexes_krn <indT1, indT2>;
756
764
765
+ cgh.parallel_for <KernelName>(ndRange, [=](sycl::nd_item<1 > ndit) {
766
+ const std::size_t group_i = ndit.get_group (0 );
767
+ const std::uint32_t l_i = ndit.get_local_id (0 );
768
+ const std::uint32_t lws = ndit.get_local_range (0 );
769
+
770
+ const std::size_t masked_block_start = group_i * lws;
771
+
772
+ for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
773
+ const size_t offset = masked_block_start + i;
774
+ lacc[i] = (offset == 0 ) ? indT1 (0 )
775
+ : (offset - 1 < masked_extent)
776
+ ? cumsum_data[offset - 1 ]
777
+ : cumsum_data[masked_extent - 1 ] + 1 ;
778
+ }
779
+
780
+ sycl::group_barrier (ndit.get_group ());
781
+
782
+ const std::size_t i = masked_block_start + l_i;
783
+ const auto cs_val = lacc[l_i];
784
+ const bool cond = (lacc[l_i + 1 ] == cs_val + 1 );
785
+
786
+ if (cond && (i < masked_extent)) {
757
787
ssize_t i_ = static_cast <ssize_t >(i);
758
788
for (int dim = nd; --dim > 0 ;) {
759
- auto sd = mask_shape[dim];
760
- ssize_t q = i_ / sd;
761
- ssize_t r = (i_ - q * sd);
762
- if (cond) {
763
- indexes_data[cs_curr_val + dim * nz_elems] =
764
- static_cast <indT2>(r);
765
- }
789
+ const auto sd = mask_shape[dim];
790
+ const ssize_t q = i_ / sd;
791
+ const ssize_t r = (i_ - q * sd);
792
+ indexes_data[cs_val + dim * nz_elems] =
793
+ static_cast <indT2>(r);
766
794
i_ = q;
767
795
}
768
- if (cond) {
769
- indexes_data[cs_curr_val] = static_cast <indT2>(i_);
770
- }
771
- });
796
+ indexes_data[cs_val] = static_cast <indT2>(i_);
797
+ }
798
+ });
772
799
});
773
800
774
801
return comp_ev;
0 commit comments