diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt
index e34cd94bb9..1025072c08 100644
--- a/dpctl/tensor/CMakeLists.txt
+++ b/dpctl/tensor/CMakeLists.txt
@@ -43,6 +43,7 @@ pybind11_add_module(${python_module_name} MODULE
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
 )
 target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py
index e970af98f5..ec9f27617b 100644
--- a/dpctl/tensor/__init__.py
+++ b/dpctl/tensor/__init__.py
@@ -86,6 +86,7 @@
     usm_ndarray_str,
 )
 from dpctl.tensor._reshape import reshape
+from dpctl.tensor._search_functions import where
 from dpctl.tensor._usmarray import usm_ndarray
 
 from ._constants import e, inf, nan, newaxis, pi
@@ -128,6 +129,7 @@
     "from_dlpack",
     "tril",
     "triu",
+    "where",
     "dtype",
     "isdtype",
     "bool",
diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py
new file mode 100644
index 0000000000..d3503f60b5
--- /dev/null
+++ b/dpctl/tensor/_search_functions.py
@@ -0,0 +1,177 @@
+# Data Parallel Control (dpctl)
+#
+# Copyright 2020-2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dpctl
+import dpctl.tensor as dpt
+import dpctl.tensor._tensor_impl as ti
+from dpctl.tensor._manipulation_functions import _broadcast_shapes
+
+from ._type_utils import _all_data_types, _can_cast
+
+
+def _where_result_type(dt1, dt2, dev):
+    res_dtype = dpt.result_type(dt1, dt2)
+    fp16 = dev.has_aspect_fp16
+    fp64 = dev.has_aspect_fp64
+
+    all_dts = _all_data_types(fp16, fp64)
+    if res_dtype in all_dts:
+        return res_dtype
+    else:
+        for res_dtype_ in all_dts:
+            if _can_cast(dt1, res_dtype_, fp16, fp64) and _can_cast(
+                dt2, res_dtype_, fp16, fp64
+            ):
+                return res_dtype_
+        return None
+
+
+def where(condition, x1, x2):
+    """where(condition, x1, x2)
+
+    Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
+    from `x1` or `x2` depending on `condition`.
+
+    Args:
+        condition (usm_ndarray): When True yields from `x1`,
+            and otherwise yields from `x2`.
+            Must be compatible with `x1` and `x2` according
+            to broadcasting rules.
+        x1 (usm_ndarray): Array from which values are chosen when
+            `condition` is True.
+            Must be compatible with `condition` and `x2` according
+            to broadcasting rules.
+        x2 (usm_ndarray): Array from which values are chosen when
+            `condition` is not True.
+            Must be compatible with `condition` and `x1` according
+            to broadcasting rules.
+
+    Returns:
+        usm_ndarray:
+            An array with elements from `x1` where `condition` is True,
+            and elements from `x2` elsewhere.
+
+    The data type of the returned array is determined by applying
+    the Type Promotion Rules to `x1` and `x2`.
+ + The memory layout of the returned array is + F-contiguous (column-major) when all inputs are F-contiguous, + and C-contiguous (row-major) otherwise. + """ + if not isinstance(condition, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}" + ) + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}" + ) + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}" + ) + exec_q = dpctl.utils.get_execution_queue( + ( + condition.sycl_queue, + x1.sycl_queue, + x2.sycl_queue, + ) + ) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError + dst_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition.usm_type, + x1.usm_type, + x2.usm_type, + ) + ) + + x1_dtype = x1.dtype + x2_dtype = x2.dtype + dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device) + if dst_dtype is None: + raise TypeError( + "function 'where' does not support input " + f"types ({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced " + "to any supported types according to the casting rule ''safe''." + ) + + res_shape = _broadcast_shapes(condition, x1, x2) + + if condition.size == 0: + return dpt.empty( + res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + ) + + deps = [] + wait_list = [] + if x1_dtype != dst_dtype: + _x1 = dpt.empty_like(x1, dtype=dst_dtype) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=_x1, sycl_queue=exec_q + ) + x1 = _x1 + deps.append(copy1_ev) + wait_list.append(ht_copy1_ev) + + if x2_dtype != dst_dtype: + _x2 = dpt.empty_like(x2, dtype=dst_dtype) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=_x2, sycl_queue=exec_q + ) + x2 = _x2 + deps.append(copy2_ev) + wait_list.append(ht_copy2_ev) + + condition = dpt.broadcast_to(condition, res_shape) + x1 = dpt.broadcast_to(x1, res_shape) + x2 = dpt.broadcast_to(x2, res_shape) + + # dst is F-contiguous when all inputs are F contiguous + # otherwise, defaults to C-contiguous + if all( + ( + condition.flags.fnc, + x1.flags.fnc, + x2.flags.fnc, + ) + ): + order = "F" + else: + order = "C" + + dst = dpt.empty( + res_shape, + dtype=dst_dtype, + order=order, + usm_type=dst_usm_type, + sycl_queue=exec_q, + ) + + hev, _ = ti._where( + condition=condition, + x1=x1, + x2=x2, + dst=dst, + sycl_queue=exec_q, + depends=deps, + ) + dpctl.SyclEvent.wait_for(wait_list) + hev.wait() + + return dst diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py new file mode 100644 index 0000000000..3ea6875fce --- /dev/null +++ b/dpctl/tensor/_type_utils.py @@ -0,0 +1,113 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
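For orientation, a minimal usage sketch of the `dpt.where` function added above (illustrative only; assumes dpctl is installed and a default SYCL device is available). It exercises the broadcasting and type-promotion behavior described in the docstring:

```python
import dpctl.tensor as dpt

cond = dpt.asarray([[True, False], [False, True]])
x1 = dpt.asarray(1, dtype="i4")  # 0-d array, broadcast against cond
x2 = dpt.asarray(0, dtype="i8")  # different dtype; result promotes to i8

res = dpt.where(cond, x1, x2)
# res.shape == (2, 2), res.dtype == int64
print(dpt.asnumpy(res))  # [[1 0]
                         #  [0 1]]
```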
+ +import dpctl.tensor as dpt + + +def _all_data_types(_fp16, _fp64): + if _fp64: + if _fp16: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + if _fp16: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.complex64, + ] + else: + return [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float32, + dpt.complex64, + ] + + +def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): + """ + Return True if data type `dt` is the + maximal size inexact data type + """ + if _fp64: + return dt in [dpt.float64, dpt.complex128] + return dt in [dpt.float32, dpt.complex64] + + +def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): + """ + Can `from_` be cast to `to_` safely on a device with + fp16 and fp64 aspects as given? + """ + can_cast_v = dpt.can_cast(from_, to_) # ask NumPy + if _fp16 and _fp64: + return can_cast_v + if not can_cast_v: + if ( + from_.kind in "biu" + and to_.kind in "fc" + and _is_maximal_inexact_type(to_, _fp16, _fp64) + ): + return True + + return can_cast_v diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp new file mode 100644 index 0000000000..912df8ad24 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -0,0 +1,288 @@ +//=== where.hpp - Implementation of where kernels ---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for dpctl.tensor.where. 
+//===----------------------------------------------------------------------===// + +#pragma once +#include "pybind11/numpy.h" +#include "pybind11/stl.h" +#include "utils/offset_utils.hpp" +#include "utils/type_utils.hpp" +#include +#include +#include +#include +#include +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace search +{ + +namespace py = pybind11; + +using namespace dpctl::tensor::offset_utils; + +template +class where_strided_kernel; +template +class where_contig_kernel; + +template +class WhereContigFunctor +{ +private: + size_t nelems = 0; + const condT *cond_p = nullptr; + const T *x1_p = nullptr; + const T *x2_p = nullptr; + T *dst_p = nullptr; + +public: + WhereContigFunctor(size_t nelems_, + const condT *cond_p_, + const T *x1_p_, + const T *x2_p_, + T *dst_p_) + : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_), + dst_p(dst_p_) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value || is_complex::value) { + std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + using dpctl::tensor::type_utils::convert_impl; + bool check = convert_impl(cond_p[offset]); + dst_p[offset] = check ? x1_p[offset] : x2_p[offset]; + } + } + else { + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * max_sgSize); + + if (base + n_vecs * vec_sz * sgSize < nelems && + sgSize == max_sgSize) { + using dst_ptrT = + sycl::multi_ptr; + using x_ptrT = + sycl::multi_ptr; + using cond_ptrT = + sycl::multi_ptr; + sycl::vec dst_vec; + sycl::vec x1_vec; + sycl::vec x2_vec; + sycl::vec cond_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + auto idx = base + it * sgSize; + x1_vec = sg.load(x_ptrT(&x1_p[idx])); + x2_vec = sg.load(x_ptrT(&x2_p[idx])); + cond_vec = sg.load(cond_ptrT(&cond_p[idx])); +#pragma unroll + for (std::uint8_t k = 0; k < vec_sz; ++k) { + dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k]; + } + sg.store(dst_ptrT(&dst_p[idx]), dst_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems; + k += sgSize) { + dst_p[k] = cond_p[k] ? 
x1_p[k] : x2_p[k]; + } + } + } + } +}; + +typedef sycl::event (*where_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + const char *, + const char *, + char *, + const std::vector &); + +template +sycl::event where_contig_impl(sycl::queue q, + size_t nelems, + const char *cond_cp, + const char *x1_cp, + const char *x2_cp, + char *dst_cp, + const std::vector &depends) +{ + const condT *cond_tp = reinterpret_cast(cond_cp); + const T *x1_tp = reinterpret_cast(x1_cp); + const T *x2_tp = reinterpret_cast(x2_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event where_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + cgh.parallel_for>( + sycl::nd_range<1>(gws_range, lws_range), + WhereContigFunctor(nelems, cond_tp, x1_tp, + x2_tp, dst_tp)); + }); + + return where_ev; +} + +template +class WhereStridedFunctor +{ +private: + const T *x1_p = nullptr; + const T *x2_p = nullptr; + T *dst_p = nullptr; + const condT *cond_p = nullptr; + IndexerT indexer; + +public: + WhereStridedFunctor(const condT *cond_p_, + const T *x1_p_, + const T *x2_p_, + T *dst_p_, + IndexerT indexer_) + : x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_), + indexer(indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + size_t gid = id[0]; + auto offsets = indexer(static_cast(gid)); + + using dpctl::tensor::type_utils::convert_impl; + bool check = + convert_impl(cond_p[offsets.get_first_offset()]); + + dst_p[gid] = check ? x1_p[offsets.get_second_offset()] + : x2_p[offsets.get_third_offset()]; + } +}; + +typedef sycl::event (*where_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const char *, + const char *, + const char *, + char *, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event where_strided_impl(sycl::queue q, + size_t nelems, + int nd, + const char *cond_cp, + const char *x1_cp, + const char *x2_cp, + char *dst_cp, + const py::ssize_t *shape_strides, + py::ssize_t x1_offset, + py::ssize_t x2_offset, + py::ssize_t cond_offset, + const std::vector &depends) +{ + const condT *cond_tp = reinterpret_cast(cond_cp); + const T *x1_tp = reinterpret_cast(x1_cp); + const T *x2_tp = reinterpret_cast(x2_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event where_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + ThreeOffsets_StridedIndexer indexer{nd, cond_offset, x1_offset, + x2_offset, shape_strides}; + + cgh.parallel_for< + where_strided_kernel>( + sycl::range<1>(nelems), + WhereStridedFunctor( + cond_tp, x1_tp, x2_tp, dst_tp, indexer)); + }); + + return where_ev; +} + +template struct WhereStridedFactory +{ + fnT get() + { + fnT fn = where_strided_impl; + return fn; + } +}; + +template struct WhereContigFactory +{ + fnT get() + { + fnT fn = where_contig_impl; + return fn; + } +}; + +} // namespace search +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 7a5886088c..cf66ede984 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -45,6 +45,7 @@ #include "triul_ctor.hpp" #include 
"utils/memory_overlap.hpp" #include "utils/strided_iters.hpp" +#include "where.hpp" namespace py = pybind11; @@ -92,6 +93,10 @@ using dpctl::tensor::py_internal::usm_ndarray_eye; using dpctl::tensor::py_internal::usm_ndarray_triul; +/* =========================== Where ============================== */ + +using dpctl::tensor::py_internal::py_where; + // populate dispatch tables void init_dispatch_tables(void) { @@ -100,6 +105,7 @@ void init_dispatch_tables(void) init_copy_and_cast_usm_to_usm_dispatch_tables(); init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables(); init_advanced_indexing_dispatch_tables(); + init_where_dispatch_tables(); return; } @@ -293,4 +299,8 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("_nonzero", &py_nonzero, "", py::arg("cumsum"), py::arg("indexes"), py::arg("mask_shape"), py::arg("sycl_queue"), py::arg("depends") = py::list()); + + m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"), + py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); } diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp new file mode 100644 index 0000000000..134ea684b4 --- /dev/null +++ b/dpctl/tensor/libtensor/source/where.cpp @@ -0,0 +1,258 @@ +//===-- where.cpp - Implementation of where --*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines Python API for implementation functions of +/// dpctl.tensor.where +//===----------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/where.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "where.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace _ns = dpctl::tensor::detail; + +using dpctl::tensor::kernels::search::where_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::search::where_strided_impl_fn_ptr_t; + +static where_contig_impl_fn_ptr_t where_contig_dispatch_table[_ns::num_types] + [_ns::num_types]; +static where_strided_impl_fn_ptr_t where_strided_dispatch_table[_ns::num_types] + [_ns::num_types]; + +using dpctl::utils::keep_args_alive; + +std::pair +py_where(dpctl::tensor::usm_ndarray condition, + dpctl::tensor::usm_ndarray x1, + dpctl::tensor::usm_ndarray x2, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends) +{ + + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, condition, dst})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + int nd = condition.get_ndim(); + int x1_nd = x1.get_ndim(); + int x2_nd = x2.get_ndim(); + int dst_nd = dst.get_ndim(); + + if (nd != x1_nd || nd != x2_nd) { + throw py::value_error( + "Input arrays are not of appropriate dimension for where kernel."); + } + + if (nd != dst_nd) { + throw py::value_error( + "Destination is not of appropriate dimension for where kernel."); + } + + const py::ssize_t *x1_shape = x1.get_shape_raw(); + const py::ssize_t *x2_shape = x2.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + const py::ssize_t *cond_shape = condition.get_shape_raw(); + + bool shapes_equal(true); + size_t nelems(1); + for (int i = 0; i < nd; ++i) { + const auto &sh_i = dst_shape[i]; + nelems *= static_cast(sh_i); + shapes_equal = shapes_equal && (x1_shape[i] == sh_i) && + (x2_shape[i] == sh_i) && (cond_shape[i] == sh_i); + } + + if (!shapes_equal) { + throw py::value_error("Axes are not of matching shapes."); + } + + if (nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, condition) || overlap(dst, x1) || overlap(dst, x2)) { + throw py::value_error("Destination array overlaps with input."); + } + + int x1_typenum = x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int cond_typenum = condition.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + int cond_typeid = array_types.typenum_to_lookup_id(cond_typenum); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) { + throw py::value_error("Value arrays must have the same data type"); + } + + // ensure that dst is sufficiently ample + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accomodate all elements + { + size_t range = + static_cast(dst_offsets.second - 
dst_offsets.first); + if (range + 1 < static_cast(nelems)) { + throw py::value_error( + "Memory addressed by the destination array can not " + "accomodate all the " + "array elements."); + } + } + + char *cond_data = condition.get_data(); + char *x1_data = x1.get_data(); + char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + + bool is_cond_c_contig = condition.is_c_contiguous(); + bool is_cond_f_contig = condition.is_f_contiguous(); + + bool all_c_contig = (is_x1_c_contig && is_x2_c_contig && is_cond_c_contig); + bool all_f_contig = (is_x1_f_contig && is_x2_f_contig && is_cond_f_contig); + + if (all_c_contig || all_f_contig) { + auto contig_fn = where_contig_dispatch_table[x1_typeid][cond_typeid]; + + auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data, + dst_data, depends); + sycl::event ht_ev = + keep_args_alive(exec_q, {x1, x2, dst, condition}, {where_ev}); + + return std::make_pair(ht_ev, where_ev); + } + + const py::ssize_t *cond_strides = condition.get_strides_raw(); + const py::ssize_t *x1_strides = x1.get_strides_raw(); + const py::ssize_t *x2_strides = x2.get_strides_raw(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_cond_strides; + shT simplified_x1_strides; + shT simplified_x2_strides; + py::ssize_t cond_offset(0); + py::ssize_t x1_offset(0); + py::ssize_t x2_offset(0); + + const py::ssize_t *_shape = x1_shape; + + constexpr py::ssize_t _itemsize = 1; + + dpctl::tensor::py_internal::simplify_iteration_space_3( + nd, _shape, cond_strides, _itemsize, is_cond_c_contig, is_cond_f_contig, + x1_strides, _itemsize, is_x1_c_contig, is_x1_f_contig, x2_strides, + _itemsize, is_x2_c_contig, is_x2_f_contig, simplified_shape, + simplified_cond_strides, simplified_x1_strides, simplified_x2_strides, + cond_offset, x1_offset, x2_offset); + + auto fn = where_strided_dispatch_table[x1_typeid][cond_typeid]; + + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, simplified_shape, simplified_cond_strides, + simplified_x1_strides, simplified_x2_strides); + py::ssize_t *packed_shape_strides = std::get<0>(ptr_size_event_tuple); + sycl::event copy_shape_strides_ev = std::get<2>(ptr_size_event_tuple); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + sycl::event where_ev = + fn(exec_q, nelems, nd, cond_data, x1_data, x2_data, dst_data, + packed_shape_strides, cond_offset, x1_offset, x2_offset, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(where_ev); + auto ctx = exec_q.get_context(); + cgh.host_task([packed_shape_strides, ctx]() { + sycl::free(packed_shape_strides, ctx); + }); + }); + + host_task_events.push_back(temporaries_cleanup_ev); + + sycl::event arg_cleanup_ev = + keep_args_alive(exec_q, {x1, x2, condition, dst}, host_task_events); + + return std::make_pair(arg_cleanup_ev, temporaries_cleanup_ev); +} + +void init_where_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search::WhereContigFactory; + 
dpctl::tensor::detail::DispatchTableBuilder< + where_contig_impl_fn_ptr_t, WhereContigFactory, + dpctl::tensor::detail::num_types> + dtb1; + dtb1.populate_dispatch_table(where_contig_dispatch_table); + + using dpctl::tensor::kernels::search::WhereStridedFactory; + dpctl::tensor::detail::DispatchTableBuilder< + where_strided_impl_fn_ptr_t, WhereStridedFactory, + dpctl::tensor::detail::num_types> + dtb2; + dtb2.populate_dispatch_table(where_strided_dispatch_table); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/where.hpp b/dpctl/tensor/libtensor/source/where.hpp new file mode 100644 index 0000000000..38d40b8550 --- /dev/null +++ b/dpctl/tensor/libtensor/source/where.hpp @@ -0,0 +1,52 @@ +//===-- where.hpp - --*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file declares Python API for implementation functions of +/// dpctl.tensor.where +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include + +#include "dpctl4pybind11.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern std::pair +py_where(dpctl::tensor::usm_ndarray, + dpctl::tensor::usm_ndarray, + dpctl::tensor::usm_ndarray, + dpctl::tensor::usm_ndarray, + sycl::queue, + const std::vector &); + +extern void init_where_dispatch_tables(void); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tests/test_type_utils.py b/dpctl/tests/test_type_utils.py new file mode 100644 index 0000000000..882478a2ce --- /dev/null +++ b/dpctl/tests/test_type_utils.py @@ -0,0 +1,68 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
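The device-aware casting helper exercised by the tests below deliberately departs from plain NumPy casting when a floating-point aspect is missing: an integer operand may still be cast to the widest inexact type the device actually supports. A short sketch of that behavior, mirroring the assertions in `test_can_cast_device`:

```python
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast

# with both fp16 and fp64 available, NumPy casting rules apply
assert not _can_cast(dpt.int64, dpt.float32, True, True)
# without fp64, float32 is the maximal inexact type, so the cast is allowed
assert _can_cast(dpt.int64, dpt.float32, True, False)
```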
+ +import pytest + +import dpctl.tensor as dpt +from dpctl.tensor._type_utils import ( + _all_data_types, + _can_cast, + _is_maximal_inexact_type, +) + + +def test_all_data_types(): + fp16_fp64_types = set([dpt.float16, dpt.float64, dpt.complex128]) + fp64_types = set([dpt.float64, dpt.complex128]) + + all_dts = _all_data_types(True, True) + assert fp16_fp64_types.issubset(all_dts) + + all_dts = _all_data_types(True, False) + assert dpt.float16 in all_dts + assert not fp64_types.issubset(all_dts) + + all_dts = _all_data_types(False, True) + assert dpt.float16 not in all_dts + assert fp64_types.issubset(all_dts) + + all_dts = _all_data_types(False, False) + assert not fp16_fp64_types.issubset(all_dts) + + +@pytest.mark.parametrize("fp16", [True, False]) +@pytest.mark.parametrize("fp64", [True, False]) +def test_maximal_inexact_types(fp16, fp64): + assert not _is_maximal_inexact_type(dpt.int32, fp16, fp64) + assert fp64 == _is_maximal_inexact_type(dpt.float64, fp16, fp64) + assert fp64 == _is_maximal_inexact_type(dpt.complex128, fp16, fp64) + assert fp64 != _is_maximal_inexact_type(dpt.float32, fp16, fp64) + assert fp64 != _is_maximal_inexact_type(dpt.complex64, fp16, fp64) + + +def test_can_cast_device(): + assert _can_cast(dpt.int64, dpt.float64, True, True) + # if f8 is available, can't cast i8 to f4 + assert not _can_cast(dpt.int64, dpt.float32, True, True) + assert not _can_cast(dpt.int64, dpt.float32, False, True) + # should be able to cast to f8 when f2 unavailable + assert _can_cast(dpt.int64, dpt.float64, False, True) + # casting to f4 acceptable when f8 unavailable + assert _can_cast(dpt.int64, dpt.float32, True, False) + assert _can_cast(dpt.int64, dpt.float32, False, False) + # can't safely cast inexact type to inexact type of lesser precision + assert not _can_cast(dpt.float32, dpt.float16, True, False) + assert not _can_cast(dpt.float64, dpt.float32, False, True) diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py new file mode 100644 index 0000000000..486350589e --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -0,0 +1,329 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
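The tests that follow also cover non-boolean masks: `where` treats any non-zero `condition` element as True, matching `numpy.where`. A small illustrative sketch of that behavior (mirrors `test_where_all_dtypes`):

```python
import dpctl.tensor as dpt

cond = dpt.asarray([0, 1, 3, 0, 10], dtype="i4")  # non-boolean mask
x1 = dpt.asarray(1, dtype="i4")
x2 = dpt.asarray(0, dtype="i4")

res = dpt.where(cond, x1, x2)
print(dpt.asnumpy(res))  # [0 1 1 0 1]
```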
+ +import numpy as np +import pytest +from helper import get_queue_or_skip, skip_if_dtype_not_supported +from numpy.testing import assert_array_equal + +import dpctl.tensor as dpt +from dpctl.tensor._search_functions import _where_result_type +from dpctl.tensor._type_utils import _all_data_types +from dpctl.utils import ExecutionPlacementError + +_all_dtypes = [ + "?", + "u1", + "i1", + "u2", + "i2", + "u4", + "i4", + "u8", + "i8", + "e", + "f", + "d", + "F", + "D", +] + + +class mock_device: + def __init__(self, fp16, fp64): + self.has_aspect_fp16 = fp16 + self.has_aspect_fp64 = fp64 + + +def test_where_basic(): + get_queue_or_skip() + + cond = dpt.asarray( + [ + [True, False, False], + [False, True, False], + [False, False, True], + [False, False, False], + [True, True, True], + ] + ) + out = dpt.where(cond, dpt.asarray(1), dpt.asarray(0)) + out_expected = dpt.asarray( + [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]] + ) + assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() + + out = dpt.where(cond, dpt.ones(cond.shape), dpt.zeros(cond.shape)) + assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() + + out = dpt.where( + cond, + dpt.ones(cond.shape[0], dtype="i4")[:, dpt.newaxis], + dpt.zeros(cond.shape[0], dtype="i4")[:, dpt.newaxis], + ) + assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all() + + +def _dtype_all_close(x1, x2): + if np.issubdtype(x2.dtype, np.floating) or np.issubdtype( + x2.dtype, np.complexfloating + ): + x2_dtype = x2.dtype + return np.allclose( + x1, x2, atol=np.finfo(x2_dtype).eps, rtol=np.finfo(x2_dtype).eps + ) + else: + return np.allclose(x1, x2) + + +@pytest.mark.parametrize("dt1", _all_dtypes) +@pytest.mark.parametrize("dt2", _all_dtypes) +@pytest.mark.parametrize("fp16", [True, False]) +@pytest.mark.parametrize("fp64", [True, False]) +def test_where_result_types(dt1, dt2, fp16, fp64): + dev = mock_device(fp16, fp64) + + dt1 = dpt.dtype(dt1) + dt2 = dpt.dtype(dt2) + res_t = _where_result_type(dt1, dt2, dev) + + if fp16 and fp64: + assert res_t == dpt.result_type(dt1, dt2) + else: + if res_t: + assert res_t.kind == dpt.result_type(dt1, dt2).kind + else: + # some illegal cases are covered above, but + # this guarantees that _where_result_type + # produces None only when one of the dtypes + # is illegal given fp aspects of device + all_dts = _all_data_types(fp16, fp64) + assert dt1 not in all_dts or dt2 not in all_dts + + +@pytest.mark.parametrize("dt", _all_dtypes) +def test_where_all_dtypes(dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + + # mask dtype changes + cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt, sycl_queue=q) + x1 = dpt.asarray(0, dtype="f4", sycl_queue=q) + x2 = dpt.asarray(1, dtype="f4", sycl_queue=q) + res = dpt.where(cond, x1, x2) + + res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + # contiguous cases + x1 = dpt.full(cond.shape, 0, dtype="f4", sycl_queue=q) + x2 = dpt.full(cond.shape, 1, dtype="f4", sycl_queue=q) + res = dpt.where(cond, x1, x2) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + # input array dtype changes + cond = dpt.asarray([False, True, True, False, True], sycl_queue=q) + x1 = dpt.asarray(0, dtype=dt, sycl_queue=q) + x2 = dpt.asarray(1, dtype=dt, sycl_queue=q) + res = dpt.where(cond, x1, x2) + + res_check = np.asarray([1, 0, 0, 1, 0], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + # contiguous cases + x1 = dpt.full(cond.shape, 0, dtype=dt, sycl_queue=q) + x2 = 
dpt.full(cond.shape, 1, dtype=dt, sycl_queue=q) + res = dpt.where(cond, x1, x2) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + +def test_where_asymmetric_dtypes(): + q = get_queue_or_skip() + + cond = dpt.asarray([0, 1, 3, 0, 10], dtype="?", sycl_queue=q) + x1 = dpt.asarray(2, dtype="i4", sycl_queue=q) + x2 = dpt.asarray(3, dtype="i8", sycl_queue=q) + + res = dpt.where(cond, x1, x2) + res_check = np.asarray([3, 2, 2, 3, 2], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + # flip order + + res = dpt.where(cond, x2, x1) + res_check = np.asarray([2, 3, 3, 2, 3], dtype=res.dtype) + assert _dtype_all_close(dpt.asnumpy(res), res_check) + + +def test_where_nan_inf(): + get_queue_or_skip() + + cond = dpt.asarray([True, False, True, False], dtype="?") + x1 = dpt.asarray([np.nan, 2.0, np.inf, 3.0], dtype="f4") + x2 = dpt.asarray([2.0, np.nan, 3.0, np.inf], dtype="f4") + + cond_np = dpt.asnumpy(cond) + x1_np = dpt.asnumpy(x1) + x2_np = dpt.asnumpy(x2) + + res = dpt.where(cond, x1, x2) + res_np = np.where(cond_np, x1_np, x2_np) + + assert np.allclose(dpt.asnumpy(res), res_np, equal_nan=True) + + res = dpt.where(x1, cond, x2) + res_np = np.where(x1_np, cond_np, x2_np) + assert _dtype_all_close(dpt.asnumpy(res), res_np) + + +def test_where_empty(): + # check that numpy returns same results when + # handling empty arrays + get_queue_or_skip() + + empty = dpt.empty(0, dtype="i2") + m = dpt.asarray(True) + x1 = dpt.asarray(1, dtype="i2") + x2 = dpt.asarray(2, dtype="i2") + res = dpt.where(empty, x1, x2) + + empty_np = np.empty(0, dtype="i2") + m_np = dpt.asnumpy(m) + x1_np = dpt.asnumpy(x1) + x2_np = dpt.asnumpy(x2) + res_np = np.where(empty_np, x1_np, x2_np) + + assert_array_equal(dpt.asnumpy(res), res_np) + + res = dpt.where(m, empty, x2) + res_np = np.where(m_np, empty_np, x2_np) + + assert_array_equal(dpt.asnumpy(res), res_np) + + # check that broadcasting is performed + with pytest.raises(ValueError): + dpt.where(empty, x1, dpt.empty((1, 2))) + + +@pytest.mark.parametrize("order", ["C", "F"]) +def test_where_contiguous(order): + get_queue_or_skip() + + cond = dpt.asarray( + [ + [[True, False, False], [False, True, True]], + [[False, True, False], [True, False, True]], + [[False, False, True], [False, False, True]], + [[False, False, False], [True, False, True]], + [[True, True, True], [True, False, True]], + ], + order=order, + ) + + x1 = dpt.full(cond.shape, 2, dtype="i4", order=order) + x2 = dpt.full(cond.shape, 3, dtype="i4", order=order) + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + + assert _dtype_all_close(dpt.asnumpy(res), expected) + + +def test_where_contiguous1D(): + get_queue_or_skip() + + cond = dpt.asarray([True, False, True, False, False, True]) + + x1 = dpt.full(cond.shape, 2, dtype="i4") + x2 = dpt.full(cond.shape, 3, dtype="i4") + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + assert_array_equal(dpt.asnumpy(res), expected) + + # test with complex dtype (branch in kernel) + x1 = dpt.astype(x1, dpt.complex64) + x2 = dpt.astype(x2, dpt.complex64) + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + assert _dtype_all_close(dpt.asnumpy(res), expected) + + +def test_where_strided(): + get_queue_or_skip() + + s0, s1 = 4, 9 + cond = dpt.reshape( + dpt.asarray( + [True, False, False, False, True, True, False, True, False] * s0 + ), + (s0, s1), + )[:, ::3] + + x1 = dpt.reshape( + 
dpt.arange(cond.shape[0] * cond.shape[1] * 2, dtype="i4"), + (cond.shape[0], cond.shape[1] * 2), + )[:, ::2] + x2 = dpt.reshape( + dpt.arange(cond.shape[0] * cond.shape[1] * 3, dtype="i4"), + (cond.shape[0], cond.shape[1] * 3), + )[:, ::3] + expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2)) + res = dpt.where(cond, x1, x2) + + assert_array_equal(dpt.asnumpy(res), expected) + + # negative strides + res = dpt.where(cond, dpt.flip(x1), x2) + expected = np.where( + dpt.asnumpy(cond), np.flip(dpt.asnumpy(x1)), dpt.asnumpy(x2) + ) + assert_array_equal(dpt.asnumpy(res), expected) + + res = dpt.where(dpt.flip(cond), x1, x2) + expected = np.where( + np.flip(dpt.asnumpy(cond)), dpt.asnumpy(x1), dpt.asnumpy(x2) + ) + assert_array_equal(dpt.asnumpy(res), expected) + + +def test_where_arg_validation(): + get_queue_or_skip() + + check = dict() + x1 = dpt.empty((1,), dtype="i4") + x2 = dpt.empty((1,), dtype="i4") + + with pytest.raises(TypeError): + dpt.where(check, x1, x2) + with pytest.raises(TypeError): + dpt.where(x1, check, x2) + with pytest.raises(TypeError): + dpt.where(x1, x2, check) + + +def test_where_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + q3 = get_queue_or_skip() + + x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1) + x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2) + + with pytest.raises(ExecutionPlacementError): + dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q1), x1, x2) + with pytest.raises(ExecutionPlacementError): + dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2) + with pytest.raises(ExecutionPlacementError): + dpt.where(x1, x1, x2)
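The compute-follows-data checks in the last test reject operands allocated on different queues. For contrast, a sketch of the working pattern, in which every operand shares one queue (assumes a default-selected SYCL device):

```python
import dpctl
import dpctl.tensor as dpt

q = dpctl.SyclQueue()  # one queue shared by all operands
cond = dpt.asarray([True, False, True], sycl_queue=q)
x1 = dpt.asarray([1, 2, 3], dtype="i4", sycl_queue=q)
x2 = dpt.asarray([10, 20, 30], dtype="i4", sycl_queue=q)

res = dpt.where(cond, x1, x2)  # executes on q, no placement error
print(dpt.asnumpy(res))  # [ 1 20  3]
```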