From d59cb83487e96116afe9105cbcdbc8d587ee89be Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 22 May 2023 09:37:42 -0500 Subject: [PATCH 1/9] Added missing license headers, updated license year to 2023 --- dpctl/tensor/libtensor/source/tensor_py.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 3a72c205fb..5f1539e56b 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -44,6 +44,7 @@ #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" #include "linear_sequences.hpp" +#include "reductions.hpp" #include "simplify_iteration_space.hpp" #include "triul_ctor.hpp" #include "utils/memory_overlap.hpp" @@ -351,4 +352,5 @@ PYBIND11_MODULE(_tensor_impl, m) dpctl::tensor::py_internal::init_elementwise_functions(m); dpctl::tensor::py_internal::init_boolean_reduction_functions(m); + dpctl::tensor::py_internal::init_reduction_functions(m); } From 55711fd2c2c5525d3552e8a2dcc7492c3ce93a8a Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 22 May 2023 09:42:11 -0500 Subject: [PATCH 2/9] Implement `dpctl.tensor.sum` reduction operation --- dpctl/tensor/CMakeLists.txt | 1 + dpctl/tensor/__init__.py | 2 + dpctl/tensor/_reduction.py | 167 +++ .../libtensor/include/kernels/reductions.hpp | 965 ++++++++++++++++++ dpctl/tensor/libtensor/source/reductions.cpp | 469 +++++++++ dpctl/tensor/libtensor/source/reductions.hpp | 40 + dpctl/tests/test_tensor_sum.py | 100 ++ 7 files changed, 1744 insertions(+) create mode 100644 dpctl/tensor/_reduction.py create mode 100644 dpctl/tensor/libtensor/include/kernels/reductions.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions.hpp create mode 100644 dpctl/tests/test_tensor_sum.py diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 9e22280f37..cc122524e8 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -47,6 +47,7 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions.cpp ) set(_clang_prefix "") if (WIN32) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index be9426a834..d22b31e9d0 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -102,6 +102,7 @@ isnan, sqrt, ) +from ._reduction import sum __all__ = [ "Device", @@ -187,4 +188,5 @@ "sqrt", "divide", "equal", + "sum", ] diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py new file mode 100644 index 0000000000..7be651bbfd --- /dev/null +++ b/dpctl/tensor/_reduction.py @@ -0,0 +1,167 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from numpy.core.numeric import normalize_axis_tuple
+
+import dpctl
+import dpctl.tensor as dpt
+import dpctl.tensor._tensor_impl as ti
+
+from ._type_utils import _to_device_supported_dtype
+
+
+def _default_reduction_dtype(inp_dt, q):
+    """Gives default output data type for given input data
+    type `inp_dt` when reduction is performed on queue `q`
+    """
+    inp_kind = inp_dt.kind
+    if inp_kind in "bi":
+        res_dt = dpt.dtype(ti.default_device_int_type(q))
+        if inp_dt.itemsize > res_dt.itemsize:
+            res_dt = inp_dt
+    elif inp_kind in "u":
+        res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
+        res_ii = dpt.iinfo(res_dt)
+        inp_ii = dpt.iinfo(inp_dt)
+        if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
+            pass
+        else:
+            res_dt = inp_dt
+    elif inp_kind in "f":
+        res_dt = dpt.dtype(ti.default_device_fp_type(q))
+        if res_dt.itemsize < inp_dt.itemsize:
+            res_dt = inp_dt
+    elif inp_kind in "c":
+        res_dt = dpt.dtype(ti.default_device_complex_type(q))
+        if res_dt.itemsize < inp_dt.itemsize:
+            res_dt = inp_dt
+
+    return res_dt
+
+
+def sum(arr, axis=None, dtype=None, keepdims=False):
+    """sum(x, axis=None, dtype=None, keepdims=False)
+
+    Calculates the sum of the input array `x`.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int, Tuple[int,...]]):
+            axis or axes along which sums must be computed. If a tuple
+            of unique integers, sums are computed over multiple axes.
+            If `None`, the sum is computed over the entire array.
+            Default: `None`.
+        dtype (Optional[dtype]):
+            data type of the returned array. If `None`, the default data
+            type is inferred from the "kind" of the input array data type.
+            * If `x` has a real-valued floating-point data type,
+              the returned array will have the default real-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has signed integral data type, the returned array
+              will have the default signed integral type for the device
+              where input array `x` is allocated.
+            * If `x` has unsigned integral data type, the returned array
+              will have the default unsigned integral type for the device
+              where input array `x` is allocated.
+            * If `x` has a complex-valued floating-point data type,
+              the returned array will have the default complex-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has a boolean data type, the returned array will
+              have the default signed integral type for the device
+              where input array `x` is allocated.
+            If the data type (either specified or resolved) differs from the
+            data type of `x`, the input array elements are cast to the
+            specified data type before computing the sum. Default: `None`.
+        keepdims (Optional[bool]):
+            if `True`, the reduced axes (dimensions) are included in the result
+            as singleton dimensions, so that the returned array remains
+            compatible with the input arrays according to Array Broadcasting
+            rules. Otherwise, if `False`, the reduced axes are not included in
+            the returned array. Default: `False`.
+    Returns:
+        usm_ndarray:
+            an array containing the sums. If the sum was computed over the
+            entire array, a zero-dimensional array is returned. The returned
+            array has the data type as described in the `dtype` parameter
+            description above.
+ """ + if not isinstance(arr, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(arr)}") + nd = arr.ndim + if axis is None: + axis = tuple(range(nd)) + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + red_nd = len(axis) + perm = [i for i in range(nd) if i not in axis] + list(axis) + arr2 = dpt.permute_dims(arr, perm) + res_shape = arr2.shape[: nd - red_nd] + q = arr.sycl_queue + inp_dt = arr.dtype + if dtype is None: + res_dt = _default_reduction_dtype(inp_dt, q) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) + + res_usm_type = arr.usm_type + if red_nd == 0: + return dpt.zeros( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + + host_tasks_list = [] + if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q): + res = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e, _ = ti._sum_over_axis( + src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q + ) + host_tasks_list.append(ht_e) + else: + if dtype is None: + raise RuntimeError( + "Automatically determined reduction data type does not " + "have direct implementation" + ) + tmp_dt = _default_reduction_dtype(inp_dt, q) + tmp = dpt.empty( + res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_tmp, r_e = ti._sum_over_axis( + src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q + ) + host_tasks_list.append(ht_e_tmp) + res = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=tmp, dst=res, sycl_queue=q, depends=[r_e] + ) + host_tasks_list.append(ht_e) + + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) + dpctl.SyclEvent.wait_for(host_tasks_list) + + return res diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp new file mode 100644 index 0000000000..3da09c32bf --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -0,0 +1,965 @@ +//=== reductions.hpp - Implementation of reduction kernels ------- *-C++-*/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor reduction along axis. 
+//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +/* ================ Reduction, using sycl::reduce_over_group, and + * sycl::atomic_ref ================ */ + +/* + This kernel only works for outT with sizeof(outT) == 4, or sizeof(outT) == 8 + if the device has aspect atomic64 and only with those supported by + sycl::atomic_ref +*/ +template +struct ReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t reductions_per_wi = 16; + +public: + ReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<2> it) const + { + + size_t iter_gid = it.get_global_id(0); + size_t reduction_batch_id = it.get_group(1); + size_t reduction_lid = it.get_local_id(1); + size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg + + // work-items sums over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl(inp_[inp_offset]); + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + if constexpr (std::is_same_v> || + std::is_same_v>) + { + res_ref += red_val_over_wg; + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } + } +}; + +template +size_t choose_workgroup_size(const size_t 
reduction_nelems, + const std::vector &sg_sizes) +{ + std::vector wg_choices; + wg_choices.reserve(f * sg_sizes.size()); + + for (const auto &sg_size : sg_sizes) { +#pragma unroll + for (size_t i = 1; i <= f; ++i) { + wg_choices.push_back(sg_size * i); + } + } + std::sort(std::begin(wg_choices), std::end(wg_choices)); + + size_t wg = 1; + for (size_t i = 0; i < wg_choices.size(); ++i) { + if (wg_choices[i] == wg) { + continue; + } + wg = wg_choices[i]; + size_t n_groups = ((reduction_nelems + wg - 1) / wg); + if (n_groups == 1) + break; + } + + return wg; +} + +typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + const std::vector &); + +template +class sum_reduction_over_group_with_atomics_krn; + +template +class sum_reduction_over_group_with_atomics_1d_krn; + +template +sycl::event sum_reduction_over_group_with_atomics_strided_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const py::ssize_t *iter_shape_and_strides, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = resTy{0}; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const py::ssize_t *const &res_shape = iter_shape_and_strides; + const py::ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); + + cgh.depends_on(depends); + + cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + constexpr size_t preferrered_reductions_per_wi = 4; + size_t reductions_per_wi = + (reduction_nelems < preferrered_reductions_per_wi * wg) + ? 
((reduction_nelems + wg - 1) / wg) + : preferrered_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + + return comp_ev; +} + +// Contig + +typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +/* @brief Reduce rows in a matrix */ +template +sycl::event sum_reduction_over_group_with_atomics_contig_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = resTy{0}; + + sycl::event res_init_ev = + exec_q.fill(res_tp, resTy(identity_val), iter_nelems, depends); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<2>(reduction_nelems, sg_sizes); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + RowsIndexerT columns_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_nelems)}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{}; + + constexpr size_t preferrered_reductions_per_wi = 8; + size_t reductions_per_wi = + (reduction_nelems < preferrered_reductions_per_wi * wg) + ? 
((reduction_nelems + wg - 1) / wg) + : preferrered_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + + return comp_ev; +} + +/* ======================= Reduction, using sycl::reduce_over_group, but not + * using atomic_ref ========================= */ + +template +struct ReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t reductions_per_wi = 16; + +public: + ReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<2> it) const + { + + size_t iter_gid = it.get_global_id(0); + size_t reduction_batch_id = it.get_group(1); + size_t reduction_lid = it.get_local_id(1); + size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg + + // work-items sums over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl(inp_[inp_offset]); + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * it.get_group_range(1) + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template +class sum_reduction_over_group_temps_krn; + +template +sycl::event sum_reduction_over_group_temps_strided_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. 
of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const py::ssize_t *iter_shape_and_strides, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = resTy{0}; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferrered_reductions_per_wi = 4; + size_t max_wg = d.get_info(); + + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + wg = max_wg; + reductions_per_wi = (reduction_nelems + wg - 1) / wg; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = + sycl::malloc_device(iter_nelems * reduction_groups, exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unabled to allocate device_memory"); + } + + sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of iterated + // dimensions of input array from iter_shape_and_strides are going + // to be accessed by inp_indexer + InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + ResIndexerT noop_tmp_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + 
noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + arg_tp, partially_reduced_tmp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + preferrered_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + if (partially_reduced_tmp2 == nullptr) { + partially_reduced_tmp2 = sycl::malloc_device( + iter_nelems * reduction_groups_, exec_q); + + if (partially_reduced_tmp2 == nullptr) { + dependent_ev.wait(); + sycl::free(partially_reduced_tmp, exec_q); + + throw std::runtime_error( + "Unable to allocate device memory"); + } + + temp2_arg = partially_reduced_tmp2; + } + + // keep reducing + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups_ * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, + preferrered_reductions_per_wi)); + }); + // FIXME: may be unnecessary + partial_reduction_ev.wait(); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + 
static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /*s trides */ iter_shape_and_strides + + 2 * iter_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = (remaining_reduction_nelems + wg - 1) / wg; + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, reductions_per_wi)); + }); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + sycl::context ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, partially_reduced_tmp, partially_reduced_tmp2] { + sycl::free(partially_reduced_tmp, ctx); + if (partially_reduced_tmp2) { + sycl::free(partially_reduced_tmp2, ctx); + } + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct 
TypePairSupportDataForSumReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SumOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + return dpctl::tensor::kernels:: + sum_reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + return dpctl::tensor::kernels:: + sum_reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisAtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + return dpctl::tensor::kernels:: + sum_reduction_over_group_with_atomics_contig_impl; + } + else { + return nullptr; + } + } +}; + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions.cpp b/dpctl/tensor/libtensor/source/reductions.cpp new file mode 100644 index 0000000000..d0c21a39fa --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions.cpp @@ -0,0 +1,469 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reductions.hpp" + +#include "simplify_iteration_space.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type, + bool require_atomic64 = false) +{ + bool supports_atomics = false; + + const sycl::device &dev = exec_q.get_device(); + if (require_atomic64) { + if (!dev.has(sycl::aspect::atomic64)) + return false; + } + + switch (usm_alloc_type) { + case sycl::usm::alloc::shared: + supports_atomics = dev.has(sycl::aspect::usm_atomic_shared_allocations); + break; + case sycl::usm::alloc::host: + supports_atomics = dev.has(sycl::aspect::usm_atomic_host_allocations); + break; + case sycl::usm::alloc::device: + supports_atomics = true; + break; + default: + supports_atomics = false; + } + + return supports_atomics; +} + +using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; +static sum_reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static sum_reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr; +static sum_reduction_contig_impl_fn_ptr + sum_over_axis_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +std::pair py_sum_over_axis( + dpctl::tensor::usm_ndarray src, + int trailing_dims_to_reduce, // sum over this many trailing indexes + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = 
dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // FIXME: check that dst and src do not overlap + // check that dst is ample enough (memory span is sufficient + // to accommodate all elements) + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int dst_itemsize = dst.get_elemsize(); + bool supports_atomics = false; + + switch (dst_itemsize) { + case sizeof(float): + { + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + supports_atomics = check_atomic_support(exec_q, usm_type); + } break; + case sizeof(double): + { + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + constexpr bool require_atomic64 = true; + supports_atomics = + check_atomic_support(exec_q, usm_type, require_atomic64); + } break; + } + + // handle special case when both reduction and iteration are 1D contiguous + // and can be done with atomics + if (supports_atomics) { + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + auto fn = sum_over_axis_contig_atomic_dispatch_table[src_typeid] + [dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event sum_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {sum_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + } + } + } + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw 
std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if (supports_atomics && (reduction_nd == 1) && + (simplified_reduction_src_strides[0] == 1) && (iteration_nd == 1) && + ((simplified_iteration_shape[0] == 1) || + ((simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems)))) + { + auto fn = + sum_over_axis_contig_atomic_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + sycl::event sum_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, iteration_dst_offset, + reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {sum_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + } + } + + using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; + sum_reduction_strided_impl_fn_ptr fn = nullptr; + + if (supports_atomics) { + fn = + sum_over_axis_strided_atomic_dispatch_table[src_typeid][dst_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = sum_over_axis_strided_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + } + + std::vector host_task_events{}; + const auto &iter_src_dst_metadata_packing_triple_ = + dpctl::tensor::offset_utils::device_allocate_and_pack( + exec_q, host_task_events, simplified_iteration_shape, + simplified_iteration_src_strides, simplified_iteration_dst_strides); + py::ssize_t *iter_shape_and_strides = + std::get<0>(iter_src_dst_metadata_packing_triple_); + if (iter_shape_and_strides == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_iter_metadata_ev = + std::get<2>(iter_src_dst_metadata_packing_triple_); + + const auto &reduction_metadata_packing_triple_ = + dpctl::tensor::offset_utils::device_allocate_and_pack( + exec_q, host_task_events, simplified_reduction_shape, + simplified_reduction_src_strides); + py::ssize_t *reduction_shape_stride = + std::get<0>(reduction_metadata_packing_triple_); + if (reduction_shape_stride == nullptr) { + sycl::event::wait(host_task_events); + sycl::free(iter_shape_and_strides, exec_q); + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_reduction_metadata_ev = + std::get<2>(reduction_metadata_packing_triple_); + + std::vector all_deps; + all_deps.reserve(depends.size() + 2); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_iter_metadata_ev); + all_deps.push_back(copy_reduction_metadata_ev); + + auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_nd, iter_shape_and_strides, + iteration_src_offset, iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = 
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, iter_shape_and_strides, reduction_shape_stride] { + sycl::free(iter_shape_and_strides, ctx); + sycl::free(reduction_shape_stride, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, comp_ev); +} + +bool py_sum_over_axis_dtype_supported(py::dtype input_dtype, + py::dtype output_dtype, + const std::string &dst_usm_type, + sycl::queue q) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; + sum_reduction_strided_impl_fn_ptr fn = nullptr; + + sycl::usm::alloc kind = sycl::usm::alloc::unknown; + + if (dst_usm_type == "device") { + kind = sycl::usm::alloc::device; + } + else if (dst_usm_type == "shared") { + kind = sycl::usm::alloc::shared; + } + else if (dst_usm_type == "host") { + kind = sycl::usm::alloc::host; + } + else { + throw py::value_error("Unrecognized `dst_usm_type` argument."); + } + + bool supports_atomics = false; + + switch (output_dtype.itemsize()) { + case sizeof(float): + { + supports_atomics = check_atomic_support(q, kind); + } break; + case sizeof(double): + { + constexpr bool require_atomic64 = true; + supports_atomics = check_atomic_support(q, kind, require_atomic64); + } break; + } + + if (supports_atomics) { + fn = + sum_over_axis_strided_atomic_dispatch_table[arg_typeid][out_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = sum_over_axis_strided_temps_dispatch_table[arg_typeid][out_typeid]; + } + + return (fn != nullptr); +} + +void populate_sum_over_axis_dispatch_table(void) +{ + using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxisAtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis_contig_atomic_dispatch_table); +} + +namespace py = pybind11; + +void init_reduction_functions(py::module_ m) +{ + populate_sum_over_axis_dispatch_table(); + + m.def("_sum_over_axis", &py_sum_over_axis, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + m.def("_sum_over_axis_dtype_supported", &py_sum_over_axis_dtype_supported, + "", py::arg("arg_dtype"), py::arg("out_dtype"), 
+ py::arg("dst_usm_type"), py::arg("sycl_queue")); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions.hpp b/dpctl/tensor/libtensor/source/reductions.hpp new file mode 100644 index 0000000000..ac612ec1f7 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions.hpp @@ -0,0 +1,40 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include +#include + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_reduction_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py new file mode 100644 index 0000000000..5cd74f3309 --- /dev/null +++ b/dpctl/tests/test_tensor_sum.py @@ -0,0 +1,100 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +_all_dtypes = [ + "?", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] +_usm_types = ["device", "shared", "host"] + + +@pytest.mark.parametrize("arg_dtype", _all_dtypes) +def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.sum(m) + + assert isinstance(r, dpt.usm_ndarray) + if m.dtype.kind == "i": + assert r.dtype.kind == "i" + elif m.dtype.kind == "u": + assert r.dtype.kind == "u" + elif m.dtype.kind == "f": + assert r.dtype.kind == "f" + elif m.dtype.kind == "c": + assert r.dtype.kind == "c" + assert (dpt.asnumpy(r) == 100).all() + + m = dpt.ones(200, dtype=arg_dtype)[:1:-2] + r = dpt.sum(m) + assert (dpt.asnumpy(r) == 99).all() + + +@pytest.mark.parametrize("arg_dtype", _all_dtypes) +@pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) +def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + skip_if_dtype_not_supported(out_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.sum(m, dtype=out_dtype) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype == dpt.dtype(out_dtype) + assert (dpt.asnumpy(r) == 100).all() + + +def test_sum_axis(): + get_queue_or_skip() + + m = dpt.ones((3, 4, 5, 6, 7), dtype="i4") + s = dpt.sum(m, axis=(1, 2, -1)) + + assert isinstance(s, dpt.usm_ndarray) + assert s.shape == (3, 6) + assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() + + +def test_sum_keepdims(): + get_queue_or_skip() + + m = dpt.ones((3, 4, 5, 6, 7), dtype="i4") + s = dpt.sum(m, axis=(1, 2, -1), keepdims=True) + + assert isinstance(s, dpt.usm_ndarray) + assert s.shape == (3, 1, 1, 6, 1) + assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() From 27d314e871e40a5dbe5dd5a0f26c7f4ecf0c29f2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 22 May 2023 20:08:33 -0500 Subject: [PATCH 3/9] Make sure reduce_per_wi is non-zero --- dpctl/tensor/libtensor/include/kernels/reductions.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 3da09c32bf..b5ddbe6dd8 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -259,7 +259,7 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( constexpr size_t preferrered_reductions_per_wi = 4; size_t reductions_per_wi = (reduction_nelems < preferrered_reductions_per_wi * wg) - ? ((reduction_nelems + wg - 1) / wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) : preferrered_reductions_per_wi; size_t reduction_groups = @@ -349,7 +349,7 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( constexpr size_t preferrered_reductions_per_wi = 8; size_t reductions_per_wi = (reduction_nelems < preferrered_reductions_per_wi * wg) - ? ((reduction_nelems + wg - 1) / wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) : preferrered_reductions_per_wi; size_t reduction_groups = @@ -514,7 +514,8 @@ sycl::event sum_reduction_over_group_temps_strided_impl( reduction_shape_stride}; wg = max_wg; - reductions_per_wi = (reduction_nelems + wg - 1) / wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); size_t reduction_groups = (reduction_nelems + reductions_per_wi * wg - 1) / @@ -698,7 +699,8 @@ sycl::event sum_reduction_over_group_temps_strided_impl( ReductionIndexerT reduction_indexer{}; wg = max_wg; - reductions_per_wi = (remaining_reduction_nelems + wg - 1) / wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); size_t reduction_groups = (remaining_reduction_nelems + reductions_per_wi * wg - 1) / From 0327cb168577b06212c61fc61436062a10662014 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 22 May 2023 20:08:57 -0500 Subject: [PATCH 4/9] Add a test for sum of an empty array --- dpctl/tests/test_tensor_sum.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 5cd74f3309..e53c51766d 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -78,6 +78,14 @@ def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype): assert (dpt.asnumpy(r) == 100).all() +def test_sum_empty(): + get_queue_or_skip() + x = dpt.empty((0,), dtype="u1") + y = dpt.sum(x) + assert y.shape == tuple() + assert int(y) == 0 + + def test_sum_axis(): get_queue_or_skip() From dd821c5400e462ebdef2be1c6a5570b269f9c5e6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 23 May 2023 07:43:46 -0500 Subject: [PATCH 5/9] Removed unneeded synchronization point. --- dpctl/tensor/libtensor/include/kernels/reductions.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index b5ddbe6dd8..2ab697bd35 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -666,8 +666,6 @@ sycl::event sum_reduction_over_group_temps_strided_impl( remaining_reduction_nelems, preferrered_reductions_per_wi)); }); - // FIXME: may be unnecessary - partial_reduction_ev.wait(); remaining_reduction_nelems = reduction_groups_; std::swap(temp_arg, temp2_arg); From 02909541e984c277cfe5ce92ac33f44ab9744a52 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 24 May 2023 17:36:04 -0500 Subject: [PATCH 6/9] Implemented PR feedback --- .../libtensor/include/kernels/reductions.hpp | 70 +++++------------- .../libtensor/include/utils/sycl_utils.hpp | 71 +++++++++++++++++++ 2 files changed, 88 insertions(+), 53 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/utils/sycl_utils.hpp diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 2ab697bd35..4e7dbedd9d 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -32,6 +32,7 @@ #include "pybind11/pybind11.h" #include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -150,35 +151,6 @@ struct ReductionOverGroupWithAtomicFunctor } }; -template -size_t choose_workgroup_size(const size_t reduction_nelems, - const std::vector &sg_sizes) -{ - std::vector wg_choices; - wg_choices.reserve(f * sg_sizes.size()); - - for 
(const auto &sg_size : sg_sizes) { -#pragma unroll - for (size_t i = 1; i <= f; ++i) { - wg_choices.push_back(sg_size * i); - } - } - std::sort(std::begin(wg_choices), std::end(wg_choices)); - - size_t wg = 1; - for (size_t i = 0; i < wg_choices.size(); ++i) { - if (wg_choices[i] == wg) { - continue; - } - wg = wg_choices[i]; - size_t n_groups = ((reduction_nelems + wg - 1) / wg); - if (n_groups == 1) - break; - } - - return wg; -} - typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( sycl::queue, size_t, @@ -200,6 +172,8 @@ class sum_reduction_over_group_with_atomics_krn; template class sum_reduction_over_group_with_atomics_1d_krn; +using dpctl::tensor::sycl_utils::choose_workgroup_size; + template sycl::event sum_reduction_over_group_with_atomics_strided_impl( sycl::queue exec_q, @@ -548,13 +522,22 @@ sycl::event sum_reduction_over_group_temps_strided_impl( (preferrered_reductions_per_wi * wg); assert(reduction_groups > 1); - resTy *partially_reduced_tmp = - sycl::malloc_device(iter_nelems * reduction_groups, exec_q); + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); resTy *partially_reduced_tmp2 = nullptr; if (partially_reduced_tmp == nullptr) { throw std::runtime_error("Unabled to allocate device_memory"); } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + } sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -610,21 +593,6 @@ sycl::event sum_reduction_over_group_temps_strided_impl( (preferrered_reductions_per_wi * wg); assert(reduction_groups_ > 1); - if (partially_reduced_tmp2 == nullptr) { - partially_reduced_tmp2 = sycl::malloc_device( - iter_nelems * reduction_groups_, exec_q); - - if (partially_reduced_tmp2 == nullptr) { - dependent_ev.wait(); - sycl::free(partially_reduced_tmp, exec_q); - - throw std::runtime_error( - "Unable to allocate device memory"); - } - - temp2_arg = partially_reduced_tmp2; - } - // keep reducing sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -727,13 +695,9 @@ sycl::event sum_reduction_over_group_temps_strided_impl( cgh.depends_on(final_reduction_ev); sycl::context ctx = exec_q.get_context(); - cgh.host_task( - [ctx, partially_reduced_tmp, partially_reduced_tmp2] { - sycl::free(partially_reduced_tmp, ctx); - if (partially_reduced_tmp2) { - sycl::free(partially_reduced_tmp2, ctx); - } - }); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); // FIXME: do not return host-task event diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp new file mode 100644 index 0000000000..2fc7b02efa --- /dev/null +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -0,0 +1,71 @@ +//=== sycl_utils.hpp - Implementation of utilities ------- *-C++-*/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines utilities used for kernel submission.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <CL/sycl.hpp>
+#include <algorithm>
+#include <cstddef>
+#include <vector>
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace sycl_utils
+{
+
+/*! @brief Find the smallest multiple of supported sub-group size larger than
+ * nelems */
+template <std::size_t f>
+size_t choose_workgroup_size(const size_t nelems,
+                             const std::vector<size_t> &sg_sizes)
+{
+    std::vector<size_t> wg_choices;
+    wg_choices.reserve(f * sg_sizes.size());
+
+    for (const auto &sg_size : sg_sizes) {
+#pragma unroll
+        for (size_t i = 1; i <= f; ++i) {
+            wg_choices.push_back(sg_size * i);
+        }
+    }
+    std::sort(std::begin(wg_choices), std::end(wg_choices));
+
+    size_t wg = 1;
+    for (size_t i = 0; i < wg_choices.size(); ++i) {
+        if (wg_choices[i] == wg) {
+            continue;
+        }
+        wg = wg_choices[i];
+        size_t n_groups = ((nelems + wg - 1) / wg);
+        if (n_groups == 1)
+            break;
+    }
+
+    return wg;
+}
+
+} // namespace sycl_utils
+} // namespace tensor
+} // namespace dpctl

From f9ddea7b06a5cd5624327b9defdc78ab1ff14738 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk
Date: Thu, 25 May 2023 14:04:01 -0500
Subject: [PATCH 7/9] Renamed source/reductions.*pp to sum_reductions.*pp

Added MemoryOverlap check, and the array range check per FIXME note and PR
review feedback.

Also consolidated transfer of iteration/reduction metadata into a single
operation to improve test stability on CPU and improve overall host
submission overhead time.
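
For reviewers, a standalone illustration (not part of this patch) of the
arithmetic behind the new "ample destination" check in `py_sum_over_axis`:
the flat offsets reachable through `dst` span [min_off, max_off], so the
view addresses max_off - min_off + 1 distinct elements, and the reduction
result only fits if that count is at least `dst_nelems`. The concrete
numbers below are made up for the example.

    #include <cstddef>
    #include <iostream>
    #include <utility>

    int main()
    {
        // Hypothetical destination view: flat offsets span [0, 5], so only
        // 6 distinct elements are addressable through it.
        std::pair<std::ptrdiff_t, std::ptrdiff_t> dst_offsets{0, 5};
        std::size_t dst_nelems = 8; // elements the reduction must store

        std::size_t range =
            static_cast<std::size_t>(dst_offsets.second - dst_offsets.first);

        // Same predicate as the check added in the hunk below; the binding
        // throws py::value_error instead of printing.
        if (range + 1 < dst_nelems) {
            std::cout << "destination too small: " << (range + 1)
                      << " addressable elements, " << dst_nelems
                      << " required\n";
        }
        return 0;
    }

Here the check fires because 6 < 8; a destination whose offsets spanned
[0, 7] or wider would pass.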
--- dpctl/tensor/CMakeLists.txt | 2 +- .../{reductions.cpp => sum_reductions.cpp} | 71 +++++++++++-------- .../{reductions.hpp => sum_reductions.hpp} | 0 dpctl/tensor/libtensor/source/tensor_py.cpp | 2 +- 4 files changed, 42 insertions(+), 33 deletions(-) rename dpctl/tensor/libtensor/source/{reductions.cpp => sum_reductions.cpp} (90%) rename dpctl/tensor/libtensor/source/{reductions.hpp => sum_reductions.hpp} (100%) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index cc122524e8..a331ab6b74 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -47,7 +47,7 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp ) set(_clang_prefix "") if (WIN32) diff --git a/dpctl/tensor/libtensor/source/reductions.cpp b/dpctl/tensor/libtensor/source/sum_reductions.cpp similarity index 90% rename from dpctl/tensor/libtensor/source/reductions.cpp rename to dpctl/tensor/libtensor/source/sum_reductions.cpp index d0c21a39fa..1907bcac2e 100644 --- a/dpctl/tensor/libtensor/source/reductions.cpp +++ b/dpctl/tensor/libtensor/source/sum_reductions.cpp @@ -35,9 +35,10 @@ #include #include "kernels/reductions.hpp" -#include "reductions.hpp" +#include "sum_reductions.hpp" #include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -135,9 +136,23 @@ std::pair py_sum_over_axis( reduction_nelems *= static_cast(src_shape_ptr[i]); } - // FIXME: check that dst and src do not overlap - // check that dst is ample enough (memory span is sufficient - // to accommodate all elements) + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accomodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accomodate all the " + "elements of source array."); + } + } int src_typenum = src.get_typenum(); int dst_typenum = dst.get_typenum(); @@ -297,38 +312,33 @@ std::pair py_sum_over_axis( } std::vector host_task_events{}; - const auto &iter_src_dst_metadata_packing_triple_ = - dpctl::tensor::offset_utils::device_allocate_and_pack( - exec_q, host_task_events, simplified_iteration_shape, - simplified_iteration_src_strides, simplified_iteration_dst_strides); - py::ssize_t *iter_shape_and_strides = - std::get<0>(iter_src_dst_metadata_packing_triple_); - if (iter_shape_and_strides == nullptr) { + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { throw 
std::runtime_error("Unable to allocate memory on device"); } - const auto ©_iter_metadata_ev = - std::get<2>(iter_src_dst_metadata_packing_triple_); + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); - const auto &reduction_metadata_packing_triple_ = - dpctl::tensor::offset_utils::device_allocate_and_pack( - exec_q, host_task_events, simplified_reduction_shape, - simplified_reduction_src_strides); + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; py::ssize_t *reduction_shape_stride = - std::get<0>(reduction_metadata_packing_triple_); - if (reduction_shape_stride == nullptr) { - sycl::event::wait(host_task_events); - sycl::free(iter_shape_and_strides, exec_q); - throw std::runtime_error("Unable to allocate memory on device"); - } - const auto ©_reduction_metadata_ev = - std::get<2>(reduction_metadata_packing_triple_); + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); std::vector all_deps; - all_deps.reserve(depends.size() + 2); + all_deps.reserve(depends.size() + 1); all_deps.resize(depends.size()); std::copy(depends.begin(), depends.end(), all_deps.begin()); - all_deps.push_back(copy_iter_metadata_ev); - all_deps.push_back(copy_reduction_metadata_ev); + all_deps.push_back(copy_metadata_ev); auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), iteration_nd, iter_shape_and_strides, @@ -339,9 +349,8 @@ std::pair py_sum_over_axis( sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(comp_ev); auto ctx = exec_q.get_context(); - cgh.host_task([ctx, iter_shape_and_strides, reduction_shape_stride] { - sycl::free(iter_shape_and_strides, ctx); - sycl::free(reduction_shape_stride, ctx); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); }); }); host_task_events.push_back(temp_cleanup_ev); diff --git a/dpctl/tensor/libtensor/source/reductions.hpp b/dpctl/tensor/libtensor/source/sum_reductions.hpp similarity index 100% rename from dpctl/tensor/libtensor/source/reductions.hpp rename to dpctl/tensor/libtensor/source/sum_reductions.hpp diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 5f1539e56b..4b36dea534 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -44,8 +44,8 @@ #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" #include "linear_sequences.hpp" -#include "reductions.hpp" #include "simplify_iteration_space.hpp" +#include "sum_reductions.hpp" #include "triul_ctor.hpp" #include "utils/memory_overlap.hpp" #include "utils/strided_iters.hpp" From de05f4577a7205fdc6ae5923c9232a075f476336 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 25 May 2023 16:12:38 -0500 Subject: [PATCH 8/9] Added use of sequential reduction --- .../libtensor/include/kernels/reductions.hpp | 358 +++++++++++++----- 1 file changed, 254 insertions(+), 104 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 4e7dbedd9d..cab8e85540 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -46,8 +46,61 @@ namespace tensor namespace kernels { -/* ================ Reduction, using sycl::reduce_over_group, and - * sycl::atomic_ref ================ */ +template +struct SequentialReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT 
identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialReduction(const argT *inp, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const py::ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const py::ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + outT red_val(identity_); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + const py::ssize_t inp_reduction_offset = + inp_reduced_dims_indexer_(m); + const py::ssize_t inp_offset = + inp_iter_offset + inp_reduction_offset; + + red_val = reduction_op_(red_val, inp_[inp_offset]); + } + + out_[out_iter_offset] = red_val; + } +}; + +/* === Reduction, using sycl::reduce_over_group, and sycl::atomic_ref === */ /* This kernel only works for outT with sizeof(outT) == 4, or sizeof(outT) == 8 @@ -169,8 +222,11 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( template class sum_reduction_over_group_with_atomics_krn; -template -class sum_reduction_over_group_with_atomics_1d_krn; +template +class sum_reduction_seq_strided_krn; + +template +class sum_reduction_seq_contig_krn; using dpctl::tensor::sycl_utils::choose_workgroup_size; @@ -198,66 +254,114 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( using ReductionOpT = sycl::plus; constexpr resTy identity_val = resTy{0}; - sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { - using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - const py::ssize_t *const &res_shape = iter_shape_and_strides; - const py::ssize_t *const &res_strides = - iter_shape_and_strides + 2 * iter_nd; - IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - cgh.depends_on(depends); + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; - cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = identity_val; + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); }); - }); - const sycl::device &d = exec_q.get_device(); - const auto &sg_sizes = d.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT 
= + dpctl::tensor::offset_utils::UnpackedStridedIndexer; - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + const py::ssize_t *const &res_shape = iter_shape_and_strides; + const py::ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + cgh.depends_on(depends); - InputOutputIterIndexerT in_out_iter_indexer{ - iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; - ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, - reduction_shape_stride}; + cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); - constexpr size_t preferrered_reductions_per_wi = 4; - size_t reductions_per_wi = - (reduction_nelems < preferrered_reductions_per_wi * wg) - ? std::max(1, (reduction_nelems + wg - 1) / wg) - : preferrered_reductions_per_wi; + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); - size_t reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - - auto globalRange = sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; - - using KernelName = class sum_reduction_over_group_with_atomics_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - - cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), - ReductionOverGroupWithAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); - }); - - return comp_ev; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + constexpr size_t preferrered_reductions_per_wi = 4; + size_t reductions_per_wi = + (reduction_nelems < preferrered_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferrered_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + if (reduction_groups > 1) { + const size_t &max_wg = + d.get_info(); + + if (reduction_nelems < preferrered_reductions_per_wi * max_wg) { + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + } + } + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + + return comp_ev; + } } // Contig @@ -295,63 +399,109 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( using ReductionOpT = sycl::plus; constexpr resTy identity_val = resTy{0}; - sycl::event res_init_ev = - exec_q.fill(res_tp, resTy(identity_val), iter_nelems, depends); - const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<2>(reduction_nelems, sg_sizes); - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - RowsIndexerT, NoOpIndexerT>; - using ReductionIndexerT = NoOpIndexerT; - - RowsIndexerT columns_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_nelems)}; - NoOpIndexerT result_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, - result_indexer}; - ReductionIndexerT reduction_indexer{}; - - constexpr size_t preferrered_reductions_per_wi = 8; - size_t reductions_per_wi = - (reduction_nelems < preferrered_reductions_per_wi * wg) - ? 
std::max(1, (reduction_nelems + wg - 1) / wg) - : preferrered_reductions_per_wi; + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - size_t reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - - auto globalRange = sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; - - using KernelName = class sum_reduction_over_group_with_atomics_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - - cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), - ReductionOverGroupWithAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); - }); - - return comp_ev; + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + RowsIndexerT columns_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_nelems)}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{}; + + constexpr size_t preferrered_reductions_per_wi = 8; + size_t reductions_per_wi = + (reduction_nelems < preferrered_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferrered_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + if (reduction_groups > 1) { + const size_t &max_wg = + d.get_info(); + + if (reduction_nelems < preferrered_reductions_per_wi * max_wg) { + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + } + } + + auto globalRange = + sycl::range<2>{iter_nelems, reduction_groups * wg}; + auto localRange = sycl::range<2>{1, wg}; + + using KernelName = class sum_reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + + cgh.parallel_for( + sycl::nd_range<2>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + reductions_per_wi)); + }); + + return comp_ev; + } } -/* ======================= Reduction, using sycl::reduce_over_group, but not - * using atomic_ref ========================= */ +/* = Reduction, using sycl::reduce_over_group, but not using atomic_ref = */ template Date: Thu, 25 May 2023 16:13:09 -0500 Subject: [PATCH 9/9] Minor optimization to boolean_reductions.hpp --- dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp index 8418fca83c..a81a385f14 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -183,7 +183,7 @@ struct SequentialBooleanReduction void operator()(sycl::id<1> id) const { - auto inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); const py::ssize_t &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); const py::ssize_t &out_iter_offset =