diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 7d5c7e9a5..d43a9e633 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -16,17 +16,9 @@ use imp_prelude::*; use arraytraits; use dimension; -use iterators; use error::{self, ShapeError, ErrorKind}; use dimension::IntoDimension; use dimension::{abs_index, axes_of, Axes, do_slice, merge_axes, stride_offset}; -use iterators::{ - new_lanes, - new_lanes_mut, - exact_chunks_of, - exact_chunks_mut_of, - windows -}; use zip::Zip; use { @@ -676,7 +668,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn genrows(&self) -> Lanes { let mut n = self.ndim(); if n == 0 { n += 1; } - new_lanes(self.view(), Axis(n - 1)) + Lanes::new(self.view(), Axis(n - 1)) } /// Return a producer and iterable that traverses over the *generalized* @@ -688,7 +680,7 @@ impl ArrayBase where S: Data, D: Dimension { let mut n = self.ndim(); if n == 0 { n += 1; } - new_lanes_mut(self.view_mut(), Axis(n - 1)) + LanesMut::new(self.view_mut(), Axis(n - 1)) } /// Return a producer and iterable that traverses over the *generalized* @@ -718,7 +710,7 @@ impl ArrayBase where S: Data, D: Dimension /// } /// ``` pub fn gencolumns(&self) -> Lanes { - new_lanes(self.view(), Axis(0)) + Lanes::new(self.view(), Axis(0)) } /// Return a producer and iterable that traverses over the *generalized* @@ -728,7 +720,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn gencolumns_mut(&mut self) -> LanesMut where S: DataMut { - new_lanes_mut(self.view_mut(), Axis(0)) + LanesMut::new(self.view_mut(), Axis(0)) } /// Return a producer and iterable that traverses over all 1D lanes @@ -760,7 +752,7 @@ impl ArrayBase where S: Data, D: Dimension /// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2])); /// ``` pub fn lanes(&self, axis: Axis) -> Lanes { - new_lanes(self.view(), axis) + Lanes::new(self.view(), axis) } /// Return a producer and iterable that traverses over all 1D lanes @@ -770,7 +762,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut where S: DataMut { - new_lanes_mut(self.view_mut(), axis) + LanesMut::new(self.view_mut(), axis) } @@ -819,7 +811,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn axis_iter(&self, axis: Axis) -> AxisIter where D: RemoveAxis, { - iterators::new_axis_iter(self.view(), axis.index()) + AxisIter::new(self.view(), axis) } @@ -834,7 +826,7 @@ impl ArrayBase where S: Data, D: Dimension where S: DataMut, D: RemoveAxis, { - iterators::new_axis_iter_mut(self.view_mut(), axis.index()) + AxisIterMut::new(self.view_mut(), axis) } @@ -865,7 +857,7 @@ impl ArrayBase where S: Data, D: Dimension /// [[26, 27]]])); /// ``` pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter { - iterators::new_chunk_iter(self.view(), axis.index(), size) + AxisChunksIter::new(self.view(), axis, size) } /// Return an iterator that traverses over `axis` by chunks of `size`, @@ -878,7 +870,7 @@ impl ArrayBase where S: Data, D: Dimension -> AxisChunksIterMut where S: DataMut { - iterators::new_chunk_iter_mut(self.view_mut(), axis.index(), size) + AxisChunksIterMut::new(self.view_mut(), axis, size) } /// Return an exact chunks producer (and iterable). @@ -895,7 +887,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn exact_chunks(&self, chunk_size: E) -> ExactChunks where E: IntoDimension, { - exact_chunks_of(self.view(), chunk_size) + ExactChunks::new(self.view(), chunk_size) } /// Return an exact chunks producer (and iterable). @@ -934,7 +926,7 @@ impl ArrayBase where S: Data, D: Dimension where E: IntoDimension, S: DataMut { - exact_chunks_mut_of(self.view_mut(), chunk_size) + ExactChunksMut::new(self.view_mut(), chunk_size) } /// Return a window producer and iterable. @@ -954,7 +946,7 @@ impl ArrayBase where S: Data, D: Dimension pub fn windows(&self, window_size: E) -> Windows where E: IntoDimension { - windows(self.view(), window_size) + Windows::new(self.view(), window_size) } // Return (length, stride) for diagonal @@ -1595,8 +1587,8 @@ impl ArrayBase where S: Data, D: Dimension // break the arrays up into their inner rows let n = self.ndim(); let dim = self.raw_dim(); - Zip::from(new_lanes_mut(self.view_mut(), Axis(n - 1))) - .and(new_lanes(rhs.broadcast_assume(dim), Axis(n - 1))) + Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1))) + .and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1))) .apply(move |s_row, r_row| { Zip::from(s_row).and(r_row).apply(|a, b| f(a, b)) }); diff --git a/src/impl_views.rs b/src/impl_views.rs index 7bd96ffeb..e1cd76466 100644 --- a/src/impl_views.rs +++ b/src/impl_views.rs @@ -23,8 +23,7 @@ use { Baseiter, }; -use iter; -use iterators; +use iter::{self, AxisIter, AxisIterMut}; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -469,7 +468,7 @@ impl<'a, A, D> ArrayView<'a, A, D> } #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter<'a, A, D> { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } @@ -477,7 +476,7 @@ impl<'a, A, D> ArrayView<'a, A, D> #[inline] pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> { - ElementsBase { inner: self.into_base_iter() } + ElementsBase::new(self) } pub(crate) fn into_iter_(self) -> Iter<'a, A, D> { @@ -490,7 +489,7 @@ impl<'a, A, D> ArrayView<'a, A, D> pub fn into_outer_iter(self) -> iter::AxisIter<'a, A, D::Smaller> where D: RemoveAxis, { - iterators::new_outer_iter(self) + AxisIter::new(self, Axis(0)) } } @@ -519,7 +518,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> } #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter<'a, A, D> { + pub(crate) fn into_base_iter(self) -> Baseiter { unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } @@ -527,7 +526,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> #[inline] pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> { - ElementsBaseMut { inner: self.into_base_iter() } + ElementsBaseMut::new(self) } pub(crate) fn into_slice_(self) -> Result<&'a mut [A], Self> { @@ -550,7 +549,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D> pub fn into_outer_iter(self) -> iter::AxisIterMut<'a, A, D::Smaller> where D: RemoveAxis, { - iterators::new_outer_iter_mut(self) + AxisIterMut::new(self, Axis(0)) } } diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 4fcaa8637..8dc03b795 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -38,26 +38,30 @@ pub struct ExactChunks<'a, A: 'a, D> { inner_strides: D, } -/// **Panics** if any chunk dimension is zero
-pub fn exact_chunks_of(mut a: ArrayView, chunk: E) -> ExactChunks - where D: Dimension, - E: IntoDimension, -{ - let chunk = chunk.into_dimension(); - ndassert!(a.ndim() == chunk.ndim(), - concat!("Chunk dimension {} does not match array dimension {} ", - "(with array of shape {:?})"), - chunk.ndim(), a.ndim(), a.shape()); - for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; - } - let inner_strides = a.raw_strides(); - a.strides *= &chunk; +impl<'a, A, D: Dimension> ExactChunks<'a, A, D> { + /// Creates a new exact chunks producer. + /// + /// **Panics** if any chunk dimension is zero + pub(crate) fn new(mut a: ArrayView<'a, A, D>, chunk: E) -> Self + where + E: IntoDimension, + { + let chunk = chunk.into_dimension(); + ndassert!(a.ndim() == chunk.ndim(), + concat!("Chunk dimension {} does not match array dimension {} ", + "(with array of shape {:?})"), + chunk.ndim(), a.ndim(), a.shape()); + for i in 0..a.ndim() { + a.dim[i] /= chunk[i]; + } + let inner_strides = a.raw_strides(); + a.strides *= &chunk; - ExactChunks { - base: a, - chunk: chunk, - inner_strides: inner_strides, + ExactChunks { + base: a, + chunk: chunk, + inner_strides: inner_strides, + } } } @@ -117,27 +121,30 @@ pub struct ExactChunksMut<'a, A: 'a, D> { inner_strides: D, } -/// **Panics** if any chunk dimension is zero
-pub fn exact_chunks_mut_of(mut a: ArrayViewMut, chunk: E) - -> ExactChunksMut - where D: Dimension, - E: IntoDimension, -{ - let chunk = chunk.into_dimension(); - ndassert!(a.ndim() == chunk.ndim(), - concat!("Chunk dimension {} does not match array dimension {} ", - "(with array of shape {:?})"), - chunk.ndim(), a.ndim(), a.shape()); - for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; - } - let inner_strides = a.raw_strides(); - a.strides *= &chunk; +impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> { + /// Creates a new exact chunks producer. + /// + /// **Panics** if any chunk dimension is zero + pub(crate) fn new(mut a: ArrayViewMut<'a, A, D>, chunk: E) -> Self + where + E: IntoDimension, + { + let chunk = chunk.into_dimension(); + ndassert!(a.ndim() == chunk.ndim(), + concat!("Chunk dimension {} does not match array dimension {} ", + "(with array of shape {:?})"), + chunk.ndim(), a.ndim(), a.shape()); + for i in 0..a.ndim() { + a.dim[i] /= chunk[i]; + } + let inner_strides = a.raw_strides(); + a.strides *= &chunk; - ExactChunksMut { - base: a, - chunk: chunk, - inner_strides: inner_strides, + ExactChunksMut { + base: a, + chunk: chunk, + inner_strides: inner_strides, + } } } diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 05b079a80..4dc6b0a7d 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use imp_prelude::*; use {NdProducer, Layout}; @@ -30,29 +31,30 @@ pub struct Lanes<'a, A: 'a, D> { inner_stride: Ixs, } - -pub fn new_lanes(v: ArrayView, axis: Axis) - -> Lanes - where D: Dimension -{ - let ndim = v.ndim(); - let len; - let stride; - let iter_v; - if ndim == 0 { - len = 1; - stride = 1; - iter_v = v.try_remove_axis(Axis(0)) - } else { - let i = axis.index(); - len = v.dim[i]; - stride = v.strides[i] as isize; - iter_v = v.try_remove_axis(axis) - } - Lanes { - inner_len: len, - inner_stride: stride, - base: iter_v, +impl<'a, A, D: Dimension> Lanes<'a, A, D> { + pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self + where + Di: Dimension, + { + let ndim = v.ndim(); + let len; + let stride; + let iter_v; + if ndim == 0 { + len = 1; + stride = 1; + iter_v = v.try_remove_axis(Axis(0)) + } else { + let i = axis.index(); + len = v.dim[i]; + stride = v.strides[i] as isize; + iter_v = v.try_remove_axis(axis) + } + Lanes { + inner_len: len, + inner_stride: stride, + base: iter_v, + } } } @@ -84,6 +86,7 @@ impl<'a, A, D> IntoIterator for Lanes<'a, A, D> iter: self.base.into_base_iter(), inner_len: self.inner_len, inner_stride: self.inner_stride, + life: PhantomData, } } } @@ -96,29 +99,30 @@ pub struct LanesMut<'a, A: 'a, D> { inner_stride: Ixs, } - -pub fn new_lanes_mut(v: ArrayViewMut, axis: Axis) - -> LanesMut - where D: Dimension -{ - let ndim = v.ndim(); - let len; - let stride; - let iter_v; - if ndim == 0 { - len = 1; - stride = 1; - iter_v = v.try_remove_axis(Axis(0)) - } else { - let i = axis.index(); - len = v.dim[i]; - stride = v.strides[i] as isize; - iter_v = v.try_remove_axis(axis) - } - LanesMut { - inner_len: len, - inner_stride: stride, - base: iter_v, +impl<'a, A, D: Dimension> LanesMut<'a, A, D> { + pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self + where + Di: Dimension, + { + let ndim = v.ndim(); + let len; + let stride; + let iter_v; + if ndim == 0 { + len = 1; + stride = 1; + iter_v = v.try_remove_axis(Axis(0)) + } else { + let i = axis.index(); + len = v.dim[i]; + stride = v.strides[i] as isize; + iter_v = v.try_remove_axis(axis) + } + LanesMut { + inner_len: len, + inner_stride: stride, + base: iter_v, + } } } @@ -132,6 +136,7 @@ impl<'a, A, D> IntoIterator for LanesMut<'a, A, D> iter: self.base.into_base_iter(), inner_len: self.inner_len, inner_stride: self.inner_stride, + life: PhantomData, } } } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index c16e9f432..0e33c869a 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -29,64 +29,51 @@ use super::{ NdProducer, }; -pub use self::windows::{ - Windows, - windows -}; +pub use self::windows::Windows; pub use self::chunks::{ ExactChunks, ExactChunksIter, - exact_chunks_of, ExactChunksMut, ExactChunksIterMut, - exact_chunks_mut_of, }; pub use self::lanes::{ - new_lanes, - new_lanes_mut, Lanes, LanesMut, }; use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; -/// Base for array iterators +/// Base for iterators over all axes. /// -/// Iterator element type is `&'a A`. -pub struct Baseiter<'a, A: 'a, D> { - // Can have pub fields because it is not itself pub. - pub ptr: *mut A, - pub dim: D, - pub strides: D, - pub index: Option, - pub life: PhantomData<&'a A>, +/// Iterator element type is `*mut A`. +pub struct Baseiter { + ptr: *mut A, + dim: D, + strides: D, + index: Option, } -impl<'a, A, D: Dimension> Baseiter<'a, A, D> { - /// Creating a Baseiter is unsafe, because it can - /// have any lifetime, be immut or mut, and the - /// boundary and stride parameters need to be correct to - /// avoid memory unsafety. - /// - /// It must be placed in the correct mother iterator to be safe. - /// - /// NOTE: Mind the lifetime, it's arbitrary +impl Baseiter { + /// Creating a Baseiter is unsafe because shape and stride parameters need + /// to be correct to avoid performing an unsafe pointer offset while + /// iterating. #[inline] - pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<'a, A, D> { + pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter { Baseiter { ptr: ptr, index: len.first_index(), dim: len, strides: stride, - life: PhantomData, } } } -impl<'a, A, D: Dimension> Baseiter<'a, A, D> { +impl Iterator for Baseiter { + type Item = *mut A; + #[inline] - pub fn next(&mut self) -> Option<*mut A> { + fn next(&mut self) -> Option<*mut A> { let index = match self.index { None => return None, Some(ref ix) => ix.clone(), @@ -96,29 +83,9 @@ impl<'a, A, D: Dimension> Baseiter<'a, A, D> { unsafe { Some(self.ptr.offset(offset)) } } - #[inline] - fn next_ref(&mut self) -> Option<&'a A> { - unsafe { self.next().map(|p| &*p) } - } - - #[inline] - fn next_ref_mut(&mut self) -> Option<&'a mut A> { - unsafe { self.next().map(|p| &mut *p) } - } - - fn len(&self) -> usize { - match self.index { - None => 0, - Some(ref ix) => { - let gone = self.dim - .default_strides() - .slice() - .iter() - .zip(ix.slice().iter()) - .fold(0, |s, (&a, &b)| s + a as usize * b as usize); - self.dim.size() - gone - } - } + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) } fn fold(mut self, init: Acc, mut g: G) -> Acc @@ -149,7 +116,24 @@ impl<'a, A, D: Dimension> Baseiter<'a, A, D> { } } -impl<'a, A> Baseiter<'a, A, Ix1> { +impl<'a, A, D: Dimension> ExactSizeIterator for Baseiter { + fn len(&self) -> usize { + match self.index { + None => 0, + Some(ref ix) => { + let gone = self.dim + .default_strides() + .slice() + .iter() + .zip(ix.slice().iter()) + .fold(0, |s, (&a, &b)| s + a as usize * b as usize); + self.dim.size() - gone + } + } + } +} + +impl DoubleEndedIterator for Baseiter { #[inline] fn next_back(&mut self) -> Option<*mut A> { let index = match self.index { @@ -164,24 +148,13 @@ impl<'a, A> Baseiter<'a, A, Ix1> { unsafe { Some(self.ptr.offset(offset)) } } - - #[inline] - fn next_back_ref(&mut self) -> Option<&'a A> { - unsafe { self.next_back().map(|p| &*p) } - } - - #[inline] - fn next_back_ref_mut(&mut self) -> Option<&'a mut A> { - unsafe { self.next_back().map(|p| &mut *p) } - } } clone_bounds!( - ['a, A, D: Clone] - Baseiter['a, A, D] { + [A, D: Clone] + Baseiter[A, D] { @copy { ptr, - life, } dim, strides, @@ -193,21 +166,30 @@ clone_bounds!( ['a, A, D: Clone] ElementsBase['a, A, D] { @copy { + life, } inner, } ); +impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { + pub fn new(v: ArrayView<'a, A, D>) -> Self { + ElementsBase { + inner: v.into_base_iter(), + life: PhantomData, + } + } +} + impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> { type Item = &'a A; #[inline] fn next(&mut self) -> Option<&'a A> { - self.inner.next_ref() + self.inner.next().map(|p| unsafe { &*p }) } fn size_hint(&self) -> (usize, Option) { - let len = self.inner.len(); - (len, Some(len)) + self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc @@ -222,7 +204,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> { impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> { #[inline] fn next_back(&mut self) -> Option<&'a A> { - self.inner.next_back_ref() + self.inner.next_back().map(|p| unsafe { &*p }) } } @@ -308,7 +290,8 @@ pub struct Iter<'a, A: 'a, D> { /// Counted read only iterator pub struct ElementsBase<'a, A: 'a, D> { - pub inner: Baseiter<'a, A, D>, + inner: Baseiter, + life: PhantomData<&'a A>, } /// An iterator over the elements of an array (mutable). @@ -324,9 +307,18 @@ pub struct IterMut<'a, A: 'a, D> { /// /// Iterator element type is `&'a mut A`. pub struct ElementsBaseMut<'a, A: 'a, D> { - pub inner: Baseiter<'a, A, D>, + inner: Baseiter, + life: PhantomData<&'a mut A>, } +impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { + pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { + ElementsBaseMut { + inner: v.into_base_iter(), + life: PhantomData, + } + } +} /// An iterator over the indexes and elements of an array. /// @@ -397,15 +389,14 @@ impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> { None => return None, Some(ref ix) => ix.clone(), }; - match self.0.inner.next_ref() { + match self.0.next() { None => None, - Some(p) => Some((index.into_pattern(), p)), + Some(elem) => Some((index.into_pattern(), elem)), } } fn size_hint(&self) -> (usize, Option) { - let len = self.0.inner.len(); - (len, Some(len)) + self.0.size_hint() } } @@ -454,12 +445,11 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> { type Item = &'a mut A; #[inline] fn next(&mut self) -> Option<&'a mut A> { - self.inner.next_ref_mut() + self.inner.next().map(|p| unsafe { &mut *p }) } fn size_hint(&self) -> (usize, Option) { - let len = self.inner.len(); - (len, Some(len)) + self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc @@ -474,7 +464,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> { impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> { #[inline] fn next_back(&mut self) -> Option<&'a mut A> { - self.inner.next_back_ref_mut() + self.inner.next_back().map(|p| unsafe { &mut *p }) } } @@ -495,15 +485,14 @@ impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> { None => return None, Some(ref ix) => ix.clone(), }; - match self.0.inner.next_ref_mut() { + match self.0.next() { None => None, - Some(p) => Some((index.into_pattern(), p)), + Some(elem) => Some((index.into_pattern(), elem)), } } fn size_hint(&self) -> (usize, Option) { - let len = self.0.inner.len(); - (len, Some(len)) + self.0.size_hint() } } @@ -515,16 +504,29 @@ impl<'a, A, D> ExactSizeIterator for IndexedIterMut<'a, A, D> } } -/// An iterator that traverses over all dimensions but the innermost, -/// and yields each inner row. +/// An iterator that traverses over all axes but one, and yields a view for +/// each lane along that axis. /// /// See [`.lanes()`](../struct.ArrayBase.html#method.lanes) for more information. pub struct LanesIter<'a, A: 'a, D> { inner_len: Ix, inner_stride: Ixs, - iter: Baseiter<'a, A, D>, + iter: Baseiter, + life: PhantomData<&'a A>, } +clone_bounds!( + ['a, A, D: Clone] + LanesIter['a, A, D] { + @copy { + inner_len, + inner_stride, + life, + } + iter, + } +); + impl<'a, A, D> Iterator for LanesIter<'a, A, D> where D: Dimension { @@ -536,8 +538,7 @@ impl<'a, A, D> Iterator for LanesIter<'a, A, D> } fn size_hint(&self) -> (usize, Option) { - let len = self.iter.len(); - (len, Some(len)) + self.iter.size_hint() } } @@ -560,7 +561,8 @@ impl<'a, A, D> ExactSizeIterator for LanesIter<'a, A, D> pub struct LanesIterMut<'a, A: 'a, D> { inner_len: Ix, inner_stride: Ixs, - iter: Baseiter<'a, A, D>, + iter: Baseiter, + life: PhantomData<&'a mut A>, } impl<'a, A, D> Iterator for LanesIterMut<'a, A, D> @@ -576,8 +578,7 @@ impl<'a, A, D> Iterator for LanesIterMut<'a, A, D> } fn size_hint(&self) -> (usize, Option) { - let len = self.iter.len(); - (len, Some(len)) + self.iter.size_hint() } } @@ -590,7 +591,7 @@ impl<'a, A, D> ExactSizeIterator for LanesIterMut<'a, A, D> } #[derive(Debug)] -pub struct OuterIterCore { +pub struct AxisIterCore { index: Ix, len: Ix, stride: Ixs, @@ -601,7 +602,7 @@ pub struct OuterIterCore { clone_bounds!( [A, D: Clone] - OuterIterCore[A, D] { + AxisIterCore[A, D] { @copy { index, len, @@ -613,33 +614,62 @@ clone_bounds!( } ); -fn new_outer_core(v: ArrayBase, axis: usize) - -> OuterIterCore - where D: RemoveAxis, - S: Data -{ - let shape = v.shape()[axis]; - let stride = v.strides()[axis]; - - OuterIterCore { - index: 0, - len: shape, - stride: stride, - inner_dim: v.dim.remove_axis(Axis(axis)), - inner_strides: v.strides.remove_axis(Axis(axis)), - ptr: v.ptr, +impl AxisIterCore { + /// Constructs a new iterator over the specified axis. + fn new(v: ArrayBase, axis: Axis) -> Self + where + Di: RemoveAxis, + S: Data, + { + let shape = v.shape()[axis.index()]; + let stride = v.strides()[axis.index()]; + AxisIterCore { + index: 0, + len: shape, + stride: stride, + inner_dim: v.dim.remove_axis(axis), + inner_strides: v.strides.remove_axis(axis), + ptr: v.ptr, + } } -} -impl OuterIterCore { unsafe fn offset(&self, index: usize) -> *mut A { debug_assert!(index <= self.len, "index={}, len={}, stride={}", index, self.len, self.stride); self.ptr.offset(index as isize * self.stride) } + + /// Split the iterator at index, yielding two disjoint iterators. + /// + /// **Panics** if `index` is strictly greater than the iterator's length + fn split_at(self, index: usize) -> (Self, Self) { + assert!(index <= self.len); + let right_ptr = if index != self.len { + unsafe { self.offset(index) } + } else { + self.ptr + }; + let left = AxisIterCore { + index: 0, + len: index, + stride: self.stride, + inner_dim: self.inner_dim.clone(), + inner_strides: self.inner_strides.clone(), + ptr: self.ptr, + }; + let right = AxisIterCore { + index: 0, + len: self.len - index, + stride: self.stride, + inner_dim: self.inner_dim, + inner_strides: self.inner_strides, + ptr: right_ptr, + }; + (left, right) + } } -impl Iterator for OuterIterCore +impl Iterator for AxisIterCore where D: Dimension, { type Item = *mut A; @@ -660,7 +690,7 @@ impl Iterator for OuterIterCore } } -impl DoubleEndedIterator for OuterIterCore +impl DoubleEndedIterator for AxisIterCore where D: Dimension, { fn next_back(&mut self) -> Option { @@ -690,7 +720,7 @@ impl DoubleEndedIterator for OuterIterCore /// for more information. #[derive(Debug)] pub struct AxisIter<'a, A: 'a, D> { - iter: OuterIterCore, + iter: AxisIterCore, life: PhantomData<&'a A>, } @@ -704,54 +734,35 @@ clone_bounds!( } ); - -macro_rules! outer_iter_split_at_impl { - ($iter: ident) => ( - impl<'a, A, D> $iter<'a, A, D> - where D: Dimension - { - /// Split the iterator at index, yielding two disjoint iterators. - /// - /// *panics* if `index` is strictly greater than the iterator's length - pub fn split_at(self, index: Ix) - -> ($iter<'a, A, D>, $iter<'a, A, D>) - { - assert!(index <= self.iter.len); - let right_ptr = if index != self.iter.len { - unsafe { self.iter.offset(index) } - } - else { - self.iter.ptr - }; - let left = $iter { - iter: OuterIterCore { - index: 0, - len: index, - stride: self.iter.stride, - inner_dim: self.iter.inner_dim.clone(), - inner_strides: self.iter.inner_strides.clone(), - ptr: self.iter.ptr, - }, - life: self.life, - }; - let right = $iter { - iter: OuterIterCore { - index: 0, - len: self.iter.len - index, - stride: self.iter.stride, - inner_dim: self.iter.inner_dim, - inner_strides: self.iter.inner_strides, - ptr: right_ptr, - }, - life: self.life, - }; - (left, right) - } +impl<'a, A, D: Dimension> AxisIter<'a, A, D> { + /// Creates a new iterator over the specified axis. + pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self + where + Di: RemoveAxis, + { + AxisIter { + iter: AxisIterCore::new(v, axis), + life: PhantomData, } - ) -} + } -outer_iter_split_at_impl!(AxisIter); + /// Split the iterator at index, yielding two disjoint iterators. + /// + /// **Panics** if `index` is strictly greater than the iterator's length + pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + AxisIter { + iter: left, + life: self.life, + }, + AxisIter { + iter: right, + life: self.life, + }, + ) + } +} impl<'a, A, D> Iterator for AxisIter<'a, A, D> where D: Dimension @@ -791,26 +802,6 @@ impl<'a, A, D> ExactSizeIterator for AxisIter<'a, A, D> } } -pub fn new_outer_iter(v: ArrayView) -> AxisIter - where D: RemoveAxis -{ - AxisIter { - iter: new_outer_core(v, 0), - life: PhantomData, - } -} - -pub fn new_axis_iter(v: ArrayView, axis: usize) - -> AxisIter - where D: RemoveAxis -{ - AxisIter { - iter: new_outer_core(v, axis), - life: PhantomData, - } -} - - /// An iterator that traverses over an axis and /// and yields each subview (mutable) /// @@ -826,11 +817,39 @@ pub fn new_axis_iter(v: ArrayView, axis: usize) /// or [`.axis_iter_mut()`](../struct.ArrayBase.html#method.axis_iter_mut) /// for more information. pub struct AxisIterMut<'a, A: 'a, D> { - iter: OuterIterCore, + iter: AxisIterCore, life: PhantomData<&'a mut A>, } -outer_iter_split_at_impl!(AxisIterMut); +impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { + /// Creates a new iterator over the specified axis. + pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self + where + Di: RemoveAxis, + { + AxisIterMut { + iter: AxisIterCore::new(v, axis), + life: PhantomData, + } + } + + /// Split the iterator at index, yielding two disjoint iterators. + /// + /// **Panics** if `index` is strictly greater than the iterator's length + pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + AxisIterMut { + iter: left, + life: self.life, + }, + AxisIterMut { + iter: right, + life: self.life, + }, + ) + } +} impl<'a, A, D> Iterator for AxisIterMut<'a, A, D> where D: Dimension @@ -870,25 +889,6 @@ impl<'a, A, D> ExactSizeIterator for AxisIterMut<'a, A, D> } } -pub fn new_outer_iter_mut(v: ArrayViewMut) -> AxisIterMut - where D: RemoveAxis -{ - AxisIterMut { - iter: new_outer_core(v, 0), - life: PhantomData, - } -} - -pub fn new_axis_iter_mut(v: ArrayViewMut, axis: usize) - -> AxisIterMut - where D: RemoveAxis -{ - AxisIterMut { - iter: new_outer_core(v, axis), - life: PhantomData, - } -} - impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { type Item = ::Item; @@ -994,7 +994,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> /// /// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information. pub struct AxisChunksIter<'a, A: 'a, D> { - iter: OuterIterCore, + iter: AxisIterCore, last_ptr: *mut A, last_dim: D, life: PhantomData<&'a A>, @@ -1012,9 +1012,10 @@ clone_bounds!( } ); -fn chunk_iter_parts(v: ArrayView, axis: usize, size: usize) - -> (OuterIterCore, *mut A, D) +fn chunk_iter_parts(v: ArrayView, axis: Axis, size: usize) + -> (AxisIterCore, *mut A, D) { + let axis = axis.index(); let axis_len = v.shape()[axis]; let size = if size > axis_len { axis_len } else { size }; let last_index = axis_len / size; @@ -1036,7 +1037,7 @@ fn chunk_iter_parts(v: ArrayView, axis: usize, size: usiz else { v.ptr }; - let iter = OuterIterCore { + let iter = AxisIterCore { index: 0, len: shape, stride: stride, @@ -1048,17 +1049,15 @@ fn chunk_iter_parts(v: ArrayView, axis: usize, size: usiz (iter, last_ptr, last_dim) } -pub fn new_chunk_iter(v: ArrayView, axis: usize, size: usize) - -> AxisChunksIter - where D: Dimension -{ - let (iter, last_ptr, last_dim) = chunk_iter_parts(v, axis, size); - - AxisChunksIter { - iter: iter, - last_ptr: last_ptr, - last_dim: last_dim, - life: PhantomData, +impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { + pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { + let (iter, last_ptr, last_dim) = chunk_iter_parts(v, axis, size); + AxisChunksIter { + iter: iter, + last_ptr: last_ptr, + last_dim: last_dim, + life: PhantomData, + } } } @@ -1131,23 +1130,21 @@ macro_rules! chunk_iter_impl { /// See [`.axis_chunks_iter_mut()`](../struct.ArrayBase.html#method.axis_chunks_iter_mut) /// for more information. pub struct AxisChunksIterMut<'a, A: 'a, D> { - iter: OuterIterCore, + iter: AxisIterCore, last_ptr: *mut A, last_dim: D, life: PhantomData<&'a mut A>, } -pub fn new_chunk_iter_mut(v: ArrayViewMut, axis: usize, size: usize) - -> AxisChunksIterMut - where D: Dimension -{ - let (iter, last_ptr, last_dim) = chunk_iter_parts(v.into_view(), axis, size); - - AxisChunksIterMut { - iter: iter, - last_ptr: last_ptr, - last_dim: last_dim, - life: PhantomData, +impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { + pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { + let (iter, last_ptr, last_dim) = chunk_iter_parts(v.into_view(), axis, size); + AxisChunksIterMut { + iter: iter, + last_ptr: last_ptr, + last_dim: last_dim, + life: PhantomData, + } } } diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 28e922c60..418a4d1d2 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -15,30 +15,31 @@ pub struct Windows<'a, A: 'a, D> { strides: D, } -pub fn windows(a: ArrayView, window_size: E) -> Windows - where D: Dimension, - E: IntoDimension, -{ - let window = window_size.into_dimension(); - ndassert!(a.ndim() == window.ndim(), - concat!("Window dimension {} does not match array dimension {} ", - "(with array of shape {:?})"), - window.ndim(), a.ndim(), a.shape()); - let mut size = a.dim; - for (sz, &ws) in size.slice_mut().iter_mut().zip(window.slice()) +impl<'a, A, D: Dimension> Windows<'a, A, D> { + pub(crate) fn new(a: ArrayView<'a, A, D>, window_size: E) -> Self + where + E: IntoDimension, { - if ws == 0 { panic!("window-size must not be zero!"); } - // cannot use std::cmp::max(0, ..) since arithmetic underflow panics - *sz = if *sz < ws { 0 } else { *sz - ws + 1 }; - } + let window = window_size.into_dimension(); + ndassert!(a.ndim() == window.ndim(), + concat!("Window dimension {} does not match array dimension {} ", + "(with array of shape {:?})"), + window.ndim(), a.ndim(), a.shape()); + let mut size = a.dim; + for (sz, &ws) in size.slice_mut().iter_mut().zip(window.slice()) { + assert_ne!(ws, 0, "window-size must not be zero!"); + // cannot use std::cmp::max(0, ..) since arithmetic underflow panics + *sz = if *sz < ws { 0 } else { *sz - ws + 1 }; + } - let window_strides = a.strides.clone(); + let window_strides = a.strides.clone(); - unsafe { - Windows { - base: ArrayView::from_shape_ptr(size.clone().strides(a.strides), a.ptr), - window: window, - strides: window_strides, + unsafe { + Windows { + base: ArrayView::from_shape_ptr(size.clone().strides(a.strides), a.ptr), + window: window, + strides: window_strides, + } } } } diff --git a/src/lib.rs b/src/lib.rs index 9868b0c53..2550f1c9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,7 +122,7 @@ pub use error::{ShapeError, ErrorKind}; pub use slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; use iterators::Baseiter; -use iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut}; +use iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut}; pub use arraytraits::AsArray; pub use linalg_traits::{LinalgScalar, NdFloat}; @@ -877,7 +877,7 @@ impl ArrayBase fn inner_rows(&self) -> iterators::Lanes { let n = self.ndim(); - iterators::new_lanes(self.view(), Axis(n.saturating_sub(1))) + Lanes::new(self.view(), Axis(n.saturating_sub(1))) } /// n-d generalization of rows, just like inner iter @@ -885,7 +885,7 @@ impl ArrayBase where S: DataMut { let n = self.ndim(); - iterators::new_lanes_mut(self.view_mut(), Axis(n.saturating_sub(1))) + LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1))) } }