From df6114dd3b67055fb64bd0ff231ac51a7778311a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 28 Feb 2023 14:43:32 -0800 Subject: [PATCH 1/3] Added get_strides_vector and get_shape_vector --- dpctl/apis/include/dpctl4pybind11.hpp | 37 +++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 921f231aa1..881451e381 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -856,6 +856,15 @@ class usm_ndarray : public py::object return api.UsmNDArray_GetShape_(raw_ar); } + const std::vector get_shape_vector() const + { + auto raw_sh = this->get_shape_raw(); + auto nd = this->get_ndim(); + + std::vector shape_vector(raw_sh, raw_sh + nd); + return shape_vector; + } + py::ssize_t get_shape(int i) const { auto shape_ptr = get_shape_raw(); @@ -870,6 +879,34 @@ class usm_ndarray : public py::object return api.UsmNDArray_GetStrides_(raw_ar); } + const std::vector get_strides_vector() const + { + auto raw_st = this->get_strides_raw(); + auto nd = this->get_ndim(); + + if (raw_st == nullptr) { + auto is_c_contig = this->is_c_contiguous(); + auto is_f_contig = this->is_f_contiguous(); + auto raw_sh = this->get_shape_raw(); + if (is_c_contig) { + const auto &contig_strides = c_contiguous_strides(nd, raw_sh); + return contig_strides; + } + else if (is_f_contig) { + const auto &contig_strides = f_contiguous_strides(nd, raw_sh); + return contig_strides; + } + else { + throw std::runtime_error("Invalid array encountered when " + "building strides"); + } + } + else { + std::vector st_vec(raw_st, raw_st + nd); + return st_vec; + } + } + py::ssize_t get_size() const { PyUSMArrayObject *raw_ar = this->usm_array_ptr(); From 900b1be1bc94378d585abfdfeb8afaa127129b2e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 28 Feb 2023 19:55:12 -0800 Subject: [PATCH 2/3] Removed uses of `this->method` in usm_ndarray --- dpctl/apis/include/dpctl4pybind11.hpp | 40 +++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 881451e381..65f8467188 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -829,7 +829,7 @@ class usm_ndarray : public py::object char *get_data() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetData_(raw_ar); @@ -842,7 +842,7 @@ class usm_ndarray : public py::object int get_ndim() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetNDim_(raw_ar); @@ -850,7 +850,7 @@ class usm_ndarray : public py::object const py::ssize_t *get_shape_raw() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetShape_(raw_ar); @@ -858,8 +858,8 @@ class usm_ndarray : public py::object const std::vector get_shape_vector() const { - auto raw_sh = this->get_shape_raw(); - auto nd = this->get_ndim(); + auto raw_sh = get_shape_raw(); + auto nd = get_ndim(); std::vector shape_vector(raw_sh, raw_sh + nd); return shape_vector; @@ -873,7 +873,7 @@ class usm_ndarray : public py::object const py::ssize_t *get_strides_raw() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetStrides_(raw_ar); @@ -881,13 +881,13 @@ class usm_ndarray : public py::object const std::vector get_strides_vector() const { - auto raw_st = this->get_strides_raw(); - auto nd = this->get_ndim(); + auto raw_st = get_strides_raw(); + auto nd = get_ndim(); if (raw_st == nullptr) { - auto is_c_contig = this->is_c_contiguous(); - auto is_f_contig = this->is_f_contiguous(); - auto raw_sh = this->get_shape_raw(); + auto is_c_contig = is_c_contiguous(); + auto is_f_contig = is_f_contiguous(); + auto raw_sh = get_shape_raw(); if (is_c_contig) { const auto &contig_strides = c_contiguous_strides(nd, raw_sh); return contig_strides; @@ -909,7 +909,7 @@ class usm_ndarray : public py::object py::ssize_t get_size() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); int ndim = api.UsmNDArray_GetNDim_(raw_ar); @@ -926,7 +926,7 @@ class usm_ndarray : public py::object std::pair get_minmax_offsets() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); int nd = api.UsmNDArray_GetNDim_(raw_ar); @@ -960,7 +960,7 @@ class usm_ndarray : public py::object sycl::queue get_queue() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar); @@ -969,7 +969,7 @@ class usm_ndarray : public py::object int get_typenum() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetTypenum_(raw_ar); @@ -977,7 +977,7 @@ class usm_ndarray : public py::object int get_flags() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetFlags_(raw_ar); @@ -985,7 +985,7 @@ class usm_ndarray : public py::object int get_elemsize() const { - PyUSMArrayObject *raw_ar = this->usm_array_ptr(); + PyUSMArrayObject *raw_ar = usm_array_ptr(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return api.UsmNDArray_GetElementSize_(raw_ar); @@ -993,21 +993,21 @@ class usm_ndarray : public py::object bool is_c_contiguous() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_C_CONTIGUOUS_); } bool is_f_contiguous() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_F_CONTIGUOUS_); } bool is_writable() const { - int flags = this->get_flags(); + int flags = get_flags(); auto const &api = ::dpctl::detail::dpctl_capi::get(); return static_cast(flags & api.USM_ARRAY_WRITABLE_); } From 7082912149c2c1e700b06d13098ce6f3e486f06e Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 1 Mar 2023 07:30:35 -0600 Subject: [PATCH 3/3] Apply suggestions from code review Removed const qualifier from returned vector. --- dpctl/apis/include/dpctl4pybind11.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 65f8467188..a26af2f51c 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -856,7 +856,7 @@ class usm_ndarray : public py::object return api.UsmNDArray_GetShape_(raw_ar); } - const std::vector get_shape_vector() const + std::vector get_shape_vector() const { auto raw_sh = get_shape_raw(); auto nd = get_ndim(); @@ -879,7 +879,7 @@ class usm_ndarray : public py::object return api.UsmNDArray_GetStrides_(raw_ar); } - const std::vector get_strides_vector() const + std::vector get_strides_vector() const { auto raw_st = get_strides_raw(); auto nd = get_ndim();