diff --git a/dpctl/tensor/libtensor/include/kernels/constructors.hpp b/dpctl/tensor/libtensor/include/kernels/constructors.hpp index 9b77b47a84..999ddec50d 100644 --- a/dpctl/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpctl/tensor/libtensor/include/kernels/constructors.hpp @@ -230,13 +230,28 @@ template class LinearSequenceAffineFunctor wTy w = wTy(n - i) / n; using dpctl::tensor::type_utils::is_complex; if constexpr (is_complex::value) { - auto _w = static_cast(w); - auto _wc = static_cast(wc); - auto re_comb = start_v.real() * _w + end_v.real() * _wc; - auto im_comb = start_v.imag() * _w + end_v.imag() * _wc; + using reT = typename Ty::value_type; + auto _w = static_cast(w); + auto _wc = static_cast(wc); + auto re_comb = sycl::fma(start_v.real(), _w, reT(0)); + re_comb = + sycl::fma(end_v.real(), _wc, + re_comb); // start_v.real() * _w + end_v.real() * _wc; + auto im_comb = + sycl::fma(start_v.imag(), _w, + reT(0)); // start_v.imag() * _w + end_v.imag() * _wc; + im_comb = sycl::fma(end_v.imag(), _wc, im_comb); Ty affine_comb = Ty{re_comb, im_comb}; p[i] = affine_comb; } + else if constexpr (std::is_floating_point::value) { + Ty _w = static_cast(w); + Ty _wc = static_cast(wc); + auto affine_comb = + sycl::fma(start_v, _w, Ty(0)); // start_v * w + end_v * wc; + affine_comb = sycl::fma(end_v, _wc, affine_comb); + p[i] = affine_comb; + } else { using dpctl::tensor::type_utils::convert_impl; auto affine_comb = start_v * w + end_v * wc; diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 0b1764d86b..46f4a23ca9 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1168,6 +1168,21 @@ def test_linspace_fp(): assert X.strides == (1,) +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_linspace_fp_max(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + n = 16 + dt = dpt.dtype(dtype) + max_ = dpt.finfo(dt).max + X = dpt.linspace(max_, max_, endpoint=True, num=n, dtype=dt, sycl_queue=q) + assert X.shape == (n,) + assert X.strides == (1,) + assert np.allclose( + dpt.asnumpy(X), np.linspace(max_, max_, endpoint=True, num=n, dtype=dt) + ) + + @pytest.mark.parametrize( "dt", _all_dtypes,