// Patch commentary. Text ahead of the first "diff --git" header is treated as
// leading commentary by git apply / patch, so these lines do not alter the hunks.
//
// sycl/include/sycl/ext/intel/experimental/bfloat16.hpp
//   * Removes the converting constructor bfloat16(const storage_t &) and adds the
//     static factory bfloat16::from_bits(const storage_t &), which default-constructs
//     a bfloat16, assigns the argument to `value`, and returns the object.
//   * Removes `operator storage_t() const` and adds `storage_t raw() const`;
//     both simply return `value`.
//   * Adds `operator sycl::half() const`, whose body returns to_float(value).
//   * Adds one #include line.
//     NOTE(review): both #include targets are empty in this copy of the patch
//     (the <...> part looks stripped) -- restore them from the original commit.
//   * NOTE(review): the unchanged context line `explicit operator bool()` is not
//     const-qualified, unlike the other conversion operators visible here;
//     consider addressing that separately.
//
// sycl/test/extensions/bfloat16.cpp
//   * Declares SYCL_EXTERNAL void foo(long x, sycl::half y) and calls foo(L, H).
//   * Replaces `bfloat16 D = some_bf16_intrinsic(A, C);` with the explicit form
//     bfloat16::from_bits(some_bf16_intrinsic(A.raw(), C.raw())).
//   * Adds CHECK lines for `long L = bfloat16(3.14f)` (expects an fptosi of the
//     ConvertBF16ToF result) and for `sycl::half H = bfloat16(2.71f)` (expects
//     an fptrunc of the ConvertBF16ToF result to half).
//
// NOTE(review): line breaks and indentation are collapsed in this copy, so the
// hunk line counts below presumably no longer match the text -- confirm against
// the original patch before applying.
diff --git a/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp b/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp index 34b1d255e6345..5a51f3746e225 100644 --- a/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp +++ b/sycl/include/sycl/ext/intel/experimental/bfloat16.hpp @@ -9,6 +9,7 @@ #pragma once #include +#include __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { @@ -43,8 +44,11 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 { #endif } - // Direct initialization - bfloat16(const storage_t &a) : value(a) {} + static bfloat16 from_bits(const storage_t &a) { + bfloat16 res; + res.value = a; + return res; + } // Implicit conversion from float to bfloat16 bfloat16(const float &a) { value = from_float(a); } @@ -56,9 +60,10 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 { // Implicit conversion from bfloat16 to float operator float() const { return to_float(value); } + operator sycl::half() const { return to_float(value); } // Get raw bits representation of bfloat16 - operator storage_t() const { return value; } + storage_t raw() const { return value; } // Logical operators (!,||,&&) are covered if we can cast to bool explicit operator bool() { return to_float(value) != 0.0f; } diff --git a/sycl/test/extensions/bfloat16.cpp b/sycl/test/extensions/bfloat16.cpp index b548d7ead5c43..f56217712f3d0 100644 --- a/sycl/test/extensions/bfloat16.cpp +++ b/sycl/test/extensions/bfloat16.cpp @@ -8,6 +8,7 @@ using sycl::ext::intel::experimental::bfloat16; SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y); +SYCL_EXTERNAL void foo(long x, sycl::half y); __attribute__((noinline)) float op(float a, float b) { // CHECK: define {{.*}} spir_func float @_Z2opff(float [[a:%.*]], float [[b:%.*]]) @@ -27,11 +28,22 @@ __attribute__((noinline)) float op(float a, float b) { // CHECK-NOT: uitofp // CHECK-NOT: fptoui - bfloat16 D = some_bf16_intrinsic(A, C); + bfloat16 D = 
bfloat16::from_bits(some_bf16_intrinsic(A.raw(), C.raw())); // CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]]) // CHECK-NOT: uitofp // CHECK-NOT: fptoui + long L = bfloat16(3.14f); + // CHECK: [[L_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x40091EB860000000) + // CHECK: [[L_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[L_bfloat16]]) + // CHECK: [[L:%.*]] = fptosi float [[L_float]] to i{{32|64}} + + sycl::half H = bfloat16(2.71f); + // CHECK: [[H_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x4005AE1480000000) + // CHECK: [[H_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[H_bfloat16]]) + // CHECK: [[H:%.*]] = fptrunc float [[H_float]] to half + foo(L, H); + return D; // CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]]) // CHECK: ret float [[RetVal]]