Skip to content

Commit 1d8c671

Browse files
Merge branch 'master' into bulk-quantiles
2 parents d5ab45c + d838ee7 commit 1d8c671

File tree

3 files changed

+186
-0
lines changed

3 files changed

+186
-0
lines changed

src/maybe_nan/mod.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ where
241241
A: 'a,
242242
F: FnMut(B, &'a A::NotNan) -> B;
243243

244+
/// Traverse the non-NaN elements and their indices and apply a fold,
245+
/// returning the resulting value.
246+
///
247+
/// Elements are visited in arbitrary order.
248+
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
249+
where
250+
A: 'a,
251+
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B;
252+
244253
/// Visit each non-NaN element in the array by calling `f` on each element.
245254
///
246255
/// Elements are visited in arbitrary order.
@@ -302,6 +311,20 @@ where
302311
})
303312
}
304313

314+
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
315+
where
316+
A: 'a,
317+
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
318+
{
319+
self.indexed_iter().fold(init, |acc, (idx, elem)| {
320+
if let Some(not_nan) = elem.try_as_not_nan() {
321+
f(acc, (idx, not_nan))
322+
} else {
323+
acc
324+
}
325+
})
326+
}
327+
305328
fn visit_skipnan<'a, F>(&'a self, mut f: F)
306329
where
307330
A: 'a,

src/quantile/mod.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,33 @@ where
4242
where
4343
A: PartialOrd;
4444

45+
/// Finds the index of the minimum value of the array skipping NaN values.
46+
///
47+
/// Returns `None` if the array is empty or none of the values in the array
48+
/// are non-NaN values.
49+
///
50+
/// Even if there are multiple (equal) elements that are minima, only one
51+
/// index is returned. (Which one is returned is unspecified and may depend
52+
/// on the memory layout of the array.)
53+
///
54+
/// # Example
55+
///
56+
/// ```
57+
/// extern crate ndarray;
58+
/// extern crate ndarray_stats;
59+
///
60+
/// use ndarray::array;
61+
/// use ndarray_stats::QuantileExt;
62+
///
63+
/// let a = array![[::std::f64::NAN, 3., 5.],
64+
/// [2., 0., 6.]];
65+
/// assert_eq!(a.argmin_skipnan(), Some((1, 1)));
66+
/// ```
67+
fn argmin_skipnan(&self) -> Option<D::Pattern>
68+
where
69+
A: MaybeNan,
70+
A::NotNan: Ord;
71+
4572
/// Finds the elementwise minimum of the array.
4673
///
4774
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -100,6 +127,33 @@ where
100127
where
101128
A: PartialOrd;
102129

130+
/// Finds the index of the maximum value of the array skipping NaN values.
131+
///
132+
/// Returns `None` if the array is empty or none of the values in the array
133+
/// are non-NaN values.
134+
///
135+
/// Even if there are multiple (equal) elements that are maxima, only one
136+
/// index is returned. (Which one is returned is unspecified and may depend
137+
/// on the memory layout of the array.)
138+
///
139+
/// # Example
140+
///
141+
/// ```
142+
/// extern crate ndarray;
143+
/// extern crate ndarray_stats;
144+
///
145+
/// use ndarray::array;
146+
/// use ndarray_stats::QuantileExt;
147+
///
148+
/// let a = array![[::std::f64::NAN, 3., 5.],
149+
/// [2., 0., 6.]];
150+
/// assert_eq!(a.argmax_skipnan(), Some((1, 2)));
151+
/// ```
152+
fn argmax_skipnan(&self) -> Option<D::Pattern>
153+
where
154+
A: MaybeNan,
155+
A::NotNan: Ord;
156+
103157
/// Finds the elementwise maximum of the array.
104158
///
105159
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -230,6 +284,28 @@ where
230284
Some(current_pattern_min)
231285
}
232286

287+
fn argmin_skipnan(&self) -> Option<D::Pattern>
288+
where
289+
A: MaybeNan,
290+
A::NotNan: Ord,
291+
{
292+
let mut pattern_min = D::zeros(self.ndim()).into_pattern();
293+
let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| {
294+
Some(match current_min {
295+
Some(m) if (m <= elem) => m,
296+
_ => {
297+
pattern_min = pattern;
298+
elem
299+
}
300+
})
301+
});
302+
if min.is_some() {
303+
Some(pattern_min)
304+
} else {
305+
None
306+
}
307+
}
308+
233309
fn min(&self) -> Option<&A>
234310
where
235311
A: PartialOrd,
@@ -272,6 +348,28 @@ where
272348
Some(current_pattern_max)
273349
}
274350

351+
fn argmax_skipnan(&self) -> Option<D::Pattern>
352+
where
353+
A: MaybeNan,
354+
A::NotNan: Ord,
355+
{
356+
let mut pattern_max = D::zeros(self.ndim()).into_pattern();
357+
let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| {
358+
Some(match current_max {
359+
Some(m) if m >= elem => m,
360+
_ => {
361+
pattern_max = pattern;
362+
elem
363+
}
364+
})
365+
});
366+
if max.is_some() {
367+
Some(pattern_max)
368+
} else {
369+
None
370+
}
371+
}
372+
275373
fn max(&self) -> Option<&A>
276374
where
277375
A: PartialOrd,

tests/quantile.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,37 @@ quickcheck! {
3636
}
3737
}
3838

39+
#[test]
40+
fn test_argmin_skipnan() {
41+
let a = array![[1., 5., 3.], [2., 0., 6.]];
42+
assert_eq!(a.argmin_skipnan(), Some((1, 1)));
43+
44+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
45+
assert_eq!(a.argmin_skipnan(), Some((0, 0)));
46+
47+
let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]];
48+
assert_eq!(a.argmin_skipnan(), Some((1, 0)));
49+
50+
let a: Array2<f64> = array![[], []];
51+
assert_eq!(a.argmin_skipnan(), None);
52+
53+
let a = arr2(&[[::std::f64::NAN; 2]; 2]);
54+
assert_eq!(a.argmin_skipnan(), None);
55+
}
56+
57+
quickcheck! {
58+
fn argmin_skipnan_matches_min_skipnan(data: Vec<Option<i32>>) -> bool {
59+
let a = Array1::from(data);
60+
let min = a.min_skipnan();
61+
let argmin = a.argmin_skipnan();
62+
if min.is_none() {
63+
argmin == None
64+
} else {
65+
a[argmin.unwrap()] == *min
66+
}
67+
}
68+
}
69+
3970
#[test]
4071
fn test_min() {
4172
let a = array![[1, 5, 3], [2, 0, 6]];
@@ -85,6 +116,40 @@ quickcheck! {
85116
}
86117
}
87118

119+
#[test]
120+
fn test_argmax_skipnan() {
121+
let a = array![[1., 5., 3.], [2., 0., 6.]];
122+
assert_eq!(a.argmax_skipnan(), Some((1, 2)));
123+
124+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]];
125+
assert_eq!(a.argmax_skipnan(), Some((0, 1)));
126+
127+
let a = array![
128+
[::std::f64::NAN, ::std::f64::NAN, 3.],
129+
[2., ::std::f64::NAN, 6.]
130+
];
131+
assert_eq!(a.argmax_skipnan(), Some((1, 2)));
132+
133+
let a: Array2<f64> = array![[], []];
134+
assert_eq!(a.argmax_skipnan(), None);
135+
136+
let a = arr2(&[[::std::f64::NAN; 2]; 2]);
137+
assert_eq!(a.argmax_skipnan(), None);
138+
}
139+
140+
quickcheck! {
141+
fn argmax_skipnan_matches_max_skipnan(data: Vec<Option<i32>>) -> bool {
142+
let a = Array1::from(data);
143+
let max = a.max_skipnan();
144+
let argmax = a.argmax_skipnan();
145+
if max.is_none() {
146+
argmax == None
147+
} else {
148+
a[argmax.unwrap()] == *max
149+
}
150+
}
151+
}
152+
88153
#[test]
89154
fn test_max() {
90155
let a = array![[1, 5, 7], [2, 0, 6]];

0 commit comments

Comments
 (0)