diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index a0a1a0b4df..98fe465add 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -126,6 +126,7 @@ pow, proj, real, + round, sin, sqrt, square, @@ -243,6 +244,7 @@ "pow", "proj", "real", + "round", "sin", "sqrt", "square", diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 06a5080a5e..2fd00c01c2 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -956,7 +956,30 @@ # FIXME: implement B22 # U28: ==== ROUND (x) -# FIXME: implement U28 +_round_docstring = """ +round(x, out=None, order='K') + +Rounds each element `x_i` of the input array `x` to +the nearest integer-valued number. + +Args: + x (usm_ndarray): + Input array, expected to have numeric data type. + out ({None, usm_ndarray}, optional): + Output array to populate. + Array have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the newly output array, if parameter `out` is `None`. + Default: "K". +Returns: + usm_narray: + An array containing the element-wise rounded value. The data type + of the returned array is determined by the Type Promotion Rules. +""" + +round = UnaryElementwiseFunc( + "round", ti._round_result_type, ti._round, _round_docstring +) # U29: ==== SIGN (x) # FIXME: implement U29 diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp new file mode 100644 index 0000000000..ac61829a1d --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -0,0 +1,205 @@ +//=== round.hpp - Unary function ROUND ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ROUND(x) function. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include + +#include "kernels/elementwise_functions/common.hpp" + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace round +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::type_utils::is_complex; + +template struct RoundFunctor +{ + + // is function constant for given argT + using is_constant = typename std::false_type; + // constant value, if constant + // constexpr resT constant_value = resT{}; + // is function defined for sycl::vec + using supports_vec = typename std::false_type; + // do both argTy and resTy support sugroup store/load operation + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) + { + + if constexpr (std::is_integral_v) { + return in; + } + else if constexpr (is_complex::value) { + using realT = typename argT::value_type; + return resT{round_func(std::real(in)), + round_func(std::imag(in))}; + } + else { + return round_func(in); + } + } + +private: + template T round_func(const T &input) const + { + return std::rint(input); + } +}; + +template +using RoundContigFunctor = + elementwise_common::UnaryContigFunctor, + vec_sz, + n_vecs>; + +template +using RoundStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + +template struct RoundOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +class round_contig_kernel; + +template +sycl::event round_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + return elementwise_common::unary_contig_impl< + argTy, RoundOutputType, RoundContigFunctor, round_contig_kernel>( + exec_q, nelems, arg_p, res_p, depends); +} + +template struct RoundContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = round_contig_impl; + return fn; + } + } +}; + +template struct RoundTypeMapFactory +{ + /*! @brief get typeid for output type of sycl::round(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename RoundOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template class round_strided_kernel; + +template +sycl::event +round_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::unary_strided_impl< + argTy, RoundOutputType, RoundStridedFunctor, round_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template struct RoundStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = round_strided_impl; + return fn; + } + } +}; + +} // namespace round +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 7246236d37..6491e6a294 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -65,6 +65,7 @@ #include "kernels/elementwise_functions/pow.hpp" #include "kernels/elementwise_functions/proj.hpp" #include "kernels/elementwise_functions/real.hpp" +#include "kernels/elementwise_functions/round.hpp" #include "kernels/elementwise_functions/sin.hpp" #include "kernels/elementwise_functions/sqrt.hpp" #include "kernels/elementwise_functions/square.hpp" @@ -1627,7 +1628,37 @@ namespace impl // U28: ==== ROUND (x) namespace impl { -// FIXME: add code for U28 + +namespace round_fn_ns = dpctl::tensor::kernels::round; + +static unary_contig_impl_fn_ptr_t + round_contig_dispatch_vector[td_ns::num_types]; +static int round_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + round_strided_dispatch_vector[td_ns::num_types]; + +void populate_round_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = round_fn_ns; + + using fn_ns::RoundContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(round_contig_dispatch_vector); + + using fn_ns::RoundStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(round_strided_dispatch_vector); + + using fn_ns::RoundTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(round_output_typeid_vector); +} + } // namespace impl // U29: ==== SIGN (x) @@ -3029,7 +3060,27 @@ void init_elementwise_functions(py::module_ m) // FIXME: // U28: ==== ROUND (x) - // FIXME: + { + impl::populate_round_dispatch_vectors(); + using impl::round_contig_dispatch_vector; + using impl::round_output_typeid_vector; + using impl::round_strided_dispatch_vector; + + auto round_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, round_output_typeid_vector, + round_contig_dispatch_vector, round_strided_dispatch_vector); + }; + m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto round_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, + round_output_typeid_vector); + }; + m.def("_round_result_type", round_result_type_pyapi); + } // U29: ==== SIGN (x) // FIXME: diff --git a/dpctl/tests/elementwise/test_round.py b/dpctl/tests/elementwise/test_round.py new file mode 100644 index 0000000000..fb2b104bb1 --- /dev/null +++ b/dpctl/tests/elementwise/test_round.py @@ -0,0 +1,222 @@ +import itertools + +import numpy as np +import pytest +from numpy.testing import assert_allclose, assert_array_equal + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _map_to_device_dtype, _usm_types + + +@pytest.mark.parametrize("dtype", _all_dtypes[1:]) +def test_round_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0.1, dtype=dtype, sycl_queue=q) + expected_dtype = np.round(np.array(0, dtype=dtype)).dtype + expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) + assert dpt.round(X).dtype == expected_dtype + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_round_real_contig(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 100 + n_rep = 137 + Xnp = np.linspace(0.01, 88.1, num=n_seq, dtype=dtype) + X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + Y = dpt.round(X) + Ynp = np.round(Xnp) + + tol = 8 * dpt.finfo(dtype).resolution + assert_allclose(dpt.asnumpy(Y), np.repeat(Ynp, n_rep), atol=tol, rtol=tol) + + Z = dpt.empty_like(X, dtype=dtype) + dpt.round(X, out=Z) + + assert_allclose(dpt.asnumpy(Z), np.repeat(Ynp, n_rep), atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_round_complex_contig(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 100 + n_rep = 137 + low = -88.0 + high = 88.0 + x1 = np.random.uniform(low=low, high=high, size=n_seq) + x2 = np.random.uniform(low=low, high=high, size=n_seq) + Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype) + + X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + Y = dpt.round(X) + + tol = 8 * dpt.finfo(dtype).resolution + assert_allclose( + dpt.asnumpy(Y), np.repeat(np.round(Xnp), n_rep), atol=tol, rtol=tol + ) + + Z = dpt.empty_like(X, dtype=dtype) + dpt.round(X, out=Z) + + assert_allclose( + dpt.asnumpy(Z), np.repeat(np.round(Xnp), n_rep), atol=tol, rtol=tol + ) + + +@pytest.mark.parametrize("usm_type", _usm_types) +def test_round_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("f4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = 16.2 + X[..., 1::2] = 23.7 + + Y = dpt.round(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = np.empty(input_shape, dtype=arg_dt) + expected_Y[..., 0::2] = np.round(np.float32(16.2)) + expected_Y[..., 1::2] = np.round(np.float32(23.7)) + tol = 8 * dpt.finfo(Y.dtype).resolution + + assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_round_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = 8.8 + X[..., 1::2] = 11.3 + + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + expected_Y = np.round(dpt.asnumpy(U)) + for ord in ["C", "F", "A", "K"]: + Y = dpt.round(U, order=ord) + assert_allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_round_real_special_cases(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + tol = 8 * dpt.finfo(dtype).resolution + x = [np.nan, np.inf, -np.inf, 1.5, 2.5, -1.5, -2.5, 0.0, -0.0] + Xnp = np.array(x, dtype=dtype) + X = dpt.asarray(x, dtype=dtype) + + Y = dpt.asnumpy(dpt.round(X)) + Ynp = np.round(Xnp) + assert_allclose(Y, Ynp, atol=tol, rtol=tol) + assert_array_equal(np.signbit(Y), np.signbit(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_round_real_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + np.random.seed(42) + strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4]) + sizes = [2, 4, 6, 8, 9, 24, 72] + tol = 8 * dpt.finfo(dtype).resolution + + for ii in sizes: + Xnp = np.random.uniform(low=0.01, high=88.1, size=ii) + Xnp.astype(dtype) + X = dpt.asarray(Xnp) + Ynp = np.round(Xnp) + for jj in strides: + assert_allclose( + dpt.asnumpy(dpt.round(X[::jj])), + Ynp[::jj], + atol=tol, + rtol=tol, + ) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_round_complex_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + np.random.seed(42) + strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4]) + sizes = [2, 4, 6, 8, 9, 24, 72] + tol = 8 * dpt.finfo(dtype).resolution + + low = -88.0 + high = 88.0 + for ii in sizes: + x1 = np.random.uniform(low=low, high=high, size=ii) + x2 = np.random.uniform(low=low, high=high, size=ii) + Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype) + X = dpt.asarray(Xnp) + Ynp = np.round(Xnp) + for jj in strides: + assert_allclose( + dpt.asnumpy(dpt.round(X[::jj])), + Ynp[::jj], + atol=tol, + rtol=tol, + ) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_round_complex_special_cases(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + x = [np.nan, np.inf, -np.inf, 1.5, 2.5, -1.5, -2.5, 0.0, -0.0] + xc = [complex(*val) for val in itertools.product(x, repeat=2)] + + Xc_np = np.array(xc, dtype=dtype) + Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q) + + Ynp = np.round(Xc_np) + Y = dpt.round(Xc) + + tol = 8 * dpt.finfo(dtype).resolution + assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol) + assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_round_out_overlap(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q) + X = dpt.reshape(X, (3, 5)) + + Xnp = dpt.asnumpy(X) + Ynp = np.round(Xnp, out=Xnp) + + Y = dpt.round(X, out=X) + tol = 8 * dpt.finfo(Y.dtype).resolution + assert Y is X + assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol) + + Ynp = np.round(Xnp, out=Xnp[::-1]) + Y = dpt.round(X, out=X[::-1]) + assert Y is not X + assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol) + assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)