@@ -232,6 +232,10 @@ extern "C" __DPCPP_SYCL_EXTERNAL void
232
232
__devicelib_ConvertBF16ToFINTELVec16 (const uint16_t *, float *) noexcept ;
233
233
#endif
234
234
235
+ // / \brief Converts a vector of bfloat16 to a vector of floats.
236
+ // / \tparam N The size of the vector. Supported sizes are 1, 2, 3, 4, 8, and 16.
237
+ // / \param src The source vector of bfloat16.
238
+ // / \param dst The destination vector of floats.
235
239
template <int N>
236
240
inline void BF16VecToFloatVec (const bfloat16 src[N], float dst[N]) {
237
241
static_assert (N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16 ,
@@ -273,6 +277,10 @@ extern "C" __DPCPP_SYCL_EXTERNAL void
273
277
__devicelib_ConvertFToBF16INTELVec16 (const float *, uint16_t *) noexcept ;
274
278
#endif
275
279
280
+ // / \brief Converts a vector of floats to a vector of bfloat16.
281
+ // / \tparam N The size of the vector.
282
+ // / \param src The source vector of floats.
283
+ // / \param dst The destination vector of bfloat16.
276
284
template <int N> inline void FloatVecToBF16Vec (float src[N], bfloat16 dst[N]) {
277
285
static_assert (N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16 ,
278
286
" Unsupported vector size" );
@@ -292,8 +300,8 @@ template <int N> inline void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) {
292
300
__devicelib_ConvertFToBF16INTELVec16 (src, dst_i16);
293
301
#else
294
302
for (int i = 0 ; i < N; ++i) {
295
- // No need to cast as bfloat16 has a assignment op overload that takes
296
- // a float.
303
+ // No need to cast as bfloat16 has an assignment operator overload that
304
+ // takes a float.
297
305
dst[i] = src[i];
298
306
}
299
307
#endif
@@ -450,10 +458,10 @@ template <typename Ty> inline size_t get_msb_pos(const Ty &x) {
450
458
return (sizeof (Ty) * 8 - 1 - idx);
451
459
}
452
460
453
- // Helper function to get BF16 from unsigned integral data types
454
- // with different rounding modes.
455
- // Reference:
456
- // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302
461
+ // Helper function to get BF16 from unsigned integral data types
462
+ // with different rounding modes.
463
+ // Reference:
464
+ // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302
457
465
template <typename T>
458
466
inline bfloat16
459
467
getBFloat16FromUIntegralWithRoundingMode (T &u, SYCLRoundingMode roundingMode) {
@@ -505,9 +513,9 @@ getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) {
505
513
return bit_cast<bfloat16, uint16_t >((b_exp << 7 ) | b_mant);
506
514
}
507
515
508
- // Helper function to get BF16 from signed integral data types.
509
- // Reference:
510
- // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353
516
+ // Helper function to get BF16 from signed integral data types.
517
+ // Reference:
518
+ // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353
511
519
template <typename T>
512
520
inline bfloat16
513
521
getBFloat16FromSIntegralWithRoundingMode (T &i, SYCLRoundingMode roundingMode) {
@@ -557,6 +565,10 @@ getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) {
557
565
return bit_cast<bfloat16, uint16_t >(b_sign | (b_exp << 7 ) | b_mant);
558
566
}
559
567
568
+ // / \brief Converts a given value to bfloat16 with a specified rounding mode.
569
+ // / \tparam rm The rounding mode to be used for conversion.
570
+ // / \param a The input value to be converted.
571
+ // / \return The converted bfloat16 value.
560
572
template <typename Ty, int rm>
561
573
inline bfloat16 getBfloat16WithRoundingMode (const Ty &a) {
562
574
if (a == 0 )
0 commit comments