Skip to content

Commit b45e174

Browse files
authored
Merge pull request #1090 from IntelPython/usm-ndarray-shape-stride-vectors
Added get_strides_vector and get_shape_vector
2 parents c32423c + 7082912 commit b45e174

File tree

1 file changed

+50
-13
lines changed

1 file changed

+50
-13
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829829

830830
char *get_data() const
831831
{
832-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
832+
PyUSMArrayObject *raw_ar = usm_array_ptr();
833833

834834
auto const &api = ::dpctl::detail::dpctl_capi::get();
835835
return api.UsmNDArray_GetData_(raw_ar);
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842842

843843
int get_ndim() const
844844
{
845-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
845+
PyUSMArrayObject *raw_ar = usm_array_ptr();
846846

847847
auto const &api = ::dpctl::detail::dpctl_capi::get();
848848
return api.UsmNDArray_GetNDim_(raw_ar);
849849
}
850850

851851
const py::ssize_t *get_shape_raw() const
852852
{
853-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
853+
PyUSMArrayObject *raw_ar = usm_array_ptr();
854854

855855
auto const &api = ::dpctl::detail::dpctl_capi::get();
856856
return api.UsmNDArray_GetShape_(raw_ar);
857857
}
858858

859+
std::vector<py::ssize_t> get_shape_vector() const
860+
{
861+
auto raw_sh = get_shape_raw();
862+
auto nd = get_ndim();
863+
864+
std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
865+
return shape_vector;
866+
}
867+
859868
py::ssize_t get_shape(int i) const
860869
{
861870
auto shape_ptr = get_shape_raw();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864873

865874
const py::ssize_t *get_strides_raw() const
866875
{
867-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
876+
PyUSMArrayObject *raw_ar = usm_array_ptr();
868877

869878
auto const &api = ::dpctl::detail::dpctl_capi::get();
870879
return api.UsmNDArray_GetStrides_(raw_ar);
871880
}
872881

882+
std::vector<py::ssize_t> get_strides_vector() const
883+
{
884+
auto raw_st = get_strides_raw();
885+
auto nd = get_ndim();
886+
887+
if (raw_st == nullptr) {
888+
auto is_c_contig = is_c_contiguous();
889+
auto is_f_contig = is_f_contiguous();
890+
auto raw_sh = get_shape_raw();
891+
if (is_c_contig) {
892+
const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
893+
return contig_strides;
894+
}
895+
else if (is_f_contig) {
896+
const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
897+
return contig_strides;
898+
}
899+
else {
900+
throw std::runtime_error("Invalid array encountered when "
901+
"building strides");
902+
}
903+
}
904+
else {
905+
std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
906+
return st_vec;
907+
}
908+
}
909+
873910
py::ssize_t get_size() const
874911
{
875-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
912+
PyUSMArrayObject *raw_ar = usm_array_ptr();
876913

877914
auto const &api = ::dpctl::detail::dpctl_capi::get();
878915
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889926

890927
std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
891928
{
892-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
929+
PyUSMArrayObject *raw_ar = usm_array_ptr();
893930

894931
auto const &api = ::dpctl::detail::dpctl_capi::get();
895932
int nd = api.UsmNDArray_GetNDim_(raw_ar);
@@ -923,7 +960,7 @@ class usm_ndarray : public py::object
923960

924961
sycl::queue get_queue() const
925962
{
926-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
963+
PyUSMArrayObject *raw_ar = usm_array_ptr();
927964

928965
auto const &api = ::dpctl::detail::dpctl_capi::get();
929966
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
@@ -932,45 +969,45 @@ class usm_ndarray : public py::object
932969

933970
int get_typenum() const
934971
{
935-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
972+
PyUSMArrayObject *raw_ar = usm_array_ptr();
936973

937974
auto const &api = ::dpctl::detail::dpctl_capi::get();
938975
return api.UsmNDArray_GetTypenum_(raw_ar);
939976
}
940977

941978
int get_flags() const
942979
{
943-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
980+
PyUSMArrayObject *raw_ar = usm_array_ptr();
944981

945982
auto const &api = ::dpctl::detail::dpctl_capi::get();
946983
return api.UsmNDArray_GetFlags_(raw_ar);
947984
}
948985

949986
int get_elemsize() const
950987
{
951-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
988+
PyUSMArrayObject *raw_ar = usm_array_ptr();
952989

953990
auto const &api = ::dpctl::detail::dpctl_capi::get();
954991
return api.UsmNDArray_GetElementSize_(raw_ar);
955992
}
956993

957994
bool is_c_contiguous() const
958995
{
959-
int flags = this->get_flags();
996+
int flags = get_flags();
960997
auto const &api = ::dpctl::detail::dpctl_capi::get();
961998
return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
962999
}
9631000

9641001
bool is_f_contiguous() const
9651002
{
966-
int flags = this->get_flags();
1003+
int flags = get_flags();
9671004
auto const &api = ::dpctl::detail::dpctl_capi::get();
9681005
return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
9691006
}
9701007

9711008
bool is_writable() const
9721009
{
973-
int flags = this->get_flags();
1010+
int flags = get_flags();
9741011
auto const &api = ::dpctl::detail::dpctl_capi::get();
9751012
return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
9761013
}

0 commit comments

Comments
 (0)