diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp
index b82f49c1e0fa1..5b9d735d60445 100644
--- a/sycl/include/sycl/detail/group_sort_impl.hpp
+++ b/sycl/include/sycl/detail/group_sort_impl.hpp
@@ -11,6 +11,8 @@
 #pragma once
 
 #include
+#include
+#include
 #include
 #include
 
@@ -25,11 +27,11 @@ namespace detail {
 // following two functions could be useless if std::[lower|upper]_bound worked
 // well
 template <typename Acc, typename Value, typename Compare>
-std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
-                        const Value &value, Compare comp) {
-  std::size_t n = last - first;
-  std::size_t cur = n;
-  std::size_t it;
+size_t lower_bound(Acc acc, size_t first, size_t last, const Value &value,
+                   Compare comp) {
+  size_t n = last - first;
+  size_t cur = n;
+  size_t it;
   while (n > 0) {
     it = first;
     cur = n / 2;
@@ -43,9 +45,8 @@ std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
 }
 
 template <typename Acc, typename Value, typename Compare>
-std::size_t upper_bound(Acc acc, const std::size_t first,
-                        const std::size_t last, const Value &value,
-                        Compare comp) {
+size_t upper_bound(Acc acc, const size_t first, const size_t last,
+                   const Value &value, Compare comp) {
   return detail::lower_bound(acc, first, last, value,
                              [comp](auto x, auto y) { return !comp(y, x); });
 }
@@ -72,7 +73,7 @@ struct GetValueType> {
 // since we couldn't assign data to raw memory, it's better to use placement
 // for first assignment
 template <typename Acc, typename T>
-void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
+void set_value(Acc ptr, const size_t idx, const T &val, bool is_first) {
   if (is_first) {
     ::new (ptr + idx) T(val);
   } else {
@@ -81,23 +82,23 @@ void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
 }
 
 template <typename InAcc, typename OutAcc, typename Compare>
-void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
-           const std::size_t start_1, const std::size_t end_1,
-           const std::size_t end_2, const std::size_t start_out, Compare comp,
-           const std::size_t chunk, bool is_first) {
-  const std::size_t start_2 = end_1;
+void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
+           const size_t start_1, const size_t end_1, const size_t end_2,
+           const size_t start_out, Compare comp, const size_t chunk,
+           bool is_first) {
+  const size_t start_2 = end_1;
   // Borders of the sequences to merge within this call
-  const std::size_t local_start_1 =
-      sycl::min(static_cast<std::size_t>(offset + start_1), end_1);
-  const std::size_t local_end_1 =
-      sycl::min(static_cast<std::size_t>(local_start_1 + chunk), end_1);
-  const std::size_t local_start_2 =
-      sycl::min(static_cast<std::size_t>(offset + start_2), end_2);
-  const std::size_t local_end_2 =
-      sycl::min(static_cast<std::size_t>(local_start_2 + chunk), end_2);
-
-  const std::size_t local_size_1 = local_end_1 - local_start_1;
-  const std::size_t local_size_2 = local_end_2 - local_start_2;
+  const size_t local_start_1 =
+      sycl::min(static_cast<size_t>(offset + start_1), end_1);
+  const size_t local_end_1 =
+      sycl::min(static_cast<size_t>(local_start_1 + chunk), end_1);
+  const size_t local_start_2 =
+      sycl::min(static_cast<size_t>(offset + start_2), end_2);
+  const size_t local_end_2 =
+      sycl::min(static_cast<size_t>(local_start_2 + chunk), end_2);
+
+  const size_t local_size_1 = local_end_1 - local_start_1;
+  const size_t local_size_2 = local_end_2 - local_start_2;
 
   // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
   // to improve performance
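Aside: the hunk above clamps each work item's merge window, so at most `chunk` elements are taken from each of the two sorted subsequences and the window is clipped to the subsequence end. A minimal host-side sketch of the same index arithmetic (not part of the patch; hypothetical helper name, std::min in place of sycl::min):

#include <algorithm>
#include <cstddef>
#include <utility>

// Window of one work item over a sorted subsequence [start, end) that is
// processed in chunks of size `chunk`; `offset` is the work item's position
// inside the current pair of subsequences.
std::pair<std::size_t, std::size_t> local_window(std::size_t offset,
                                                 std::size_t start,
                                                 std::size_t end,
                                                 std::size_t chunk) {
  const std::size_t lo = std::min(offset + start, end);
  const std::size_t hi = std::min(lo + chunk, end);
  return {lo, hi}; // empty window when offset points past the subsequence end
}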
@@ -107,15 +108,15 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
     // Reduce the range for searching within the 2nd sequence and handle bound
     // items
     // find left border in 2nd sequence
     const auto local_l_item_1 = in_acc1[local_start_1];
-    std::size_t l_search_bound_2 =
+    size_t l_search_bound_2 =
         detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
-    const std::size_t l_shift_1 = local_start_1 - start_1;
-    const std::size_t l_shift_2 = l_search_bound_2 - start_2;
+    const size_t l_shift_1 = local_start_1 - start_1;
+    const size_t l_shift_2 = l_search_bound_2 - start_2;
 
     set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
               is_first);
 
-    std::size_t r_search_bound_2{};
+    size_t r_search_bound_2{};
     // find right border in 2nd sequence
     if (local_size_1 > 1) {
       const auto local_r_item_1 = in_acc1[local_end_1 - 1];
@@ -129,15 +130,15 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
     }
 
     // Handle intermediate items
-    for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
+    for (size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
       const auto intermediate_item_1 = in_acc1[idx];
       // we shouldn't seek in whole 2nd sequence. Just for the part where the
       // 1st sequence should be
       l_search_bound_2 =
           detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
                               intermediate_item_1, comp);
-      const std::size_t shift_1 = idx - start_1;
-      const std::size_t shift_2 = l_search_bound_2 - start_2;
+      const size_t shift_1 = idx - start_1;
+      const size_t shift_2 = l_search_bound_2 - start_2;
 
       set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
                 is_first);
@@ -148,22 +149,22 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
     // Reduce the range for searching within the 1st sequence and handle bound
     // items
     // find left border in 1st sequence
     const auto local_l_item_2 = in_acc1[local_start_2];
-    std::size_t l_search_bound_1 =
+    size_t l_search_bound_1 =
         detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
-    const std::size_t l_shift_1 = l_search_bound_1 - start_1;
-    const std::size_t l_shift_2 = local_start_2 - start_2;
+    const size_t l_shift_1 = l_search_bound_1 - start_1;
+    const size_t l_shift_2 = local_start_2 - start_2;
 
     set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
               is_first);
 
-    std::size_t r_search_bound_1{};
+    size_t r_search_bound_1{};
     // find right border in 1st sequence
     if (local_size_2 > 1) {
       const auto local_r_item_2 = in_acc1[local_end_2 - 1];
       r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
                                              local_r_item_2, comp);
-      const std::size_t r_shift_1 = r_search_bound_1 - start_1;
-      const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
+      const size_t r_shift_1 = r_search_bound_1 - start_1;
+      const size_t r_shift_2 = local_end_2 - 1 - start_2;
 
       set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
                 is_first);
@@ -177,8 +178,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
       l_search_bound_1 =
           detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
                               intermediate_item_2, comp);
-      const std::size_t shift_1 = l_search_bound_1 - start_1;
-      const std::size_t shift_2 = idx - start_2;
+      const size_t shift_1 = l_search_bound_1 - start_1;
+      const size_t shift_2 = idx - start_2;
 
       set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
                 is_first);
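Aside: the output index used by set_value above is a sum of two ranks, the element's offset inside its own run plus the number of elements of the other run that must precede it. A host-side sketch of the same idea (not part of the patch; std::lower_bound and std::upper_bound stand in for the device-side detail helpers):

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int> merge_by_rank(const std::vector<int> &a,
                               const std::vector<int> &b) {
  std::vector<int> out(a.size() + b.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    // rank of a[i] inside b: how many elements of b are strictly smaller
    const std::size_t rank_b =
        std::lower_bound(b.begin(), b.end(), a[i]) - b.begin();
    out[i + rank_b] = a[i];
  }
  for (std::size_t j = 0; j < b.size(); ++j) {
    // rank of b[j] inside a: how many elements of a are less than or equal
    const std::size_t rank_a =
        std::upper_bound(a.begin(), a.end(), b[j]) - a.begin();
    out[rank_a + j] = b[j];
  }
  return out; // stable: equal keys from `a` precede equal keys from `b`
}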
@@ -187,12 +188,12 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
 }
 
 template <typename Iter, typename Compare>
-void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,
+void bubble_sort(Iter first, const size_t begin, const size_t end,
                  Compare comp) {
   if (begin < end) {
-    for (std::size_t i = begin; i < end; ++i) {
+    for (size_t i = begin; i < end; ++i) {
       // Handle intermediate items
-      for (std::size_t idx = i + 1; idx < end; ++idx) {
+      for (size_t idx = i + 1; idx < end; ++idx) {
         if (comp(first[idx], first[i])) {
           detail::swap_tuples(first[i], first[idx]);
         }
@@ -202,12 +203,12 @@ void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,
 }
 
 template <typename Group, typename Iter, typename Compare>
-void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
+void merge_sort(Group group, Iter first, const size_t n, Compare comp,
                 std::byte *scratch) {
   using T = typename GetValueType<Iter>::type;
-  const std::size_t idx = group.get_local_linear_id();
-  const std::size_t local = group.get_local_range().size();
-  const std::size_t chunk = (n - 1) / local + 1;
+  const size_t idx = group.get_local_linear_id();
+  const size_t local = group.get_local_range().size();
+  const size_t chunk = (n - 1) / local + 1;
 
   // we need to sort within work item first
   bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
@@ -216,13 +217,13 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
   T *temp = reinterpret_cast<T *>(scratch);
   bool data_in_temp = false;
   bool is_first = true;
-  std::size_t sorted_size = 1;
+  size_t sorted_size = 1;
   while (sorted_size * chunk < n) {
-    const std::size_t start_1 =
+    const size_t start_1 =
         sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
-    const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
-    const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
-    const std::size_t offset = chunk * (idx % sorted_size);
+    const size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
+    const size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
+    const size_t offset = chunk * (idx % sorted_size);
 
     if (!data_in_temp) {
       merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
@@ -241,7 +242,7 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
 
   // copy back if data is in a temporary storage
   if (data_in_temp) {
-    for (std::size_t i = 0; i < chunk; ++i) {
+    for (size_t i = 0; i < chunk; ++i) {
       if (idx * chunk + i < n) {
         first[idx * chunk + i] = temp[idx * chunk + i];
       }
@@ -250,6 +251,408 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
   }
 }
 
+// traits for ascending functors
+template <typename Comp> struct IsCompAscending {
+  static constexpr bool value = false;
+};
+template <typename T> struct IsCompAscending<std::less<T>> {
+  static constexpr bool value = true;
+};
+
+// get number of states radix bits can represent
+constexpr uint32_t getStatesInBits(uint32_t radix_bits) {
+  return (1 << radix_bits);
+}
+
+//------------------------------------------------------------------------
+// Ordered traits for a given size and integral/float flag
+//------------------------------------------------------------------------
+
+template <size_t size, bool is_integral> struct GetOrdered {};
+
+template <> struct GetOrdered<1, true> {
+  using Type = uint8_t;
+  constexpr static int8_t mask = 0x80;
+};
+
+template <> struct GetOrdered<2, true> {
+  using Type = uint16_t;
+  constexpr static int16_t mask = 0x8000;
+};
+
+template <> struct GetOrdered<4, true> {
+  using Type = uint32_t;
+  constexpr static int32_t mask = 0x80000000;
+};
+
+template <> struct GetOrdered<8, true> {
+  using Type = uint64_t;
+  constexpr static int64_t mask = 0x8000000000000000;
+};
+
+template <> struct GetOrdered<2, false> {
+  using Type = uint16_t;
+  constexpr static uint32_t nmask = 0xFFFF; // for negative numbers
+  constexpr static uint32_t pmask = 0x8000; // for positive numbers
+};
+
+template <> struct GetOrdered<4, false> {
+  using Type = uint32_t;
+  constexpr static uint32_t nmask = 0xFFFFFFFF; // for negative numbers
+  constexpr static uint32_t pmask = 0x80000000; // for positive numbers
+};
+
+template <> struct GetOrdered<8, false> {
+  using Type = uint64_t;
+  constexpr static uint64_t nmask = 0xFFFFFFFFFFFFFFFF; // for negative numbers
+  constexpr static uint64_t pmask = 0x8000000000000000; // for positive numbers
+};
+
+//------------------------------------------------------------------------
+// Ordered type for a given type
+//------------------------------------------------------------------------
+
+// for unknown/unsupported type we do not have any trait
+template <typename ValT, typename = void> struct Ordered {};
+
+// for unsigned integrals we use the same type
+template <typename ValT>
+struct Ordered<ValT, std::enable_if_t<std::is_integral<ValT>::value &&
+                                      std::is_unsigned<ValT>::value>> {
+  using Type = ValT;
+};
+
+// for signed integrals or floating-point types we map: size -> corresponding
+// unsigned integral
+template <typename ValT>
+struct Ordered<
+    ValT, std::enable_if_t<
+              (std::is_integral<ValT>::value && std::is_signed<ValT>::value) ||
+              std::is_floating_point<ValT>::value ||
+              std::is_same<ValT, sycl::half>::value ||
+              std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value>> {
+  using Type =
+      typename GetOrdered<sizeof(ValT), std::is_integral<ValT>::value>::Type;
+};
+
+// shorthand
+template <typename ValT> using OrderedT = typename Ordered<ValT>::Type;
+
+//------------------------------------------------------------------------
+// functions for conversion to Ordered type
+//------------------------------------------------------------------------
+
+// for already Ordered types (any uints) we use the same type
+template <typename ValT>
+std::enable_if_t<std::is_same<ValT, OrderedT<ValT>>::value, OrderedT<ValT>>
+convertToOrdered(ValT value) {
+  return value;
+}
+
+// converts integral type to Ordered (in terms of bitness) type
+template <typename ValT>
+std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
+                     std::is_integral<ValT>::value,
+                 OrderedT<ValT>>
+convertToOrdered(ValT value) {
+  ValT result = value ^ GetOrdered<sizeof(ValT), true>::mask;
+  return *reinterpret_cast<OrderedT<ValT> *>(&result);
+}
+
+// converts floating-point type to Ordered (in terms of bitness) type
+template <typename ValT>
+std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
+                     (std::is_floating_point<ValT>::value ||
+                      std::is_same<ValT, sycl::half>::value ||
+                      std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
+                 OrderedT<ValT>>
+convertToOrdered(ValT value) {
+  OrderedT<ValT> uvalue = *reinterpret_cast<OrderedT<ValT> *>(&value);
+  // check if the value is negative
+  OrderedT<ValT> is_negative = uvalue >> (sizeof(ValT) * CHAR_BIT - 1);
+  // for positive: 00..00 -> 00..00 -> 10..00
+  // for negative: 00..01 -> 11..11 -> 11..11
+  OrderedT<ValT> ordered_mask =
+      (is_negative * GetOrdered<sizeof(ValT), false>::nmask) |
+      GetOrdered<sizeof(ValT), false>::pmask;
+  return uvalue ^ ordered_mask;
+}
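Aside: a host-side sketch of the float-to-ordered mapping above (not part of the patch; assumes 32-bit IEEE float and uses std::memcpy rather than reinterpret_cast so the host code stays well defined). Negative values have every bit flipped and non-negative values only the sign bit, so plain unsigned comparison of the results reproduces the original ordering:

#include <cassert>
#include <cstdint>
#include <cstring>

uint32_t to_ordered(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  // sign bit set: flip everything; otherwise: flip only the sign bit
  const uint32_t mask = (u >> 31) ? 0xFFFFFFFFu : 0x80000000u;
  return u ^ mask;
}

int main() {
  assert(to_ordered(-2.5f) < to_ordered(-0.0f));
  assert(to_ordered(-0.0f) < to_ordered(0.0f));
  assert(to_ordered(0.0f) < to_ordered(1.5f));
}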
+
+//------------------------------------------------------------------------
+// bit pattern functions
+//------------------------------------------------------------------------
+
+// required for descending comparator support
+template <bool is_comp_asc> struct InvertIf {
+  template <typename ValT> ValT operator()(ValT value) { return value; }
+};
+
+// invert value if descending comparator is passed
+template <> struct InvertIf<false> {
+  template <typename ValT> ValT operator()(ValT value) { return ~value; }
+
+  // inversion for the bool type has to be logical rather than bitwise
+  bool operator()(bool value) { return !value; }
+};
+
+// get bit values in a certain bucket of a value
+template <uint32_t radix_bits, bool is_comp_asc, typename ValT>
+uint32_t getBucketValue(ValT value, uint32_t radix_iter) {
+  // invert value if we need to sort in descending order
+  value = InvertIf<is_comp_asc>{}(value);
+
+  // get bucket offset idx from the end of bit type (least significant bits)
+  uint32_t bucket_offset = radix_iter * radix_bits;
+
+  // get offset mask for one bucket, e.g.
+  // radix_bits=2: 0000 0001 -> 0000 0100 -> 0000 0011
+  OrderedT<ValT> bucket_mask = (1u << radix_bits) - 1u;
+
+  // get bits under bucket mask
+  return (value >> bucket_offset) & bucket_mask;
+}
+
+template <typename ValT> ValT getDefaultValue(bool is_comp_asc) {
+  if (is_comp_asc)
+    return std::numeric_limits<ValT>::max();
+  else
+    return std::numeric_limits<ValT>::lowest();
+}
+
+template <bool is_key_value_sort> struct ValuesAssigner {
+  template <typename IterOutT, typename IterInT>
+  void operator()(IterOutT output, size_t idx_out, IterInT input,
+                  size_t idx_in) {
+    output[idx_out] = input[idx_in];
+  }
+
+  template <typename IterOutT, typename ValT>
+  void operator()(IterOutT output, size_t idx_out, ValT value) {
+    output[idx_out] = value;
+  }
+};
+
+template <> struct ValuesAssigner<false> {
+  template <typename IterOutT, typename IterInT>
+  void operator()(IterOutT, size_t, IterInT, size_t) {}
+
+  template <typename IterOutT, typename ValT>
+  void operator()(IterOutT, size_t, ValT) {}
+};
+
+// The iteration of radix sort for unknown number of elements per work item
+template <bool is_key_value_sort, bool is_comp_asc, uint32_t radix_bits,
+          typename KeysT, typename ValueT, typename GroupT>
+void performRadixIterDynamicSize(GroupT group,
+                                 const uint32_t items_per_work_item,
+                                 const uint32_t radix_iter, const size_t n,
+                                 KeysT *keys_input, ValueT *vals_input,
+                                 KeysT *keys_output, ValueT *vals_output,
+                                 uint32_t *memory) {
+  const uint32_t radix_states = getStatesInBits(radix_bits);
+  const size_t wgsize = group.get_local_linear_range();
+  const size_t idx = group.get_local_linear_id();
+
+  // 1.1. Zero-initialize local memory
+  uint32_t *scan_memory = reinterpret_cast<uint32_t *>(memory);
+  for (uint32_t state = 0; state < radix_states; ++state)
+    scan_memory[state * wgsize + idx] = 0;
+
+  sycl::group_barrier(group);
+
+  // 1.2. count values and write result to private count array and count
+  // memory
+  for (uint32_t i = 0; i < items_per_work_item; ++i) {
+    const uint32_t val_idx = items_per_work_item * idx + i;
+    // get value, convert it to Ordered (in terms of bitness)
+    const auto val =
+        convertToOrdered((val_idx < n) ? keys_input[val_idx]
+                                       : getDefaultValue<KeysT>(is_comp_asc));
+    // get bit values in a certain bucket of a value
+    const uint32_t bucket_val =
+        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
+
+    // increment counter for this bit bucket
+    if (val_idx < n)
+      scan_memory[bucket_val * wgsize + idx]++;
+  }
+
+  sycl::group_barrier(group);
+
+  // 2.1 Scan. Upsweep: reduce over radix states
+  uint32_t reduced = 0;
+  for (uint32_t i = 0; i < radix_states; ++i)
+    reduced += scan_memory[idx * radix_states + i];
+
+  // 2.2. Exclusive scan: over work items
+  uint32_t scanned =
+      sycl::exclusive_scan_over_group(group, reduced, std::plus());
+
+  // 2.3. Exclusive downsweep: exclusive scan over radix states
+  for (uint32_t i = 0; i < radix_states; ++i) {
+    uint32_t value = scan_memory[idx * radix_states + i];
+    scan_memory[idx * radix_states + i] = scanned;
+    scanned += value;
+  }
+
+  sycl::group_barrier(group);
+
+  uint32_t private_scan_memory[radix_states] = {0};
+
+  // 3. Reorder
+  for (uint32_t i = 0; i < items_per_work_item; ++i) {
+    const uint32_t val_idx = items_per_work_item * idx + i;
+    // get value, convert it to Ordered (in terms of bitness)
+    auto val =
+        convertToOrdered((val_idx < n) ? keys_input[val_idx]
+                                       : getDefaultValue<KeysT>(is_comp_asc));
+    // get bit values in a certain bucket of a value
+    uint32_t bucket_val =
+        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
+
+    uint32_t new_offset_idx = private_scan_memory[bucket_val]++ +
+                              scan_memory[bucket_val * wgsize + idx];
+    if (val_idx < n) {
+      keys_output[new_offset_idx] = keys_input[val_idx];
+      ValuesAssigner<is_key_value_sort>()(vals_output, new_offset_idx,
+                                          vals_input, val_idx);
+    }
+  }
+}
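Aside: a single-threaded host sketch of one pass like performRadixIterDynamicSize above (not part of the patch; hypothetical function name). It shows the same three steps, count bucket occupancy, exclusive-scan the counts into start offsets, then scatter keys stably, which the device code performs cooperatively across the work-group:

#include <cstdint>
#include <vector>

std::vector<uint32_t> counting_pass(const std::vector<uint32_t> &keys,
                                    uint32_t radix_bits, uint32_t radix_iter) {
  const uint32_t states = 1u << radix_bits;
  const uint32_t mask = states - 1u;
  std::vector<uint32_t> count(states, 0), offset(states, 0);
  for (uint32_t k : keys) // 1. count occupancy per bucket
    ++count[(k >> (radix_iter * radix_bits)) & mask];
  for (uint32_t s = 1; s < states; ++s) // 2. exclusive prefix sum
    offset[s] = offset[s - 1] + count[s - 1];
  std::vector<uint32_t> out(keys.size());
  for (uint32_t k : keys) // 3. stable scatter to offset[bucket]++
    out[offset[(k >> (radix_iter * radix_bits)) & mask]++] = k;
  return out;
}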
+
+// The iteration of radix sort for known number of elements per work item
+template <bool is_key_value_sort, bool is_comp_asc, bool is_blocked,
+          uint32_t items_per_work_item, uint32_t radix_bits, typename KeysT,
+          typename ValsT, typename GroupT>
+void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
+                                const uint32_t last_iter, KeysT *keys,
+                                ValsT vals, std::byte *memory) {
+  const uint32_t radix_states = getStatesInBits(radix_bits);
+  const size_t wgsize = group.get_local_linear_range();
+  const size_t idx = group.get_local_linear_id();
+
+  // 1.1. count per work item: create a private array for storing count values
+  uint32_t count_arr[items_per_work_item] = {0};
+  uint32_t ranks[items_per_work_item] = {0};
+
+  // 1.1. Zero-initialize local memory
+  uint32_t *scan_memory = reinterpret_cast<uint32_t *>(memory);
+  for (uint32_t i = 0; i < radix_states; ++i)
+    scan_memory[i * wgsize + idx] = 0;
+
+  sycl::group_barrier(group);
+
+  uint32_t *pointers[items_per_work_item] = {nullptr};
+  // 1.2. count values and write result to private count array
+  for (uint32_t i = 0; i < items_per_work_item; ++i) {
+    // get value, convert it to Ordered (in terms of bitness)
+    OrderedT<KeysT> val = convertToOrdered(keys[i]);
+    // get bit values in a certain bucket of a value
+    uint32_t bucket_val =
+        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
+    pointers[i] = scan_memory + (bucket_val * wgsize + idx);
+    count_arr[i] = (*pointers[i])++;
+  }
+  sycl::group_barrier(group);
+
+  // 2.1 Scan. Upsweep: reduce over radix states
+  uint32_t reduced = 0;
+  for (uint32_t i = 0; i < radix_states; ++i)
+    reduced += scan_memory[idx * radix_states + i];
+
+  // 2.2. Exclusive scan: over work items
+  uint32_t scanned =
+      sycl::exclusive_scan_over_group(group, reduced, std::plus());
+
+  // 2.3. Exclusive downsweep: exclusive scan over radix states
+  for (uint32_t i = 0; i < radix_states; ++i) {
+    uint32_t value = scan_memory[idx * radix_states + i];
+    scan_memory[idx * radix_states + i] = scanned;
+    scanned += value;
+  }
+
+  sycl::group_barrier(group);
+
+  // 2.4. Fill ranks with offsets
+  for (uint32_t i = 0; i < items_per_work_item; ++i)
+    ranks[i] = count_arr[i] + *pointers[i];
+
+  sycl::group_barrier(group);
+
+  // 3. Reorder
+  KeysT *keys_temp = reinterpret_cast<KeysT *>(memory);
+  ValsT *vals_temp = reinterpret_cast<ValsT *>(
+      memory + wgsize * items_per_work_item * sizeof(KeysT));
+  for (uint32_t i = 0; i < items_per_work_item; ++i) {
+    keys_temp[ranks[i]] = keys[i];
+    ValuesAssigner<is_key_value_sort>()(vals_temp, ranks[i], vals, i);
+  }
+
+  sycl::group_barrier(group);
+
+  // 4. Copy back to input
+  for (uint32_t i = 0; i < items_per_work_item; ++i) {
+    size_t shift = idx * items_per_work_item + i;
+    if constexpr (!is_blocked) {
+      if (radix_iter == last_iter - 1)
+        shift = i * wgsize + idx;
+    }
+    keys[i] = keys_temp[shift];
+    ValuesAssigner<is_key_value_sort>()(vals, i, vals_temp, shift);
+  }
+}
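Aside: a small host sketch of the copy-back indexing above (not part of the patch; hypothetical sizes). A blocked layout keeps items_per_work_item consecutive elements per work item, while the striped shift used on the last pass of a non-blocked sort puts consecutive elements into consecutive work items:

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t wgsize = 4, items_per_work_item = 2;
  for (uint32_t idx = 0; idx < wgsize; ++idx)
    for (uint32_t i = 0; i < items_per_work_item; ++i)
      std::printf("work item %u, element %u: blocked index %u, striped index %u\n",
                  idx, i, idx * items_per_work_item + i, i * wgsize + idx);
}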
+
+template <bool is_key_value_sort, bool is_comp_asc, uint32_t radix_bits,
+          typename KeysT, typename ValsT, typename GroupT>
+void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
+                        const size_t n, std::byte *scratch,
+                        const uint32_t first_bit, const uint32_t last_bit) {
+  const size_t wgsize = group.get_local_linear_range();
+  constexpr uint32_t radix_states = getStatesInBits(radix_bits);
+  const uint32_t first_iter = first_bit / radix_bits;
+  const uint32_t last_iter = last_bit / radix_bits;
+
+  KeysT *keys_input = keys;
+  ValsT *vals_input = values;
+  const uint32_t runtime_items_per_work_item = (n - 1) / wgsize + 1;
+
+  // set pointers to unaligned memory
+  uint32_t *scan_memory = reinterpret_cast<uint32_t *>(scratch);
+  KeysT *keys_output = reinterpret_cast<KeysT *>(
+      scratch + radix_states * wgsize * sizeof(uint32_t));
+  // Adding 4 bytes extra space for keys due to specifics of some hardware
+  // architectures.
+  ValsT *vals_output = reinterpret_cast<ValsT *>(
+      keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t));
+
+  for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
+    performRadixIterDynamicSize<is_key_value_sort, is_comp_asc, radix_bits>(
+        group, runtime_items_per_work_item, radix_iter, n, keys_input,
+        vals_input, keys_output, vals_output, scan_memory);
+
+    sycl::group_barrier(group);
+
+    std::swap(keys_input, keys_output);
+    std::swap(vals_input, vals_output);
+  }
+}
+
+template <bool is_key_value_sort, bool is_comp_asc, bool is_blocked,
+          uint32_t items_per_work_item, uint32_t radix_bits, typename GroupT,
+          typename T, typename U>
+void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
+                       const uint32_t first_bit, const uint32_t last_bit) {
+
+  const uint32_t first_iter = first_bit / radix_bits;
+  const uint32_t last_iter = last_bit / radix_bits;
+
+  for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
+    performRadixIterStaticSize<is_key_value_sort, is_comp_asc, is_blocked,
+                               items_per_work_item, radix_bits>(
+        group, radix_iter, last_iter, keys, values, scratch);
+    sycl::group_barrier(group);
+  }
+}
+
 } // namespace detail
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
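Aside: a host sketch of the scratch layout that privateDynamicSort above carves out of the raw byte buffer (not part of the patch; example sizes). Per-bucket counters come first, then the key ping-pong buffer, then, for key/value sorts, the value buffer:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const std::size_t radix_states = 1u << 4; // radix_bits = 4
  const std::size_t wgsize = 64;            // work-group size
  const std::size_t n = 256;                // number of keys, KeysT = uint32_t
  const std::size_t counter_bytes = radix_states * wgsize * sizeof(uint32_t);
  const std::size_t key_bytes = n * sizeof(uint32_t);
  std::printf("counters: bytes [0, %zu)\n", counter_bytes);
  std::printf("keys:     bytes [%zu, %zu)\n", counter_bytes,
              counter_bytes + key_bytes);
  std::printf("values:   from byte %zu (plus alignment slack)\n",
              counter_bytes + key_bytes + alignof(uint32_t));
}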
diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
index d390134f380ee..7a7d5283bd8ec 100644
--- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
+++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
@@ -10,13 +10,14 @@
 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
 
 #include
+#include
 
 namespace sycl {
 __SYCL_INLINE_VER_NAMESPACE(_V1) {
 namespace ext::oneapi::experimental {
 
 // ---- group helpers
-template <typename Group, std::size_t Extent> class group_with_scratchpad {
+template <typename Group, size_t Extent> class group_with_scratchpad {
   Group g;
   sycl::span<std::byte, Extent> scratch;
 
@@ -31,10 +32,10 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
 template <typename Compare = std::less<>> class default_sorter {
   Compare comp;
   std::byte *scratch;
-  std::size_t scratch_size;
+  size_t scratch_size;
 
 public:
-  template <std::size_t Extent>
+  template <size_t Extent>
   default_sorter(sycl::span<std::byte, Extent> scratch_,
                  Compare comp_ = Compare())
       : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
 
@@ -60,7 +61,7 @@ template <typename Compare = std::less<>> class default_sorter {
 #ifdef __SYCL_DEVICE_ONLY__
     auto range_size = g.get_local_range().size();
     if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
-      std::size_t local_id = g.get_local_linear_id();
+      size_t local_id = g.get_local_linear_id();
       T *temp = reinterpret_cast<T *>(scratch);
       ::new (temp + local_id) T(val);
       sycl::detail::merge_sort(g, temp, range_size, comp,
@@ -78,18 +79,119 @@ template <typename Compare = std::less<>> class default_sorter {
   }
 
   template <typename T>
-  static constexpr std::size_t memory_required(sycl::memory_scope,
-                                               std::size_t range_size) {
+  static constexpr size_t memory_required(sycl::memory_scope,
+                                          size_t range_size) {
     return range_size * sizeof(T) + alignof(T);
   }
 
   template <typename T, int dim>
-  static constexpr std::size_t memory_required(sycl::memory_scope scope,
-                                               sycl::range<dim> r) {
+  static constexpr size_t memory_required(sycl::memory_scope scope,
+                                          sycl::range<dim> r) {
     return 2 * memory_required<T>(scope, r.size());
   }
 };
 
+enum class sorting_order { ascending, descending };
+
+namespace detail {
+
+template <typename T, sorting_order OrderT>
+struct ConvertToComp {
+  using Type = std::less<T>;
+};
+
+template <typename T> struct ConvertToComp<T, sorting_order::descending> {
+  using Type = std::greater<T>;
+};
+} // namespace detail
+
+template <typename ValT, sorting_order OrderT = sorting_order::ascending,
+          unsigned int BitsPerPass = 4>
+class radix_sorter {
+
+  std::byte *scratch = nullptr;
+  uint32_t first_bit = 0;
+  uint32_t last_bit = 0;
+  size_t scratch_size = 0;
+
+  static constexpr uint32_t bits = BitsPerPass;
+
+public:
+  template <size_t Extent>
+  radix_sorter(sycl::span<std::byte, Extent> scratch_,
+               const std::bitset<sizeof(ValT) * CHAR_BIT> mask =
+                   std::bitset<sizeof(ValT) * CHAR_BIT>(
+                       std::numeric_limits<unsigned long long>::max()))
+      : scratch(scratch_.data()), scratch_size(scratch_.size()) {
+    static_assert((std::is_arithmetic<ValT>::value ||
+                   std::is_same<ValT, sycl::half>::value ||
+                   std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
+                  "radix sort is not usable");
+
+    first_bit = 0;
+    while (first_bit < mask.size() && !mask[first_bit])
+      ++first_bit;
+
+    last_bit = first_bit;
+    while (last_bit < mask.size() && mask[last_bit])
+      ++last_bit;
+  }
+
+  template <typename GroupT, typename PtrT>
+  void operator()(GroupT g, PtrT first, PtrT last) {
+    (void)g;
+    (void)first;
+    (void)last;
+#ifdef __SYCL_DEVICE_ONLY__
+    sycl::detail::privateDynamicSort</*is_key_value_sort=*/false,
+                                     OrderT == sorting_order::ascending, bits>(
+        g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0,
+        scratch, first_bit, last_bit);
+#else
+    throw sycl::exception(
+        std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
+        "radix_sorter is not supported on host device.");
+#endif
+  }
+
+  template <typename GroupT> ValT operator()(GroupT g, ValT val) {
+    (void)g;
+    (void)val;
+#ifdef __SYCL_DEVICE_ONLY__
+    ValT result[]{val};
+    sycl::detail::privateStaticSort<
+        /*is_key_value_sort=*/false, OrderT == sorting_order::ascending,
+        /*is_blocked=*/true, /*items_per_work_item=*/1, bits>(
+        g, result, /*empty*/ result, scratch, first_bit, last_bit);
+    return result[0];
+#else
+    throw sycl::exception(
+        std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
+        "radix_sorter is not supported on host device.");
+#endif
+  }
+
+  static constexpr size_t memory_required(sycl::memory_scope scope,
+                                          size_t range_size) {
+    // Scope is not important so far
+    (void)scope;
+    return range_size * sizeof(ValT) +
+           (1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t);
+  }
+
+  // memory_helpers
+  template <int dimensions = 1>
+  static constexpr size_t memory_required(sycl::memory_scope scope,
+                                          sycl::range<dimensions> local_range) {
+    // Scope is not important so far
+    (void)scope;
+    return std::max(local_range.size() * sizeof(ValT),
+                    local_range.size() * (1 << bits) * sizeof(uint32_t));
+  }
+};
+
 } // namespace ext::oneapi::experimental
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
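Aside: a hedged usage sketch for the radix_sorter added above (not part of the patch). It assumes a USM pointer `keys`, one key per work item, and that the experimental headers are pulled in through <sycl/sycl.hpp>; scratch is sized with radix_sorter::memory_required and handed to the sorter through a dynamic-extent sycl::span:

#include <sycl/sycl.hpp>

namespace sycl_exp = sycl::ext::oneapi::experimental;

void sort_group_keys(sycl::queue &q, uint32_t *keys, size_t wg_size) {
  using sorter_t =
      sycl_exp::radix_sorter<uint32_t, sycl_exp::sorting_order::descending>;
  // scratch bytes for one work-group of wg_size items
  const size_t scratch_bytes = sorter_t::memory_required(
      sycl::memory_scope::work_group, sycl::range<1>{wg_size});
  q.submit([&](sycl::handler &cgh) {
    sycl::local_accessor<std::byte, 1> scratch(sycl::range<1>{scratch_bytes},
                                               cgh);
    cgh.parallel_for(sycl::nd_range<1>{wg_size, wg_size},
                     [=](sycl::nd_item<1> item) {
                       sorter_t sorter(
                           sycl::span<std::byte>(&scratch[0], scratch_bytes));
                       const size_t i = item.get_global_id(0);
                       // each work item gets back the i-th largest key
                       keys[i] = sycl_exp::sort_over_group(item.get_group(),
                                                           keys[i], sorter);
                     });
  });
  q.wait();
}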
diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
index beea117bbc9d0..d1b7a4fefd1a5 100644
--- a/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
+++ b/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
@@ -81,7 +81,7 @@ sort_over_group(Group group, T value, Sorter sorter) {
 #endif
 }
 
-template <typename Group, typename T, typename Compare, std::size_t Extent>
+template <typename Group, typename T, typename Compare, size_t Extent>
 typename std::enable_if<!detail::is_sorter<Compare, Group, T>::value, T>::type
 sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                 T value, Compare comp) {
@@ -90,7 +90,7 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                          experimental::default_sorter<Compare>(exec.get_memory(), comp));
 }
 
-template <typename Group, typename T, std::size_t Extent>
+template <typename Group, typename T, size_t Extent>
 typename std::enable_if<sycl::is_group_v<std::decay_t<Group>>, T>::type
 sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                 T value) {
@@ -116,7 +116,7 @@ joint_sort(Group group, Iter first, Iter last, Sorter sorter) {
 #endif
 }
 
-template <typename Group, typename Iter, typename Compare, std::size_t Extent>
+template <typename Group, typename Iter, typename Compare, size_t Extent>
 typename std::enable_if<!detail::is_sorter<Compare, Group, Iter>::value,
                         void>::type
 joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
@@ -125,7 +125,7 @@ joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
              experimental::default_sorter<Compare>(exec.get_memory(), comp));
 }
 
-template <typename Group, typename Iter, std::size_t Extent>
+template <typename Group, typename Iter, size_t Extent>
 typename std::enable_if<sycl::is_group_v<std::decay_t<Group>>, void>::type
 joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
            Iter last) {
diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp
index e521e624a6820..66a9e155c0aa6 100644
--- a/sycl/include/sycl/group_algorithm.hpp
+++ b/sycl/include/sycl/group_algorithm.hpp
@@ -15,7 +15,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp
index 5664f7270c7d4..14b9361714e49 100644
--- a/sycl/include/sycl/sycl.hpp
+++ b/sycl/include/sycl/sycl.hpp
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
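Aside: a hedged end-to-end sketch of the convenience overloads touched above (not part of the patch). It assumes a USM data pointer and relies on class template argument deduction for group_with_scratchpad; the scratch size comes from default_sorter::memory_required, and joint_sort then sorts all n keys within one work-group:

#include <sycl/sycl.hpp>

namespace sycl_exp = sycl::ext::oneapi::experimental;

void joint_sort_block(sycl::queue &q, int *data, size_t n, size_t wg_size) {
  // n * sizeof(int) + alignof(int) bytes, per the formula in the hunk above
  const size_t scratch_bytes = sycl_exp::default_sorter<>::memory_required<int>(
      sycl::memory_scope::work_group, n);
  q.submit([&](sycl::handler &cgh) {
    sycl::local_accessor<std::byte, 1> scratch(sycl::range<1>{scratch_bytes},
                                               cgh);
    cgh.parallel_for(sycl::nd_range<1>{wg_size, wg_size},
                     [=](sycl::nd_item<1> item) {
                       sycl_exp::group_with_scratchpad gws(
                           item.get_group(),
                           sycl::span<std::byte>(&scratch[0], scratch_bytes));
                       // convenience overload: default_sorter with std::less
                       sycl_exp::joint_sort(gws, data, data + n);
                     });
  });
  q.wait();
}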