@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829
829
830
830
char * get_data () const
831
831
{
832
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
832
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
833
833
834
834
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
835
835
return api .UsmNDArray_GetData_ (raw_ar );
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842
842
843
843
int get_ndim () const
844
844
{
845
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
845
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
846
846
847
847
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
848
848
return api .UsmNDArray_GetNDim_ (raw_ar );
849
849
}
850
850
851
851
const py ::ssize_t * get_shape_raw () const
852
852
{
853
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
853
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
854
854
855
855
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
856
856
return api .UsmNDArray_GetShape_ (raw_ar );
857
857
}
858
858
859
+ std ::vector < py ::ssize_t > get_shape_vector () const
860
+ {
861
+ auto raw_sh = get_shape_raw ();
862
+ auto nd = get_ndim ();
863
+
864
+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865
+ return shape_vector ;
866
+ }
867
+
859
868
py ::ssize_t get_shape (int i ) const
860
869
{
861
870
auto shape_ptr = get_shape_raw ();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864
873
865
874
const py ::ssize_t * get_strides_raw () const
866
875
{
867
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
876
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
868
877
869
878
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
870
879
return api .UsmNDArray_GetStrides_ (raw_ar );
871
880
}
872
881
882
+ std ::vector < py ::ssize_t > get_strides_vector () const
883
+ {
884
+ auto raw_st = get_strides_raw ();
885
+ auto nd = get_ndim ();
886
+
887
+ if (raw_st == nullptr ) {
888
+ auto is_c_contig = is_c_contiguous ();
889
+ auto is_f_contig = is_f_contiguous ();
890
+ auto raw_sh = get_shape_raw ();
891
+ if (is_c_contig ) {
892
+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893
+ return contig_strides ;
894
+ }
895
+ else if (is_f_contig ) {
896
+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897
+ return contig_strides ;
898
+ }
899
+ else {
900
+ throw std ::runtime_error ("Invalid array encountered when "
901
+ "building strides" );
902
+ }
903
+ }
904
+ else {
905
+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906
+ return st_vec ;
907
+ }
908
+ }
909
+
873
910
py ::ssize_t get_size () const
874
911
{
875
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
912
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
876
913
877
914
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
878
915
int ndim = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889
926
890
927
std ::pair < py ::ssize_t , py ::ssize_t > get_minmax_offsets () const
891
928
{
892
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
929
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
893
930
894
931
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
895
932
int nd = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -923,7 +960,7 @@ class usm_ndarray : public py::object
923
960
924
961
sycl ::queue get_queue () const
925
962
{
926
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
963
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
927
964
928
965
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
929
966
DPCTLSyclQueueRef QRef = api .UsmNDArray_GetQueueRef_ (raw_ar );
@@ -932,45 +969,45 @@ class usm_ndarray : public py::object
932
969
933
970
int get_typenum () const
934
971
{
935
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
972
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
936
973
937
974
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
938
975
return api .UsmNDArray_GetTypenum_ (raw_ar );
939
976
}
940
977
941
978
int get_flags () const
942
979
{
943
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
980
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
944
981
945
982
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
946
983
return api .UsmNDArray_GetFlags_ (raw_ar );
947
984
}
948
985
949
986
int get_elemsize () const
950
987
{
951
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
988
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
952
989
953
990
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
954
991
return api .UsmNDArray_GetElementSize_ (raw_ar );
955
992
}
956
993
957
994
bool is_c_contiguous () const
958
995
{
959
- int flags = this -> get_flags ();
996
+ int flags = get_flags ();
960
997
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
961
998
return static_cast < bool > (flags & api .USM_ARRAY_C_CONTIGUOUS_ );
962
999
}
963
1000
964
1001
bool is_f_contiguous () const
965
1002
{
966
- int flags = this -> get_flags ();
1003
+ int flags = get_flags ();
967
1004
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
968
1005
return static_cast < bool > (flags & api .USM_ARRAY_F_CONTIGUOUS_ );
969
1006
}
970
1007
971
1008
bool is_writable () const
972
1009
{
973
- int flags = this -> get_flags ();
1010
+ int flags = get_flags ();
974
1011
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
975
1012
return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
976
1013
}
0 commit comments