diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt
index 9a79830e47..083f2558c5 100644
--- a/dpctl/tensor/CMakeLists.txt
+++ b/dpctl/tensor/CMakeLists.txt
@@ -44,6 +44,7 @@ pybind11_add_module(${python_module_name} MODULE
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
 )
 target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py
index ec9f27617b..7364397648 100644
--- a/dpctl/tensor/__init__.py
+++ b/dpctl/tensor/__init__.py
@@ -88,6 +88,7 @@
 from dpctl.tensor._reshape import reshape
 from dpctl.tensor._search_functions import where
 from dpctl.tensor._usmarray import usm_ndarray
+from dpctl.tensor._utility_functions import all, any
 
 from ._constants import e, inf, nan, newaxis, pi
 
@@ -130,6 +131,8 @@
     "tril",
     "triu",
     "where",
+    "all",
+    "any",
     "dtype",
     "isdtype",
     "bool",
diff --git a/dpctl/tensor/_utility_functions.py b/dpctl/tensor/_utility_functions.py
new file mode 100644
index 0000000000..500c997e8f
--- /dev/null
+++ b/dpctl/tensor/_utility_functions.py
@@ -0,0 +1,125 @@
+from numpy.core.numeric import normalize_axis_tuple
+
+import dpctl
+import dpctl.tensor as dpt
+import dpctl.tensor._tensor_impl as ti
+
+
+def _boolean_reduction(x, axis, keepdims, func):
+    if not isinstance(x, dpt.usm_ndarray):
+        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
+
+    nd = x.ndim
+    if axis is None:
+        red_nd = nd
+        # case of a scalar
+        if red_nd == 0:
+            return dpt.astype(x, dpt.bool)
+        x_tmp = x
+        res_shape = tuple()
+        perm = list(range(nd))
+    else:
+        if not isinstance(axis, (tuple, list)):
+            axis = (axis,)
+        axis = normalize_axis_tuple(axis, nd, "axis")
+
+        red_nd = len(axis)
+        # check for axis=()
+        if red_nd == 0:
+            return dpt.astype(x, dpt.bool)
+        perm = [i for i in range(nd) if i not in axis] + list(axis)
+        x_tmp = dpt.permute_dims(x, perm)
+        res_shape = x_tmp.shape[: nd - red_nd]
+
+    exec_q = x.sycl_queue
+    res_usm_type = x.usm_type
+
+    wait_list = []
+    res_tmp = dpt.empty(
+        res_shape,
+        dtype=dpt.int32,
+        usm_type=res_usm_type,
+        sycl_queue=exec_q,
+    )
+    hev0, ev0 = func(
+        src=x_tmp,
+        trailing_dims_to_reduce=red_nd,
+        dst=res_tmp,
+        sycl_queue=exec_q,
+    )
+    wait_list.append(hev0)
+
+    # copy to boolean result array
+    res = dpt.empty(
+        res_shape,
+        dtype=dpt.bool,
+        usm_type=res_usm_type,
+        sycl_queue=exec_q,
+    )
+    hev1, _ = ti._copy_usm_ndarray_into_usm_ndarray(
+        src=res_tmp, dst=res, sycl_queue=exec_q, depends=[ev0]
+    )
+    wait_list.append(hev1)
+
+    if keepdims:
+        res_shape = res_shape + (1,) * red_nd
+        inv_perm = sorted(range(nd), key=lambda d: perm[d])
+        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
+    dpctl.SyclEvent.wait_for(wait_list)
+    return res
+
+
+def all(x, axis=None, keepdims=False):
+    """all(x, axis=None, keepdims=False)
+
+    Tests whether all input array elements evaluate to True along a given axis.
+
+    Args:
+        x (usm_ndarray): Input array.
+        axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
+            along which to perform a logical AND reduction.
+            When `axis` is `None`, a logical AND reduction
+            is performed over all dimensions of `x`.
+            If `axis` is negative, the axis is counted from
+            the last dimension to the first.
+            Default: `None`.
+ keepdims (bool, optional): If `True`, the reduced axes are included + in the result as singleton dimensions, and the result is + broadcastable to the input array shape. + If `False`, the reduced axes are not included in the result. + Default: `False`. + + Returns: + usm_ndarray: + An array with a data type of `bool` + containing the results of the logical AND reduction. + """ + return _boolean_reduction(x, axis, keepdims, ti._all) + + +def any(x, axis=None, keepdims=False): + """any(x, axis=None, keepdims=False) + + Tests whether any input array elements evaluate to True along a given axis. + + Args: + x (usm_ndarray): Input array. + axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes) + along which to perform a logical OR reduction. + When `axis` is `None`, a logical OR reduction + is performed over all dimensions of `x`. + If `axis` is negative, the axis is counted from + the last dimension to the first. + Default: `None`. + keepdims (bool, optional): If `True`, the reduced axes are included + in the result as singleton dimensions, and the result is + broadcastable to the input array shape. + If `False`, the reduced axes are not included in the result. + Default: `False`. + + Returns: + usm_ndarray: + An array with a data type of `bool` + containing the results of the logical OR reduction. + """ + return _boolean_reduction(x, axis, keepdims, ti._any) diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp new file mode 100644 index 0000000000..dec96fab2a --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -0,0 +1,636 @@ +//=== boolean_reductions.hpp - Implementation of boolean reduction kernels +//---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for dpctl.tensor.any and dpctl.tensor.all +//===----------------------------------------------------------------------===// + +#pragma once +#include + +#include +#include +#include +#include + +#include "pybind11/pybind11.h" + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +template struct boolean_predicate +{ + bool operator()(const T &v) const + { + using dpctl::tensor::type_utils::convert_impl; + return convert_impl(v); + } +}; + +template +struct all_reduce_wg_contig +{ + void operator()(sycl::nd_item &ndit, + outT *out, + size_t &out_idx, + const inpT *start, + const inpT *end) const + { + PredicateT pred{}; + auto wg = ndit.get_group(); + outT red_val_over_wg = + static_cast(sycl::joint_all_of(wg, start, end, pred)); + + if (wg.leader()) { + sycl::atomic_ref + res_ref(out[out_idx]); + res_ref.fetch_and(red_val_over_wg); + } + } +}; + +template +struct any_reduce_wg_contig +{ + void operator()(sycl::nd_item &ndit, + outT *out, + size_t &out_idx, + const inpT *start, + const inpT *end) const + { + PredicateT pred{}; + auto wg = ndit.get_group(); + outT red_val_over_wg = + static_cast(sycl::joint_any_of(wg, start, end, pred)); + + if (wg.leader()) { + sycl::atomic_ref + res_ref(out[out_idx]); + res_ref.fetch_or(red_val_over_wg); + } + } +}; + +template struct all_reduce_wg_strided +{ + void operator()(sycl::nd_item &ndit, + T *out, + const size_t &out_idx, + const T &local_val) const + { + auto wg = ndit.get_group(); + T red_val_over_wg = static_cast(sycl::all_of_group(wg, local_val)); + + if (wg.leader()) { + sycl::atomic_ref + res_ref(out[out_idx]); + res_ref.fetch_and(red_val_over_wg); + } + } +}; + +template struct any_reduce_wg_strided +{ + void operator()(sycl::nd_item &ndit, + T *out, + const size_t &out_idx, + const T &local_val) const + { + auto wg = ndit.get_group(); + T red_val_over_wg = static_cast(sycl::any_of_group(wg, local_val)); + + if (wg.leader()) { + sycl::atomic_ref + res_ref(out[out_idx]); + res_ref.fetch_or(red_val_over_wg); + } + } +}; + +template +struct SequentialBooleanReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialBooleanReduction(const argT *inp, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const py::ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const py::ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + outT red_val(identity_); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + py::ssize_t inp_reduction_offset = + static_cast(inp_reduced_dims_indexer_(m)); + py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + 
// must convert to boolean first to handle nans + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl(inp_[inp_offset]); + + red_val = reduction_op_(red_val, val); + } + + out_[out_iter_offset] = red_val; + } +}; + +template +struct ContigBooleanReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + GroupOp group_op_; + size_t reduction_max_gid_ = 0; + size_t reductions_per_wi = 16; + +public: + ContigBooleanReduction(const argT *inp, + outT *res, + GroupOp group_op, + size_t reduction_size, + size_t reduction_size_per_wi) + : inp_(inp), out_(res), group_op_(group_op), + reduction_max_gid_(reduction_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<2> it) const + { + + size_t reduction_id = it.get_group(0); + size_t reduction_batch_id = it.get_group(1); + size_t wg_size = it.get_local_range(1); + + size_t base = reduction_id * reduction_max_gid_; + size_t start = base + reduction_batch_id * wg_size * reductions_per_wi; + size_t end = std::min((start + (reductions_per_wi * wg_size)), + base + reduction_max_gid_); + // reduction and atomic operations are performed + // in group_op_ + group_op_(it, out_, reduction_id, inp_ + start, inp_ + end); + } +}; + +typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +class boolean_reduction_contig_krn; + +template +class boolean_reduction_seq_contig_krn; + +template +sycl::event +boolean_reduction_contig_impl(sycl::queue exec_q, + size_t iter_nelems, + size_t reduction_nelems, + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t red_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + red_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + constexpr resTy identity_val = sycl::known_identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = + 4 * (*std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialBooleanReduction( + arg_tp, res_tp, RedOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems)); + }); + } + else { + sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + IndexerT res_indexer{}; + + cgh.depends_on(depends); + + cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(init_ev); + + constexpr std::uint8_t group_dim = 2; + + constexpr size_t 
preferred_reductions_per_wi = 4; + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? ((reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto gws = + sycl::range{iter_nelems, reduction_groups * wg}; + auto lws = sycl::range{1, wg}; + + cgh.parallel_for< + class boolean_reduction_contig_krn>( + sycl::nd_range(gws, lws), + ContigBooleanReduction( + arg_tp, res_tp, GroupOpT(), reduction_nelems, + reductions_per_wi)); + }); + } + return red_ev; +} + +template struct AllContigFactory +{ + fnT get() const + { + using resTy = std::int32_t; + using RedOpT = sycl::logical_and; + using GroupOpT = + all_reduce_wg_contig>; + + return dpctl::tensor::kernels::boolean_reduction_contig_impl< + srcTy, resTy, RedOpT, GroupOpT>; + } +}; + +template struct AnyContigFactory +{ + fnT get() const + { + using resTy = std::int32_t; + using RedOpT = sycl::logical_or; + using GroupOpT = + any_reduce_wg_contig>; + + return dpctl::tensor::kernels::boolean_reduction_contig_impl< + srcTy, resTy, RedOpT, GroupOpT>; + } +}; + +template +struct StridedBooleanReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + GroupOp group_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t reductions_per_wi = 16; + +public: + StridedBooleanReduction(const argT *inp, + outT *res, + ReductionOp reduction_op, + GroupOp group_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t reduction_size_per_wi) + : inp_(inp), out_(res), reduction_op_(reduction_op), + group_op_(group_op), identity_(identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<2> it) const + { + + size_t reduction_id = it.get_group(0); + size_t reduction_batch_id = it.get_group(1); + size_t reduction_lid = it.get_local_id(1); + size_t wg_size = it.get_local_range(1); + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id); + const py::ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const py::ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg_size * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg_size; + + if (arg_reduce_gid < reduction_max_gid_) { + py::ssize_t inp_reduction_offset = static_cast( + inp_reduced_dims_indexer_(arg_reduce_gid)); + py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + // must convert to boolean first to handle nans + using dpctl::tensor::type_utils::convert_impl; + bool val = convert_impl(inp_[inp_offset]); + + local_red_val = + reduction_op_(local_red_val, static_cast(val)); + } + } + // reduction and atomic operations are performed + // in group_op_ + group_op_(it, out_, out_iter_offset, local_red_val); + } +}; + +template +class boolean_reduction_strided_krn; + +template +class boolean_reduction_seq_strided_krn; + +typedef sycl::event (*boolean_reduction_strided_impl_fn_ptr)( + sycl::queue, 
+ size_t, + size_t, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + const std::vector &); + +template +sycl::event +boolean_reduction_strided_impl(sycl::queue exec_q, + size_t iter_nelems, + size_t reduction_nelems, + const char *arg_cp, + char *res_cp, + int iter_nd, + const py::ssize_t *iter_shape_and_strides, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + constexpr resTy identity_val = sycl::known_identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = + 4 * (*std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialBooleanReduction( + arg_tp, res_tp, RedOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems)); + }); + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const py::ssize_t *const &res_shape = iter_shape_and_strides; + const py::ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + + cgh.depends_on(depends); + + cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + constexpr std::uint8_t group_dim = 2; + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
((reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto gws = + sycl::range{iter_nelems, reduction_groups * wg}; + auto lws = sycl::range{1, wg}; + + cgh.parallel_for>( + sycl::nd_range(gws, lws), + StridedBooleanReduction( + arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + } + return red_ev; +} + +template struct AllStridedFactory +{ + fnT get() const + { + using resTy = std::int32_t; + using RedOpT = sycl::logical_and; + using GroupOpT = all_reduce_wg_strided; + + return dpctl::tensor::kernels::boolean_reduction_strided_impl< + srcTy, resTy, RedOpT, GroupOpT>; + } +}; + +template struct AnyStridedFactory +{ + fnT get() const + { + using resTy = std::int32_t; + using RedOpT = sycl::logical_or; + using GroupOpT = any_reduce_wg_strided; + + return dpctl::tensor::kernels::boolean_reduction_strided_impl< + srcTy, resTy, RedOpT, GroupOpT>; + } +}; + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/boolean_reductions.cpp b/dpctl/tensor/libtensor/source/boolean_reductions.cpp new file mode 100644 index 0000000000..5def6c5158 --- /dev/null +++ b/dpctl/tensor/libtensor/source/boolean_reductions.cpp @@ -0,0 +1,165 @@ +//=-- boolean_reductions.cpp - Implementation of boolean reductions +//-//--*-C++-*-/=// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines Python API for implementation functions of +/// dpctl.tensor.all and dpctl.tensor.any +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "boolean_reductions.hpp" +#include "dpctl4pybind11.hpp" + +#include "kernels/boolean_reductions.hpp" +#include "utils/type_utils.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +// All +namespace impl +{ +using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; +using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; +static boolean_reduction_strided_impl_fn_ptr + all_reduction_strided_dispatch_vector[td_ns::num_types]; +static boolean_reduction_contig_impl_fn_ptr + all_reduction_contig_dispatch_vector[td_ns::num_types]; + +void populate_all_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; + + using dpctl::tensor::kernels::AllStridedFactory; + DispatchVectorBuilder + all_dvb1; + all_dvb1.populate_dispatch_vector(all_reduction_strided_dispatch_vector); + + using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; + + using dpctl::tensor::kernels::AllContigFactory; + DispatchVectorBuilder + all_dvb2; + all_dvb2.populate_dispatch_vector(all_reduction_contig_dispatch_vector); +}; + +} // namespace impl + +// Any +namespace impl +{ +using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; +static boolean_reduction_strided_impl_fn_ptr + any_reduction_strided_dispatch_vector[td_ns::num_types]; +using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; +static boolean_reduction_contig_impl_fn_ptr + any_reduction_contig_dispatch_vector[td_ns::num_types]; + +void populate_any_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; + + using dpctl::tensor::kernels::AnyStridedFactory; + DispatchVectorBuilder + any_dvb1; + any_dvb1.populate_dispatch_vector(any_reduction_strided_dispatch_vector); + + using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; + + using dpctl::tensor::kernels::AnyContigFactory; + DispatchVectorBuilder + any_dvb2; + any_dvb2.populate_dispatch_vector(any_reduction_contig_dispatch_vector); +}; + +} // namespace impl + +void init_boolean_reduction_functions(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + // ALL + { + impl::populate_all_dispatch_vectors(); + using impl::all_reduction_contig_dispatch_vector; + using impl::all_reduction_strided_dispatch_vector; + + auto all_pyapi = [&](arrayT src, int trailing_dims_to_reduce, + arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction(src, trailing_dims_to_reduce, dst, + exec_q, depends, + all_reduction_contig_dispatch_vector, + all_reduction_strided_dispatch_vector); + }; + m.def("_all", all_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // ANY + { + impl::populate_any_dispatch_vectors(); + using impl::any_reduction_contig_dispatch_vector; + using 
impl::any_reduction_strided_dispatch_vector; + + auto any_pyapi = [&](arrayT src, int trailing_dims_to_reduce, + arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction(src, trailing_dims_to_reduce, dst, + exec_q, depends, + any_reduction_contig_dispatch_vector, + any_reduction_strided_dispatch_vector); + }; + m.def("_any", any_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/boolean_reductions.hpp b/dpctl/tensor/libtensor/source/boolean_reductions.hpp new file mode 100644 index 0000000000..7b00932a8c --- /dev/null +++ b/dpctl/tensor/libtensor/source/boolean_reductions.hpp @@ -0,0 +1,270 @@ +//===-- boolean_reductions.hpp ---*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file declares Python API for implementation functions of +/// dpctl.tensor.any and dpctl.tensor.all +//===----------------------------------------------------------------------===// + +#pragma once +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +template +std::pair +py_boolean_reduction(dpctl::tensor::usm_ndarray src, + int trailing_dims_to_reduce, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends, + const contig_dispatchT &contig_dispatch_vector, + const strided_dispatchT &strided_dispatch_vector) +{ + int src_nd = src.get_ndim(); + int iter_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iter_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iter_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + 
} + + size_t dst_nelems = dst.get_size(); + + size_t red_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + red_nelems *= static_cast(src_shape_ptr[i]); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, src)) { + throw py::value_error("Arrays are expected to have no memory overlap"); + } + + // ensure that dst is sufficiently ample + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accomodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < static_cast(dst_nelems)) { + throw py::value_error( + "Memory addressed by the destination array can not " + "accomodate all the array elements."); + } + } + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + constexpr int int32_typeid = static_cast(td_ns::typenum_t::INT32); + if (dst_typeid != int32_typeid) { + throw py::value_error( + "Unexpected data type of destination array, expecting 'int32'"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nd == 0)) { + auto fn = contig_dispatch_vector[src_typeid]; + constexpr py::ssize_t zero_offset = 0; + + auto red_ev = fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, + zero_offset, zero_offset, zero_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + + auto src_shape_vecs = src.get_shape_vector(); + auto src_strides_vecs = src.get_strides_vector(); + auto dst_strides_vecs = dst.get_strides_vector(); + + int simplified_red_nd = trailing_dims_to_reduce; + + using shT = std::vector; + shT red_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_red_shape; + shT simplified_red_src_strides; + py::ssize_t red_src_offset(0); + + using dpctl::tensor::py_internal::simplify_iteration_space_1; + simplify_iteration_space_1( + simplified_red_nd, src_shape_ptr + dst_nd, red_src_strides, + // output + simplified_red_shape, simplified_red_src_strides, red_src_offset); + + shT iter_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iter_nd); + shT const &iter_dst_strides = dst_strides_vecs; + + shT simplified_iter_shape; + shT simplified_iter_src_strides; + shT simplified_iter_dst_strides; + py::ssize_t iter_src_offset(0); + py::ssize_t iter_dst_offset(0); + + if (iter_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iter_nd = 1; + simplified_iter_shape.push_back(1); + simplified_iter_src_strides.push_back(0); + simplified_iter_dst_strides.push_back(0); + } + else { + using dpctl::tensor::py_internal::simplify_iteration_space; + simplify_iteration_space( + iter_nd, src_shape_ptr, iter_src_strides, iter_dst_strides, + // output + simplified_iter_shape, simplified_iter_src_strides, + simplified_iter_dst_strides, iter_src_offset, iter_dst_offset); + } + + if ((simplified_red_nd == 1) && (simplified_red_src_strides[0] == 1) && + (iter_nd == 1) && + 
((simplified_iter_shape[0] == 1) || + ((simplified_iter_dst_strides[0] == 1) && + (simplified_iter_src_strides[0] == + static_cast(red_nelems))))) + { + auto fn = contig_dispatch_vector[src_typeid]; + size_t iter_nelems = dst_nelems; + + sycl::event red_ev = + fn(exec_q, iter_nelems, red_nelems, src.get_data(), dst.get_data(), + iter_src_offset, iter_dst_offset, red_src_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + + auto fn = strided_dispatch_vector[src_typeid]; + + // using a single host_task for packing here + // prevents crashes on CPU + std::vector host_task_events{}; + const auto &iter_red_metadata_packing_triple_ = + dpctl::tensor::offset_utils::device_allocate_and_pack( + exec_q, host_task_events, simplified_iter_shape, + simplified_iter_src_strides, simplified_iter_dst_strides, + simplified_red_shape, simplified_red_src_strides); + py::ssize_t *packed_shapes_and_strides = + std::get<0>(iter_red_metadata_packing_triple_); + if (packed_shapes_and_strides == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = + std::get<2>(iter_red_metadata_packing_triple_); + + py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides; + py::ssize_t *red_shape_stride = packed_shapes_and_strides + (3 * iter_nd); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, dst_nd, + iter_shape_and_strides, iter_src_offset, iter_dst_offset, + simplified_red_nd, red_shape_stride, red_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_and_strides] { + sycl::free(packed_shapes_and_strides, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, red_ev); +} + +extern void init_boolean_reduction_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 2cf627be18..e4c8a5e9b1 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -34,6 +34,7 @@ #include "dpctl4pybind11.hpp" #include "boolean_advanced_indexing.hpp" +#include "boolean_reductions.hpp" #include "copy_and_cast_usm_to_usm.hpp" #include "copy_for_reshape.hpp" #include "copy_numpy_ndarray_into_usm_ndarray.hpp" @@ -346,4 +347,6 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"), py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); + + dpctl::tensor::py_internal::init_boolean_reduction_functions(m); } diff --git a/dpctl/tests/test_usm_ndarray_utility_functions.py b/dpctl/tests/test_usm_ndarray_utility_functions.py new file mode 100644 index 0000000000..f99517daef --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_utility_functions.py @@ -0,0 +1,150 @@ +from random import randrange + +import numpy as np +import pytest +from numpy import AxisError +from numpy.testing 
import assert_array_equal, assert_equal
+
+import dpctl.tensor as dpt
+from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
+
+
+_all_dtypes = [
+    "?",
+    "i1",
+    "u1",
+    "i2",
+    "u2",
+    "i4",
+    "u4",
+    "i8",
+    "u8",
+    "f2",
+    "f4",
+    "f8",
+    "c8",
+    "c16",
+]
+
+
+@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
+@pytest.mark.parametrize("dtype", _all_dtypes)
+def test_boolean_reduction_dtypes_contig(func, identity, dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    x = dpt.full(10, identity, dtype=dtype, sycl_queue=q)
+    res = func(x)
+
+    assert_equal(dpt.asnumpy(res), identity)
+
+    x[randrange(x.size)] = not identity
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), not identity)
+
+    # test branch in kernel for large arrays
+    wg_size = 4 * 32
+    x = dpt.full((wg_size + 1), identity, dtype=dtype, sycl_queue=q)
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), identity)
+
+    x[randrange(x.size)] = not identity
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), not identity)
+
+
+@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
+@pytest.mark.parametrize("dtype", _all_dtypes)
+def test_boolean_reduction_dtypes_strided(func, identity, dtype):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    x = dpt.full(20, identity, dtype=dtype, sycl_queue=q)[::-2]
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), identity)
+
+    x[randrange(x.size)] = not identity
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), not identity)
+
+
+@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
+def test_boolean_reduction_axis(func, identity):
+    get_queue_or_skip()
+
+    x = dpt.full((2, 3, 4, 5, 6), identity, dtype="i4")
+    res = func(x, axis=(1, 2, -1))
+
+    assert res.shape == (2, 5)
+    assert_array_equal(dpt.asnumpy(res), np.full(res.shape, identity))
+
+    # make first row of output negation of identity
+    x[0, 0, 0, ...] = not identity
+    res = func(x, axis=(1, 2, -1))
+    assert_array_equal(
+        dpt.asnumpy(res[0]), np.full(res.shape[1], not identity)
+    )
+
+
+@pytest.mark.parametrize("func", [dpt.all, dpt.any])
+def test_boolean_reduction_keepdims(func):
+    get_queue_or_skip()
+
+    x = dpt.ones((2, 3, 4, 5, 6), dtype="i4")
+    res = func(x, axis=(1, 2, -1), keepdims=True)
+    assert res.shape == (2, 1, 1, 5, 1)
+    assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
+
+    res = func(x, axis=None, keepdims=True)
+    assert res.shape == (1,) * x.ndim
+
+
+@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
+def test_boolean_reduction_empty(func, identity):
+    get_queue_or_skip()
+
+    x = dpt.empty((0,), dtype="i4")
+    res = func(x)
+    assert_equal(dpt.asnumpy(res), identity)
+
+
+# nan, inf, and -inf should evaluate to true
+@pytest.mark.parametrize("func", [dpt.all, dpt.any])
+def test_boolean_reductions_nan_inf(func):
+    q = get_queue_or_skip()
+
+    x = dpt.asarray([dpt.nan, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q)[
+        :, dpt.newaxis
+    ]
+    res = func(x, axis=1)
+    assert_equal(dpt.asnumpy(res), True)
+
+
+@pytest.mark.parametrize("func", [dpt.all, dpt.any])
+def test_boolean_reduction_scalars(func):
+    get_queue_or_skip()
+
+    x = dpt.ones((), dtype="i4")
+    assert_equal(dpt.asnumpy(func(x)), True)
+
+    x = dpt.zeros((), dtype="i4")
+    assert_equal(dpt.asnumpy(func(x)), False)
+
+
+@pytest.mark.parametrize("func", [dpt.all, dpt.any])
+def test_boolean_reduction_empty_axis(func):
+    get_queue_or_skip()
+
+    x = dpt.ones((5,), dtype="i4")
+    res = func(x, axis=())
+    assert_array_equal(dpt.asnumpy(res), dpt.asnumpy(x).astype(np.bool_))
+
+
+@pytest.mark.parametrize("func", [dpt.all, dpt.any])
+def test_arg_validation_boolean_reductions(func):
+    get_queue_or_skip()
+
+    x = dpt.ones((4, 5), dtype="i4")
+    d = dict()
+
+    with pytest.raises(TypeError):
+        func(d)
+    with pytest.raises(AxisError):
+        func(x, axis=-3)
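
Usage sketch (illustrative, not part of the patch): a minimal example of the dpt.all / dpt.any entry points added in dpctl/tensor/_utility_functions.py, assuming dpctl is built with these changes and a SYCL device is available; the array values are arbitrary.

    import dpctl.tensor as dpt

    x = dpt.asarray([[1, 2, 0], [3, 4, 5]], dtype="i4")
    dpt.all(x)                               # 0-d bool array, False (a zero is present)
    dpt.any(x, axis=1)                       # [True, True], one value per row
    dpt.all(x, axis=0, keepdims=True).shape  # (1, 3), reduced axis kept as singleton
    dpt.any(x, axis=())                      # empty axis tuple: no reduction, x cast to bool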