@@ -856,6 +856,16 @@ class usm_ndarray : public py::object
856
856
return api .UsmNDArray_GetShape_ (raw_ar );
857
857
}
858
858
859
+ const std ::vector < py ::ssize_t > get_shape_vector () const
860
+ {
861
+ auto raw_sh = this -> get_shape_raw ();
862
+ auto nd = this -> get_ndim ();
863
+
864
+ std ::vector < py ::ssize_t >
865
+ shape_vector (raw_sh , raw_sh + nd );
866
+ return shape_vector ;
867
+ }
868
+
859
869
py ::ssize_t get_shape (int i ) const
860
870
{
861
871
auto shape_ptr = get_shape_raw ();
@@ -870,6 +880,39 @@ class usm_ndarray : public py::object
870
880
return api .UsmNDArray_GetStrides_ (raw_ar );
871
881
}
872
882
883
+ const std ::vector < py ::ssize_t > get_strides_vector () const
884
+ {
885
+ auto raw_st = this -> get_strides_raw ();
886
+ auto nd = this -> get_ndim ();
887
+
888
+ if (raw_st == nullptr) {
889
+ auto is_c_contig = this -> is_c_contiguous ();
890
+ auto is_f_contig = this -> is_f_contiguous ();
891
+ auto raw_sh = this -> get_shape_raw ();
892
+ if (is_c_contig ) {
893
+ const auto & contig_strides =
894
+ c_contiguous_strides (nd , raw_sh );
895
+ return contig_strides ;
896
+ }
897
+ else if (is_f_contig ) {
898
+ const auto & contig_strides =
899
+ f_contiguous_strides (nd , raw_sh );
900
+ return contig_strides ;
901
+ }
902
+ else {
903
+ throw std ::runtime_error (
904
+ "Invalid array encountered when "
905
+ "building strides"
906
+ );
907
+ }
908
+ }
909
+ else {
910
+ std ::vector < py ::ssize_t >
911
+ st_vec (raw_st , raw_st + nd );
912
+ return st_vec ;
913
+ }
914
+ }
915
+
873
916
py ::ssize_t get_size () const
874
917
{
875
918
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
0 commit comments