12
12
#include " src/__support/CPP/bit.h"
13
13
#include " src/__support/CPP/limits.h"
14
14
#include " src/__support/CPP/type_traits.h"
15
+ #include " src/__support/FPUtil/BasicOperations.h"
15
16
#include " src/__support/FPUtil/FPBits.h"
17
+ #include " src/__support/FPUtil/dyadic_float.h"
16
18
#include " src/__support/FPUtil/rounding_mode.h"
17
19
#include " src/__support/big_int.h"
18
20
#include " src/__support/macros/attributes.h" // LIBC_INLINE
@@ -106,20 +108,52 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
106
108
sizeof (OutType) <= sizeof (InType),
107
109
OutType>
108
110
fma (InType x, InType y, InType z) {
109
- using OutFPBits = fputil:: FPBits<OutType>;
111
+ using OutFPBits = FPBits<OutType>;
110
112
using OutStorageType = typename OutFPBits::StorageType;
111
- using InFPBits = fputil:: FPBits<InType>;
113
+ using InFPBits = FPBits<InType>;
112
114
using InStorageType = typename InFPBits::StorageType;
113
115
114
116
constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1 ;
115
117
constexpr size_t PROD_LEN = 2 * IN_EXPLICIT_MANT_LEN;
116
118
constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil (PROD_LEN + 1 );
117
119
using TmpResultType = UInt<TMP_RESULT_LEN>;
120
+ using DyadicFloat = DyadicFloat<TMP_RESULT_LEN>;
118
121
119
- constexpr size_t EXTRA_FRACTION_LEN =
120
- TMP_RESULT_LEN - 1 - OutFPBits::FRACTION_LEN;
121
- constexpr TmpResultType EXTRA_FRACTION_STICKY_MASK =
122
- (TmpResultType (1 ) << (EXTRA_FRACTION_LEN - 1 )) - 1 ;
122
+ InFPBits x_bits (x), y_bits (y), z_bits (z);
123
+
124
+ if (LIBC_UNLIKELY (x_bits.is_nan () || y_bits.is_nan () || z_bits.is_nan ())) {
125
+ if (x_bits.is_nan () || y_bits.is_nan ()) {
126
+ if (x_bits.is_signaling_nan () || y_bits.is_signaling_nan () ||
127
+ z_bits.is_signaling_nan ())
128
+ raise_except_if_required (FE_INVALID);
129
+
130
+ if (x_bits.is_quiet_nan ()) {
131
+ InStorageType x_payload = static_cast <InStorageType>(getpayload (x));
132
+ if ((x_payload & ~(OutFPBits::FRACTION_MASK >> 1 )) == 0 )
133
+ return OutFPBits::quiet_nan (x_bits.sign (),
134
+ static_cast <OutStorageType>(x_payload))
135
+ .get_val ();
136
+ }
137
+
138
+ if (y_bits.is_quiet_nan ()) {
139
+ InStorageType y_payload = static_cast <InStorageType>(getpayload (y));
140
+ if ((y_payload & ~(OutFPBits::FRACTION_MASK >> 1 )) == 0 )
141
+ return OutFPBits::quiet_nan (y_bits.sign (),
142
+ static_cast <OutStorageType>(y_payload))
143
+ .get_val ();
144
+ }
145
+
146
+ if (z_bits.is_quiet_nan ()) {
147
+ InStorageType z_payload = static_cast <InStorageType>(getpayload (z));
148
+ if ((z_payload & ~(OutFPBits::FRACTION_MASK >> 1 )) == 0 )
149
+ return OutFPBits::quiet_nan (z_bits.sign (),
150
+ static_cast <OutStorageType>(z_payload))
151
+ .get_val ();
152
+ }
153
+
154
+ return OutFPBits::quiet_nan ().get_val ();
155
+ }
156
+ }
123
157
124
158
if (LIBC_UNLIKELY (x == 0 || y == 0 || z == 0 ))
125
159
return static_cast <OutType>(x * y + z);
@@ -142,7 +176,9 @@ fma(InType x, InType y, InType z) {
142
176
z *= InType (InStorageType (1 ) << InFPBits::FRACTION_LEN);
143
177
}
144
178
145
- InFPBits x_bits (x), y_bits (y), z_bits (z);
179
+ x_bits = InFPBits (x);
180
+ y_bits = InFPBits (y);
181
+ z_bits = InFPBits (z);
146
182
const Sign z_sign = z_bits.sign ();
147
183
Sign prod_sign = (x_bits.sign () == y_bits.sign ()) ? Sign::POS : Sign::NEG;
148
184
x_exp += x_bits.get_biased_exponent ();
@@ -182,7 +218,6 @@ fma(InType x, InType y, InType z) {
182
218
constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
183
219
z_mant <<= RESULT_MIN_LEN;
184
220
int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);
185
- bool round_bit = false ;
186
221
bool sticky_bits = false ;
187
222
bool z_shifted = false ;
188
223
@@ -221,85 +256,18 @@ fma(InType x, InType y, InType z) {
221
256
}
222
257
}
223
258
224
- OutStorageType result = 0 ;
225
- int r_exp = 0 ; // Unbiased exponent of the result
226
-
227
- int round_mode = fputil::quick_get_round ();
228
-
229
- // Normalize the result.
230
- if (prod_mant != 0 ) {
231
- int lead_zeros = cpp::countl_zero (prod_mant);
232
- // Move the leading 1 to the most significant bit.
233
- prod_mant <<= lead_zeros;
234
- prod_lsb_exp -= lead_zeros;
235
- r_exp = prod_lsb_exp + (cpp::numeric_limits<TmpResultType>::digits - 1 ) -
236
- InFPBits::EXP_BIAS + OutFPBits::EXP_BIAS;
237
-
238
- if (r_exp > 0 ) {
239
- // The result is normal. We will shift the mantissa to the right by the
240
- // amount of extra bits compared to the length of the explicit mantissa in
241
- // the output type. The rounding bit then becomes the highest bit that is
242
- // shifted out, and the following lower bits are merged into sticky bits.
243
- round_bit =
244
- (prod_mant & (TmpResultType (1 ) << (EXTRA_FRACTION_LEN - 1 ))) != 0 ;
245
- sticky_bits |= (prod_mant & EXTRA_FRACTION_STICKY_MASK) != 0 ;
246
- result = static_cast <OutStorageType>(prod_mant >> EXTRA_FRACTION_LEN);
247
- } else {
248
- if (r_exp < -OutFPBits::FRACTION_LEN) {
249
- // The result is smaller than 1/2 of the smallest denormal number.
250
- sticky_bits = true ; // since the result is non-zero.
251
- result = 0 ;
252
- } else {
253
- // The result is denormal.
254
- TmpResultType mask = TmpResultType (1 ) << (EXTRA_FRACTION_LEN - r_exp);
255
- round_bit = (prod_mant & mask) != 0 ;
256
- sticky_bits |= (prod_mant & (mask - 1 )) != 0 ;
257
- if (r_exp > -OutFPBits::FRACTION_LEN)
258
- result = static_cast <OutStorageType>(
259
- prod_mant >> (EXTRA_FRACTION_LEN + 1 - r_exp));
260
- else
261
- result = 0 ;
262
- }
263
-
264
- r_exp = 0 ;
265
- }
266
- } else {
259
+ if (prod_mant == 0 ) {
267
260
// When there is exact cancellation, i.e., x*y == -z exactly, return -0.0 if
268
261
// rounding downward and +0.0 for other rounding modes.
269
- if (round_mode == FE_DOWNWARD)
262
+ if (quick_get_round () == FE_DOWNWARD)
270
263
prod_sign = Sign::NEG;
271
264
else
272
265
prod_sign = Sign::POS;
273
266
}
274
267
275
- // Finalize the result.
276
- if (LIBC_UNLIKELY (r_exp >= OutFPBits::MAX_BIASED_EXPONENT)) {
277
- if ((round_mode == FE_TOWARDZERO) ||
278
- (round_mode == FE_UPWARD && prod_sign.is_neg ()) ||
279
- (round_mode == FE_DOWNWARD && prod_sign.is_pos ())) {
280
- return OutFPBits::max_normal (prod_sign).get_val ();
281
- }
282
- return OutFPBits::inf (prod_sign).get_val ();
283
- }
284
-
285
- // Remove hidden bit and append the exponent field and sign bit.
286
- result = static_cast <OutStorageType>(
287
- (result & OutFPBits::FRACTION_MASK) |
288
- (static_cast <OutStorageType>(r_exp) << OutFPBits::FRACTION_LEN));
289
- if (prod_sign.is_neg ())
290
- result |= OutFPBits::SIGN_MASK;
291
-
292
- // Rounding.
293
- if (round_mode == FE_TONEAREST) {
294
- if (round_bit && (sticky_bits || ((result & 1 ) != 0 )))
295
- ++result;
296
- } else if ((round_mode == FE_UPWARD && prod_sign.is_pos ()) ||
297
- (round_mode == FE_DOWNWARD && prod_sign.is_neg ())) {
298
- if (round_bit || sticky_bits)
299
- ++result;
300
- }
301
-
302
- return cpp::bit_cast<OutType>(result);
268
+ DyadicFloat result (prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant);
269
+ result.mantissa |= sticky_bits;
270
+ return result.template as <OutType, /* ShouldSignalExceptions=*/ true >();
303
271
}
304
272
305
273
} // namespace generic
0 commit comments