diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp index ed21bf39201bf..3c97bda5b4e90 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp @@ -11,6 +11,10 @@ #include #include +#if !defined(__SYCL_DEVICE_ONLY__) +#include +#endif + namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { namespace ext { @@ -35,9 +39,17 @@ class bfloat16 { return __spirv_ConvertFToBF16INTEL(a); #endif #else - (void)a; - throw exception{errc::feature_not_supported, - "Bfloat16 conversion is not supported on host device"}; + // In case of float value is nan - propagate bfloat16's qnan + if (std::isnan(a)) + return 0xffc1; + union { + uint32_t intStorage; + float floatValue; + }; + floatValue = a; + // Do RNE and truncate + uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF; + return static_cast((intStorage + roundingBias) >> 16); #endif } static float to_float(const storage_t &a) { @@ -51,9 +63,10 @@ class bfloat16 { return __spirv_ConvertBF16ToFINTEL(a); #endif #else - (void)a; - throw exception{errc::feature_not_supported, - "Bfloat16 conversion is not supported on host device"}; + // Shift temporary variable to silence the warning + uint32_t bits = a; + bits <<= 16; + return static_cast(bits); #endif } diff --git a/sycl/test/extensions/bfloat16_host.cpp b/sycl/test/extensions/bfloat16_host.cpp new file mode 100644 index 0000000000000..e3cfb71abb558 --- /dev/null +++ b/sycl/test/extensions/bfloat16_host.cpp @@ -0,0 +1,88 @@ +//==------------ bfloat16_host.cpp - SYCL vectors test ---------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %RUN_ON_HOST %t.out +#include +#include + +#include +#include +#include +#include +#include + +// Helper to convert the expected bits to float value to compare with the result +typedef union { + float Value; + struct { + uint32_t Mantissa : 23; + uint32_t Exponent : 8; + uint32_t Sign : 1; + } RawData; +} floatConvHelper; + +float bitsToFloatConv(std::string Bits) { + floatConvHelper Helper; + Helper.RawData.Sign = static_cast(Bits[0] - '0'); + uint32_t Exponent = 0; + for (size_t I = 1; I != 9; ++I) + Exponent = Exponent + static_cast(Bits[I] - '0') * pow(2, 8 - I); + Helper.RawData.Exponent = Exponent; + uint32_t Mantissa = 0; + for (size_t I = 9; I != 32; ++I) + Mantissa = Mantissa + static_cast(Bits[I] - '0') * pow(2, 31 - I); + Helper.RawData.Mantissa = Mantissa; + return Helper.Value; +} + +bool check_bf16_from_float(float Val, uint16_t Expected) { + uint16_t Result = sycl::ext::oneapi::experimental::bfloat16::from_float(Val); + if (Result != Expected) { + std::cout << "from_float check for Val = " << Val << " failed!\n" + << "Expected " << Expected << " Got " << Result << "\n"; + return false; + } + return true; +} + +bool check_bf16_to_float(uint16_t Val, float Expected) { + float Result = sycl::ext::oneapi::experimental::bfloat16::to_float(Val); + if (Result != Expected) { + std::cout << "to_float check for Val = " << Val << " failed!\n" + << "Expected " << Expected << " Got " << Result << "\n"; + return false; + } + return true; +} + +int main() { + bool Success = + check_bf16_from_float(0.0f, std::stoi("0000000000000000", nullptr, 2)); + Success &= + check_bf16_from_float(42.0f, std::stoi("100001000101000", nullptr, 2)); + Success &= check_bf16_from_float(std::numeric_limits::min(), + std::stoi("0000000010000000", nullptr, 2)); + Success &= check_bf16_from_float(std::numeric_limits::max(), + std::stoi("0111111110000000", nullptr, 2)); + Success &= check_bf16_from_float(std::numeric_limits::quiet_NaN(), + std::stoi("1111111111000001", nullptr, 2)); + + Success &= check_bf16_to_float( + 0, bitsToFloatConv(std::string("00000000000000000000000000000000"))); + Success &= check_bf16_to_float( + 1, bitsToFloatConv(std::string("01000111100000000000000000000000"))); + Success &= check_bf16_to_float( + 42, bitsToFloatConv(std::string("01001010001010000000000000000000"))); + Success &= check_bf16_to_float( + std::numeric_limits::max(), + bitsToFloatConv(std::string("01001111011111111111111100000000"))); + if (!Success) + return -1; + return 0; +}