Skip to content

Commit 03552e2

Browse files
authored
Merge pull request #574 from jturner314/optimize-fold
Improve performance of .fold()
2 parents 55aca3b + 1157763 commit 03552e2

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/dimension/axes.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ impl<'a, D> Iterator for Axes<'a, D>
7575
}
7676
}
7777

78+
fn fold<B, F>(self, init: B, f: F) -> B
79+
where
80+
F: FnMut(B, AxisDescription) -> B,
81+
{
82+
(self.start..self.end)
83+
.map(move |i| AxisDescription(Axis(i), self.dim[i], self.strides[i] as isize))
84+
.fold(init, f)
85+
}
86+
7887
fn size_hint(&self) -> (usize, Option<usize>) {
7988
let len = self.end - self.start;
8089
(len, Some(len))

src/impl_methods.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,13 +1854,25 @@ where
18541854
} else {
18551855
let mut v = self.view();
18561856
// put the narrowest axis at the last position
1857-
if v.ndim() > 1 {
1858-
let last = v.ndim() - 1;
1859-
let narrow_axis = v.axes()
1860-
.filter(|ax| ax.len() > 1)
1861-
.min_by_key(|ax| ax.stride().abs())
1862-
.map_or(last, |ax| ax.axis().index());
1863-
v.swap_axes(last, narrow_axis);
1857+
match v.ndim() {
1858+
0 | 1 => {}
1859+
2 => {
1860+
if self.len_of(Axis(1)) <= 1
1861+
|| self.len_of(Axis(0)) > 1
1862+
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
1863+
{
1864+
v.swap_axes(0, 1);
1865+
}
1866+
}
1867+
n => {
1868+
let last = n - 1;
1869+
let narrow_axis = v
1870+
.axes()
1871+
.filter(|ax| ax.len() > 1)
1872+
.min_by_key(|ax| ax.stride().abs())
1873+
.map_or(last, |ax| ax.axis().index());
1874+
v.swap_axes(last, narrow_axis);
1875+
}
18641876
}
18651877
v.into_elements_base().fold(init, f)
18661878
}

0 commit comments

Comments
 (0)