diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 921f231aa1..a26af2f51c 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -829,7 +829,7 @@ class usm_ndarray : public py::object char *get_data() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetData_(raw_ar); @@ -842,7 +842,7 @@ class usm_ndarray : public py::object int get_ndim() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetNDim_(raw_ar); @@ -850,12 +850,21 @@ class usm_ndarray : public py::object const py::ssize_t *get_shape_raw() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetShape_(raw_ar); } + std::vector get_shape_vector() const + { + auto raw_sh = get_shape_raw(); + auto nd = get_ndim(); + + std::vector shape_vector(raw_sh, raw_sh + nd); + return shape_vector; + } + py::ssize_t get_shape(int i) const { auto shape_ptr = get_shape_raw(); @@ -864,15 +873,43 @@ class usm_ndarray : public py::object const py::ssize_t *get_strides_raw() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetStrides_(raw_ar); } + std::vector get_strides_vector() const + { + auto raw_st = get_strides_raw(); + auto nd = get_ndim(); + + if (raw_st == nullptr) { + auto is_c_contig = is_c_contiguous(); + auto is_f_contig = is_f_contiguous(); + auto raw_sh = get_shape_raw(); + if (is_c_contig) { + const auto &contig_strides = c_contiguous_strides(nd, raw_sh); + return contig_strides; + } + else if (is_f_contig) { + const auto &contig_strides = f_contiguous_strides(nd, raw_sh); + return contig_strides; + } + else { + throw std::runtime_error("Invalid array encountered when " + "building strides"); + } + } + else { + std::vector st_vec(raw_st, raw_st + nd); + return st_vec; + } + } + py::ssize_t get_size() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); int ndim = api.UsmNDArray_GetNDim_(raw_ar); @@ -889,7 +926,7 @@ class usm_ndarray : public py::object std::pair get_minmax_offsets() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); int nd = api.UsmNDArray_GetNDim_(raw_ar); @@ -923,7 +960,7 @@ class usm_ndarray : public py::object sycl::queue get_queue() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar); @@ -932,7 +969,7 @@ class usm_ndarray : public py::object int get_typenum() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetTypenum_(raw_ar); @@ -940,7 +977,7 @@ class usm_ndarray : public py::object int get_flags() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetFlags_(raw_ar); @@ -948,7 +985,7 @@ class usm_ndarray : public py::object int get_elemsize() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetElementSize_(raw_ar); @@ -956,21 +993,21 @@ class usm_ndarray : public py::object bool is_c_contiguous() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_C_CONTIGUOUS_); } bool is_f_contiguous() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_F_CONTIGUOUS_); } bool is_writable() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_WRITABLE_); }