diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 2643cf35daaff..203ca0b9b9a3c 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -25,9 +25,21 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) { namespace ext { namespace oneapi { +class bfloat16; + +namespace detail { +using Bfloat16StorageT = uint16_t; +Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value); +bfloat16 bitsToBfloat16(const Bfloat16StorageT Value); +} // namespace detail + class bfloat16 { - using storage_t = uint16_t; - storage_t value; + detail::Bfloat16StorageT value; + + friend inline detail::Bfloat16StorageT + detail::bfloat16ToBits(const bfloat16 &Value); + friend inline bfloat16 + detail::bitsToBfloat16(const detail::Bfloat16StorageT Value); public: bfloat16() = default; @@ -36,7 +48,7 @@ class bfloat16 { private: // Explicit conversion functions - static storage_t from_float(const float &a) { + static detail::Bfloat16StorageT from_float(const float &a) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) #if (__CUDA_ARCH__ >= 800) @@ -72,7 +84,7 @@ class bfloat16 { #endif } - static float to_float(const storage_t &a) { + static float to_float(const detail::Bfloat16StorageT &a) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) return __devicelib_ConvertBF16ToFINTEL(a); #else @@ -85,12 +97,6 @@ class bfloat16 { #endif } - static bfloat16 from_bits(const storage_t &a) { - bfloat16 res; - res.value = a; - return res; - } - public: // Implicit conversion from float to bfloat16 bfloat16(const float &a) { value = from_float(a); } @@ -122,7 +128,7 @@ class bfloat16 { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) #if (__CUDA_ARCH__ >= 800) - return from_bits(__nvvm_neg_bf16(lhs.value)); + return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value)); #else return -to_float(lhs.value); #endif @@ -203,6 +209,23 @@ class bfloat16 { // for floating-point types. 
}; +namespace detail { + +// Helper function for getting the internal representation of a bfloat16. +inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) { + return Value.value; +} + +// Helper function for creating a bfloat16 from a value with the same type as +// the internal representation. +inline bfloat16 bitsToBfloat16(const Bfloat16StorageT Value) { + bfloat16 res; + res.value = Value; + return res; +} + +} // namespace detail + } // namespace oneapi } // namespace ext diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index e69de29bb2d1d..8bce9d045eb59 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -0,0 +1,208 @@ +//==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { + +namespace detail { +template +uint32_t to_uint32_t(sycl::marray x, size_t start) { + uint32_t res; + std::memcpy(&res, &x[start], sizeof(uint32_t)); + return res; +} +} // namespace detail + +template +std::enable_if_t::value, T> fabs(T x) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits)); +#else + std::ignore = x; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +sycl::marray fabs(sycl::marray x) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::marray res; + + for (size_t i = 0; i < N / 2; i++) { + auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2)); + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); + } + + if (N % 2) { + oneapi::detail::Bfloat16StorageT XBits = + oneapi::detail::bfloat16ToBits(x[N - 1]); + res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits)); + } + return res; +#else + std::ignore = x; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +std::enable_if_t::value, T> fmin(T x, T y) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + return 
oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); +#else + std::ignore = x; + std::ignore = y; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +sycl::marray fmin(sycl::marray x, + sycl::marray y) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::marray res; + + for (size_t i = 0; i < N / 2; i++) { + auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2), + detail::to_uint32_t(y, i * 2)); + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); + } + + if (N % 2) { + oneapi::detail::Bfloat16StorageT XBits = + oneapi::detail::bfloat16ToBits(x[N - 1]); + oneapi::detail::Bfloat16StorageT YBits = + oneapi::detail::bfloat16ToBits(y[N - 1]); + res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); + } + + return res; +#else + std::ignore = x; + std::ignore = y; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +std::enable_if_t::value, T> fmax(T x, T y) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); +#else + std::ignore = x; + std::ignore = y; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +sycl::marray fmax(sycl::marray x, + sycl::marray y) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::marray res; + + for (size_t i = 0; i < N / 2; i++) { + auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2), + detail::to_uint32_t(y, i * 2)); + std::memcpy(&res[i * 2], 
&partial_res, sizeof(uint32_t)); + } + + if (N % 2) { + oneapi::detail::Bfloat16StorageT XBits = + oneapi::detail::bfloat16ToBits(x[N - 1]); + oneapi::detail::Bfloat16StorageT YBits = + oneapi::detail::bfloat16ToBits(y[N - 1]); + res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); + } + return res; +#else + std::ignore = x; + std::ignore = y; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +std::enable_if_t::value, T> fma(T x, T y, T z) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z); + return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits)); +#else + std::ignore = x; + std::ignore = y; + std::ignore = z; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +sycl::marray fma(sycl::marray x, + sycl::marray y, + sycl::marray z) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::marray res; + + for (size_t i = 0; i < N / 2; i++) { + auto partial_res = + __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2), + detail::to_uint32_t(z, i * 2)); + std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); + } + + if (N % 2) { + oneapi::detail::Bfloat16StorageT XBits = + oneapi::detail::bfloat16ToBits(x[N - 1]); + oneapi::detail::Bfloat16StorageT YBits = + oneapi::detail::bfloat16ToBits(y[N - 1]); + oneapi::detail::Bfloat16StorageT ZBits = + oneapi::detail::bfloat16ToBits(z[N - 1]); + res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits)); + } + return res; +#else + std::ignore = x; 
+ std::ignore = y; + std::ignore = z; + throw runtime_error("bfloat16 is not currently supported on the host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index e676cf1013295..7a7105bf5519a 100755 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -32,15 +32,6 @@ namespace ext { namespace oneapi { namespace experimental { -namespace detail { -template -uint32_t to_uint32_t(sycl::marray x, size_t start) { - uint32_t res; - std::memcpy(&res, &x[start], sizeof(uint32_t)); - return res; -} -} // namespace detail - // Provides functionality to print data from kernels in a C way: // - On non-host devices this function is directly mapped to printf from // OpenCL C diff --git a/sycl/include/sycl/ext/oneapi/sub_group.hpp b/sycl/include/sycl/ext/oneapi/sub_group.hpp index e0f33486749e5..3be48fb447d76 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group.hpp @@ -205,8 +205,7 @@ struct sub_group { template using EnableIfIsScalarArithmetic = - sycl::detail::enable_if_t::value, - T>; + std::enable_if_t::value, T>; /* --- one-input shuffles --- */ /* indices in [0 , sub_group size) */ @@ -260,7 +259,7 @@ struct sub_group { #ifdef __SYCL_DEVICE_ONLY__ // Method for decorated pointer template > - detail::enable_if_t, T>::value, T> + std::enable_if_t, T>::value, T> load(CVT *cv_src) const { T *src = const_cast(cv_src); return load(sycl::multi_ptr, @@ -270,7 +269,7 @@ struct sub_group { // Method for raw pointer template > - detail::enable_if_t, T>::value, T> + std::enable_if_t, T>::value, T> load(CVT *cv_src) const { T *src = const_cast(cv_src); @@ 
-300,10 +299,11 @@ struct sub_group { template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value, T> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); #ifdef __SYCL_DEVICE_ONLY__ #ifdef __NVPTX__ return src.get()[get_local_id()[0]]; @@ -319,10 +319,11 @@ struct sub_group { template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForLocalLoadStore::value, T> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); #ifdef __SYCL_DEVICE_ONLY__ return src.get()[get_local_id()[0]]; #else @@ -335,11 +336,12 @@ struct sub_group { #ifdef __NVPTX__ template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); vec res; for (int i = 0; i < N; ++i) { res[i] = *(src.get() + i * get_max_local_range()[0] + get_local_id()[0]); @@ -349,23 +351,25 @@ struct sub_group { #else // __NVPTX__ template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N != 1 && N != 3 && N != 16, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); return sycl::detail::sub_group::load(src); } template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 16, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); return {sycl::detail::sub_group::load<8, T>(src), 
sycl::detail::sub_group::load<8, T>(src + 8 * get_max_local_range()[0])}; @@ -373,12 +377,13 @@ struct sub_group { template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 3, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); return { sycl::detail::sub_group::load<1, T>(src), sycl::detail::sub_group::load<2, T>(src + get_max_local_range()[0])}; @@ -386,19 +391,20 @@ struct sub_group { template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 1, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); return sycl::detail::sub_group::load(src); } #endif // ___NVPTX___ #else // __SYCL_DEVICE_ONLY__ template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value, vec> load(const multi_ptr src) const { @@ -410,11 +416,12 @@ struct sub_group { template > - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForLocalLoadStore::value, vec> load(const multi_ptr cv_src) const { - multi_ptr src = detail::GetUnqualMultiPtr(cv_src); + multi_ptr src = + sycl::detail::GetUnqualMultiPtr(cv_src); #ifdef __SYCL_DEVICE_ONLY__ vec res; for (int i = 0; i < N; ++i) { @@ -431,7 +438,7 @@ struct sub_group { #ifdef __SYCL_DEVICE_ONLY__ // Method for decorated pointer template - detail::enable_if_t, T>::value> + std::enable_if_t, T>::value> store(T *dst, const remove_decoration_t &x) const { store(sycl::multi_ptr, sycl::detail::deduce_AS::value, @@ -441,7 +448,7 @@ struct sub_group { // Method for raw pointer template - detail::enable_if_t, T>::value> + std::enable_if_t, T>::value> store(T *dst, const remove_decoration_t &x) const { #ifdef __NVPTX__ @@ -475,7 
+482,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value> store(multi_ptr dst, const T &x) const { #ifdef __SYCL_DEVICE_ONLY__ @@ -494,7 +501,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForLocalLoadStore::value> store(multi_ptr dst, const T &x) const { #ifdef __SYCL_DEVICE_ONLY__ @@ -511,7 +518,7 @@ struct sub_group { #ifdef __NVPTX__ template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value> store(multi_ptr dst, const vec &x) const { for (int i = 0; i < N; ++i) { @@ -521,7 +528,7 @@ struct sub_group { #else // __NVPTX__ template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N != 1 && N != 3 && N != 16> store(multi_ptr dst, const vec &x) const { @@ -530,7 +537,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 1> store(multi_ptr dst, const vec &x) const { @@ -539,7 +546,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 3> store(multi_ptr dst, const vec &x) const { @@ -550,7 +557,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 16> store(multi_ptr dst, const vec &x) const { @@ -563,7 +570,7 @@ struct sub_group { #else // __SYCL_DEVICE_ONLY__ template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value> store(multi_ptr dst, const vec &x) const { (void)dst; @@ -575,7 +582,7 @@ struct sub_group { template - sycl::detail::enable_if_t< + std::enable_if_t< sycl::detail::sub_group::AcceptableForLocalLoadStore::value> store(multi_ptr 
dst, const vec &x) const { #ifdef __SYCL_DEVICE_ONLY__