diff --git a/src/array.rs b/src/array.rs index e186e30d2..a764d90ec 100644 --- a/src/array.rs +++ b/src/array.rs @@ -9,7 +9,7 @@ use std::{ use ndarray::{ Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dim, Dimension, IntoDimension, Ix0, Ix1, - Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, Shape, ShapeBuilder, + Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, RawArrayView, RawArrayViewMut, RawData, ShapeBuilder, StrideShape, }; use num_traits::AsPrimitive; @@ -338,42 +338,19 @@ impl PyArray { } } - /// Calcurates the total number of elements in the array. + /// Calculates the total number of elements in the array. pub fn len(&self) -> usize { self.shape().iter().product() } + /// Returns `true` if the there are no elements in the array. pub fn is_empty(&self) -> bool { - self.len() == 0 + self.shape().iter().any(|dim| *dim == 0) } - /// Returns the pointer to the first element of the inner array. + /// Returns the pointer to the first element of the array. pub(crate) fn data(&self) -> *mut T { - let ptr = self.as_array_ptr(); - unsafe { (*ptr).data as *mut _ } - } -} - -struct InvertedAxes(u32); - -impl InvertedAxes { - fn new(len: usize) -> Self { - assert!(len <= 32, "Only dimensionalities of up to 32 are supported"); - Self(0) - } - - fn push(&mut self, axis: usize) { - debug_assert!(axis < 32); - self.0 |= 1 << axis; - } - - fn invert(mut self, array: &mut ArrayBase) { - while self.0 != 0 { - let axis = self.0.trailing_zeros() as usize; - self.0 &= !(1 << axis); - - array.invert_axis(Axis(axis)); - } + unsafe { (*self.as_array_ptr()).data as *mut _ } } } @@ -384,38 +361,6 @@ impl PyArray { D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions") } - fn ndarray_shape_ptr(&self) -> (StrideShape, *mut T, InvertedAxes) { - let shape = self.shape(); - let strides = self.strides(); - - let mut new_strides = D::zeros(strides.len()); - let mut data_ptr = self.data(); - let mut inverted_axes = InvertedAxes::new(strides.len()); - - for i in 0..strides.len() { - // FIXME(kngwyu): Replace this hacky negative strides support with - // a proper constructor, when it's implemented. - // See https://github.com/rust-ndarray/ndarray/issues/842 for more. - if strides[i] < 0 { - // Move the pointer to the start position - let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::() as isize; - unsafe { - data_ptr = data_ptr.offset(offset); - } - new_strides[i] = (-strides[i]) as usize / mem::size_of::(); - - inverted_axes.push(i); - } else { - new_strides[i] = strides[i] as usize / mem::size_of::(); - } - } - - let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions")); - let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions"); - - (shape.strides(new_strides), data_ptr, inverted_axes) - } - /// Creates a new uninitialized PyArray in python heap. /// /// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array. @@ -883,6 +828,63 @@ impl PyArray { self.try_readwrite().unwrap() } + fn as_view(&self, from_shape_ptr: F) -> ArrayBase + where + F: FnOnce(StrideShape, *mut T) -> ArrayBase, + { + fn inner( + shape: &[usize], + strides: &[isize], + itemsize: usize, + mut data_ptr: *mut u8, + ) -> (StrideShape, u32, *mut u8) { + let shape = D::from_dimension(&Dim(shape)).expect("mismatching dimensions"); + + assert!( + strides.len() <= 32, + "Only dimensionalities of up to 32 are supported" + ); + + let mut new_strides = D::zeros(strides.len()); + let mut inverted_axes = 0_u32; + + for i in 0..strides.len() { + // FIXME(kngwyu): Replace this hacky negative strides support with + // a proper constructor, when it's implemented. + // See https://github.com/rust-ndarray/ndarray/issues/842 for more. + if strides[i] >= 0 { + new_strides[i] = strides[i] as usize / itemsize; + } else { + // Move the pointer to the start position. + data_ptr = unsafe { data_ptr.offset(strides[i] * (shape[i] as isize - 1)) }; + + new_strides[i] = (-strides[i]) as usize / itemsize; + inverted_axes |= 1 << i; + } + } + + (shape.strides(new_strides), inverted_axes, data_ptr) + } + + let (shape, mut inverted_axes, data_ptr) = inner( + self.shape(), + self.strides(), + mem::size_of::(), + self.data() as _, + ); + + let mut array = from_shape_ptr(shape, data_ptr as _); + + while inverted_axes != 0 { + let axis = inverted_axes.trailing_zeros() as usize; + inverted_axes &= !(1 << axis); + + array.invert_axis(Axis(axis)); + } + + array + } + /// Returns the internal array as [`ArrayView`]. /// /// See also [`PyReadonlyArray::as_array`]. @@ -891,10 +893,7 @@ impl PyArray { /// /// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior. pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> { - let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); - let mut res = ArrayView::from_shape_ptr(shape, ptr); - inverted_axes.invert(&mut res); - res + self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr)) } /// Returns the internal array as [`ArrayViewMut`]. @@ -905,26 +904,17 @@ impl PyArray { /// /// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior. pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> { - let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); - let mut res = ArrayViewMut::from_shape_ptr(shape, ptr); - inverted_axes.invert(&mut res); - res + self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr)) } /// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers pub fn as_raw_array(&self) -> RawArrayView { - let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); - let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) }; - inverted_axes.invert(&mut res); - res + self.as_view(|shape, ptr| unsafe { RawArrayView::from_shape_ptr(shape, ptr) }) } /// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers pub fn as_raw_array_mut(&self) -> RawArrayViewMut { - let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); - let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) }; - inverted_axes.invert(&mut res); - res + self.as_view(|shape, ptr| unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) }) } /// Get a copy of `PyArray` as diff --git a/src/convert.rs b/src/convert.rs index c107a1262..9df6a5798 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -144,12 +144,11 @@ where fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray { let len = self.len(); match self.order() { - Some(order) if A::IS_COPY => { + Some(flag) if A::IS_COPY => { // if the array is contiguous, copy it by `copy_nonoverlapping`. let strides = self.npy_strides(); unsafe { - let array = - PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag()); + let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), flag); ptr::copy_nonoverlapping(self.as_ptr(), array.data(), len); array } @@ -157,14 +156,8 @@ where _ => { // if the array is not contiguous, copy all elements by `ArrayBase::iter`. let dim = self.raw_dim(); - let strides = NpyStrides::new::<_, A>( - dim.default_strides() - .slice() - .iter() - .map(|&x| x as npyffi::npy_intp), - ); unsafe { - let array = PyArray::::new_(py, dim, strides.as_ptr(), 0); + let array = PyArray::::new(py, dim, false); let mut data_ptr = array.data(); for item in self.iter() { data_ptr.write(item.clone()); @@ -177,23 +170,9 @@ where } } -pub(crate) enum Order { - Standard, - Fortran, -} - -impl Order { - fn to_flag(&self) -> c_int { - match self { - Order::Standard => 0, - Order::Fortran => 1, - } - } -} - pub(crate) trait ArrayExt { - fn npy_strides(&self) -> NpyStrides; - fn order(&self) -> Option; + fn npy_strides(&self) -> [npyffi::npy_intp; 32]; + fn order(&self) -> Option; } impl ArrayExt for ArrayBase @@ -201,45 +180,35 @@ where S: Data, D: Dimension, { - fn npy_strides(&self) -> NpyStrides { - NpyStrides::new::<_, A>(self.strides().iter().map(|&x| x as npyffi::npy_intp)) + fn npy_strides(&self) -> [npyffi::npy_intp; 32] { + let strides = self.strides(); + let itemsize = mem::size_of::() as isize; + + assert!( + strides.len() <= 32, + "Only dimensionalities of up to 32 are supported" + ); + + let mut new_strides = [0; 32]; + + for i in 0..strides.len() { + new_strides[i] = (strides[i] * itemsize) as npyffi::npy_intp; + } + + new_strides } - fn order(&self) -> Option { + fn order(&self) -> Option { if self.is_standard_layout() { - Some(Order::Standard) + Some(npyffi::NPY_ORDER::NPY_CORDER as _) } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { - Some(Order::Fortran) + Some(npyffi::NPY_ORDER::NPY_FORTRANORDER as _) } else { None } } } -/// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS] -/// -/// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40 -pub(crate) struct NpyStrides([npyffi::npy_intp; 32]); - -impl NpyStrides { - pub(crate) fn as_ptr(&self) -> *const npy_intp { - self.0.as_ptr() - } - - fn new(strides: S) -> Self - where - S: Iterator, - { - let type_size = mem::size_of::() as npyffi::npy_intp; - let mut res = [0; 32]; - for (i, s) in strides.enumerate() { - *res.get_mut(i) - .expect("Only dimensionalities of up to 32 are supported") = s * type_size; - } - Self(res) - } -} - /// Utility trait to specify the dimensions of an array. pub trait ToNpyDims: Dimension + Sealed { #[doc(hidden)] diff --git a/tests/array.rs b/tests/array.rs index e0517e8ae..bb1ab75b2 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -90,6 +90,11 @@ fn rank_zero_array_has_invalid_strides_dimensions() { assert_eq!(arr.ndim(), 0); assert_eq!(arr.strides(), &[]); assert_eq!(arr.shape(), &[]); + + assert_eq!(arr.len(), 1); + assert!(!arr.is_empty()); + + assert_eq!(arr.item(), 0.0); }) }