diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a64a711e..0a3e3c7d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265)) - `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220)) - `PyArray::iter`, `NpySingleIterBuilder::readwrite` and `NpyMultiIterBuilder::add_readwrite` are now `unsafe`, as they allow aliasing mutable references to be created ([#278/](https://github.com/PyO3/rust-numpy/pull/278)) + - The `npyiter` module is deprecated as rust-ndarray's facilities for iteration are more flexible and performant ([#280](https://github.com/PyO3/rust-numpy/pull/280)) - `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262)) - `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260)) - `rayon` feature is now removed, and directly specifying the feature via `ndarray` dependency is recommended ([#250](https://github.com/PyO3/rust-numpy/pull/250)) diff --git a/src/array.rs b/src/array.rs index 8fea6243f..565f027da 100644 --- a/src/array.rs +++ b/src/array.rs @@ -23,6 +23,7 @@ use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; use crate::dtype::{Element, PyArrayDescr}; use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; +#[allow(deprecated)] use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite}; use crate::readonly::PyReadonlyArray; use crate::slice_container::PySliceContainer; @@ -1079,6 +1080,10 @@ impl PyArray { /// /// The iterator will produce mutable references into the array which must not be /// aliased by other references for the life time of the iterator. + #[deprecated( + note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead." + )] + #[allow(deprecated)] pub unsafe fn iter<'py>(&'py self) -> PyResult> { NpySingleIterBuilder::readwrite(self).build() } diff --git a/src/lib.rs b/src/lib.rs index 2b1dbe76b..bca8b1276 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,6 +56,7 @@ pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; +#[allow(deprecated)] pub use crate::npyiter::{ IterMode, NpyIterFlag, NpyMultiIter, NpyMultiIterBuilder, NpySingleIter, NpySingleIterBuilder, }; diff --git a/src/npyiter.rs b/src/npyiter.rs index 53a787b3e..4d4601972 100644 --- a/src/npyiter.rs +++ b/src/npyiter.rs @@ -1,8 +1,31 @@ -//! Wrapper of [Array Iterator API](https://numpy.org/doc/stable/reference/c-api/iterator.html). +//! Wrapper of the [array iterator API][iterator]. //! -//! This module exposes two iterators: -//! [NpySingleIter](./struct.NpySingleIter.html) and -//! [NpyMultiIter](./struct.NpyMultiIter.html). +//! This module exposes two iterators: [`NpySingleIter`] and [`NpyMultiIter`]. +//! +//! As general recommendation, the usage of ndarray's facilities for iteration should be preferred: +//! +//! * They are more performant due to being transparent to the Rust compiler, using statically known dimensions +//! without dynamic dispatch into NumPy's C implementation, c.f. [`ndarray::iter::Iter`]. +//! * They are more flexible as to which parts of the array iterated in which order, c.f. [`ndarray::iter::Lanes`]. +//! * They can zip up to six arrays together and operate on their elements using multiple threads, c.f. [`ndarray::Zip`]. +//! +//! To safely use these types, extension functions should take [`PyReadonlyArray`] as arguments +//! which provide the [`as_array`][PyReadonlyArray::as_array] method to acquire an [`ndarray::ArrayView`]. +//! +//! [iterator]: https://numpy.org/doc/stable/reference/c-api/iterator.html +#![deprecated( + note = "The wrappers of the array iterator API are deprecated, please use ndarray's iterators like `Lanes` and `Zip` instead." +)] + +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int}; +use std::ptr; + +use ndarray::Dimension; +use pyo3::{PyErr, PyNativeType, PyResult, Python}; + +use crate::array::{PyArray, PyArrayDyn}; +use crate::dtype::Element; use crate::npyffi::{ array::PY_ARRAY_API, npy_intp, npy_uint32, @@ -13,29 +36,16 @@ use crate::npyffi::{ NPY_ITER_READONLY, NPY_ITER_READWRITE, NPY_ITER_REDUCE_OK, NPY_ITER_REFS_OK, NPY_ITER_ZEROSIZE_OK, }; -use crate::{Element, PyArray, PyArrayDyn, PyReadonlyArray}; -use pyo3::{prelude::*, PyNativeType}; - -use std::marker::PhantomData; -use std::os::raw::*; -use std::ptr; +use crate::readonly::PyReadonlyArray; /// Flags for constructing an iterator. -/// For the meanings of each flag, readers can refer to [the numpy document][doc]. /// -/// Note that this enum doesn't provide all flags in the numpy C-API. -/// If you have any inconvenience about that, please file an [issue]. +/// The meanings of these flags are defined in the [the NumPy documentation][iterator]. /// -/// [doc]: https://numpy.org/doc/stable/reference/c-api/iterator.html#c.NpyIter_MultiNew -/// [issue]: https://github.com/PyO3/rust-numpy/issues -// Here's a list of unsupported flags: -// CIndex, -// FIndex, -// MultiIndex, -// ExternalLoop, -// ReadWrite, -// ReadOnly, -// WriteOnly, +/// Note that some flags like `MultiIndex` and `ReadOnly` are directly represented +/// by the iterators types provided here. +/// +/// [iterator]: https://numpy.org/doc/stable/reference/c-api/iterator.html#c.NpyIter_MultiNew #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum NpyIterFlag { CommonDtype, @@ -48,6 +58,13 @@ pub enum NpyIterFlag { DelayBufAlloc, DontNegateStrides, CopyIfOverlap, + // CIndex, + // FIndex, + // MultiIndex, + // ExternalLoop, + // ReadWrite, + // ReadOnly, + // WriteOnly, } impl NpyIterFlag { @@ -68,7 +85,7 @@ impl NpyIterFlag { } } -/// Defines IterMode and MultiIterMode. +/// Defines the sealed traits `IterMode` and `MultiIterMode`. mod itermode { use super::*; @@ -86,15 +103,14 @@ mod itermode { }; } - /// A combinator type that represents the mode of an iterator - /// (E.g., Readonly, ReadWrite, Readonly + ReadWrite). + /// A combinator type that represents the mode of an iterator. pub trait MultiIterMode { private_decl!(); type Pre: MultiIterMode; const FLAG: npy_uint32 = 0; fn flags() -> Vec { if Self::FLAG == 0 { - vec![] + Vec::new() } else { let mut res = Self::Pre::flags(); res.push(Self::FLAG); @@ -151,7 +167,7 @@ pub use itermode::{ IterMode, MultiIterMode, MultiIterModeWithManyArrays, ReadWrite, Readonly, RO, RW, }; -/// Builder of [NpySingleIter](./struct.NpySingleIter.html). +/// Builder of [`NpySingleIter`]. pub struct NpySingleIterBuilder<'py, T, I: IterMode> { flags: npy_uint32, array: &'py PyArrayDyn, @@ -160,9 +176,10 @@ pub struct NpySingleIterBuilder<'py, T, I: IterMode> { } impl<'py, T: Element> NpySingleIterBuilder<'py, T, Readonly> { - /// Makes a new builder for a readonly iterator. - pub fn readonly(array: PyReadonlyArray<'py, T, D>) -> Self { + /// Create a new builder for a readonly iterator. + pub fn readonly(array: PyReadonlyArray<'py, T, D>) -> Self { let (array, was_writable) = array.destruct(); + Self { flags: NPY_ITER_READONLY, array: array.to_dyn(), @@ -173,13 +190,13 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T, Readonly> { } impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> { - /// Makes a new builder for a writable iterator. + /// Create a new builder for a writable iterator. /// /// # Safety /// /// The iterator will produce mutable references into the array which must not be /// aliased by other references for the life time of the iterator. - pub unsafe fn readwrite(array: &'py PyArray) -> Self { + pub unsafe fn readwrite(array: &'py PyArray) -> Self { Self { flags: NPY_ITER_READWRITE, array: array.to_dyn(), @@ -190,7 +207,7 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> { } impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { - /// Sets a flag to this builder, returning `self`. + /// Applies a flag to this builder, returning `self`. #[must_use] pub fn set(mut self, flag: NpyIterFlag) -> Self { self.flags |= flag.to_c_enum(); @@ -201,6 +218,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { pub fn build(self) -> PyResult> { let array_ptr = self.array.as_array_ptr(); let py = self.array.py(); + let iter_ptr = unsafe { PY_ARRAY_API.NpyIter_New( py, @@ -211,54 +229,54 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { ptr::null_mut(), ) }; + let readonly_array_ptr = if self.was_writable { Some(array_ptr) } else { None }; + NpySingleIter::new(iter_ptr, readonly_array_ptr, py) } } -/// An iterator over a single array, construced by -/// [NpySingleIterBuilder](./struct.NpySingleIterBuilder.html). -/// This iterator iterates all elements in the array as `&mut T` (in case `readwrite` is used) -/// or `&T` (in case `readonly` is used). +/// An iterator over a single array, construced by [`NpySingleIterBuilder`]. +/// +/// The elements are access `&mut T` in case `readwrite` is used or +/// `&T` in case `readonly` is used. /// /// # Example /// -/// You can use -/// [`NpySingleIterBuilder::readwrite`](./struct.NpySingleIterBuilder.html#method.readwrite) -/// to get a mutable iterator. +/// You can use [`NpySingleIterBuilder::readwrite`] to get a mutable iterator. /// /// ``` -/// use numpy::NpySingleIterBuilder; -/// pyo3::Python::with_gil(|py| { -/// let array = numpy::PyArray::arange(py, 0, 10, 1); +/// use numpy::pyo3::Python; +/// use numpy::{NpySingleIterBuilder, PyArray}; +/// +/// Python::with_gil(|py| { +/// let array = PyArray::arange(py, 0, 10, 1); +/// /// let iter = unsafe { NpySingleIterBuilder::readwrite(array).build().unwrap() }; +/// /// for (i, elem) in iter.enumerate() { /// assert_eq!(*elem, i as i64); -/// *elem = *elem * 2; // elements are mutable -/// } -/// }); -/// ``` -/// Or, as a shorthand, `PyArray::iter` can be also used. -/// ``` -/// # use numpy::NpySingleIterBuilder; -/// # pyo3::Python::with_gil(|py| { -/// # let array = numpy::PyArray::arange(py, 0, 10, 1); -/// for (i, elem) in unsafe { array.iter().unwrap().enumerate() } { -/// assert_eq!(*elem, i as i64); -/// *elem = *elem * 2; // elements are mutable +/// +/// *elem = *elem * 2; // Elements can be mutated. /// } /// }); /// ``` -/// On the other hand, immutable iterator requires [readonly array](../struct.PyReadonlyArray.html). +/// +/// On the other hand, a readonly iterator requires an instance of [`PyReadonlyArray`]. +/// /// ``` -/// use numpy::NpySingleIterBuilder; -/// pyo3::Python::with_gil(|py| { -/// let array = numpy::PyArray::arange(py, 0, 1, 10); +/// use numpy::pyo3::Python; +/// use numpy::{NpySingleIterBuilder, PyArray}; +/// +/// Python::with_gil(|py| { +/// let array = PyArray::arange(py, 0, 1, 10); +/// /// let iter = NpySingleIterBuilder::readonly(array.readonly()).build().unwrap(); +/// /// for (i, elem) in iter.enumerate() { /// assert_eq!(*elem, i as i64); /// } @@ -267,7 +285,6 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { pub struct NpySingleIter<'py, T, I> { iterator: ptr::NonNull, iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int, - empty: bool, iter_size: npy_intp, dataptr: *mut *mut c_char, return_type: PhantomData, @@ -284,21 +301,17 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { ) -> PyResult { let mut iterator = match ptr::NonNull::new(iterator) { Some(iter) => iter, - None => { - return Err(PyErr::fetch(py)); - } + None => return Err(PyErr::fetch(py)), }; let iternext = match unsafe { PY_ARRAY_API.NpyIter_GetIterNext(py, iterator.as_mut(), ptr::null_mut()) } { Some(ptr) => ptr, - None => { - return Err(PyErr::fetch(py)); - } + None => return Err(PyErr::fetch(py)), }; - let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(py, iterator.as_mut()) }; + let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(py, iterator.as_mut()) }; if dataptr.is_null() { unsafe { PY_ARRAY_API.NpyIter_Deallocate(py, iterator.as_mut()) }; return Err(PyErr::fetch(py)); @@ -310,7 +323,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { iterator, iternext, iter_size, - empty: iter_size == 0, dataptr, return_type: PhantomData, mode: PhantomData, @@ -320,7 +332,7 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { } fn iternext(&mut self) -> Option<*mut T> { - if self.empty { + if self.iter_size == 0 { None } else { // Note: This pointer is correct and doesn't need to be updated, @@ -328,7 +340,10 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { // and then transforming that into a reference, the value that dataptr // points to is being updated by iternext to point to the next value. let ret = unsafe { *self.dataptr as *mut T }; - self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0; + let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0; + debug_assert_ne!(self.iter_size, 0); + self.iter_size -= 1; + debug_assert!(self.iter_size > 0 || empty); Some(ret) } } @@ -337,6 +352,7 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { impl<'py, T, I> Drop for NpySingleIter<'py, T, I> { fn drop(&mut self) { let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.py, self.iterator.as_mut()) }; + if let Some(ptr) = self.readonly_array_ptr { unsafe { (*ptr).flags |= NPY_ARRAY_WRITEABLE; @@ -345,7 +361,7 @@ impl<'py, T, I> Drop for NpySingleIter<'py, T, I> { } } -impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T, Readonly> { +impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, Readonly> { type Item = &'py T; fn next(&mut self) -> Option { @@ -353,11 +369,17 @@ impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T, Readonly> { } fn size_hint(&self) -> (usize, Option) { - (self.iter_size as usize, Some(self.iter_size as usize)) + (self.len(), Some(self.len())) + } +} + +impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, Readonly> { + fn len(&self) -> usize { + self.iter_size as usize } } -impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T, ReadWrite> { +impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, ReadWrite> { type Item = &'py mut T; fn next(&mut self) -> Option { @@ -365,11 +387,17 @@ impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T, ReadWrite> { } fn size_hint(&self) -> (usize, Option) { - (self.iter_size as usize, Some(self.iter_size as usize)) + (self.len(), Some(self.len())) } } -/// Builder for [NpyMultiIter](./struct.NpyMultiIter.html). +impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, ReadWrite> { + fn len(&self) -> usize { + self.iter_size as usize + } +} + +/// Builder for [`NpyMultiIter`]. pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> { flags: npy_uint32, arrays: Vec<&'py PyArrayDyn>, @@ -394,7 +422,7 @@ impl<'py, T: Element> NpyMultiIterBuilder<'py, T, ()> { } } - /// Set a flag to this builder, returning `self`. + /// Applies a flag to this builder, returning `self`. #[must_use] pub fn set(mut self, flag: NpyIterFlag) -> Self { self.flags |= flag.to_c_enum(); @@ -404,13 +432,15 @@ impl<'py, T: Element> NpyMultiIterBuilder<'py, T, ()> { impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> { /// Add a readonly array to the resulting iterator. - pub fn add_readonly( + pub fn add_readonly( mut self, array: PyReadonlyArray<'py, T, D>, ) -> NpyMultiIterBuilder<'py, T, RO> { let (array, was_writable) = array.destruct(); + self.arrays.push(array.to_dyn()); self.was_writables.push(was_writable); + NpyMultiIterBuilder { flags: self.flags, arrays: self.arrays, @@ -425,12 +455,13 @@ impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> { /// /// The iterator will produce mutable references into the array which must not be /// aliased by other references for the life time of the iterator. - pub unsafe fn add_readwrite( + pub unsafe fn add_readwrite( mut self, array: &'py PyArray, ) -> NpyMultiIterBuilder<'py, T, RW> { self.arrays.push(array.to_dyn()); self.was_writables.push(false); + NpyMultiIterBuilder { flags: self.flags, arrays: self.arrays, @@ -449,16 +480,15 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T was_writables, .. } = self; - debug_assert!(arrays.len() <= std::i32::MAX as usize); + + debug_assert!(arrays.len() <= i32::MAX as usize); debug_assert!(2 <= arrays.len()); let mut opflags = S::flags(); + let py = arrays[0].py(); - let mut arrays = arrays - .iter() - .map(|x| x.as_array_ptr()) - .collect::>() - .into_boxed_slice(); + + let mut arrays = arrays.iter().map(|x| x.as_array_ptr()).collect::>(); let iter_ptr = unsafe { PY_ARRAY_API.NpyIter_MultiNew( @@ -472,26 +502,27 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T ptr::null_mut(), ) }; + NpyMultiIter::new(iter_ptr, arrays, was_writables, py) } } -/// An iterator over multiple arrays, construced by -/// [NpyMultiIterBuilder](./struct.NpyMultiIterBuilder.html). -/// You can add -/// [`NpyMultiIterBuilder::add_readwrite`](./struct.NpyMultiIterBuilder.html#method.add_readwrite) -/// for adding a mutable component to the iterator, and -/// [`NpyMultiIterBuilder::add_readonly`](./struct.NpyMultiIterBuilder.html#method.add_readonly) -/// for adding a immutable one. +/// An iterator over multiple arrays, construced by [`NpyMultiIterBuilder`]. +/// +/// [`NpyMultiIterBuilder::add_readwrite`] is used for adding a mutable component and +/// [`NpyMultiIterBuilder::add_readonly`] is used for adding an immutable one. /// /// # Example /// /// ``` +/// use numpy::pyo3::Python; /// use numpy::NpyMultiIterBuilder; -/// pyo3::Python::with_gil(|py| { +/// +/// Python::with_gil(|py| { /// let array1 = numpy::PyArray::arange(py, 0, 10, 1); /// let array2 = numpy::PyArray::arange(py, 10, 20, 1); /// let array3 = numpy::PyArray::arange(py, 10, 30, 2); +/// /// let iter = unsafe { /// NpyMultiIterBuilder::new() /// .add_readonly(array1.readonly()) @@ -500,20 +531,20 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T /// .build() /// .unwrap() /// }; +/// /// for (i, j, k) in iter { /// assert_eq!(*i + *j, *k); -/// *j += *i + *k; // The third element is only mutable. +/// *j += *i + *k; // Only the third element can be mutated. /// } /// }); /// ``` pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> { iterator: ptr::NonNull, iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int, - empty: bool, iter_size: npy_intp, dataptr: *mut *mut c_char, marker: PhantomData<(T, S)>, - arrays: Box<[*mut PyArrayObject]>, + arrays: Vec<*mut PyArrayObject>, was_writables: Vec, py: Python<'py>, } @@ -521,27 +552,23 @@ pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> { impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { fn new( iterator: *mut NpyIter, - arrays: Box<[*mut PyArrayObject]>, + arrays: Vec<*mut PyArrayObject>, was_writables: Vec, py: Python<'py>, ) -> PyResult { let mut iterator = match ptr::NonNull::new(iterator) { Some(ptr) => ptr, - None => { - return Err(PyErr::fetch(py)); - } + None => return Err(PyErr::fetch(py)), }; let iternext = match unsafe { PY_ARRAY_API.NpyIter_GetIterNext(py, iterator.as_mut(), ptr::null_mut()) } { Some(ptr) => ptr, - None => { - return Err(PyErr::fetch(py)); - } + None => return Err(PyErr::fetch(py)), }; - let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(py, iterator.as_mut()) }; + let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(py, iterator.as_mut()) }; if dataptr.is_null() { unsafe { PY_ARRAY_API.NpyIter_Deallocate(py, iterator.as_mut()) }; return Err(PyErr::fetch(py)); @@ -553,7 +580,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { iterator, iternext, iter_size, - empty: iter_size == 0, dataptr, marker: PhantomData, arrays, @@ -566,6 +592,7 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { impl<'py, T, S: MultiIterModeWithManyArrays> Drop for NpyMultiIter<'py, T, S> { fn drop(&mut self) { let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.py, self.iterator.as_mut()) }; + for (array_ptr, &was_writable) in self.arrays.iter().zip(self.was_writables.iter()) { if was_writable { unsafe { (**array_ptr).flags |= NPY_ARRAY_WRITEABLE }; @@ -576,10 +603,11 @@ impl<'py, T, S: MultiIterModeWithManyArrays> Drop for NpyMultiIter<'py, T, S> { macro_rules! impl_multi_iter { ($structure: ty, $($ty: ty)+, $($ptr: ident)+, $expand: ident, $deref: expr) => { - impl<'py, T: 'py> std::iter::Iterator for NpyMultiIter<'py, T, $structure> { + impl<'py, T: 'py> Iterator for NpyMultiIter<'py, T, $structure> { type Item = ($($ty,)+); + fn next(&mut self) -> Option { - if self.empty { + if self.iter_size == 0 { None } else { // Note: This pointer is correct and doesn't need to be updated, @@ -588,19 +616,27 @@ macro_rules! impl_multi_iter { // points to is being updated by iternext to point to the next value. let ($($ptr,)+) = unsafe { $expand::(self.dataptr) }; let retval = Some(unsafe { $deref }); - self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0; + let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0; + debug_assert_ne!(self.iter_size, 0); + self.iter_size -= 1; + debug_assert!(self.iter_size > 0 || empty); retval } } fn size_hint(&self) -> (usize, Option) { - (self.iter_size as usize, Some(self.iter_size as usize)) + (self.len(), Some(self.len())) + } + } + + impl<'py, T: 'py> ExactSizeIterator for NpyMultiIter<'py, T, $structure> { + fn len(&self) -> usize { + self.iter_size as usize } } }; } -// Helper functions for conversion #[inline(always)] unsafe fn expand2(dataptr: *mut *mut c_char) -> (*mut T, *mut T) { (*dataptr as *mut T, *dataptr.offset(1) as *mut T) diff --git a/src/readonly.rs b/src/readonly.rs index 7a25a0eb6..cce7d9f7a 100644 --- a/src/readonly.rs +++ b/src/readonly.rs @@ -1,9 +1,12 @@ //! Readonly arrays -use crate::npyffi::NPY_ARRAY_WRITEABLE; -use crate::{Element, NotContiguousError, NpyIndex, PyArray}; use ndarray::{ArrayView, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; use pyo3::{prelude::*, types::PyAny, AsPyPointer}; +use crate::npyffi::NPY_ARRAY_WRITEABLE; +#[allow(deprecated)] +use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, Readonly}; +use crate::{Element, NotContiguousError, NpyIndex, PyArray}; + /// Readonly reference of [`PyArray`](../array/struct.PyArray.html). /// /// This struct ensures that the internal array is not writeable while holding `PyReadonlyArray`. @@ -135,8 +138,12 @@ impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> { /// Iterates all elements of this array. /// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more. - pub fn iter(self) -> PyResult> { - crate::NpySingleIterBuilder::readonly(self).build() + #[deprecated( + note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter` instead." + )] + #[allow(deprecated)] + pub fn iter(self) -> PyResult> { + NpySingleIterBuilder::readonly(self).build() } pub(crate) fn destruct(self) -> (&'py PyArray, bool) { diff --git a/tests/iter.rs b/tests/iter.rs index f0bb82c1a..664060156 100644 --- a/tests/iter.rs +++ b/tests/iter.rs @@ -1,6 +1,8 @@ +#![allow(deprecated)] + use ndarray::array; -use numpy::{NpyMultiIterBuilder, NpySingleIterBuilder, PyArray}; -use pyo3::PyResult; +use numpy::{pyarray, NpyMultiIterBuilder, NpySingleIterBuilder, PyArray}; +use pyo3::{PyResult, Python}; macro_rules! assert_approx_eq { ($x: expr, $y: expr) => { @@ -94,3 +96,48 @@ fn multiiter_rw() -> PyResult<()> { Ok(()) }) } + +#[test] +fn single_iter_size_hint_len() { + Python::with_gil(|py| { + let arr = pyarray![py, [0, 1], [2, 3], [4, 5]]; + + let mut iter = NpySingleIterBuilder::readonly(arr.readonly()) + .build() + .unwrap(); + + for len in (1..=6).rev() { + assert_eq!(iter.len(), len); + assert_eq!(iter.size_hint(), (len, Some(len))); + assert!(iter.next().is_some()); + } + + assert_eq!(iter.len(), 0); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert!(iter.next().is_none()); + }); +} + +#[test] +fn multi_iter_size_hint_len() { + Python::with_gil(|py| { + let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]]; + + let mut iter = NpyMultiIterBuilder::new() + .add_readonly(arr1.readonly()) + .add_readonly(arr2.readonly()) + .build() + .unwrap(); + + for len in (1..=6).rev() { + assert_eq!(iter.len(), len); + assert_eq!(iter.size_hint(), (len, Some(len))); + assert!(iter.next().is_some()); + } + + assert_eq!(iter.len(), 0); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert!(iter.next().is_none()); + }); +}