diff --git a/sycl/include/sycl/ext/intel/math.hpp b/sycl/include/sycl/ext/intel/math.hpp
index d0c74a8319c7f..8a6913945083b 100644
--- a/sycl/include/sycl/ext/intel/math.hpp
+++ b/sycl/include/sycl/ext/intel/math.hpp
@@ -36,6 +36,9 @@ double __imf_floor(double);
 _iml_half_internal __imf_floorf16(_iml_half_internal);
 float __imf_rintf(float);
 double __imf_rint(double);
+_iml_half_internal __imf_invf16(_iml_half_internal);
+float __imf_invf(float);
+double __imf_inv(double);
 _iml_half_internal __imf_rintf16(_iml_half_internal);
 float __imf_sqrtf(float);
 double __imf_sqrt(double);
@@ -118,6 +121,33 @@ sycl::half2 floor(sycl::half2 x) {
   return sycl::half2{floor(x.s0()), floor(x.s1())};
 }
 
+// inv(x): scalar overloads forwarding to the matching __imf_inv* device
+// library entry point (presumably the reciprocal 1/x -- confirm against the
+// device library documentation). Selected by exact type via enable_if.
+template <typename Tp>
+std::enable_if_t<std::is_same_v<Tp, float>, float> inv(Tp x) {
+  return __imf_invf(x);
+}
+
+template <typename Tp>
+std::enable_if_t<std::is_same_v<Tp, double>, double> inv(Tp x) {
+  return __imf_inv(x);
+}
+
+template <typename Tp>
+std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> inv(Tp x) {
+  // __imf_invf16 works on _iml_half_internal, so the sycl::half value is
+  // bit-cast on the way in and on the way out.
+  _iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
+  return __builtin_bit_cast(sycl::half, __imf_invf16(xi));
+}
+
+// Element-wise inv over both components of a half2.
+// `inline` because this is a non-template function defined in a header.
+inline sycl::half2 inv(sycl::half2 x) {
+  return sycl::half2{inv(x.s0()), inv(x.s1())};
+}
+
 template <typename Tp>
 std::enable_if_t<std::is_same_v<Tp, float>, float> rint(Tp x) {
   return __imf_rintf(x);