diff --git a/sycl/include/CL/sycl/detail/type_traits.hpp b/sycl/include/CL/sycl/detail/type_traits.hpp index bde91d54f49ef..3a50940c25a03 100644 --- a/sycl/include/CL/sycl/detail/type_traits.hpp +++ b/sycl/include/CL/sycl/detail/type_traits.hpp @@ -157,6 +157,10 @@ template class S> using is_gen_based_on_type_sizeof = bool_constant::value && (sizeof(vector_element_t) == N)>; +template struct is_vec : std::false_type {}; +template +struct is_vec> : std::true_type {}; + // is_integral template struct is_integral : std::is_integral> {}; diff --git a/sycl/include/CL/sycl/intel/sub_group.hpp b/sycl/include/CL/sycl/intel/sub_group.hpp index 8de208238b71a..e7eeb9d8844d9 100644 --- a/sycl/include/CL/sycl/intel/sub_group.hpp +++ b/sycl/include/CL/sycl/intel/sub_group.hpp @@ -8,15 +8,21 @@ #pragma once +#include #include #include +#include #include #include #include #include #include #include + +#include // std::memcpy +#include // std::bit_cast #include + #ifdef __SYCL_DEVICE_ONLY__ __SYCL_INLINE namespace cl { @@ -25,69 +31,157 @@ template class multi_ptr; namespace detail { -template struct is_vec : std::false_type {}; -template -struct is_vec> : std::true_type {}; +namespace sub_group { -template -static typename std::enable_if< - !detail::is_floating_point::value && std::is_signed::value, T>::type -calc(T x, intel::minimum op) { - return __spirv_GroupSMin(__spv::Scope::Subgroup, O, x); +template T broadcast(T x, id<1> local_id) { + using OCLT = detail::ConvertToOpenCLType_t; + return __spirv_GroupBroadcast(__spv::Scope::Subgroup, OCLT(x), + local_id.get(0)); } -template -static typename std::enable_if< - !detail::is_floating_point::value && std::is_unsigned::value, T>::type -calc(T x, intel::minimum op) { - return __spirv_GroupUMin(__spv::Scope::Subgroup, O, x); -} +#define __SYCL_SG_GENERATE_BODY_1ARG(name, SPIRVOperation) \ + template T name(T x, id<1> local_id) { \ + using OCLT = detail::ConvertToOpenCLType_t; \ + return __spirv_##SPIRVOperation(OCLT(x), local_id.get(0)); \ + } -template -static typename std::enable_if::value, T>::type -calc(T x, intel::minimum op) { - return __spirv_GroupFMin(__spv::Scope::Subgroup, O, x); -} +__SYCL_SG_GENERATE_BODY_1ARG(shuffle, SubgroupShuffleINTEL) +__SYCL_SG_GENERATE_BODY_1ARG(shuffle_xor, SubgroupShuffleXorINTEL) + +#undef __SYCL_SG_GENERATE_BODY_1ARG + +#define __SYCL_SG_GENERATE_BODY_2ARG(name, SPIRVOperation) \ + template T name(T A, T B, uint32_t Delta) { \ + using OCLT = detail::ConvertToOpenCLType_t; \ + return __spirv_##SPIRVOperation(OCLT(A), OCLT(B), Delta); \ + } + +__SYCL_SG_GENERATE_BODY_2ARG(shuffle_down, SubgroupShuffleDownINTEL) +__SYCL_SG_GENERATE_BODY_2ARG(shuffle_up, SubgroupShuffleUpINTEL) -template -static typename std::enable_if< - !detail::is_floating_point::value && std::is_signed::value, T>::type -calc(T x, intel::maximum op) { - return __spirv_GroupSMax(__spv::Scope::Subgroup, O, x); +#undef __SYCL_SG_GENERATE_BODY_2ARG + +// Selects 8-bit, 16-bit or 32-bit type depending on size of T. If T doesn't +// maps to mentioned types, then void is returned +template +using SelectBlockT = + select_apply_cl_scalar_t; + +template +using AcceptableForLoadStore = + bool_constant>::value && + Space == access::address_space::global_space>; + +// TODO: move this to public cl::sycl::bit_cast as extension? +template To bit_cast(const From &from) { +#if __cpp_lib_bit_cast + return std::bit_cast(from); +#else + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif // __has_builtin + +#if __has_builtin(__builtin_bit_cast) + return __builtin_bit_cast(To, from); +#else + To to; + std::memcpy(&to, &from, sizeof(To)); + return to; +#endif // __has_builtin(__builtin_bit_cast) +#endif // __cpp_lib_bit_cast } -template -static typename std::enable_if< - !detail::is_floating_point::value && std::is_unsigned::value, T>::type -calc(T x, intel::maximum op) { - return __spirv_GroupUMax(__spv::Scope::Subgroup, O, x); +template +T load(const multi_ptr src) { + using BlockT = SelectBlockT; + using PtrT = detail::ConvertToOpenCLType_t>; + + BlockT Ret = + __spirv_SubgroupBlockReadINTEL(reinterpret_cast(src.get())); + + return bit_cast(Ret); } -template -static typename std::enable_if::value, T>::type -calc(T x, intel::maximum op) { - return __spirv_GroupFMax(__spv::Scope::Subgroup, O, x); +template +vec load(const multi_ptr src) { + using BlockT = SelectBlockT; + using VecT = detail::ConvertToOpenCLType_t>; + using PtrT = detail::ConvertToOpenCLType_t>; + + VecT Ret = + __spirv_SubgroupBlockReadINTEL(reinterpret_cast(src.get())); + + return bit_cast::vector_t>(Ret); } -template -static typename std::enable_if< - !detail::is_floating_point::value && std::is_integral::value, T>::type -calc(T x, intel::plus op) { - return __spirv_GroupIAdd(__spv::Scope::Subgroup, O, x); +template +void store(multi_ptr dst, const T &x) { + using BlockT = SelectBlockT; + using PtrT = detail::ConvertToOpenCLType_t>; + + __spirv_SubgroupBlockWriteINTEL(reinterpret_cast(dst.get()), + bit_cast(x)); } -template -static typename std::enable_if::value, T>::type -calc(T x, intel::plus op) { - return __spirv_GroupFAdd(__spv::Scope::Subgroup, O, x); +template +void store(multi_ptr dst, const vec &x) { + using BlockT = SelectBlockT; + using VecT = detail::ConvertToOpenCLType_t>; + using PtrT = detail::ConvertToOpenCLType_t>; + + __spirv_SubgroupBlockWriteINTEL(reinterpret_cast(dst.get()), + bit_cast(x)); } +struct GroupOpISigned {}; struct GroupOpIUnsigned {}; struct GroupOpFP {}; + +template struct GroupOpTag; + +template +struct GroupOpTag::value>> { + using type = GroupOpISigned; +}; + +template +struct GroupOpTag::value>> { + using type = GroupOpIUnsigned; +}; + +template +struct GroupOpTag::value>> { + using type = GroupOpFP; +}; + +#define __SYCL_SG_CALC_OVERLOAD(GroupTag, SPIRVOperation, BinaryOperation) \ + template \ + static T calc(GroupTag, T x, BinaryOperation op) { \ + using OCLT = detail::ConvertToOpenCLType_t; \ + OCLT Arg = x; \ + OCLT Ret = __spirv_Group##SPIRVOperation(__spv::Scope::Subgroup, O, Arg); \ + return Ret; \ + } + +__SYCL_SG_CALC_OVERLOAD(GroupOpISigned, SMin, intel::minimum) +__SYCL_SG_CALC_OVERLOAD(GroupOpIUnsigned, UMin, intel::minimum) +__SYCL_SG_CALC_OVERLOAD(GroupOpFP, FMin, intel::minimum) +__SYCL_SG_CALC_OVERLOAD(GroupOpISigned, SMax, intel::maximum) +__SYCL_SG_CALC_OVERLOAD(GroupOpIUnsigned, UMax, intel::maximum) +__SYCL_SG_CALC_OVERLOAD(GroupOpFP, FMax, intel::maximum) +__SYCL_SG_CALC_OVERLOAD(GroupOpISigned, IAdd, intel::plus) +__SYCL_SG_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, intel::plus) +__SYCL_SG_CALC_OVERLOAD(GroupOpFP, FAdd, intel::plus) + +#undef __SYCL_SG_CALC_OVERLOAD + template class BinaryOperation> -static T calc(T x, BinaryOperation) { - return calc(x, BinaryOperation()); +static T calc(typename GroupOpTag::type, T x, BinaryOperation) { + return calc(typename GroupOpTag::type(), x, BinaryOperation()); } +} // namespace sub_group + } // namespace detail namespace intel { @@ -106,9 +200,7 @@ struct sub_group { id<1> get_group_id() const { return __spirv_BuiltInSubgroupId; } - unsigned int get_group_range() const { - return __spirv_BuiltInNumSubgroups; - } + unsigned int get_group_range() const { return __spirv_BuiltInNumSubgroups; } unsigned int get_uniform_group_range() const { return __spirv_BuiltInNumEnqueuedSubgroups; @@ -124,7 +216,6 @@ struct sub_group { return __spirv_GroupAll(__spv::Scope::Subgroup, predicate); } - template using EnableIfIsScalarArithmetic = detail::enable_if_t< !detail::is_vec::value && detail::is_arithmetic::value, T>; @@ -132,14 +223,14 @@ struct sub_group { /* --- collectives --- */ template - T broadcast(EnableIfIsScalarArithmetic x, id<1> local_id) const { - return __spirv_GroupBroadcast(__spv::Scope::Subgroup, x, - local_id.get(0)); + EnableIfIsScalarArithmetic broadcast(T x, id<1> local_id) const { + return detail::sub_group::broadcast(x, local_id); } template EnableIfIsScalarArithmetic reduce(T x, BinaryOperation op) const { - return detail::calc(x, op); + return detail::sub_group::calc( + typename detail::sub_group::GroupOpTag::type(), x, op); } template @@ -149,12 +240,13 @@ struct sub_group { template EnableIfIsScalarArithmetic exclusive_scan(T x, BinaryOperation op) const { - return detail::calc(x, op); + return detail::sub_group::calc( + typename detail::sub_group::GroupOpTag::type(), x, op); } template EnableIfIsScalarArithmetic exclusive_scan(T x, T init, - BinaryOperation op) const { + BinaryOperation op) const { if (get_local_id().get(0) == 0) { x = op(init, x); } @@ -167,7 +259,8 @@ struct sub_group { template EnableIfIsScalarArithmetic inclusive_scan(T x, BinaryOperation op) const { - return detail::calc(x, op); + return detail::sub_group::calc( + typename detail::sub_group::GroupOpTag::type(), x, op); } template @@ -179,197 +272,92 @@ struct sub_group { return inclusive_scan(x, op); } - /* --- one - input shuffles --- */ - /* indices in [0 , sub - group size ) */ + /* --- one-input shuffles --- */ + /* indices in [0 , sub_group size) */ template - EnableIfIsScalarArithmetic - shuffle(T x, id<1> local_id) const { - return __spirv_SubgroupShuffleINTEL(x, local_id.get(0)); + T shuffle(T x, id<1> local_id) const { + return detail::sub_group::shuffle(x, local_id); } - template - typename std::enable_if::value, T>::type - shuffle(T x, id<1> local_id) const { - return __spirv_SubgroupShuffleINTEL((typename T::vector_t)x, - local_id.get(0)); + template T shuffle_down(T x, uint32_t delta) const { + return detail::sub_group::shuffle_down(x, x, delta); } template - EnableIfIsScalarArithmetic - shuffle_down(T x, uint32_t delta) const { - return shuffle_down(x, x, delta); + T shuffle_up(T x, uint32_t delta) const { + return detail::sub_group::shuffle_up(x, x, delta); } template - typename std::enable_if::value, T>::type - shuffle_down(T x, uint32_t delta) const { - return shuffle_down(x, x, delta); + T shuffle_xor(T x, id<1> value) const { + return detail::sub_group::shuffle_xor(x, value); } - template - EnableIfIsScalarArithmetic - shuffle_up(T x, uint32_t delta) const { - return shuffle_up(x, x, delta); - } - - template - typename std::enable_if::value, T>::type - shuffle_up(T x, uint32_t delta) const { - return shuffle_up(x, x, delta); - } + /* --- two-input shuffles --- */ + /* indices in [0 , 2 * sub_group size) */ template - EnableIfIsScalarArithmetic - shuffle_xor(T x, id<1> value) const { - return __spirv_SubgroupShuffleXorINTEL(x, (uint32_t)value.get(0)); + T shuffle(T x, T y, id<1> local_id) const { + return detail::sub_group::shuffle_down(x, y, + (local_id - get_local_id()).get(0)); } template - typename std::enable_if::value, T>::type - shuffle_xor(T x, id<1> value) const { - return __spirv_SubgroupShuffleXorINTEL((typename T::vector_t)x, - (uint32_t)value.get(0)); + T shuffle_down(T current, T next, uint32_t delta) const { + return detail::sub_group::shuffle_down(current, next, delta); } - /* --- two - input shuffles --- */ - /* indices in [0 , 2* sub - group size ) */ template - EnableIfIsScalarArithmetic - shuffle(T x, T y, id<1> local_id) const { - return __spirv_SubgroupShuffleDownINTEL( - x, y, local_id.get(0) - get_local_id().get(0)); + T shuffle_up(T previous, T current, uint32_t delta) const { + return detail::sub_group::shuffle_up(previous, current, delta); } - template - typename std::enable_if::value, T>::type - shuffle(T x, T y, id<1> local_id) const { - return __spirv_SubgroupShuffleDownINTEL( - (typename T::vector_t)x, (typename T::vector_t)y, - local_id.get(0) - get_local_id().get(0)); - } - - template - EnableIfIsScalarArithmetic - shuffle_down(T current, T next, uint32_t delta) const { - return __spirv_SubgroupShuffleDownINTEL(current, next, delta); - } - - template - typename std::enable_if::value, T>::type - shuffle_down(T current, T next, uint32_t delta) const { - return __spirv_SubgroupShuffleDownINTEL( - (typename T::vector_t)current, (typename T::vector_t)next, delta); - } - - template - EnableIfIsScalarArithmetic - shuffle_up(T previous, T current, uint32_t delta) const { - return __spirv_SubgroupShuffleUpINTEL(previous, current, delta); - } - - template - typename std::enable_if::value, T>::type - shuffle_up(T previous, T current, uint32_t delta) const { - return __spirv_SubgroupShuffleUpINTEL( - (typename T::vector_t)previous, (typename T::vector_t)current, delta); - } - - /* --- sub - group load / stores --- */ - /* these can map to SIMD or block read / write hardware where available */ + /* --- sub_group load/stores --- */ + /* these can map to SIMD or block read/write hardware where available */ template - typename std::enable_if<(sizeof(T) == sizeof(uint32_t) || - sizeof(T) == sizeof(uint16_t) || - sizeof(T) == sizeof(uint8_t)) && - Space == access::address_space::global_space, - T>::type + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value, T> load(const multi_ptr src) const { - T data; - if (sizeof(T) == sizeof(uint32_t)) { - uint32_t t = __spirv_SubgroupBlockReadINTEL( - (const __attribute__((opencl_global)) uint32_t *)src.get()); - data = *((T *)(&t)); - } else if (sizeof(T) == sizeof(uint16_t)) { - uint16_t t = __spirv_SubgroupBlockReadINTEL( - (const __attribute__((opencl_global)) uint16_t *)src.get()); - data = *((T *)(&t)); - } else { - uint8_t t = __spirv_SubgroupBlockReadINTEL( - (const __attribute__((opencl_global)) uint8_t *)src.get()); - data = *((T *)(&t)); - } - return data; + return detail::sub_group::load(src); } template - vec::type, - N> + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value && N != 1, + vec> load(const multi_ptr src) const { - if (N == 1) { - return load(src); - } - if (sizeof(T) == sizeof(uint32_t)) { - typedef uint32_t ocl_t __attribute__((ext_vector_type(N))); - - ocl_t t = __spirv_SubgroupBlockReadINTEL( - (const __attribute__((opencl_global)) uint32_t *)src.get()); - return *((typename vec::vector_t *)(&t)); - } - typedef uint16_t ocl_t __attribute__((ext_vector_type(N))); + return detail::sub_group::load(src); + } - ocl_t t = __spirv_SubgroupBlockReadINTEL( - (const __attribute__((opencl_global)) uint16_t *)src.get()); - return *((typename vec::vector_t *)(&t)); + template + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value && N == 1, + vec> + load(const multi_ptr src) const { + return detail::sub_group::load(src); } template - void - store(multi_ptr dst, - const typename std::enable_if< - (sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint16_t) || - sizeof(T) == sizeof(uint8_t)) && - Space == access::address_space::global_space, - T>::type &x) const { - if (sizeof(T) == sizeof(uint32_t)) { - __spirv_SubgroupBlockWriteINTEL( - (__attribute__((opencl_global)) uint32_t *)dst.get(), *((uint32_t *)&x)); - } else if (sizeof(T) == sizeof(uint16_t)) { - __spirv_SubgroupBlockWriteINTEL( - (__attribute__((opencl_global)) uint16_t *)dst.get(), *((uint16_t *)&x)); - } else { - __spirv_SubgroupBlockWriteINTEL( - (__attribute__((opencl_global)) uint8_t *)dst.get(), *((uint8_t *)&x)); - } + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value> + store(multi_ptr dst, const T &x) const { + detail::sub_group::store(dst, x); } template - void store(multi_ptr dst, - const vec::type, N> &x) const { + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value && N == 1> + store(multi_ptr dst, const vec &x) const { store(dst, x); } template - void store( - multi_ptr dst, - const vec::type, - N> &x) const { - if (sizeof(T) == sizeof(uint32_t)) { - typedef uint32_t ocl_t __attribute__((ext_vector_type(N))); - __spirv_SubgroupBlockWriteINTEL((__attribute__((opencl_global)) uint32_t *)dst.get(), - *((ocl_t *)&x)); - } else { - typedef uint16_t ocl_t __attribute__((ext_vector_type(N))); - __spirv_SubgroupBlockWriteINTEL((__attribute__((opencl_global)) uint16_t *)dst.get(), - *((ocl_t *)&x)); - } + detail::enable_if_t< + detail::sub_group::AcceptableForLoadStore::value && N != 1> + store(multi_ptr dst, const vec &x) const { + detail::sub_group::store(dst, x); } /* --- synchronization functions --- */