Skip to content

Commit 2bf459c

Browse files
tamirdphimuemue
authored andcommitted
GroupMap: add fold_with
This is a generalization of `fold` which takes a function rather than a value, which removes the need for a `Clone` bound. `fold` is implemented in terms of `fold_with`.
1 parent f00e3ae commit 2bf459c

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed

src/grouping_map.rs

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,50 @@ where
115115
destination_map
116116
}
117117

118+
/// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements
119+
/// of each group sequentially, passing the previously accumulated value, a reference to the key
120+
/// and the current element as arguments, and stores the results in a new map.
121+
///
122+
/// `init` is called to obtain the initial value of each accumulator.
123+
///
124+
/// `operation` is a function that is invoked on each element with the following parameters:
125+
/// - the current value of the accumulator of the group;
126+
/// - a reference to the key of the group this element belongs to;
127+
/// - the element from the source being accumulated.
128+
///
129+
/// Return a `HashMap` associating the key of each group with the result of folding that group's elements.
130+
///
131+
/// ```
132+
/// use itertools::Itertools;
133+
///
134+
/// #[derive(Debug, Default)]
135+
/// struct Accumulator {
136+
/// acc: usize,
137+
/// }
138+
///
139+
/// let lookup = (1..=7)
140+
/// .into_grouping_map_by(|&n| n % 3)
141+
/// .fold_with(|_key| Default::default(), |Accumulator { acc }, _key, val| {
142+
/// let acc = acc + val;
143+
/// Accumulator { acc }
144+
/// });
145+
///
146+
/// assert_eq!(lookup[&0].acc, 3 + 6);
147+
/// assert_eq!(lookup[&1].acc, 1 + 4 + 7);
148+
/// assert_eq!(lookup[&2].acc, 2 + 5);
149+
/// assert_eq!(lookup.len(), 3);
150+
/// ```
151+
pub fn fold_with<FI, FO, R>(self, mut init: FI, mut operation: FO) -> HashMap<K, R>
152+
where
153+
FI: FnMut(&K) -> R,
154+
FO: FnMut(R, &K, V) -> R,
155+
{
156+
self.aggregate(|acc, key, val| {
157+
let acc = acc.unwrap_or_else(|| init(key));
158+
Some(operation(acc, key, val))
159+
})
160+
}
161+
118162
/// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements
119163
/// of each group sequentially, passing the previously accumulated value, a reference to the key
120164
/// and the current element as arguments, and stores the results in a new map.
@@ -140,15 +184,12 @@ where
140184
/// assert_eq!(lookup[&2], 2 + 5);
141185
/// assert_eq!(lookup.len(), 3);
142186
/// ```
143-
pub fn fold<FO, R>(self, init: R, mut operation: FO) -> HashMap<K, R>
187+
pub fn fold<FO, R>(self, init: R, operation: FO) -> HashMap<K, R>
144188
where
145189
R: Clone,
146190
FO: FnMut(R, &K, V) -> R,
147191
{
148-
self.aggregate(|acc, key, val| {
149-
let acc = acc.unwrap_or_else(|| init.clone());
150-
Some(operation(acc, key, val))
151-
})
192+
self.fold_with(|_: &K| init.clone(), operation)
152193
}
153194

154195
/// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements

tests/quick.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,35 @@ quickcheck! {
14721472
}
14731473
}
14741474

1475+
fn correct_grouping_map_by_fold_with_modulo_key(a: Vec<u8>, modulo: u8) -> () {
1476+
#[derive(Debug, Default, PartialEq)]
1477+
struct Accumulator {
1478+
acc: u64,
1479+
}
1480+
1481+
let modulo = if modulo == 0 { 1 } else { modulo } as u64; // Avoid `% 0`
1482+
let lookup = a.iter().map(|&b| b as u64) // Avoid overflows
1483+
.into_grouping_map_by(|i| i % modulo)
1484+
.fold_with(|_key| Default::default(), |Accumulator { acc }, &key, val| {
1485+
assert!(val % modulo == key);
1486+
let acc = acc + val;
1487+
Accumulator { acc }
1488+
});
1489+
1490+
let group_map_lookup = a.iter()
1491+
.map(|&b| b as u64)
1492+
.map(|i| (i % modulo, i))
1493+
.into_group_map()
1494+
.into_iter()
1495+
.map(|(key, vals)| (key, vals.into_iter().sum())).map(|(key, acc)| (key,Accumulator { acc }))
1496+
.collect::<HashMap<_,_>>();
1497+
assert_eq!(lookup, group_map_lookup);
1498+
1499+
for (&key, &Accumulator { acc: sum }) in lookup.iter() {
1500+
assert_eq!(sum, a.iter().map(|&b| b as u64).filter(|&val| val % modulo == key).sum::<u64>());
1501+
}
1502+
}
1503+
14751504
fn correct_grouping_map_by_fold_modulo_key(a: Vec<u8>, modulo: u8) -> () {
14761505
let modulo = if modulo == 0 { 1 } else { modulo } as u64; // Avoid `% 0`
14771506
let lookup = a.iter().map(|&b| b as u64) // Avoid overflows

0 commit comments

Comments
 (0)