Skip to content

Commit 3d8b8b9

Browse files
author
Kunming Jiang
committed
Added parallel version for interleave
1 parent 1c504a5 commit 3d8b8b9

File tree

1 file changed

+83
-2
lines changed

1 file changed

+83
-2
lines changed

mpcs/src/lib.rs

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![deny(clippy::cargo)]
22
use ff_ext::ExtensionField;
3-
use itertools::{interleave, Either, Itertools};
3+
use itertools::{Either, Itertools};
44
use multilinear_extensions::{mle::{DenseMultilinearExtension, FieldType, MultilinearExtension}, virtual_poly::{build_eq_x_r, eq_eval, VPAuxInfo}};
55
use serde::{Serialize, de::DeserializeOwned};
66
use std::fmt::Debug;
@@ -10,6 +10,8 @@ use p3_field::PrimeCharacteristicRing;
1010
use multilinear_extensions::virtual_poly::VirtualPolynomial;
1111
use sumcheck::structs::{IOPProof, IOPProverState, IOPVerifierState};
1212
use witness::RowMajorMatrix;
13+
#[cfg(feature = "parallel")]
14+
use rayon::prelude::*;
1315

1416
pub mod sum_check;
1517
pub mod util;
@@ -172,7 +174,8 @@ fn interleave_polys<E: ExtensionField>(
172174
// Interleave the polys give their position on the binary tree
173175
// Assume the polys are sorted by decreasing size
174176
// Denote: N - size of the interleaved poly; M - num of polys
175-
// This function performs interleave in O(M) + O(N) time and is *potentially* parallelizable (maybe? idk)
177+
// This function performs interleave in O(M) + O(N) time
178+
#[cfg(not(feature = "parallel"))]
176179
fn interleave_polys<E: ExtensionField>(
177180
polys: Vec<&DenseMultilinearExtension<E>>,
178181
comps: &Vec<Vec<bool>>,
@@ -223,6 +226,84 @@ fn interleave_polys<E: ExtensionField>(
223226
DenseMultilinearExtension { num_vars: interleaved_num_vars, evaluations: interleaved_evaluations }
224227
}
225228

229+
// Parallel version: divide interleaved_evaluation into chunks
230+
#[cfg(feature = "parallel")]
231+
fn interleave_polys<E: ExtensionField>(
232+
polys: Vec<&DenseMultilinearExtension<E>>,
233+
comps: &Vec<Vec<bool>>,
234+
) -> DenseMultilinearExtension<E> {
235+
use std::cmp::min;
236+
237+
assert!(polys.len() > 0);
238+
let sizes: Vec<usize> = polys.iter().map(|p| p.evaluations.len()).collect();
239+
let interleaved_size = sizes.iter().sum::<usize>().next_power_of_two();
240+
let interleaved_num_vars = interleaved_size.ilog2() as usize;
241+
242+
// Compute Start and Gap for each poly
243+
// * Start: where's its first entry in the interleaved poly?
244+
// * Gap: how many entires are between its consecutive entries in the interleaved poly?
245+
let start_list: Vec<usize> = comps.iter().map(|comp| {
246+
let mut start = 0;
247+
let mut pow_2 = 1;
248+
for b in comp {
249+
start += if *b { pow_2 } else { 0 };
250+
pow_2 *= 2;
251+
}
252+
start
253+
}).collect();
254+
let gap_list: Vec<usize> = polys.iter().map(|poly|
255+
1 << (interleaved_num_vars - poly.num_vars)
256+
).collect();
257+
// Minimally each chunk needs one entry from the smallest poly
258+
let num_chunks = min(rayon::current_num_threads().next_power_of_two(), sizes[sizes.len() - 1]);
259+
let interleaved_chunk_size = interleaved_size / num_chunks;
260+
// Length of the poly each thread processes
261+
let poly_chunk_size: Vec<usize> = sizes.iter().map(|s| s / num_chunks).collect();
262+
263+
// Initialize the interleaved poly
264+
// Is there a better way to deal with field types?
265+
let interleaved_evaluations = match polys[0].evaluations {
266+
FieldType::Base(_) => {
267+
let mut interleaved_eval = vec![E::BaseField::ZERO; interleaved_size];
268+
interleaved_eval.par_chunks_exact_mut(interleaved_chunk_size).enumerate().for_each(|(i, chunk)| {
269+
for (p, poly) in polys.iter().enumerate() {
270+
match &poly.evaluations {
271+
FieldType::Base(pe) => {
272+
// Each thread processes a chunk of pe
273+
for (j, e) in pe[i * poly_chunk_size[p]..(i+1) * poly_chunk_size[p]].iter().enumerate() {
274+
chunk[start_list[p] + gap_list[p] * j] = *e;
275+
}
276+
}
277+
b => panic!("do not support merge BASE field type with b: {:?}", b)
278+
}
279+
}
280+
});
281+
FieldType::Base(interleaved_eval)
282+
}
283+
FieldType::Ext(_) => {
284+
let mut interleaved_eval = vec![E::ZERO; interleaved_size];
285+
interleaved_eval.par_chunks_exact_mut(num_chunks).enumerate().for_each(|(i, chunk)| {
286+
for (p, poly) in polys.iter().enumerate() {
287+
match &poly.evaluations {
288+
FieldType::Ext(pe) => {
289+
// Each thread processes a chunk of pe
290+
for (j, e) in pe[i * poly_chunk_size[p]..(i+1) * poly_chunk_size[p]].iter().enumerate() {
291+
chunk[start_list[p] + gap_list[p] * j] = *e;
292+
}
293+
}
294+
b => panic!("do not support merge EXT field type with b: {:?}", b)
295+
}
296+
}
297+
});
298+
FieldType::Ext(interleaved_eval)
299+
}
300+
_ => unreachable!()
301+
};
302+
303+
DenseMultilinearExtension { num_vars: interleaved_num_vars, evaluations: interleaved_evaluations }
304+
}
305+
306+
226307
// Pack polynomials of different sizes into the same, returns
227308
// 0: A list of packed polys
228309
// 1: The final packed poly, if of different size

0 commit comments

Comments
 (0)