@@ -856,6 +856,15 @@ class usm_ndarray : public py::object
856
856
return api .UsmNDArray_GetShape_ (raw_ar );
857
857
}
858
858
859
+ const std ::vector < py ::ssize_t > get_shape_vector () const
860
+ {
861
+ auto raw_sh = this -> get_shape_raw ();
862
+ auto nd = this -> get_ndim ();
863
+
864
+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865
+ return shape_vector ;
866
+ }
867
+
859
868
py ::ssize_t get_shape (int i ) const
860
869
{
861
870
auto shape_ptr = get_shape_raw ();
@@ -870,6 +879,34 @@ class usm_ndarray : public py::object
870
879
return api .UsmNDArray_GetStrides_ (raw_ar );
871
880
}
872
881
882
+ const std ::vector < py ::ssize_t > get_strides_vector () const
883
+ {
884
+ auto raw_st = this -> get_strides_raw ();
885
+ auto nd = this -> get_ndim ();
886
+
887
+ if (raw_st == nullptr) {
888
+ auto is_c_contig = this -> is_c_contiguous ();
889
+ auto is_f_contig = this -> is_f_contiguous ();
890
+ auto raw_sh = this -> get_shape_raw ();
891
+ if (is_c_contig ) {
892
+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893
+ return contig_strides ;
894
+ }
895
+ else if (is_f_contig ) {
896
+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897
+ return contig_strides ;
898
+ }
899
+ else {
900
+ throw std ::runtime_error ("Invalid array encountered when "
901
+ "building strides" );
902
+ }
903
+ }
904
+ else {
905
+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906
+ return st_vec ;
907
+ }
908
+ }
909
+
873
910
py ::ssize_t get_size () const
874
911
{
875
912
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
0 commit comments