From fd621a59d9bee9afdbbd4926f99a968f853e8abb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 8 Jun 2023 14:49:37 +0200 Subject: [PATCH 1/5] Impementation of dpctl.tensor.less function --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_elementwise_funcs.py | 28 +- .../kernels/elementwise_functions/less.hpp | 314 ++++++++++++++++++ .../source/elementwise_functions.cpp | 73 +++- 4 files changed, 414 insertions(+), 3 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 2e720cba92..946303892c 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -104,6 +104,7 @@ isfinite, isinf, isnan, + less, log, log1p, multiply, @@ -200,6 +201,7 @@ "isinf", "isnan", "isfinite", + "less", "log", "log1p", "proj", diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index cb24929b76..ae48180a19 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -381,7 +381,33 @@ ) # B13: ==== LESS (x1, x2) -# FIXME: implement B13 +_less_docstring_ = """ +divide(x1, x2, out=None, order='K') + +Computes the less-than test results for each element `x1_i` of +the input array `x1` the respective element `x2_i` of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have numeric data type. + x2 (usm_ndarray): + Second input array, also expected to have numeric data type. + out ({None, usm_ndarray}, optional): + Output array to populate. + Array have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the newly output array, if parameter `out` is `None`. + Default: "K". +Returns: + usm_narray: + An array containing the result of element-wise less-than comparison. + The data type of the returned array is determined by the + Type Promotion Rules. +""" + +less = BinaryElementwiseFunc( + "less", ti._less_result_type, ti._less, _less_docstring_ +) # B14: ==== LESS_EQUAL (x1, x2) # FIXME: implement B14 diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp new file mode 100644 index 0000000000..413938bb76 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -0,0 +1,314 @@ +//=== less.hpp - Binary function LESS ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain in1 copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of comparison of +/// tensor elements. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "kernels/elementwise_functions/common.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace less +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template struct LessFunctor +{ + static_assert(std::is_same_v); + + using supports_complex = + std::conjunction, tu_ns::is_complex>; + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::conjunction< + std::is_same, + std::negation, + tu_ns::is_complex>>>; + + resT operator()(const argT1 &in1, const argT2 &in2) + { + if constexpr (supports_complex::value) { + + if (std::real(in1) < std::real(in2)) { + return (std::imag(in1) == in1.imag() || + std::imag(in2) != std::imag(in2)); + } + else if (std::real(in1) > std::real(in2)) { + return (std::imag(in2) != std::imag(in2) && + in1.imag() == in1.imag()); + } + else if (std::real(in1) == std::real(in2) || + (std::real(in1) != std::real(in1) && + std::real(in2) != std::real(in2))) + { + return (in1.imag() < std::imag(in2) || + (std::imag(in2) != std::imag(in2) && + in1.imag() == in1.imag())); + } + else { + return (std::real(in2) != std::real(in2)); + } + } + + else { + return (in1 < in2); + } + } + + template + sycl::vec operator()(const sycl::vec &in1, + const sycl::vec &in2) + { + + auto tmp = (in1 < in2); + + if constexpr (std::is_same_v) { + return tmp; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + return vec_cast( + tmp); + } + } +}; + +template +using LessContigFunctor = + elementwise_common::BinaryContigFunctor, + vec_sz, + n_vecs>; + +template +using LessStridedFunctor = + elementwise_common::BinaryStridedFunctor>; + +template struct LessOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + bool>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + bool>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +class less_contig_kernel; + +template +sycl::event less_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy = typename LessOutputType::value_type; + + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy *res_tp = reinterpret_cast(res_p) + res_offset; + + cgh.parallel_for< + less_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + LessContigFunctor( + arg1_tp, arg2_tp, res_tp, nelems)); + }); + return comp_ev; +} + +template struct LessContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename LessOutputType::value_type, void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = less_contig_impl; + return fn; + } + } +}; + +template struct LessTypeMapFactory +{ + /*! @brief get typeid for output type of operator()>(x, y), always bool */ + std::enable_if_t::value, int> get() + { + using rT = typename LessOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class less_strided_strided_kernel; + +template +sycl::event +less_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename LessOutputType::value_type; + + using IndexerT = + typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + + IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset, + shape_and_strides}; + + const argTy1 *arg1_tp = reinterpret_cast(arg1_p); + const argTy2 *arg2_tp = reinterpret_cast(arg2_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + less_strided_strided_kernel>( + {nelems}, LessStridedFunctor( + arg1_tp, arg2_tp, res_tp, indexer)); + }); + return comp_ev; +} + +template struct LessStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename LessOutputType::value_type, void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = less_strided_impl; + return fn; + } + } +}; + +} // namespace less +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 88597512bc..413aa7f224 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -43,6 +43,7 @@ #include "kernels/elementwise_functions/isfinite.hpp" #include "kernels/elementwise_functions/isinf.hpp" #include "kernels/elementwise_functions/isnan.hpp" +#include "kernels/elementwise_functions/less.hpp" #include "kernels/elementwise_functions/log.hpp" #include "kernels/elementwise_functions/log1p.hpp" #include "kernels/elementwise_functions/multiply.hpp" @@ -723,7 +724,39 @@ void populate_isnan_dispatch_vectors(void) // B13: ==== LESS (x1, x2) namespace impl { -// FIXME: add code for B13 +namespace less_fn_ns = dpctl::tensor::kernels::less; + +static binary_contig_impl_fn_ptr_t less_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int less_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + less_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_less_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = less_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LessTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(less_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LessStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(less_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LessContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(less_contig_dispatch_table); +}; } // namespace impl // B14: ==== LESS_EQUAL (x1, x2) @@ -1587,7 +1620,43 @@ void init_elementwise_functions(py::module_ m) } // B13: ==== LESS (x1, x2) - // FIXME: + { + impl::populate_less_dispatch_tables(); + using impl::less_contig_dispatch_table; + using impl::less_output_id_table; + using impl::less_strided_dispatch_table; + + auto less_pyapi = [&](dpctl::tensor::usm_ndarray src1, + dpctl::tensor::usm_ndarray src2, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, less_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + less_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + less_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto less_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + less_output_id_table); + }; + m.def("_less", less_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_less_result_type", less_result_type_pyapi, ""); + } // B14: ==== LESS_EQUAL (x1, x2) // FIXME: From 54b3b7b5559a45f5eca127ce93c05bede3ec1567 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 8 Jun 2023 14:50:20 +0200 Subject: [PATCH 2/5] Add tests for dpctl.tensor.less --- dpctl/tests/elementwise/test_less.py | 212 +++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 dpctl/tests/elementwise/test_less.py diff --git a/dpctl/tests/elementwise/test_less.py b/dpctl/tests/elementwise/test_less.py new file mode 100644 index 0000000000..55c3abe35c --- /dev/null +++ b/dpctl/tests/elementwise/test_less.py @@ -0,0 +1,212 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ctypes + +import numpy as np +import pytest + +import dpctl +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _compare_dtypes, _usm_types + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_less_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.less(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected = np.less(np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + assert r.shape == ar1.shape + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + assert r.sycl_queue == ar1.sycl_queue + + ar3 = dpt.ones(sz, dtype=op1_dtype) + ar4 = dpt.ones(2 * sz, dtype=op2_dtype) + + r = dpt.less(ar3[::-1], ar4[::2]) + assert isinstance(r, dpt.usm_ndarray) + expected = np.less(np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + assert r.shape == ar3.shape + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + + +@pytest.mark.parametrize("op_dtype", ["c8", "c16"]) +def test_less_complex_matrix(op_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op_dtype, q) + + sz = 127 + ar1_np_real = np.random.randint(0, 10, sz) + ar1_np_imag = np.random.randint(0, 10, sz) + ar1 = dpt.asarray(ar1_np_real + 1j * ar1_np_imag, dtype=op_dtype) + + ar2_np_real = np.random.randint(0, 10, sz) + ar2_np_imag = np.random.randint(0, 10, sz) + ar2 = dpt.asarray(ar2_np_real + 1j * ar2_np_imag, dtype=op_dtype) + + r = dpt.less(ar1, ar2) + expected = np.less(dpt.asnumpy(ar1), dpt.asnumpy(ar2)) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + assert r.shape == expected.shape + assert (dpt.asnumpy(r) == expected).all() + + r1 = dpt.less(ar1[::-2], ar2[::2]) + expected1 = np.less(dpt.asnumpy(ar1[::-2]), dpt.asnumpy(ar2[::2])) + assert _compare_dtypes(r.dtype, expected1.dtype, sycl_queue=q) + assert r1.shape == expected1.shape + assert (dpt.asnumpy(r1) == expected1).all() + + +@pytest.mark.parametrize("op1_usm_type", _usm_types) +@pytest.mark.parametrize("op2_usm_type", _usm_types) +def test_less_usm_type_matrix(op1_usm_type, op2_usm_type): + get_queue_or_skip() + + sz = 128 + ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) + ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) + + r = dpt.less(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_usm_type = dpctl.utils.get_coerced_usm_type( + (op1_usm_type, op2_usm_type) + ) + assert r.usm_type == expected_usm_type + + +def test_less_order(): + get_queue_or_skip() + + ar1 = dpt.ones((20, 20), dtype="i4", order="C") + ar2 = dpt.ones((20, 20), dtype="i4", order="C") + r1 = dpt.less(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.less(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.less(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.less(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones((20, 20), dtype="i4", order="F") + ar2 = dpt.ones((20, 20), dtype="i4", order="F") + r1 = dpt.less(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.less(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.less(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.less(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + r4 = dpt.less(ar1, ar2, order="K") + assert r4.strides == (20, -1) + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + r4 = dpt.less(ar1, ar2, order="K") + assert r4.strides == (-1, 20) + + +def test_less_broadcasting(): + get_queue_or_skip() + + m = dpt.ones((100, 5), dtype="i4") + v = dpt.arange(1, 6, dtype="i4") + + r = dpt.less(m, v) + + expected = np.less( + np.ones((100, 5), dtype="i4"), np.arange(1, 6, dtype="i4") + ) + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + + r2 = dpt.less(v, m) + expected2 = np.less( + np.arange(1, 6, dtype="i4"), np.ones((100, 5), dtype="i4") + ) + assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all() + + +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_less_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q) + py_ones = ( + bool(1), + int(1), + float(1), + complex(1), + np.float32(1), + ctypes.c_int(1), + ) + for sc in py_ones: + R = dpt.less(X, sc) + assert isinstance(R, dpt.usm_ndarray) + R = dpt.less(sc, X) + assert isinstance(R, dpt.usm_ndarray) + + +class MockArray: + def __init__(self, arr): + self.data_ = arr + + @property + def __sycl_usm_array_interface__(self): + return self.data_.__sycl_usm_array_interface__ + + +def test_less_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + b = dpt.ones(10) + c = MockArray(b) + r = dpt.less(a, c) + assert isinstance(r, dpt.usm_ndarray) + + +def test_less_canary_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + + class Canary: + def __init__(self): + pass + + @property + def __sycl_usm_array_interface__(self): + return None + + c = Canary() + with pytest.raises(ValueError): + dpt.less(a, c) From 54fb83af0fa9181ac74fc8a943c4b90d7562e95c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 8 Jun 2023 15:14:14 +0200 Subject: [PATCH 3/5] Replace branching with ternary operator --- .../kernels/elementwise_functions/less.hpp | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 413938bb76..67eed5c0b2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -65,28 +65,20 @@ template struct LessFunctor resT operator()(const argT1 &in1, const argT2 &in2) { if constexpr (supports_complex::value) { - - if (std::real(in1) < std::real(in2)) { - return (std::imag(in1) == in1.imag() || - std::imag(in2) != std::imag(in2)); - } - else if (std::real(in1) > std::real(in2)) { - return (std::imag(in2) != std::imag(in2) && - in1.imag() == in1.imag()); - } - else if (std::real(in1) == std::real(in2) || - (std::real(in1) != std::real(in1) && - std::real(in2) != std::real(in2))) - { - return (in1.imag() < std::imag(in2) || - (std::imag(in2) != std::imag(in2) && - in1.imag() == in1.imag())); - } - else { - return (std::real(in2) != std::real(in2)); - } + return (std::real(in1) < std::real(in2)) + ? (std::imag(in1) == in1.imag() || + std::imag(in2) != std::imag(in2)) + : (std::real(in1) > std::real(in2)) + ? (std::imag(in2) != std::imag(in2) && + in1.imag() == in1.imag()) + : (std::real(in1) == std::real(in2) || + (std::real(in1) != std::real(in1) && + std::real(in2) != std::real(in2))) + ? (in1.imag() < std::imag(in2) || + (std::imag(in2) != std::imag(in2) && + in1.imag() == in1.imag())) + : (std::real(in2) != std::real(in2)); } - else { return (in1 < in2); } From 6e4d5bc14c0b85c4b9256ea1d1debb2bf455fdf5 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 9 Jun 2023 20:21:32 +0200 Subject: [PATCH 4/5] Fix remarks and extend support for complex and float --- dpctl/tensor/_elementwise_funcs.py | 2 +- .../kernels/elementwise_functions/less.hpp | 42 ++++++++++++------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index ae48180a19..578f70e671 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -382,7 +382,7 @@ # B13: ==== LESS (x1, x2) _less_docstring_ = """ -divide(x1, x2, out=None, order='K') +less(x1, x2, out=None, order='K') Computes the less-than test results for each element `x1_i` of the input array `x1` the respective element `x2_i` of the input array `x2`. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 67eed5c0b2..4c123009da 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -53,8 +53,6 @@ template struct LessFunctor { static_assert(std::is_same_v); - using supports_complex = - std::conjunction, tu_ns::is_complex>; using supports_sg_loadstore = std::negation< std::disjunction, tu_ns::is_complex>>; using supports_vec = std::conjunction< @@ -64,20 +62,28 @@ template struct LessFunctor resT operator()(const argT1 &in1, const argT2 &in2) { - if constexpr (supports_complex::value) { - return (std::real(in1) < std::real(in2)) - ? (std::imag(in1) == in1.imag() || - std::imag(in2) != std::imag(in2)) - : (std::real(in1) > std::real(in2)) - ? (std::imag(in2) != std::imag(in2) && - in1.imag() == in1.imag()) - : (std::real(in1) == std::real(in2) || - (std::real(in1) != std::real(in1) && - std::real(in2) != std::real(in2))) - ? (in1.imag() < std::imag(in2) || - (std::imag(in2) != std::imag(in2) && - in1.imag() == in1.imag())) - : (std::real(in2) != std::real(in2)); + if constexpr (std::is_same_v> && + std::is_same_v) + { + float real1 = std::real(in1); + return (real1 == in2) ? (std::imag(in1) < 0.0f) : real1 < in2; + } + else if constexpr (std::is_same_v && + std::is_same_v>) + { + float real2 = std::real(in2); + return (in1 == real2) ? (0.0f < std::imag(in2)) : in1 < real2; + } + else if constexpr (tu_ns::is_complex::value || + tu_ns::is_complex::value) + { + static_assert(std::is_same_v); + using realT = typename argT1::value_type; + realT real1 = std::real(in1); + realT real2 = std::real(in2); + + return (real1 == real2) ? (std::imag(in1) < std::imag(in2)) + : real1 < real2; } else { return (in1 < in2); @@ -167,6 +173,10 @@ template struct LessOutputType T2, std::complex, bool>, + td_ns:: + BinaryTypeMapResultEntry, bool>, + td_ns:: + BinaryTypeMapResultEntry, T2, float, bool>, td_ns::DefaultResultEntry>::result_type; }; From 14e32b70663c03b05f7ddf36c693355a93d65fb6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 9 Jun 2023 20:22:35 +0200 Subject: [PATCH 5/5] Update tests for less function --- dpctl/tests/elementwise/test_less.py | 50 +++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/dpctl/tests/elementwise/test_less.py b/dpctl/tests/elementwise/test_less.py index 55c3abe35c..767a390614 100644 --- a/dpctl/tests/elementwise/test_less.py +++ b/dpctl/tests/elementwise/test_less.py @@ -34,23 +34,27 @@ def test_less_dtype_matrix(op1_dtype, op2_dtype): skip_if_dtype_not_supported(op2_dtype, q) sz = 127 - ar1 = dpt.ones(sz, dtype=op1_dtype) + ar1 = dpt.zeros(sz, dtype=op1_dtype) ar2 = dpt.ones_like(ar1, dtype=op2_dtype) r = dpt.less(ar1, ar2) assert isinstance(r, dpt.usm_ndarray) - expected = np.less(np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)) + expected = np.less( + np.zeros(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) assert r.shape == ar1.shape assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() assert r.sycl_queue == ar1.sycl_queue - ar3 = dpt.ones(sz, dtype=op1_dtype) + ar3 = dpt.zeros(sz, dtype=op1_dtype) ar4 = dpt.ones(2 * sz, dtype=op2_dtype) r = dpt.less(ar3[::-1], ar4[::2]) assert isinstance(r, dpt.usm_ndarray) - expected = np.less(np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)) + expected = np.less( + np.zeros(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) assert r.shape == ar3.shape assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() @@ -82,6 +86,44 @@ def test_less_complex_matrix(op_dtype): assert r1.shape == expected1.shape assert (dpt.asnumpy(r1) == expected1).all() + ar3 = dpt.asarray([1.0 + 9j, 2.0 + 0j, 2.0 + 1j, 2.0 + 2j], dtype=op_dtype) + ar4 = dpt.asarray([2.0 + 0j, dpt.nan, dpt.inf, -dpt.inf], dtype=op_dtype) + r2 = dpt.less(ar3, ar4) + with np.errstate(invalid="ignore"): + expected2 = np.less(dpt.asnumpy(ar3), dpt.asnumpy(ar4)) + assert (dpt.asnumpy(r2) == expected2).all() + + r3 = dpt.less(ar4, ar4) + with np.errstate(invalid="ignore"): + expected3 = np.less(dpt.asnumpy(ar4), dpt.asnumpy(ar4)) + assert (dpt.asnumpy(r3) == expected3).all() + + +def test_less_complex_float(): + get_queue_or_skip() + + ar1 = dpt.asarray([1.0 + 9j, 2.0 + 0j, 2.0 + 1j, 2.0 + 2j], dtype="c8") + ar2 = dpt.full((4,), 2, dtype="f4") + + r = dpt.less(ar1, ar2) + expected = np.less(dpt.asnumpy(ar1), dpt.asnumpy(ar2)) + assert (dpt.asnumpy(r) == expected).all() + + r1 = dpt.less(ar2, ar1) + expected1 = np.less(dpt.asnumpy(ar2), dpt.asnumpy(ar1)) + assert (dpt.asnumpy(r1) == expected1).all() + with np.errstate(invalid="ignore"): + for tp in [dpt.nan, dpt.inf, -dpt.inf]: + + ar3 = dpt.full((4,), tp) + r2 = dpt.less(ar1, ar3) + expected2 = np.less(dpt.asnumpy(ar1), dpt.asnumpy(ar3)) + assert (dpt.asnumpy(r2) == expected2).all() + + r3 = dpt.less(ar3, ar1) + expected3 = np.less(dpt.asnumpy(ar3), dpt.asnumpy(ar1)) + assert (dpt.asnumpy(r3) == expected3).all() + @pytest.mark.parametrize("op1_usm_type", _usm_types) @pytest.mark.parametrize("op2_usm_type", _usm_types)