From b98e30b2d59d30df7de7238bc29ba36903026e32 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 2 Feb 2019 22:31:07 -0500 Subject: [PATCH 1/5] Add more benchmarks of sum/sum_axis --- benches/numeric.rs | 46 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/benches/numeric.rs b/benches/numeric.rs index 76d07f1e4..5a7c111a5 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -1,7 +1,7 @@ #![feature(test)] extern crate test; -use test::Bencher; +use test::{black_box, Bencher}; extern crate ndarray; use ndarray::prelude::*; @@ -65,6 +65,38 @@ fn contiguous_sum_1e2(bench: &mut Bencher) }); } +#[bench] +fn contiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * n) + .into_shape([n, n, n]) + .unwrap(); + bench.iter(|| black_box(&a).sum()); +} + +#[bench] +fn inner_discontiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * 2*n) + .into_shape([n, n, 2*n]) + .unwrap(); + let v = a.slice(s![.., .., ..;2]); + bench.iter(|| black_box(&v).sum()); +} + +#[bench] +fn middle_discontiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * 2*n * n) + .into_shape([n, 2*n, n]) + .unwrap(); + let v = a.slice(s![.., ..;2, ..]); + bench.iter(|| black_box(&v).sum()); +} + #[bench] fn sum_by_row_1e4(bench: &mut Bencher) { @@ -88,3 +120,15 @@ fn sum_by_col_1e4(bench: &mut Bencher) a.sum_axis(Axis(1)) }); } + +#[bench] +fn sum_by_middle_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * n) + .into_shape([n, n, n]) + .unwrap(); + bench.iter(|| { + a.sum_axis(Axis(1)) + }); +} From ed88e2e9a8445c1f9dccf1d408f6a79ea1ace17b Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 2 Feb 2019 23:32:56 -0500 Subject: [PATCH 2/5] Improve performance of iterator_pairwise_sum --- src/numeric_util.rs | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/numeric_util.rs b/src/numeric_util.rs index bcd1080e8..23966b9a9 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -56,16 +56,18 @@ where I: Iterator, A: Clone + Add + Zero, { - let mut partial_sums = vec![]; - let mut partial_sum = A::zero(); - for (i, x) in iter.enumerate() { - partial_sum = partial_sum + x.clone(); - if i % NAIVE_SUM_THRESHOLD == NAIVE_SUM_THRESHOLD - 1 { + let (len, _) = iter.size_hint(); + let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division + let mut partial_sums = Vec::with_capacity(cap); + let (_, last_sum) = iter.fold((0, A::zero()), |(count, partial_sum), x| { + if count < NAIVE_SUM_THRESHOLD { + (count + 1, partial_sum + x.clone()) + } else { partial_sums.push(partial_sum); - partial_sum = A::zero(); + (1, x.clone()) } - } - partial_sums.push(partial_sum); + }); + partial_sums.push(last_sum); pure_pairwise_sum(&partial_sums) } @@ -205,3 +207,17 @@ pub fn unrolled_eq(xs: &[A], ys: &[A]) -> bool true } + +#[cfg(test)] +mod tests { + use quickcheck::quickcheck; + use std::num::Wrapping; + use super::iterator_pairwise_sum; + + quickcheck! { + fn iterator_pairwise_sum_is_correct(xs: Vec) -> bool { + let xs: Vec<_> = xs.into_iter().map(|x| Wrapping(x)).collect(); + iterator_pairwise_sum(xs.iter()) == xs.iter().sum() + } + } +} From e7835ee671eb547bab955251eeda9bcc6a75f38f Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 2 Feb 2019 23:40:02 -0500 Subject: [PATCH 3/5] Make sum pairwise over all dimensions --- src/numeric/impl_numeric.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 15ddd789e..7c417a1aa 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -33,17 +33,10 @@ impl ArrayBase where A: Clone + Add + num_traits::Zero, { if let Some(slc) = self.as_slice_memory_order() { - return numeric_util::pairwise_sum(&slc) - } - let mut sum = A::zero(); - for row in self.inner_rows() { - if let Some(slc) = row.as_slice() { - sum = sum + numeric_util::pairwise_sum(&slc); - } else { - sum = sum + numeric_util::iterator_pairwise_sum(row.iter()); - } + numeric_util::pairwise_sum(&slc) + } else { + numeric_util::iterator_pairwise_sum(self.iter()) } - sum } /// Return the sum of all elements in the array. From 8301c25bb242bd788bbaf6359a30d2e286c3753a Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 2 Feb 2019 23:59:30 -0500 Subject: [PATCH 4/5] Implement contiguous sum_axis in terms of Zip This is slightly faster and is easier to understand. --- src/numeric/impl_numeric.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 7c417a1aa..0dd273525 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -8,7 +8,6 @@ use std::ops::{Add, Div, Mul}; use num_traits::{self, Zero, Float, FromPrimitive}; -use itertools::free::enumerate; use crate::imp_prelude::*; use crate::numeric_util; @@ -97,14 +96,12 @@ impl ArrayBase D: RemoveAxis, { let n = self.len_of(axis); - let stride = self.strides()[axis.index()]; - if self.ndim() == 2 && stride == 1 { + if self.stride_of(axis) == 1 { // contiguous along the axis we are summing let mut res = Array::zeros(self.raw_dim().remove_axis(axis)); - let ax = axis.index(); - for (i, elt) in enumerate(&mut res) { - *elt = self.index_axis(Axis(1 - ax), i).sum(); - } + Zip::from(&mut res) + .and(self.lanes(axis)) + .apply(|sum, lane| *sum = lane.sum()); res } else if self.len_of(axis) <= numeric_util::NAIVE_SUM_THRESHOLD { self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone()) From 82453dfc73a217acc24fa85aa7f9108e28da5713 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 00:02:24 -0500 Subject: [PATCH 5/5] Remove redundant len_of call --- src/numeric/impl_numeric.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 0dd273525..0be4a0ce0 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -103,7 +103,7 @@ impl ArrayBase .and(self.lanes(axis)) .apply(|sum, lane| *sum = lane.sum()); res - } else if self.len_of(axis) <= numeric_util::NAIVE_SUM_THRESHOLD { + } else if n <= numeric_util::NAIVE_SUM_THRESHOLD { self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone()) } else { let (v1, v2) = self.view().split_at(axis, n / 2);