diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index cc90480b50..cc9bfa3171 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -28,7 +28,9 @@ #include "dpctl_capi.h" #include #include +#include #include +#include namespace py = pybind11; @@ -497,4 +499,38 @@ class usm_ndarray : public py::object }; } // end namespace tensor + +namespace utils +{ + +template +sycl::event keep_args_alive(sycl::queue q, + const py::object (&py_objs)[num], + const std::vector &depends = {}) +{ + sycl::event host_task_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + std::array, num> shp_arr; + for (std::size_t i = 0; i < num; ++i) { + shp_arr[i] = std::make_shared(py_objs[i]); + shp_arr[i]->inc_ref(); + } + cgh.host_task([=]() { + bool guard = (Py_IsInitialized() && !_Py_IsFinalizing()); + if (guard) { + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + for (std::size_t i = 0; i < num; ++i) { + shp_arr[i]->dec_ref(); + } + PyGILState_Release(gstate); + } + }); + }); + + return host_task_ev; +} + +} // end namespace utils + } // end namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 4b378ae306..454d83d6ea 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -342,33 +342,7 @@ std::vector f_contiguous_strides(int nd, } } -template -sycl::event keep_args_alive(sycl::queue q, - const py::object (&py_objs)[num], - const std::vector &depends = {}) -{ - sycl::event host_task_ev = q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - std::array, num> shp_arr; - for (std::size_t i = 0; i < num; ++i) { - shp_arr[i] = std::make_shared(py_objs[i]); - shp_arr[i]->inc_ref(); - } - cgh.host_task([=]() { - bool guard = (Py_IsInitialized() && !_Py_IsFinalizing()); - if (guard) { - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); - for (std::size_t i = 0; i < num; ++i) { - shp_arr[i]->dec_ref(); - } - PyGILState_Release(gstate); - } - }); - }); - - return host_task_ev; -} +using dpctl::utils::keep_args_alive; void simplify_iteration_space(int &nd, const py::ssize_t *&shape, diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp index 4a3df41f55..569428004f 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp @@ -8,34 +8,7 @@ namespace py = pybind11; -sycl::event keep_args_alive(sycl::queue q, - py::object o1, - py::object o2, - py::object o3, - const std::vector &depends = {}) -{ - sycl::event ht_event = q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - std::shared_ptr shp1 = std::make_shared(o1); - std::shared_ptr shp2 = std::make_shared(o2); - std::shared_ptr shp3 = std::make_shared(o3); - shp1->inc_ref(); - shp2->inc_ref(); - shp3->inc_ref(); - cgh.host_task([=]() { - bool guard = (Py_IsInitialized() && !_Py_IsFinalizing()); - if (guard) { - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); - shp1->dec_ref(); - shp2->dec_ref(); - shp3->dec_ref(); - PyGILState_Release(gstate); - } - }); - }); - return ht_event; -} +using dpctl::utils::keep_args_alive; std::pair gemv(sycl::queue q, @@ -131,7 +104,8 @@ gemv(sycl::queue q, throw std::runtime_error("Type dispatch ran into trouble."); } - sycl::event ht_event = keep_args_alive(q, matrix, vector, result, {res_ev}); + sycl::event ht_event = + keep_args_alive(q, {matrix, vector, result}, {res_ev}); return std::make_pair(ht_event, res_ev); }