Skip to content

Commit df6114d

Browse files
committed
Added get_strides_vector and get_shape_vector
1 parent c32423c commit df6114d

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,15 @@ class usm_ndarray : public py::object
856856
return api.UsmNDArray_GetShape_(raw_ar);
857857
}
858858

859+
const std::vector<py::ssize_t> get_shape_vector() const
860+
{
861+
auto raw_sh = this->get_shape_raw();
862+
auto nd = this->get_ndim();
863+
864+
std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
865+
return shape_vector;
866+
}
867+
859868
py::ssize_t get_shape(int i) const
860869
{
861870
auto shape_ptr = get_shape_raw();
@@ -870,6 +879,34 @@ class usm_ndarray : public py::object
870879
return api.UsmNDArray_GetStrides_(raw_ar);
871880
}
872881

882+
const std::vector<py::ssize_t> get_strides_vector() const
883+
{
884+
auto raw_st = this->get_strides_raw();
885+
auto nd = this->get_ndim();
886+
887+
if (raw_st == nullptr) {
888+
auto is_c_contig = this->is_c_contiguous();
889+
auto is_f_contig = this->is_f_contiguous();
890+
auto raw_sh = this->get_shape_raw();
891+
if (is_c_contig) {
892+
const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
893+
return contig_strides;
894+
}
895+
else if (is_f_contig) {
896+
const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
897+
return contig_strides;
898+
}
899+
else {
900+
throw std::runtime_error("Invalid array encountered when "
901+
"building strides");
902+
}
903+
}
904+
else {
905+
std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
906+
return st_vec;
907+
}
908+
}
909+
873910
py::ssize_t get_size() const
874911
{
875912
PyUSMArrayObject *raw_ar = this->usm_array_ptr();

0 commit comments

Comments
 (0)