From a0e44d2d4fa4b1e446df4c86aa4082d160615d78 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 30 Mar 2023 01:40:48 -0700 Subject: [PATCH 1/6] Implements dpt.where --- dpctl/tensor/CMakeLists.txt | 1 + dpctl/tensor/__init__.py | 2 + dpctl/tensor/_search_functions.py | 103 ++++++ .../libtensor/include/kernels/where.hpp | 293 ++++++++++++++++++ dpctl/tensor/libtensor/source/tensor_py.cpp | 10 + dpctl/tensor/libtensor/source/where.cpp | 258 +++++++++++++++ dpctl/tensor/libtensor/source/where.hpp | 52 ++++ 7 files changed, 719 insertions(+) create mode 100644 dpctl/tensor/_search_functions.py create mode 100644 dpctl/tensor/libtensor/include/kernels/where.hpp create mode 100644 dpctl/tensor/libtensor/source/where.cpp create mode 100644 dpctl/tensor/libtensor/source/where.hpp diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index e34cd94bb9..1025072c08 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -43,6 +43,7 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp ) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index e970af98f5..ec9f27617b 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -86,6 +86,7 @@ usm_ndarray_str, ) from dpctl.tensor._reshape import reshape +from dpctl.tensor._search_functions import where from dpctl.tensor._usmarray import usm_ndarray from ._constants import e, inf, nan, newaxis, pi @@ -128,6 +129,7 @@ "from_dlpack", "tril", "triu", + "where", "dtype", "isdtype", "bool", diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py new file mode 100644 index 0000000000..613c8c3445 --- /dev/null +++ b/dpctl/tensor/_search_functions.py @@ -0,0 +1,103 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
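+
+# A quick usage sketch of the function this module adds (illustrative
+# only, not part of the patch's code; output follows from the semantics
+# implemented below):
+#
+#     import dpctl.tensor as dpt
+#     cond = dpt.asarray([True, False, True])
+#     x1 = dpt.asarray([1, 2, 3])
+#     x2 = dpt.asarray([10, 20, 30])
+#     y = dpt.where(cond, x1, x2)
+#     # dpt.asnumpy(y) -> array([ 1, 20,  3])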
+ +import dpctl +import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti +from dpctl.tensor._manipulation_functions import _broadcast_shapes + + +def where(condition, x1, x2): + if not isinstance(condition, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}" + ) + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}" + ) + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}" + ) + exec_q = dpctl.utils.get_execution_queue( + ( + condition.sycl_queue, + x1.sycl_queue, + x2.sycl_queue, + ) + ) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError + dst_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition.usm_type, + x1.usm_type, + x2.usm_type, + ) + ) + + x1_dtype = x1.dtype + x2_dtype = x2.dtype + dst_dtype = dpt.result_type(x1.dtype, x2.dtype) + + if condition.size == 0: + return dpt.asarray( + (), dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + ) + + res_shape = _broadcast_shapes(condition, x1, x2) + + deps = [] + wait_list = [] + if x1_dtype is not dst_dtype: + _x1 = dpt.empty_like(x1, dtype=dst_dtype) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=_x1, sycl_queue=exec_q + ) + x1 = _x1 + deps.append(copy1_ev) + wait_list.append(ht_copy1_ev) + + if x2_dtype is not dst_dtype: + _x2 = dpt.empty_like(x2, dtype=dst_dtype) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=_x2, sycl_queue=exec_q + ) + x2 = _x2 + deps.append(copy2_ev) + wait_list.append(ht_copy2_ev) + + condition = dpt.broadcast_to(condition, res_shape) + x1 = dpt.broadcast_to(x1, res_shape) + x2 = dpt.broadcast_to(x2, res_shape) + + dst = dpt.empty( + res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + ) + + hev, _ = ti._where( + condition=condition, + x1=x1, + x2=x2, + dst=dst, + sycl_queue=exec_q, + depends=deps, + ) + wait_list.append(hev) + dpctl.SyclEvent.wait_for(wait_list) + + return dst diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp new file mode 100644 index 0000000000..54e502af0e --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -0,0 +1,293 @@ +//=== where.hpp - Implementation of where kernels ---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for dpctl.tensor.where. 
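+///
+/// Two kernels are defined here: a contiguous kernel, in which each
+/// work-item processes n_vecs sub-group-vectorized chunks of vec_sz
+/// elements, and a strided kernel, which resolves one element per
+/// work-item through a three-offset indexer. As an illustrative sketch
+/// of the partitioning (using the defaults below, lws = 64, vec_sz = 4,
+/// n_vecs = 2), each work-group covers lws * n_vecs * vec_sz =
+/// 64 * 2 * 4 = 512 contiguous elements, so nelems = 10000 launches
+/// ceil(10000 / 512) = 20 groups.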
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include "pybind11/numpy.h"
+#include "pybind11/stl.h"
+#include "utils/offset_utils.hpp"
+#include "utils/type_utils.hpp"
+#include <CL/sycl.hpp>
+#include <complex>
+#include <cstdint>
+#include <pybind11/pybind11.h>
+#include <type_traits>
+#include <vector>
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace kernels
+{
+namespace search
+{
+
+namespace py = pybind11;
+
+using namespace dpctl::tensor::offset_utils;
+
+template <typename T, typename condT, typename IndexerT>
+class where_strided_kernel;
+template <typename T, typename condT, int vec_sz, int n_vecs>
+class where_contig_kernel;
+
+template <typename T, typename condT, int vec_sz = 4, int n_vecs = 2>
+class WhereContigFunctor
+{
+private:
+    size_t nelems = 0;
+    const char *x1_cp = nullptr;
+    const char *x2_cp = nullptr;
+    char *dst_cp = nullptr;
+    const char *cond_cp = nullptr;
+
+public:
+    WhereContigFunctor(size_t nelems_,
+                       const char *cond_data_p,
+                       const char *x1_data_p,
+                       const char *x2_data_p,
+                       char *dst_data_p)
+        : nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
+          dst_cp(dst_data_p), cond_cp(cond_data_p)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
+        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
+        T *dst_data = reinterpret_cast<T *>(dst_cp);
+        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
+
+        using dpctl::tensor::type_utils::convert_impl;
+
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (is_complex<condT>::value || is_complex<T>::value) {
+            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            size_t base = ndit.get_global_linear_id();
+
+            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            for (size_t offset = base;
+                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
+                 offset += sgSize)
+            {
+                bool check = convert_impl<bool, condT>(cond_data[offset]);
+                dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
+            }
+        }
+        else {
+            auto sg = ndit.get_sub_group();
+            std::uint8_t sgSize = sg.get_local_range()[0];
+            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
+            size_t base = n_vecs * vec_sz *
+                          (ndit.get_group(0) * ndit.get_local_range(0) +
+                           sg.get_group_id()[0] * max_sgSize);
+
+            if (base + n_vecs * vec_sz * sgSize < nelems &&
+                sgSize == max_sgSize) {
+                using dst_ptrT =
+                    sycl::multi_ptr<T,
+                                    sycl::access::address_space::global_space>;
+                using x_ptrT =
+                    sycl::multi_ptr<const T,
+                                    sycl::access::address_space::global_space>;
+                using cond_ptrT =
+                    sycl::multi_ptr<const condT,
+                                    sycl::access::address_space::global_space>;
+
+                sycl::vec<T, vec_sz> dst_vec;
+                sycl::vec<T, vec_sz> x1_vec;
+                sycl::vec<T, vec_sz> x2_vec;
+                sycl::vec<condT, vec_sz> cond_vec;
+
+#pragma unroll
+                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                    auto idx = base + it * sgSize;
+                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
+                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
+                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
+
+#pragma unroll
+                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
+                        bool check = convert_impl<bool, condT>(cond_vec[k]);
+                        dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
+                    }
+                    sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
+                }
+            }
+            else {
+                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
+                     k += sgSize) {
+                    bool check = convert_impl<bool, condT>(cond_data[k]);
+                    dst_data[k] = check ? x1_data[k] : x2_data[k];
+                }
+            }
+        }
+    }
+};
+
+typedef sycl::event (*where_contig_impl_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    const char *,
+    const char *,
+    const char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+template <typename T, typename condT>
+sycl::event where_contig_impl(sycl::queue q,
+                              size_t nelems,
+                              const char *cond_p,
+                              const char *x1_p,
+                              const char *x2_p,
+                              char *dst_p,
+                              const std::vector<sycl::event> &depends)
+{
+    sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        size_t lws = 64;
+        constexpr unsigned int vec_sz = 4;
+        constexpr unsigned int n_vecs = 2;
+        const size_t n_groups =
+            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
+        const auto gws_range = sycl::range<1>(n_groups * lws);
+        const auto lws_range = sycl::range<1>(lws);
+
+        cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
+            sycl::nd_range<1>(gws_range, lws_range),
+            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
+                                                         x2_p, dst_p));
+    });
+
+    return where_ev;
+}
+
+template <typename T, typename condT, typename IndexerT>
+class WhereStridedFunctor
+{
+private:
+    const char *x1_cp = nullptr;
+    const char *x2_cp = nullptr;
+    char *dst_cp = nullptr;
+    const char *cond_cp = nullptr;
+    IndexerT indexer;
+
+public:
+    WhereStridedFunctor(const char *cond_data_p,
+                        const char *x1_data_p,
+                        const char *x2_data_p,
+                        char *dst_data_p,
+                        IndexerT indexer_)
+        : x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
+          cond_cp(cond_data_p), indexer(indexer_)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
+        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
+        T *dst_data = reinterpret_cast<T *>(dst_cp);
+        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
+
+        size_t gid = id[0];
+        auto offsets = indexer(static_cast<py::ssize_t>(gid));
+
+        using dpctl::tensor::type_utils::convert_impl;
+        bool check =
+            convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
+
+        dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
+                              : x2_data[offsets.get_third_offset()];
+    }
+};
+
+typedef sycl::event (*where_strided_impl_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    int,
+    const char *,
+    const char *,
+    const char *,
+    char *,
+    const py::ssize_t *,
+    py::ssize_t,
+    py::ssize_t,
+    py::ssize_t,
+    const std::vector<sycl::event> &);
+
+template <typename T, typename condT>
+sycl::event where_strided_impl(sycl::queue q,
+                               size_t nelems,
+                               int nd,
+                               const char *cond_p,
+                               const char *x1_p,
+                               const char *x2_p,
+                               char *dst_p,
+                               const py::ssize_t *shape_strides,
+                               py::ssize_t x1_offset,
+                               py::ssize_t x2_offset,
+                               py::ssize_t cond_offset,
+                               const std::vector<sycl::event> &depends)
+{
+    sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        ThreeOffsets_StridedIndexer indexer{nd, cond_offset, x1_offset,
+                                            x2_offset, shape_strides};
+
+        cgh.parallel_for<
+            where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
+            sycl::range<1>(nelems),
+            WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
+                cond_p, x1_p, x2_p, dst_p, indexer));
+    });
+
+    return where_ev;
+}
+
+template <typename fnT, typename T, typename condT> struct WhereStridedFactory
+{
+    fnT get()
+    {
+        fnT fn = where_strided_impl<T, condT>;
+        return fn;
+    }
+};
+
+template <typename fnT, typename T, typename condT> struct WhereContigFactory
+{
+    fnT get()
+    {
+        fnT fn = where_contig_impl<T, condT>;
+        return fn;
+    }
+};
+
+} // namespace search
+} // namespace kernels
+} // namespace tensor
+} // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp
index 7a5886088c..cf66ede984 100644
--- a/dpctl/tensor/libtensor/source/tensor_py.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_py.cpp
@@ -45,6 +45,7 @@
 #include "triul_ctor.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/strided_iters.hpp"
+#include "where.hpp"
 
 namespace py = pybind11;
 
@@ -92,6 +93,10 @@ using dpctl::tensor::py_internal::usm_ndarray_eye;
 
 using dpctl::tensor::py_internal::usm_ndarray_triul;
 
+/* =========================== Where ============================== */
+
+using dpctl::tensor::py_internal::py_where;
+
 // populate dispatch tables
 void init_dispatch_tables(void)
 {
@@ -100,6 +105,7 @@ void init_dispatch_tables(void)
     init_copy_and_cast_usm_to_usm_dispatch_tables();
     init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
     init_advanced_indexing_dispatch_tables();
+    init_where_dispatch_tables();
 
     return;
 }
@@ -293,4 +299,8 @@ PYBIND11_MODULE(_tensor_impl, m)
     m.def("_nonzero", &py_nonzero, "", py::arg("cumsum"), py::arg("indexes"),
           py::arg("mask_shape"), py::arg("sycl_queue"),
           py::arg("depends") = py::list());
+
+    m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"),
+          py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"),
+          py::arg("depends") = py::list());
 }
diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp
new file mode 100644
index 0000000000..862877ee9a
--- /dev/null
+++ b/dpctl/tensor/libtensor/source/where.cpp
@@ -0,0 +1,258 @@
+//===-- where.cpp - Implementation of where  --*-C++-*-/===//
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2022 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines Python API for implementation functions of
+/// dpctl.tensor.where
+//===----------------------------------------------------------------------===//
+
+#include "dpctl4pybind11.hpp"
+#include <CL/sycl.hpp>
+#include <complex>
+#include <cstdint>
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <utility>
+
+#include "kernels/where.hpp"
+#include "simplify_iteration_space.hpp"
+#include "utils/memory_overlap.hpp"
+#include "utils/offset_utils.hpp"
+#include "utils/type_dispatch.hpp"
+#include "where.hpp"
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace py_internal
+{
+
+namespace _ns = dpctl::tensor::detail;
+
+using dpctl::tensor::kernels::search::where_contig_impl_fn_ptr_t;
+using dpctl::tensor::kernels::search::where_strided_impl_fn_ptr_t;
+
+static where_contig_impl_fn_ptr_t where_contig_dispatch_table[_ns::num_types]
+                                                             [_ns::num_types];
+static where_strided_impl_fn_ptr_t where_strided_dispatch_table[_ns::num_types]
+                                                               [_ns::num_types];
+
+using dpctl::utils::keep_args_alive;
+
+std::pair<sycl::event, sycl::event>
+py_where(dpctl::tensor::usm_ndarray condition,
+         dpctl::tensor::usm_ndarray x1,
+         dpctl::tensor::usm_ndarray x2,
+         dpctl::tensor::usm_ndarray dst,
+         sycl::queue exec_q,
+         const std::vector<sycl::event> &depends)
+{
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, condition, dst}))
+    {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    int nd = condition.get_ndim();
+    int x1_nd = x1.get_ndim();
+    int x2_nd = x2.get_ndim();
+    int dst_nd = dst.get_ndim();
+
+    if (nd != x1_nd || nd != x2_nd) {
+        throw py::value_error(
+            "Input arrays are not of appropriate dimension for where kernel.");
+    }
+
+    if (nd != dst_nd) {
+        throw py::value_error(
+            "Destination is not of appropriate dimension for where kernel.");
+    }
+
+    const py::ssize_t *x1_shape = x1.get_shape_raw();
+    const py::ssize_t *x2_shape = x2.get_shape_raw();
+    const py::ssize_t *dst_shape = dst.get_shape_raw();
+    const py::ssize_t *cond_shape = condition.get_shape_raw();
+
+    bool shapes_equal(true);
+    size_t nelems(1);
+    for (int i = 0; i < nd; ++i) {
+        nelems *= static_cast<size_t>(dst_shape[i]);
+        shapes_equal = shapes_equal && (x1_shape[i] == dst_shape[i]) &&
+                       (x2_shape[i] == dst_shape[i]) &&
+                       (cond_shape[i] == dst_shape[i]);
+    }
+
+    if (!shapes_equal) {
+        throw py::value_error("Axes are not of matching shapes.");
+    }
+
+    if (nelems == 0) {
+        return std::make_pair(sycl::event{}, sycl::event{});
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(dst, condition) || overlap(dst, x1) || overlap(dst, x2)) {
+        throw py::value_error("Destination array overlaps with input.");
+    }
+
+    int x1_typenum = x1.get_typenum();
+    int x2_typenum = x2.get_typenum();
+    int cond_typenum = condition.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
+    int cond_typeid = array_types.typenum_to_lookup_id(cond_typenum);
+    int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum);
+    int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) {
+        throw py::value_error("Non-condition are not of same type.");
+    }
+
+    // ensure that dst is sufficiently ample
+    auto dst_offsets = dst.get_minmax_offsets();
+    // destination must be ample enough to accommodate all elements
+    {
+        size_t range =
+            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
+        if (range + 1 < static_cast<size_t>(nelems)) {
+            throw py::value_error(
+                "Memory addressed by the destination array can not "
+                "accommodate all the "
+                "array elements.");
+        }
+    }
+
+    char *cond_data = condition.get_data();
+    char *x1_data = x1.get_data();
+    char *x2_data = x2.get_data();
+    char *dst_data = dst.get_data();
+
+    bool is_x1_c_contig = x1.is_c_contiguous();
+    bool is_x1_f_contig = x1.is_f_contiguous();
+
+    bool is_x2_c_contig = x2.is_c_contiguous();
+    bool is_x2_f_contig = x2.is_f_contiguous();
+
+    bool is_cond_c_contig = condition.is_c_contiguous();
+    bool is_cond_f_contig = condition.is_f_contiguous();
+
+    bool all_c_contig = (is_x1_c_contig && is_x2_c_contig && is_cond_c_contig);
+    bool all_f_contig = (is_x1_f_contig && is_x2_f_contig && is_cond_f_contig);
+
+    if (all_c_contig || all_f_contig) {
+        auto contig_fn = where_contig_dispatch_table[x1_typeid][cond_typeid];
+
+        auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data,
+                                  dst_data, depends);
+        sycl::event ht_ev = dpctl::utils::keep_args_alive(
+            exec_q, {x1, x2, dst, condition}, {where_ev});
+
+        return std::make_pair(ht_ev, where_ev);
+    }
+
+    const py::ssize_t *cond_strides = condition.get_strides_raw();
+    const py::ssize_t *x1_strides = x1.get_strides_raw();
+    const py::ssize_t *x2_strides = x2.get_strides_raw();
+
+    using shT = std::vector<py::ssize_t>;
+    shT simplified_shape;
+    shT simplified_cond_strides;
+    shT simplified_x1_strides;
+    shT simplified_x2_strides;
+    py::ssize_t cond_offset(0);
+    py::ssize_t x1_offset(0);
+    py::ssize_t x2_offset(0);
+
+    const py::ssize_t *_shape = x1_shape;
+
+    constexpr py::ssize_t _itemsize = 1;
+
+    dpctl::tensor::py_internal::simplify_iteration_space_3(
+        nd, _shape, cond_strides, _itemsize, is_cond_c_contig, is_cond_f_contig,
+        x1_strides, _itemsize, is_x1_c_contig, is_x1_f_contig, x2_strides,
+        _itemsize, is_x2_c_contig, is_x2_f_contig, simplified_shape,
+        simplified_cond_strides, simplified_x1_strides, simplified_x2_strides,
+        cond_offset, x1_offset, x2_offset);
+
+    auto fn = where_strided_dispatch_table[x1_typeid][cond_typeid];
+
+    std::vector<sycl::event> host_task_events;
+    host_task_events.reserve(2);
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    auto ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, simplified_shape, simplified_cond_strides,
+        simplified_x1_strides, simplified_x2_strides);
+    py::ssize_t *packed_shape_strides = std::get<0>(ptr_size_event_tuple);
+    sycl::event copy_shape_strides_ev = std::get<2>(ptr_size_event_tuple);
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_shape_strides_ev);
+
+    assert(all_deps.size() == depends.size() + 1);
+
+    sycl::event where_ev =
+        fn(exec_q, nelems, nd, cond_data, x1_data, x2_data, dst_data,
+           packed_shape_strides, cond_offset, x1_offset, x2_offset, all_deps);
+
+    // free packed temporaries
+    sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(where_ev);
+        auto ctx = exec_q.get_context();
+        cgh.host_task([packed_shape_strides, ctx]() {
+            sycl::free(packed_shape_strides, ctx);
+        });
+    });
+
+    host_task_events.push_back(temporaries_cleanup_ev);
+
+    sycl::event arg_cleanup_ev =
+        keep_args_alive(exec_q, {x1, x2, condition, dst}, host_task_events);
+
+    return std::make_pair(arg_cleanup_ev, temporaries_cleanup_ev);
+}
+
+void init_where_dispatch_tables(void)
+{
+    using dpctl::tensor::kernels::search::WhereContigFactory;
+    dpctl::tensor::detail::DispatchTableBuilder<
+        where_contig_impl_fn_ptr_t, WhereContigFactory,
+        dpctl::tensor::detail::num_types>
+        dtb1;
+    dtb1.populate_dispatch_table(where_contig_dispatch_table);
+
+    using dpctl::tensor::kernels::search::WhereStridedFactory;
+    dpctl::tensor::detail::DispatchTableBuilder<
+        where_strided_impl_fn_ptr_t, WhereStridedFactory,
+        dpctl::tensor::detail::num_types>
+        dtb2;
+    dtb2.populate_dispatch_table(where_strided_dispatch_table);
+}
+
+} // namespace py_internal
+} // namespace tensor
+} // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/where.hpp b/dpctl/tensor/libtensor/source/where.hpp
new file mode 100644
index 0000000000..95f70e560e
--- /dev/null
+++ b/dpctl/tensor/libtensor/source/where.hpp
@@ -0,0 +1,52 @@
+//===-- where.hpp -  --*-C++-*-/===//
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2022 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file declares Python API for implementation functions of
+/// dpctl.tensor.where
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <CL/sycl.hpp>
+#include <utility>
+#include <vector>
+
+#include "dpctl4pybind11.hpp"
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace py_internal
+{
+
+extern std::pair<sycl::event, sycl::event>
+py_where(dpctl::tensor::usm_ndarray,
+         dpctl::tensor::usm_ndarray,
+         dpctl::tensor::usm_ndarray,
+         dpctl::tensor::usm_ndarray,
+         sycl::queue,
+         const std::vector<sycl::event> &);
+
+extern void init_where_dispatch_tables(void);
+
+} // namespace py_internal
+} // namespace tensor
+} // namespace dpctl

From 9018a79f47c958be61e1ed3a9e96ab1758b0d919 Mon Sep 17 00:00:00 2001
From: Nikita Grigorian
Date: Mon, 3 Apr 2023 07:49:22 -0700
Subject: [PATCH 2/6] Added utility functions, basic where tests

- Utility functions are for finding an output type for universal
  and binary functions when the device of allocation lacks fp16
  or fp64
---
 dpctl/tensor/_search_functions.py             |  21 +++-
 dpctl/tensor/_type_utils.py                   | 113 ++++++++++++++++++
 .../test_usm_ndarray_search_functions.py      |  84 +++++++++++++
 3 files changed, 217 insertions(+), 1 deletion(-)
 create mode 100644 dpctl/tensor/_type_utils.py
 create mode 100644 dpctl/tests/test_usm_ndarray_search_functions.py

diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py
index 613c8c3445..7179d959dd 100644
--- a/dpctl/tensor/_search_functions.py
+++ b/dpctl/tensor/_search_functions.py
@@ -19,6 +19,25 @@
 import dpctl.tensor._tensor_impl as ti
 from dpctl.tensor._manipulation_functions import _broadcast_shapes
 
+from ._type_utils import _all_data_types, _can_cast
+
+
+def _where_result_type(dt1, dt2, dev):
+    res_dtype = dpt.result_type(dt1, dt2)
+    fp16 = dev.has_aspect_fp16
+    fp64 = dev.has_aspect_fp64
+
+    all_dts = _all_data_types(fp16, fp64)
+    if res_dtype in all_dts:
+        return res_dtype
+    else:
+        for res_dtype_ in all_dts:
+            if _can_cast(dt1, res_dtype_, fp16, fp64) and _can_cast(
dt2, res_dtype_, fp16, fp64 + ): + return res_dtype_ + return None + def where(condition, x1, x2): if not isinstance(condition, dpt.usm_ndarray): @@ -52,7 +71,7 @@ def where(condition, x1, x2): x1_dtype = x1.dtype x2_dtype = x2.dtype - dst_dtype = dpt.result_type(x1.dtype, x2.dtype) + dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device) if condition.size == 0: return dpt.asarray( diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py new file mode 100644 index 0000000000..f39c9374db --- /dev/null +++ b/dpctl/tensor/_type_utils.py @@ -0,0 +1,113 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dpctl.tensor as dpt + + +def _all_data_types(_fp16, _fp64): + if _fp64: + if _fp16: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + if _fp16: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.complex64, + ] + else: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float32, + dpt.complex64, + ] + + +def is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): + """ + Return True if data type `dt` is the + maximal size inexact data type + """ + if _fp64: + return dt in [dpt.float64, dpt.complex128] + return dt in [dpt.float32, dpt.complex64] + + +def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): + """ + Can `from_` be cast to `to_` safely on a device with + fp16 and fp64 aspects as given? + """ + can_cast_v = dpt.can_cast(from_, to_) # ask NumPy + if _fp16 and _fp64: + return can_cast_v + if not can_cast_v: + if ( + from_.kind in "biu" + and to_.kind in "fc" + and is_maximal_inexact_type(to_, _fp16, _fp64) + ): + return True + + return can_cast_v diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py new file mode 100644 index 0000000000..7f4ec11577 --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -0,0 +1,84 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+from helper import get_queue_or_skip, skip_if_dtype_not_supported
+from numpy.testing import assert_array_equal
+
+import dpctl.tensor as dpt
+
+_all_dtypes = [
+    "u1",
+    "i1",
+    "u2",
+    "i2",
+    "u4",
+    "i4",
+    "u8",
+    "i8",
+    "e",
+    "f",
+    "d",
+    "F",
+    "D",
+]
+
+
+def test_where_basic():
+    get_queue_or_skip
+
+    cond = dpt.asarray(
+        [
+            [True, False, False],
+            [False, True, False],
+            [False, False, True],
+            [False, False, False],
+            [True, True, True],
+        ]
+    )
+    out = dpt.where(cond, dpt.asarray(1), dpt.asarray(0))
+    out_expected = dpt.asarray(
+        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]]
+    )
+
+    assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
+
+
+@pytest.mark.parametrize("dt1", _all_dtypes)
+@pytest.mark.parametrize("dt2", _all_dtypes)
+def test_where_all_dtypes(dt1, dt2):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dt1, q)
+    skip_if_dtype_not_supported(dt2, q)
+
+    cond_np = np.arange(5) > 2
+    x1_np = np.asarray(2, dtype=dt1)
+    x2_np = np.asarray(3, dtype=dt2)
+
+    cond = dpt.asarray(cond_np, sycl_queue=q)
+    x1 = dpt.asarray(x1_np, sycl_queue=q)
+    x2 = dpt.asarray(x2_np, sycl_queue=q)
+
+    res = dpt.where(cond, x1, x2)
+    res_np = np.where(cond_np, x1_np, x2_np)
+
+    if res.dtype != res_np.dtype:
+        assert res.dtype.kind == res_np.dtype.kind
+        assert_array_equal(dpt.asnumpy(res).astype(res_np.dtype), res_np)
+
+    else:
+        assert_array_equal(dpt.asnumpy(res), res_np)

From b35b083e306b5125291bf56975ce17f172695db2 Mon Sep 17 00:00:00 2001
From: Nikita Grigorian
Date: Mon, 3 Apr 2023 22:10:30 -0700
Subject: [PATCH 3/6] Where changed for empty and F-contiguous input

- Where now outputs an F-contiguous array when all inputs are F-contiguous
- Where now outputs an empty 0D array if any input is a 0D empty array
- Added tests for these cases

Fixed incorrect logic in where test
---
 dpctl/tensor/_search_functions.py             | 34 +++++--
 .../test_usm_ndarray_search_functions.py      | 88 ++++++++++++++++---
 2 files changed, 103 insertions(+), 19 deletions(-)

diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py
index 7179d959dd..9cccfe922a 100644
--- a/dpctl/tensor/_search_functions.py
+++ b/dpctl/tensor/_search_functions.py
@@ -72,14 +72,21 @@ def where(condition, x1, x2):
     x1_dtype = x1.dtype
     x2_dtype = x2.dtype
     dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
-
-    if condition.size == 0:
-        return dpt.asarray(
-            (), dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
+    if dst_dtype is None:
+        raise TypeError(
+            "function 'where' does not support input "
+            f"types ({x1_dtype}, {x2_dtype}), "
+            "and the inputs could not be safely coerced "
+            "to any supported types according to the casting rule ''safe''."
) res_shape = _broadcast_shapes(condition, x1, x2) + if condition.size == 0: + return dpt.empty( + res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + ) + deps = [] wait_list = [] if x1_dtype is not dst_dtype: @@ -104,8 +111,25 @@ def where(condition, x1, x2): x1 = dpt.broadcast_to(x1, res_shape) x2 = dpt.broadcast_to(x2, res_shape) + # dst is F-contiguous when all inputs are F contiguous + # otherwise, defaults to C-contiguous + if all( + ( + condition.flags.fnc, + x1.flags.fnc, + x2.flags.fnc, + ) + ): + order = "F" + else: + order = "C" + dst = dpt.empty( - res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + res_shape, + dtype=dst_dtype, + order=order, + usm_type=dst_usm_type, + sycl_queue=exec_q, ) hev, _ = ti._where( diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index 7f4ec11577..ec64991000 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -39,7 +39,7 @@ def test_where_basic(): - get_queue_or_skip + get_queue_or_skip() cond = dpt.asarray( [ @@ -58,6 +58,18 @@ def test_where_basic(): assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() +def _dtype_all_close(x1, x2): + if np.issubdtype(x2.dtype, np.floating) or np.issubdtype( + x2.dtype, np.complexfloating + ): + x2_dtype = x2.dtype + return np.allclose( + x1, x2, atol=np.finfo(x2_dtype).eps, rtol=np.finfo(x2_dtype).eps + ) + else: + return np.allclose(x1, x2) + + @pytest.mark.parametrize("dt1", _all_dtypes) @pytest.mark.parametrize("dt2", _all_dtypes) def test_where_all_dtypes(dt1, dt2): @@ -65,20 +77,68 @@ def test_where_all_dtypes(dt1, dt2): skip_if_dtype_not_supported(dt1, q) skip_if_dtype_not_supported(dt2, q) - cond_np = np.arange(5) > 2 - x1_np = np.asarray(2, dtype=dt1) - x2_np = np.asarray(3, dtype=dt2) - - cond = dpt.asarray(cond_np, sycl_queue=q) - x1 = dpt.asarray(x1_np, sycl_queue=q) - x2 = dpt.asarray(x2_np, sycl_queue=q) + cond = dpt.asarray([False, False, False, True, True], sycl_queue=q) + x1 = dpt.asarray(2, sycl_queue=q) + x2 = dpt.asarray(3, sycl_queue=q) res = dpt.where(cond, x1, x2) - res_np = np.where(cond_np, x1_np, x2_np) + res_check = np.asarray([3, 3, 3, 2, 2], dtype=res.dtype) - if res.dtype != res_np.dtype: - assert res.dtype.kind == res_np.dtype.kind - assert_array_equal(dpt.asnumpy(res).astype(res_np.dtype), res_np) + dev = q.sycl_device - else: - assert_array_equal(dpt.asnumpy(res), res_np) + if not dev.has_aspect_fp16 or not dev.has_aspect_fp64: + assert res.dtype.kind == dpt.result_type(x1.dtype, x2.dtype).kind + + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + +def test_where_empty(): + # check that numpy returns same results when + # handling empty arrays + get_queue_or_skip() + + empty = dpt.empty(0) + m = dpt.asarray(True) + x1 = dpt.asarray(1) + x2 = dpt.asarray(2) + res = dpt.where(empty, x1, x2) + + empty_np = np.empty(0) + m_np = dpt.asnumpy(m) + x1_np = dpt.asnumpy(x1) + x2_np = dpt.asnumpy(x2) + res_np = np.where(empty_np, x1_np, x2_np) + + assert_array_equal(dpt.asnumpy(res), res_np) + + res = dpt.where(m, empty, x2) + res_np = np.where(m_np, empty_np, x2_np) + + assert_array_equal(dpt.asnumpy(res), res_np) + + +@pytest.mark.parametrize("dt", _all_dtypes) +@pytest.mark.parametrize("order", ["C", "F"]) +def test_where_contiguous(dt, order): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + + cond = dpt.asarray( + [ + [[True, False, False], [False, True, True]], + [[False, True, False], 
[True, False, True]], + [[False, False, True], [False, False, True]], + [[False, False, False], [True, False, True]], + [[True, True, True], [True, False, True]], + ], + sycl_queue=q, + order=order, + ) + + x1 = dpt.full(cond.shape, 2, dtype=dt, order=order, sycl_queue=q) + x2 = dpt.full(cond.shape, 3, dtype=dt, order=order, sycl_queue=q) + + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + + assert _dtype_all_close(dpt.asnumpy(res), expected) From be96aec2897ad0e3ceefa5e7aa01f5bebbe722ab Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 7 Apr 2023 09:52:17 -0700 Subject: [PATCH 4/6] Added tests for where and type utility functions --- dpctl/tests/test_type_utils.py | 68 ++++++ .../test_usm_ndarray_search_functions.py | 211 ++++++++++++++++-- 2 files changed, 257 insertions(+), 22 deletions(-) create mode 100644 dpctl/tests/test_type_utils.py diff --git a/dpctl/tests/test_type_utils.py b/dpctl/tests/test_type_utils.py new file mode 100644 index 0000000000..882478a2ce --- /dev/null +++ b/dpctl/tests/test_type_utils.py @@ -0,0 +1,68 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import dpctl.tensor as dpt +from dpctl.tensor._type_utils import ( + _all_data_types, + _can_cast, + _is_maximal_inexact_type, +) + + +def test_all_data_types(): + fp16_fp64_types = set([dpt.float16, dpt.float64, dpt.complex128]) + fp64_types = set([dpt.float64, dpt.complex128]) + + all_dts = _all_data_types(True, True) + assert fp16_fp64_types.issubset(all_dts) + + all_dts = _all_data_types(True, False) + assert dpt.float16 in all_dts + assert not fp64_types.issubset(all_dts) + + all_dts = _all_data_types(False, True) + assert dpt.float16 not in all_dts + assert fp64_types.issubset(all_dts) + + all_dts = _all_data_types(False, False) + assert not fp16_fp64_types.issubset(all_dts) + + +@pytest.mark.parametrize("fp16", [True, False]) +@pytest.mark.parametrize("fp64", [True, False]) +def test_maximal_inexact_types(fp16, fp64): + assert not _is_maximal_inexact_type(dpt.int32, fp16, fp64) + assert fp64 == _is_maximal_inexact_type(dpt.float64, fp16, fp64) + assert fp64 == _is_maximal_inexact_type(dpt.complex128, fp16, fp64) + assert fp64 != _is_maximal_inexact_type(dpt.float32, fp16, fp64) + assert fp64 != _is_maximal_inexact_type(dpt.complex64, fp16, fp64) + + +def test_can_cast_device(): + assert _can_cast(dpt.int64, dpt.float64, True, True) + # if f8 is available, can't cast i8 to f4 + assert not _can_cast(dpt.int64, dpt.float32, True, True) + assert not _can_cast(dpt.int64, dpt.float32, False, True) + # should be able to cast to f8 when f2 unavailable + assert _can_cast(dpt.int64, dpt.float64, False, True) + # casting to f4 acceptable when f8 unavailable + assert _can_cast(dpt.int64, dpt.float32, True, False) + assert _can_cast(dpt.int64, dpt.float32, False, False) + # can't safely cast inexact type to inexact type of lesser precision + assert not 
_can_cast(dpt.float32, dpt.float16, True, False) + assert not _can_cast(dpt.float64, dpt.float32, False, True) diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index ec64991000..6e1bfe3135 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -1,6 +1,6 @@ # Data Parallel Control (dpctl) # -# Copyright 2020-2022 Intel Corporation +# Copyright 2020-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,12 @@ from numpy.testing import assert_array_equal import dpctl.tensor as dpt +from dpctl.tensor._search_functions import _where_result_type +from dpctl.tensor._type_utils import _all_data_types +from dpctl.utils import ExecutionPlacementError _all_dtypes = [ + "?", "u1", "i1", "u2", @@ -38,6 +42,12 @@ ] +class mock_device: + def __init__(self, fp16, fp64): + self.has_aspect_fp16 = fp16 + self.has_aspect_fp64 = fp64 + + def test_where_basic(): get_queue_or_skip() @@ -54,7 +64,16 @@ def test_where_basic(): out_expected = dpt.asarray( [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]] ) + assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() + out = dpt.where(cond, dpt.ones(cond.shape), dpt.zeros(cond.shape)) + assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() + + out = dpt.where( + cond, + dpt.ones(cond.shape[0], dtype="i4")[:, dpt.newaxis], + dpt.zeros(cond.shape[0], dtype="i4")[:, dpt.newaxis], + ) assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() @@ -72,38 +91,98 @@ def _dtype_all_close(x1, x2): @pytest.mark.parametrize("dt1", _all_dtypes) @pytest.mark.parametrize("dt2", _all_dtypes) -def test_where_all_dtypes(dt1, dt2): +@pytest.mark.parametrize("fp16", [True, False]) +@pytest.mark.parametrize("fp64", [True, False]) +def test_where_result_types(dt1, dt2, fp16, fp64): + dev = mock_device(fp16, fp64) + + dt1 = dpt.dtype(dt1) + dt2 = dpt.dtype(dt2) + res_t = _where_result_type(dt1, dt2, dev) + + if fp16 and fp64: + assert res_t == dpt.result_type(dt1, dt2) + else: + if res_t: + assert res_t.kind == dpt.result_type(dt1, dt2).kind + else: + # some illegal cases are covered above, but + # this guarantees that _where_result_type + # produces None only when one of the dtypes + # is illegal given fp aspects of device + all_dts = _all_data_types(fp16, fp64) + assert dt1 not in all_dts or dt2 not in all_dts + + +@pytest.mark.parametrize("dt", _all_dtypes) +def test_where_all_dtypes(dt): q = get_queue_or_skip() - skip_if_dtype_not_supported(dt1, q) - skip_if_dtype_not_supported(dt2, q) + skip_if_dtype_not_supported(dt, q) - cond = dpt.asarray([False, False, False, True, True], sycl_queue=q) - x1 = dpt.asarray(2, sycl_queue=q) - x2 = dpt.asarray(3, sycl_queue=q) + # mask dtype changes + cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q) + x1 = dpt.asarray(0, dtype="f", sycl_queue=q) + x2 = dpt.asarray(1, dtype="f", sycl_queue=q) + res = dpt.where(cond, x1, x2) + + res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + # contiguous cases + x1 = dpt.full(cond.shape, 0, dtype="f4", sycl_queue=q) + x2 = dpt.full(cond.shape, 1, dtype="f4", sycl_queue=q) res = dpt.where(cond, x1, x2) - res_check = np.asarray([3, 3, 3, 2, 2], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) - dev = q.sycl_device + # input array dtype changes + cond = 
dpt.asarray([False, True, True, False, True], sycl_queue=q) + x1 = dpt.asarray(0, dtype=dt, sycl_queue=q) + x2 = dpt.asarray(1, dtype=dt, sycl_queue=q) + res = dpt.where(cond, x1, x2) - if not dev.has_aspect_fp16 or not dev.has_aspect_fp64: - assert res.dtype.kind == dpt.result_type(x1.dtype, x2.dtype).kind + res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + # contiguous cases + x1 = dpt.full(cond.shape, 0, dtype=dt, sycl_queue=q) + x2 = dpt.full(cond.shape, 1, dtype=dt, sycl_queue=q) + res = dpt.where(cond, x1, x2) assert _dtype_all_close(dpt.asnumpy(res), res_check) +def test_where_nan_inf(): + get_queue_or_skip() + + cond = dpt.asarray([True, False, True, False], dtype="?") + x1 = dpt.asarray([np.nan, 2.0, np.inf, 3.0], dtype="f4") + x2 = dpt.asarray([2.0, np.nan, 3.0, np.inf], dtype="f4") + + cond_np = dpt.asnumpy(cond) + x1_np = dpt.asnumpy(x1) + x2_np = dpt.asnumpy(x2) + + res = dpt.where(cond, x1, x2) + res_np = np.where(cond_np, x1_np, x2_np) + + assert np.allclose(dpt.asnumpy(res), res_np, equal_nan=True) + + res = dpt.where(x1, cond, x2) + res_np = np.where(x1_np, cond_np, x2_np) + assert _dtype_all_close(dpt.asnumpy(res), res_np) + + def test_where_empty(): # check that numpy returns same results when # handling empty arrays get_queue_or_skip() - empty = dpt.empty(0) + empty = dpt.empty(0, dtype="i2") m = dpt.asarray(True) - x1 = dpt.asarray(1) - x2 = dpt.asarray(2) + x1 = dpt.asarray(1, dtype="i2") + x2 = dpt.asarray(2, dtype="i2") res = dpt.where(empty, x1, x2) - empty_np = np.empty(0) + empty_np = np.empty(0, dtype="i2") m_np = dpt.asnumpy(m) x1_np = dpt.asnumpy(x1) x2_np = dpt.asnumpy(x2) @@ -116,12 +195,14 @@ def test_where_empty(): assert_array_equal(dpt.asnumpy(res), res_np) + # check that broadcasting is performed + with pytest.raises(ValueError): + dpt.where(empty, x1, dpt.empty((1, 2))) + -@pytest.mark.parametrize("dt", _all_dtypes) @pytest.mark.parametrize("order", ["C", "F"]) -def test_where_contiguous(dt, order): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dt, q) +def test_where_contiguous(order): + get_queue_or_skip() cond = dpt.asarray( [ @@ -131,14 +212,100 @@ def test_where_contiguous(dt, order): [[False, False, False], [True, False, True]], [[True, True, True], [True, False, True]], ], - sycl_queue=q, order=order, ) - x1 = dpt.full(cond.shape, 2, dtype=dt, order=order, sycl_queue=q) - x2 = dpt.full(cond.shape, 3, dtype=dt, order=order, sycl_queue=q) + x1 = dpt.full(cond.shape, 2, dtype="i4", order=order) + x2 = dpt.full(cond.shape, 3, dtype="i4", order=order) + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + + assert _dtype_all_close(dpt.asnumpy(res), expected) + + +def test_where_contiguous1D(): + get_queue_or_skip() + cond = dpt.asarray([True, False, True, False, False, True]) + + x1 = dpt.full(cond.shape, 2, dtype="i4") + x2 = dpt.full(cond.shape, 3, dtype="i4") expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) res = dpt.where(cond, x1, x2) + assert_array_equal(dpt.asnumpy(res), expected) + # test with complex dtype (branch in kernel) + x1 = dpt.astype(x1, dpt.complex64) + x2 = dpt.astype(x2, dpt.complex64) + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) assert _dtype_all_close(dpt.asnumpy(res), expected) + + +def test_where_strided(): + get_queue_or_skip() + + s0, s1 = 4, 9 + cond = dpt.reshape( + dpt.asarray( + [True, False, False, False, 
True, True, False, True, False] * s0
+        ),
+        (s0, s1),
+    )[:, ::3]
+
+    x1 = dpt.reshape(
+        dpt.arange(cond.shape[0] * cond.shape[1] * 2, dtype="i4"),
+        (cond.shape[0], cond.shape[1] * 2),
+    )[:, ::2]
+    x2 = dpt.reshape(
+        dpt.arange(cond.shape[0] * cond.shape[1] * 3, dtype="i4"),
+        (cond.shape[0], cond.shape[1] * 3),
+    )[:, ::3]
+    expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
+    res = dpt.where(cond, x1, x2)
+
+    assert_array_equal(dpt.asnumpy(res), expected)
+
+    # negative strides
+    res = dpt.where(cond, dpt.flip(x1), x2)
+    expected = np.where(
+        dpt.asnumpy(cond), np.flip(dpt.asnumpy(x1)), dpt.asnumpy(x2)
+    )
+    assert_array_equal(dpt.asnumpy(res), expected)
+
+    res = dpt.where(dpt.flip(cond), x1, x2)
+    expected = np.where(
+        np.flip(dpt.asnumpy(cond)), dpt.asnumpy(x1), dpt.asnumpy(x2)
+    )
+    assert_array_equal(dpt.asnumpy(res), expected)
+
+
+def test_where_arg_validation():
+    get_queue_or_skip()
+
+    check = dict()
+    x1 = dpt.empty((1,), dtype="i4")
+    x2 = dpt.empty((1,), dtype="i4")
+
+    with pytest.raises(TypeError):
+        dpt.where(check, x1, x2)
+    with pytest.raises(TypeError):
+        dpt.where(x1, check, x2)
+    with pytest.raises(TypeError):
+        dpt.where(x1, x2, check)
+
+
+def test_where_compute_follows_data():
+    q1 = get_queue_or_skip()
+    q2 = get_queue_or_skip()
+    q3 = get_queue_or_skip()
+
+    x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1)
+    x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2)
+
+    with pytest.raises(ExecutionPlacementError):
+        dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q1), x1, x2)
+    with pytest.raises(ExecutionPlacementError):
+        dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
+    with pytest.raises(ExecutionPlacementError):
+        dpt.where(x1, x1, x2)

From d50d1c65f658c1fac9ba7d6c4886f343a3635d32 Mon Sep 17 00:00:00 2001
From: Nikita Grigorian
Date: Fri, 7 Apr 2023 10:13:39 -0700
Subject: [PATCH 5/6] Made changes as per PR review by @oleksandr-pavlyk

---
 dpctl/tensor/_search_functions.py             |  39 +++++-
 dpctl/tensor/_type_utils.py                   |   6 +-
 .../libtensor/include/kernels/where.hpp       | 113 +++++++++---------
 dpctl/tensor/libtensor/source/where.cpp       |  16 +--
 dpctl/tensor/libtensor/source/where.hpp       |   2 +-
 .../test_usm_ndarray_search_functions.py      |   4 +-
 6 files changed, 103 insertions(+), 77 deletions(-)

diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py
index 9cccfe922a..d3503f60b5 100644
--- a/dpctl/tensor/_search_functions.py
+++ b/dpctl/tensor/_search_functions.py
@@ -1,6 +1,6 @@
 # Data Parallel Control (dpctl)
 #
-# Copyright 2020-2022 Intel Corporation
+# Copyright 2020-2023 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,6 +40,37 @@ def _where_result_type(dt1, dt2, dev):
 
 
 def where(condition, x1, x2):
+    """where(condition, x1, x2)
+
+    Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
+    from `x1` or `x2` depending on `condition`.
+
+    Args:
+        condition (usm_ndarray): When True yields from `x1`,
+            and otherwise yields from `x2`.
+            Must be compatible with `x1` and `x2` according
+            to broadcasting rules.
+        x1 (usm_ndarray): Array from which values are chosen when
+            `condition` is True.
+            Must be compatible with `condition` and `x2` according
+            to broadcasting rules.
+        x2 (usm_ndarray): Array from which values are chosen when
+            `condition` is not True.
+            Must be compatible with `condition` and `x1` according
+            to broadcasting rules.
+ + Returns: + usm_ndarray: + An array with elements from `x1` where `condition` is True, + and elements from `x2` elsewhere. + + The data type of the returned array is determined by applying + the Type Promotion Rules to `x1` and `x2`. + + The memory layout of the returned array is + F-contiguous (column-major) when all inputs are F-contiguous, + and C-contiguous (row-major) otherwise. + """ if not isinstance(condition, dpt.usm_ndarray): raise TypeError( "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}" @@ -89,7 +120,7 @@ def where(condition, x1, x2): deps = [] wait_list = [] - if x1_dtype is not dst_dtype: + if x1_dtype != dst_dtype: _x1 = dpt.empty_like(x1, dtype=dst_dtype) ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x1, dst=_x1, sycl_queue=exec_q @@ -98,7 +129,7 @@ def where(condition, x1, x2): deps.append(copy1_ev) wait_list.append(ht_copy1_ev) - if x2_dtype is not dst_dtype: + if x2_dtype != dst_dtype: _x2 = dpt.empty_like(x2, dtype=dst_dtype) ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x2, dst=_x2, sycl_queue=exec_q @@ -140,7 +171,7 @@ def where(condition, x1, x2): sycl_queue=exec_q, depends=deps, ) - wait_list.append(hev) dpctl.SyclEvent.wait_for(wait_list) + hev.wait() return dst diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index f39c9374db..3ea6875fce 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -1,6 +1,6 @@ # Data Parallel Control (dpctl) # -# Copyright 2020-2022 Intel Corporation +# Copyright 2020-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ def _all_data_types(_fp16, _fp64): ] -def is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): +def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): """ Return True if data type `dt` is the maximal size inexact data type @@ -106,7 +106,7 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): if ( from_.kind in "biu" and to_.kind in "fc" - and is_maximal_inexact_type(to_, _fp16, _fp64) + and _is_maximal_inexact_type(to_, _fp16, _fp64) ): return True diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp index 54e502af0e..912df8ad24 100644 --- a/dpctl/tensor/libtensor/include/kernels/where.hpp +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -57,31 +57,24 @@ class WhereContigFunctor
 {
 private:
     size_t nelems = 0;
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const condT *cond_p = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;
 
 public:
     WhereContigFunctor(size_t nelems_,
-                       const char *cond_data_p,
-                       const char *x1_data_p,
-                       const char *x2_data_p,
-                       char *dst_data_p)
-        : nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
-          dst_cp(dst_data_p), cond_cp(cond_data_p)
+                       const condT *cond_p_,
+                       const T *x1_p_,
+                       const T *x2_p_,
+                       T *dst_p_)
+        : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_),
+          dst_p(dst_p_)
     {
     }
 
     void operator()(sycl::nd_item<1> ndit) const
     {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
-        using dpctl::tensor::type_utils::convert_impl;
-
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (is_complex<condT>::value || is_complex<T>::value) {
             std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,8 +85,9 @@ class WhereContigFunctor
              offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
              offset += sgSize)
         {
-            bool check = convert_impl<bool, condT>(cond_data[offset]);
-            dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
+            using dpctl::tensor::type_utils::convert_impl;
+            bool check = convert_impl<bool, condT>(cond_p[offset]);
+            dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
         }
     }
     else {
@@ -115,7 +109,6 @@ class WhereContigFunctor
                 using cond_ptrT =
                     sycl::multi_ptr<const condT,
                                     sycl::access::address_space::global_space>;
-
                 sycl::vec<T, vec_sz> dst_vec;
                 sycl::vec<T, vec_sz> x1_vec;
                 sycl::vec<T, vec_sz> x2_vec;
@@ -124,23 +117,20 @@ class WhereContigFunctor
 #pragma unroll
                 for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
                     auto idx = base + it * sgSize;
-                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
-                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
-                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
-
+                    x1_vec = sg.load<vec_sz>(x_ptrT(&x1_p[idx]));
+                    x2_vec = sg.load<vec_sz>(x_ptrT(&x2_p[idx]));
+                    cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_p[idx]));
 #pragma unroll
                     for (std::uint8_t k = 0; k < vec_sz; ++k) {
-                        bool check = convert_impl<bool, condT>(cond_vec[k]);
-                        dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
+                        dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
                     }
-                    sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[idx]), dst_vec);
                 }
             }
             else {
                 for (size_t k = base + sg.get_local_id()[0]; k < nelems;
                      k += sgSize) {
-                    bool check = convert_impl<bool, condT>(cond_data[k]);
-                    dst_data[k] = check ? x1_data[k] : x2_data[k];
+                    dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
                 }
             }
         }
@@ -159,12 +149,17 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
 
 template <typename T, typename condT>
 sycl::event where_contig_impl(sycl::queue q,
                               size_t nelems,
-                              const char *cond_p,
-                              const char *x1_p,
-                              const char *x2_p,
-                              char *dst_p,
+                              const char *cond_cp,
+                              const char *x1_cp,
+                              const char *x2_cp,
+                              char *dst_cp,
                               const std::vector<sycl::event> &depends)
 {
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
     sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -178,8 +173,8 @@ sycl::event where_contig_impl(sycl::queue q,
 
         cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
             sycl::nd_range<1>(gws_range, lws_range),
-            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
-                                                         x2_p, dst_p));
+            WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_tp, x1_tp,
+                                                         x2_tp, dst_tp));
     });
 
     return where_ev;
@@ -189,39 +184,34 @@ template <typename T, typename condT, typename IndexerT>
 class WhereStridedFunctor
 {
 private:
-    const char *x1_cp = nullptr;
-    const char *x2_cp = nullptr;
-    char *dst_cp = nullptr;
-    const char *cond_cp = nullptr;
+    const T *x1_p = nullptr;
+    const T *x2_p = nullptr;
+    T *dst_p = nullptr;
+    const condT *cond_p = nullptr;
     IndexerT indexer;
 
 public:
-    WhereStridedFunctor(const char *cond_data_p,
-                        const char *x1_data_p,
-                        const char *x2_data_p,
-                        char *dst_data_p,
+    WhereStridedFunctor(const condT *cond_p_,
+                        const T *x1_p_,
+                        const T *x2_p_,
+                        T *dst_p_,
                         IndexerT indexer_)
-        : x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
-          cond_cp(cond_data_p), indexer(indexer_)
+        : x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_),
+          indexer(indexer_)
     {
     }
 
     void operator()(sycl::id<1> id) const
     {
-        const T *x1_data = reinterpret_cast<const T *>(x1_cp);
-        const T *x2_data = reinterpret_cast<const T *>(x2_cp);
-        T *dst_data = reinterpret_cast<T *>(dst_cp);
-        const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
-
         size_t gid = id[0];
         auto offsets = indexer(static_cast<py::ssize_t>(gid));
 
         using dpctl::tensor::type_utils::convert_impl;
         bool check =
-            convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
+            convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);
 
-        dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
-                              : x2_data[offsets.get_third_offset()];
+        dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
+                           : x2_p[offsets.get_third_offset()];
     }
 };
 
@@ -243,16 +233,21 @@ typedef sycl::event (*where_strided_impl_fn_ptr_t)(
 
 template <typename T, typename condT>
 sycl::event where_strided_impl(sycl::queue q,
                                size_t nelems,
                                int nd,
-                               const char *cond_p,
-                               const char *x1_p,
-                               const char *x2_p,
-                               char *dst_p,
+                               const char *cond_cp,
+                               const char *x1_cp,
+                               const char *x2_cp,
+                               char *dst_cp,
                                const py::ssize_t *shape_strides,
                                py::ssize_t x1_offset,
                                py::ssize_t x2_offset,
                                py::ssize_t cond_offset,
                                const std::vector<sycl::event> &depends)
 {
+    const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
+    const T *x1_tp = reinterpret_cast<const T *>(x1_cp);
+    const T *x2_tp = reinterpret_cast<const T *>(x2_cp);
+    T *dst_tp = reinterpret_cast<T *>(dst_cp);
+
     sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -263,7 +258,7 @@ sycl::event where_strided_impl(sycl::queue q,
             where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
             sycl::range<1>(nelems),
             WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
-                cond_p, x1_p, x2_p, dst_p, indexer));
+                cond_tp, x1_tp, x2_tp, dst_tp, indexer));
     });
 
     return where_ev;
diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp
index 862877ee9a..134ea684b4 100644
--- a/dpctl/tensor/libtensor/source/where.cpp
+++ b/dpctl/tensor/libtensor/source/where.cpp
@@ -2,7 +2,7 @@
 //
 // Data Parallel Control (dpctl)
 //
-// Copyright 2020-2022 Intel Corporation
+// Copyright 2020-2023 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -96,10 +96,10 @@ py_where(dpctl::tensor::usm_ndarray condition,
     bool shapes_equal(true);
     size_t nelems(1);
     for (int i = 0; i < nd; ++i) {
-        nelems *= static_cast<size_t>(dst_shape[i]);
-        shapes_equal = shapes_equal && (x1_shape[i] == dst_shape[i]) &&
-                       (x2_shape[i] == dst_shape[i]) &&
-                       (cond_shape[i] == dst_shape[i]);
+        const auto &sh_i = dst_shape[i];
+        nelems *= static_cast<size_t>(sh_i);
+        shapes_equal = shapes_equal && (x1_shape[i] == sh_i) &&
+                       (x2_shape[i] == sh_i) && (cond_shape[i] == sh_i);
     }
 
     if (!shapes_equal) {
@@ -127,7 +127,7 @@ py_where(dpctl::tensor::usm_ndarray condition,
     int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
 
     if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) {
-        throw py::value_error("Non-condition are not of same type.");
+        throw py::value_error("Value arrays must have the same data type");
     }
 
     // ensure that dst is sufficiently ample
@@ -166,8 +166,8 @@ py_where(dpctl::tensor::usm_ndarray condition,
 
         auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data,
                                   dst_data, depends);
-        sycl::event ht_ev = dpctl::utils::keep_args_alive(
-            exec_q, {x1, x2, dst, condition}, {where_ev});
+        sycl::event ht_ev =
+            keep_args_alive(exec_q, {x1, x2, dst, condition}, {where_ev});
 
         return std::make_pair(ht_ev, where_ev);
     }
diff --git a/dpctl/tensor/libtensor/source/where.hpp b/dpctl/tensor/libtensor/source/where.hpp
index 95f70e560e..38d40b8550 100644
--- a/dpctl/tensor/libtensor/source/where.hpp
+++ b/dpctl/tensor/libtensor/source/where.hpp
@@ -2,7 +2,7 @@
 //
 // Data Parallel Control (dpctl)
 //
-// Copyright 2020-2022 Intel Corporation
+// Copyright 2020-2023 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index 6e1bfe3135..ff2e23c076 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -121,8 +121,8 @@ def test_where_all_dtypes(dt): # mask dtype changes cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q) - x1 = dpt.asarray(0, dtype="f", sycl_queue=q) - x2 = dpt.asarray(1, dtype="f", sycl_queue=q) + x1 = dpt.asarray(0, dtype="f4", sycl_queue=q) + x2 = dpt.asarray(1, dtype="f4", sycl_queue=q) res = dpt.where(cond, x1, x2) res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype) From db4304271dd9f96edf47bffbf3167a45648ed11c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 7 Apr 2023 11:45:36 -0700 Subject: [PATCH 6/6] Added where test - Asymmetric dtype test to improve coverage --- .../tests/test_usm_ndarray_search_functions.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index ff2e23c076..486350589e 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -150,6 +150,24 @@ def test_where_all_dtypes(dt): assert _dtype_all_close(dpt.asnumpy(res), res_check) +def test_where_asymmetric_dtypes(): + q = get_queue_or_skip() + + cond = dpt.asarray([0, 1, 3, 0, 10], dtype="?", sycl_queue=q) + x1 = dpt.asarray(2, dtype="i4", sycl_queue=q) + x2 = dpt.asarray(3, dtype="i8", sycl_queue=q) + + res = dpt.where(cond, x1, x2) + res_check = np.asarray([3, 2, 2, 3, 2], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + # flip order + + res = dpt.where(cond, x2, x1) + res_check = np.asarray([2, 3, 3, 2, 3], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + def test_where_nan_inf(): get_queue_or_skip()
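
A minimal usage sketch of the promotion behavior exercised by the new
asymmetric-dtype test (illustrative only; assumes a device selected by
the default queue):

    import dpctl.tensor as dpt

    cond = dpt.asarray([0, 1, 3, 0, 10], dtype="?")  # nonzero -> True
    x1 = dpt.asarray(2, dtype="i4")
    x2 = dpt.asarray(3, dtype="i8")

    res = dpt.where(cond, x1, x2)
    # res.dtype is int64: dpt.result_type(i4, i8) promotes to i8;
    # dpt.asnumpy(res) -> array([3, 2, 2, 3, 2])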