Add dpctl.tensor.sqrt

Exposes an element-wise square root: the Python-level `sqrt`
UnaryElementwiseFunc, SYCL kernels for contiguous and strided inputs,
dispatch vectors bound to Python as `_sqrt` / `_sqrt_result_type`,
and tests.

diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py
index 48c8c045a8..c427dc99ab 100644
--- a/dpctl/tensor/__init__.py
+++ b/dpctl/tensor/__init__.py
@@ -90,7 +90,7 @@
 from dpctl.tensor._usmarray import usm_ndarray
 
 from ._constants import e, inf, nan, newaxis, pi
-from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan
+from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt
 
 __all__ = [
     "Device",
@@ -171,4 +171,5 @@
     "isinf",
     "isnan",
     "isfinite",
+    "sqrt",
 ]
diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py
index 566fea9658..97b873773b 100644
--- a/dpctl/tensor/_elementwise_funcs.py
+++ b/dpctl/tensor/_elementwise_funcs.py
@@ -71,3 +71,13 @@
 isinf = UnaryElementwiseFunc(
     "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
 )
+
+# SQRT
+
+_sqrt_docstring_ = """
+Computes sqrt for each element `x_i` for input array `x`.
+"""
+
+sqrt = UnaryElementwiseFunc(
+    "sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_
+)
diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp
new file mode 100644
index 0000000000..719670cea0
--- /dev/null
+++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp
@@ -0,0 +1,217 @@
+#pragma once
+#include <CL/sycl.hpp>
+#include <cmath>
+#include <complex>
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+#include <vector>
+
+#include "kernels/elementwise_functions/common.hpp"
+
+#include "utils/offset_utils.hpp"
+#include "utils/type_dispatch.hpp"
+#include "utils/type_utils.hpp"
+#include <pybind11/pybind11.h>
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace kernels
+{
+namespace sqrt
+{
+
+namespace py = pybind11;
+namespace td_ns = dpctl::tensor::type_dispatch;
+
+using dpctl::tensor::type_utils::is_complex;
+
+/*! @brief Element-wise functor returning std::sqrt of its argument */
+template <typename argT, typename resT> struct SqrtFunctor
+{
+
+    // is function constant for given argT
+    using is_constant = typename std::false_type;
+    // constant value, if constant
+    // constexpr resT constant_value = resT{};
+    // is function defined for sycl::vec
+    using supports_vec = typename std::false_type;
+    // do both argTy and resTy support subgroup store/load operation
+    using supports_sg_loadstore = typename std::negation<
+        std::disjunction<is_complex<resT>, is_complex<argT>>>;
+
+    resT operator()(const argT &in)
+    {
+        return std::sqrt(in);
+    }
+};
+
+template <typename argTy,
+          typename resTy = argTy,
+          unsigned int vec_sz = 4,
+          unsigned int n_vecs = 2>
+using SqrtContigFunctor = elementwise_common::
+    UnaryContigFunctor<argTy, resTy, SqrtFunctor<argTy, resTy>, vec_sz, n_vecs>;
+
+template <typename argTy, typename resTy, typename IndexerT>
+using SqrtStridedFunctor = elementwise_common::
+    UnaryStridedFunctor<argTy, resTy, IndexerT, SqrtFunctor<argTy, resTy>>;
+
+/*! @brief Result type of sqrt for argument type T; void means unsupported */
+template <typename T> struct SqrtOutputType
+{
+    using value_type = typename std::disjunction< // disjunction is C++17
+                                                  // feature, supported by DPC++
+        td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
+        td_ns::TypeMapEntry<T, float, float>,
+        td_ns::TypeMapEntry<T, double, double>,
+        td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
+        td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
+        td_ns::DefaultEntry<void>>::result_type;
+};
+
+typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    const char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
+class sqrt_contig_kernel;
+
+/*! @brief Submits the sqrt kernel for contiguous data to exec_q and
+ *         returns the event of the submitted task */
+template <typename argTy>
+sycl::event sqrt_contig_impl(sycl::queue exec_q,
+                             size_t nelems,
+                             const char *arg_p,
+                             char *res_p,
+                             const std::vector<sycl::event> &depends = {})
+{
+    sycl::event sqrt_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+        constexpr size_t lws = 64;
+        constexpr unsigned int vec_sz = 4;
+        constexpr unsigned int n_vecs = 2;
+        static_assert(lws % vec_sz == 0);
+        // global range: ceil(nelems / (lws * n_vecs * vec_sz)) groups of lws
+        auto gws_range = sycl::range<1>(
+            ((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
+            lws);
+        auto lws_range = sycl::range<1>(lws);
+
+        using resTy = typename SqrtOutputType<argTy>::value_type;
+        const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
+        resTy *res_tp = reinterpret_cast<resTy *>(res_p);
+
+        cgh.parallel_for<
+            class sqrt_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
+            sycl::nd_range<1>(gws_range, lws_range),
+            SqrtContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
+                                                            nelems));
+    });
+    return sqrt_ev;
+}
+
+/*! @brief Yields sqrt_contig_impl<T>, or nullptr if output type is void */
+template <typename fnT, typename T> struct SqrtContigFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
+                                     void>) {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT fn = sqrt_contig_impl<T>;
+            return fn;
+        }
+    }
+};
+
+template <typename fnT, typename T> struct SqrtTypeMapFactory
+{
+    /*! @brief get typeid for output type of std::sqrt(T x) */
+    std::enable_if_t<std::is_same<fnT, int>::value, int> get()
+    {
+        using rT = typename SqrtOutputType<T>::value_type;
+        return td_ns::GetTypeid<rT>{}.get();
+    }
+};
+
+template <typename T1, typename T2, typename T3> class sqrt_strided_kernel;
+
+typedef sycl::event (*sqrt_strided_impl_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    int,
+    const py::ssize_t *,
+    const char *,
+    py::ssize_t,
+    char *,
+    py::ssize_t,
+    const std::vector<sycl::event> &,
+    const std::vector<sycl::event> &);
+
+/*! @brief Submits the sqrt kernel for strided data to exec_q and
+ *         returns the event of the submitted task */
+template <typename argTy>
+sycl::event
+sqrt_strided_impl(sycl::queue exec_q,
+                  size_t nelems,
+                  int nd,
+                  const py::ssize_t *shape_and_strides,
+                  const char *arg_p,
+                  py::ssize_t arg_offset,
+                  char *res_p,
+                  py::ssize_t res_offset,
+                  const std::vector<sycl::event> &depends,
+                  const std::vector<sycl::event> &additional_depends)
+{
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+        cgh.depends_on(additional_depends);
+
+        using resTy = typename SqrtOutputType<argTy>::value_type;
+        using IndexerT =
+            typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
+
+        IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
+
+        const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
+        resTy *res_tp = reinterpret_cast<resTy *>(res_p);
+
+        sycl::range<1> gRange{nelems};
+
+        cgh.parallel_for<sqrt_strided_kernel<argTy, resTy, IndexerT>>(
+            gRange, SqrtStridedFunctor<argTy, resTy, IndexerT>(
+                        arg_tp, res_tp, arg_res_indexer));
+    });
+    return comp_ev;
+}
+
+/*! @brief Yields sqrt_strided_impl<T>, or nullptr if output type is void */
+template <typename fnT, typename T> struct SqrtStridedFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
+                                     void>) {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT fn = sqrt_strided_impl<T>;
+            return fn;
+        }
+    }
+};
+
+} // namespace sqrt
+} // namespace kernels
+} // namespace tensor
+} // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp
index 65b2b1d9a1..acae73541b 100644
--- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp
+++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp
@@ -38,6 +38,7 @@
 #include "kernels/elementwise_functions/isfinite.hpp"
 #include "kernels/elementwise_functions/isinf.hpp"
 #include "kernels/elementwise_functions/isnan.hpp"
+#include "kernels/elementwise_functions/sqrt.hpp"
 
 namespace dpctl
 {
@@ -325,6 +326,43 @@ void populate_add_dispatch_tables(void)
 
 } // namespace impl
 
+// SQRT
+namespace impl
+{
+
+namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt;
+using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t;
+using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t;
+
+static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types];
+static int sqrt_output_typeid_vector[td_ns::num_types];
+static sqrt_strided_impl_fn_ptr_t
+    sqrt_strided_dispatch_vector[td_ns::num_types];
+
+void populate_sqrt_dispatch_vectors(void)
+{
+    using namespace td_ns;
+    namespace fn_ns = sqrt_fn_ns;
+
+    using fn_ns::SqrtContigFactory;
+    DispatchVectorBuilder<sqrt_contig_impl_fn_ptr_t, SqrtContigFactory,
+                          num_types>
+        dvb1;
+    dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector);
+
+    using fn_ns::SqrtStridedFactory;
+    DispatchVectorBuilder<sqrt_strided_impl_fn_ptr_t, SqrtStridedFactory,
+                          num_types>
+        dvb2;
+    dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector);
+
+    using fn_ns::SqrtTypeMapFactory;
+    DispatchVectorBuilder<int, SqrtTypeMapFactory, num_types> dvb3;
+    dvb3.populate_dispatch_vector(sqrt_output_typeid_vector);
+}
+
+} // namespace impl
+
 namespace py = pybind11;
 
 void init_elementwise_functions(py::module_ m)
@@ -628,7 +666,26 @@ void init_elementwise_functions(py::module_ m)
     // FIXME:
 
     // U33: ==== SQRT (x)
-    // FIXME:
+    {
+        impl::populate_sqrt_dispatch_vectors();
+        using impl::sqrt_contig_dispatch_vector;
+        using impl::sqrt_output_typeid_vector;
+        using impl::sqrt_strided_dispatch_vector;
+
+        auto sqrt_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q,
+                              const event_vecT &depends = {}) {
+            return py_unary_ufunc(
+                src, dst, exec_q, depends, sqrt_output_typeid_vector,
+                sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector);
+        };
+        m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"),
+              py::arg("sycl_queue"), py::arg("depends") = py::list());
+
+        auto sqrt_result_type_pyapi = [&](py::dtype dtype) {
+            return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector);
+        };
+        m.def("_sqrt_result_type", sqrt_result_type_pyapi);
+    }
 
     // B23: ==== SUBTRACT (x1, x2)
     // FIXME:
diff --git a/dpctl/tests/test_tensor_sqrt.py b/dpctl/tests/test_tensor_sqrt.py
new file mode 100644
index 0000000000..2f924027dc
--- /dev/null
+++ b/dpctl/tests/test_tensor_sqrt.py
@@ -0,0 +1,133 @@
+import itertools
+
+import numpy as np
+import pytest
+from numpy.testing import assert_equal
+
+import dpctl.tensor as dpt
+import dpctl.tensor._type_utils as tu
+from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
+
+_all_dtypes = [
+    "b1",
+    "i1",
+    "u1",
+    "i2",
+    "u2",
+    "i4",
+    "u4",
+    "i8",
+    "u8",
+    "f2",
+    "f4",
+    "f8",
+    "c8",
+    "c16",
+]
+_usm_types = ["device", "shared", "host"]
+
+
+def _map_to_device_dtype(dt, dev):
+    return tu._to_device_supported_dtype(dt, dev)
+
+
+@pytest.mark.parametrize("dtype", _all_dtypes)
+def test_sqrt_out_type(dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
+    expected_dtype = np.sqrt(np.array(0, dtype=dtype)).dtype
+    expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
+    assert dpt.sqrt(X).dtype == expected_dtype
+
+
+@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
+def test_sqrt_output_contig(dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    n_seq = 1027
+
+    X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)
+    Xnp = dpt.asnumpy(X)
+
+    Y = dpt.sqrt(X)
+    tol = 8 * dpt.finfo(Y.dtype).resolution
+
+    np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
+
+
+@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
+def test_sqrt_output_strided(dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    n_seq = 2054
+
+    X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
+    Xnp = dpt.asnumpy(X)
+
+    Y = dpt.sqrt(X)
+    tol = 8 * dpt.finfo(Y.dtype).resolution
+
+    np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
+
+
+@pytest.mark.parametrize("usm_type", _usm_types)
+def test_sqrt_usm_type(usm_type):
+    q = get_queue_or_skip()
+
+    arg_dt = np.dtype("f4")
+    input_shape = (10, 10, 10, 10)
+    X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
+    X[..., 0::2] = 16.0
+    X[..., 1::2] = 23.0
+
+    Y = dpt.sqrt(X)
+    assert Y.usm_type == X.usm_type
+    assert Y.sycl_queue == X.sycl_queue
+    assert Y.flags.c_contiguous
+
+    expected_Y = np.empty(input_shape, dtype=arg_dt)
+    expected_Y[..., 0::2] = np.sqrt(np.float32(16.0))
+    expected_Y[..., 1::2] = np.sqrt(np.float32(23.0))
+    tol = 8 * dpt.finfo(Y.dtype).resolution
+
+    np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
+
+
+@pytest.mark.parametrize("dtype", _all_dtypes)
+def test_sqrt_order(dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    arg_dt = np.dtype(dtype)
+    input_shape = (10, 10, 10, 10)
+    X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
+    X[..., 0::2] = 16.0
+    X[..., 1::2] = 23.0
+
+    for order in ["C", "F", "A", "K"]:
+        for perms in itertools.permutations(range(4)):
+            U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
+            Y = dpt.sqrt(U, order=order)
+            expected_Y = np.sqrt(dpt.asnumpy(U))
+            tol = 8 * max(
+                dpt.finfo(Y.dtype).resolution,
+                np.finfo(expected_Y.dtype).resolution,
+            )
+            np.testing.assert_allclose(
+                dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
+            )
+
+
+def test_sqrt_special_cases():
+    q = get_queue_or_skip()
+
+    X = dpt.asarray(
+        [dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
+    )
+    Xnp = dpt.asnumpy(X)
+
+    assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp))
+