diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 09ecd066f2..830cc82078 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -671,6 +671,53 @@ struct type_caster> DPCTL_TYPE_CASTER(sycl::kernel_bundle, _("dpctl.program.SyclProgram")); }; + +/* This type caster associates + * ``sycl::half`` C++ class with Python :class:`float` for the purposes + * of generation of Python bindings by pybind11. + */ +template <> struct type_caster +{ +public: + bool load(handle src, bool convert) + { + double py_value; + + if (!src) { + return false; + } + + PyObject *source = src.ptr(); + + if (convert || PyFloat_Check(source)) { + py_value = PyFloat_AsDouble(source); + } + else { + return false; + } + + bool py_err = (py_value == double(-1)) && PyErr_Occurred(); + + if (py_err) { + PyErr_Clear(); + if (convert && (PyNumber_Check(source) != 0)) { + auto tmp = reinterpret_steal(PyNumber_Float(source)); + return load(tmp, false); + } + return false; + } + value = static_cast(py_value); + return true; + } + + static handle cast(sycl::half src, return_value_policy, handle) + { + return PyFloat_FromDouble(static_cast(src)); + } + + PYBIND11_TYPE_CASTER(sycl::half, _("float")); +}; + } // namespace detail } // namespace pybind11 diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index 7610ab3a11..41b3093652 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -36,7 +36,6 @@ #include "utils/type_utils.hpp" #include "full_ctor.hpp" -#include "unboxing_helper.hpp" namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; @@ -79,14 +78,7 @@ sycl::event full_contig_impl(sycl::queue &exec_q, char *dst_p, const std::vector &depends) { - dstTy fill_v; - - PythonObjectUnboxer unboxer{}; - try { - fill_v = unboxer(py_value); - } catch (const py::error_already_set &e) { - throw; - } + dstTy fill_v = py::cast(py_value); using dpctl::tensor::kernels::constructors::full_contig_impl; diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index aaf3e9e932..1a6b9811fe 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -36,7 +36,6 @@ #include "utils/type_utils.hpp" #include "linear_sequences.hpp" -#include "unboxing_helper.hpp" namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; @@ -86,16 +85,8 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q, char *array_data, const std::vector &depends) { - Ty start_v; - Ty step_v; - - const auto &unboxer = PythonObjectUnboxer{}; - try { - start_v = unboxer(start); - step_v = unboxer(step); - } catch (const py::error_already_set &e) { - throw; - } + Ty start_v = py::cast(start); + Ty step_v = py::cast(step); using dpctl::tensor::kernels::constructors::lin_space_step_impl; @@ -143,14 +134,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q, char *array_data, const std::vector &depends) { - Ty start_v, end_v; - const auto &unboxer = PythonObjectUnboxer{}; - try { - start_v = unboxer(start); - end_v = unboxer(end); - } catch (const py::error_already_set &e) { - throw; - } + Ty start_v = py::cast(start); + Ty end_v = py::cast(end); using dpctl::tensor::kernels::constructors::lin_space_affine_impl; diff --git a/dpctl/tensor/libtensor/source/unboxing_helper.hpp b/dpctl/tensor/libtensor/source/unboxing_helper.hpp deleted file mode 100644 index 36fd85af43..0000000000 --- a/dpctl/tensor/libtensor/source/unboxing_helper.hpp +++ /dev/null @@ -1,53 +0,0 @@ -//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2024 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===--------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions -//===--------------------------------------------------------------------===// - -#pragma once - -#include -#include - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -template struct PythonObjectUnboxer -{ - T operator()(const py::object &o) const - { - if constexpr (std::is_same_v) { - float tmp = py::cast(o); - return static_cast(tmp); - } - else { - return py::cast(o); - } - } -}; - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl