From 4733e39abe0ae08c0bf0f7e5489eeeed29f8c9b9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 20 Mar 2025 11:00:06 +0800 Subject: [PATCH 1/6] wip --- sumcheck/src/prover.rs | 112 +++++++++++++++++++++++--------------- sumcheck/src/structs.rs | 2 + sumcheck_macro/src/lib.rs | 3 +- 3 files changed, 72 insertions(+), 45 deletions(-) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 6859f38e9..f5c9ea679 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -329,12 +329,15 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { end_timer!(start); let max_degree = polynomial.aux_info.max_degree; + let num_polys = polynomial.flattened_ml_extensions.len(); assert!(extrapolation_aux.len() == max_degree - 1); Self { + max_num_variables: polynomial.aux_info.max_num_variables, challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux, + poly_index_is_bind: vec![false; num_polys], } } @@ -381,30 +384,44 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.challenges.push(chal); let r = self.challenges[self.round - 1]; - if self.challenges.len() == 1 { - self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { + let expected_numvars_at_round = self.expected_numvars_at_round(); + println!("num poly {}", self.poly.flattened_ml_extensions.len()); + self.poly_index_is_bind + .par_iter_mut() + .zip(self.poly.flattened_ml_extensions.par_iter_mut()) + .for_each(|(has_bind, poly)| { + if *has_bind { + // in place + let f = Arc::get_mut(poly); + if let Some(f) = f { - if f.num_vars() > 0 { - f.fix_variables_in_place(&[r.elements]); + debug_assert!(f.num_vars() <= expected_numvars_at_round); + debug_assert!(f.num_vars() > 0); + println!("binded f.num_vars() {}", f.num_vars()); + if f.num_vars() == expected_numvars_at_round { + println!( + "has bind at round {}, num_var {}", + self.round, + f.num_vars() + ); + f.fix_variables_in_place_parallel(&[r.elements]) } + }; + } else { + debug_assert!(poly.num_vars() <= expected_numvars_at_round); + debug_assert!(poly.num_vars() > 0); + println!("f.num_vars() {}", poly.num_vars()); + if expected_numvars_at_round == poly.num_vars() { + println!( + "what bind at round {}, num_var {}", + self.round, + poly.num_vars() + ); + *poly = Arc::new(poly.fix_variables_parallel(&[r.elements])); + *has_bind = true; } - }); - } + } + }); } exit_span!(span); // end_timer!(fix_argument); @@ -485,6 +502,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }) .collect() } + pub fn expected_numvars_at_round(&self) -> usize { + // first round start from 1 + let num_vars = self.max_num_variables + 1 - self.round; + debug_assert!(num_vars > 0, "make sumcheck work on constant"); + num_vars + } } /// parallel version @@ -588,10 +611,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { ); let max_degree = polynomial.aux_info.max_degree; + let num_polys = polynomial.flattened_ml_extensions.len(); let prover_state = Self { + max_num_variables: polynomial.aux_info.max_num_variables, challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, + poly_index_is_bind: vec![false; num_polys], extrapolation_aux: (1..max_degree) .map(|degree| { let points = (0..1 + degree as u64).map(E::from_u64).collect::>(); @@ -644,33 +670,31 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.challenges.push(chal); let r = self.challenges[self.round - 1]; - if self.challenges.len() == 1 { - self.poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables_parallel(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .par_iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { + let expected_numvars_at_round = self.expected_numvars_at_round(); + + self.poly_index_is_bind + .par_iter_mut() + .zip(self.poly.flattened_ml_extensions.par_iter_mut()) + .for_each(|(has_bind, poly)| { + if *has_bind { + // in place + let f = Arc::get_mut(poly); if let Some(f) = f { - if f.num_vars() > 0 { + debug_assert!(f.num_vars() <= expected_numvars_at_round); + debug_assert!(f.num_vars() > 0); + if f.num_vars() == expected_numvars_at_round { f.fix_variables_in_place_parallel(&[r.elements]) } + }; + } else { + debug_assert!(poly.num_vars() <= expected_numvars_at_round); + debug_assert!(poly.num_vars() > 0); + if expected_numvars_at_round == poly.num_vars() { + *poly = Arc::new(poly.fix_variables_parallel(&[r.elements])); + *has_bind = true; } - }); - } + } + }); } exit_span!(span); // end_timer!(fix_argument); diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index a910f368f..b4794a283 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -44,6 +44,8 @@ pub struct IOPProverState<'a, E: ExtensionField> { /// points with precomputed barycentric weights for extrapolating smaller /// degree uni-polys to `max_degree + 1` evaluations. pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, + pub(crate) max_num_variables: usize, + pub(crate) poly_index_is_bind: Vec, } /// Prover State of a PolyIOP diff --git a/sumcheck_macro/src/lib.rs b/sumcheck_macro/src/lib.rs index 428500261..a62f6997c 100644 --- a/sumcheck_macro/src/lib.rs +++ b/sumcheck_macro/src/lib.rs @@ -239,7 +239,8 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr } else { res }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(v1.len()).max(1) + self.round - 1); + println!("self.expected_numvars_in_current_round() {}, ceil_log2(v1.len()) {}", self.expected_numvars_at_round(), ceil_log2(v1.len())); + let num_vars_multiplicity = self.expected_numvars_at_round() - ceil_log2(v1.len()); // +1 due to we already bind, thus if num_vars_multiplicity > 0 { AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity))) } else { From 387d23147077b7b56db98ffd8c3570a22ff62e70 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 20 Mar 2025 11:27:30 +0800 Subject: [PATCH 2/6] refactor sumcheck boilerplate code by fix_var --- multilinear_extensions/src/virtual_polys.rs | 1 - sumcheck/src/prover.rs | 183 +++++++------------- sumcheck/src/structs.rs | 2 + sumcheck/src/test.rs | 27 +-- 4 files changed, 61 insertions(+), 152 deletions(-) diff --git a/multilinear_extensions/src/virtual_polys.rs b/multilinear_extensions/src/virtual_polys.rs index 80dba2fcf..46d6dad20 100644 --- a/multilinear_extensions/src/virtual_polys.rs +++ b/multilinear_extensions/src/virtual_polys.rs @@ -18,7 +18,6 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> { impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { pub fn new(num_threads: usize, max_num_variables: usize) -> Self { - println!("ceil_log2(num_threads) {}", ceil_log2(num_threads)); VirtualPolynomials { num_threads, polys: (0..num_threads) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 6859f38e9..13bfeba15 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -5,10 +5,7 @@ use crossbeam_channel::bounded; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - mle::{DenseMultilinearExtension, FieldType, MultilinearExtension}, - op_mle, - util::largest_even_below, - virtual_poly::VirtualPolynomial, + mle::FieldType, op_mle, util::largest_even_below, virtual_poly::VirtualPolynomial, }; use rayon::{ Scope, @@ -114,16 +111,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - }); + prover_state.fix_var(p.elements); + tx_prover_state .send(Some((thread_id, prover_state))) .unwrap(); @@ -183,29 +172,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - if num_variables == 1 { - // first time fix variable should be create new instance - if mle.num_vars() > 0 { - *mle = mle.fix_variables(&[p.elements]).into(); - } else { - *mle = - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) - } - } else { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - } - }); + prover_state.fix_var(p.elements); tx_prover_state .send(Some((main_thread_id, prover_state))) .unwrap(); @@ -280,21 +247,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each( - |mle: &mut Arc< - dyn MultilinearExtension>, - >| { - if mle.num_vars() > 0 { - Arc::get_mut(mle) - .unwrap() - .fix_variables_in_place(&[p.elements]); - } - }, - ); + prover_state.fix_var(p.elements); }; exit_span!(span); @@ -330,11 +283,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let max_degree = polynomial.aux_info.max_degree; assert!(extrapolation_aux.len() == max_degree - 1); + let num_polys = polynomial.flattened_ml_extensions.len(); Self { challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, extrapolation_aux, + poly_index_fixvar_in_place: vec![false; num_polys], } } @@ -381,30 +336,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.challenges.push(chal); let r = self.challenges[self.round - 1]; - if self.challenges.len() == 1 { - self.poly.flattened_ml_extensions.iter_mut().for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { - if let Some(f) = f { - if f.num_vars() > 0 { - f.fix_variables_in_place(&[r.elements]); - } - } - }); - } + self.fix_var(r.elements); } exit_span!(span); // end_timer!(fix_argument); @@ -485,6 +417,29 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }) .collect() } + + /// fix_var + pub fn fix_var(&mut self, r: E) { + self.poly_index_fixvar_in_place + .iter_mut() + .zip_eq(self.poly.flattened_ml_extensions.iter_mut()) + .for_each(|(has_fixvar_in_place, poly)| { + if *has_fixvar_in_place { + // in place + let poly = Arc::get_mut(poly); + if let Some(f) = poly { + if f.num_vars() > 0 { + f.fix_variables_in_place(&[r]) + } + }; + } else if poly.num_vars() > 0 { + *poly = Arc::new(poly.fix_variables(&[r])); + *has_fixvar_in_place = true; + } else { + panic!("calling sumcheck on constant") + } + }); + } } /// parallel version @@ -538,28 +493,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|mle| { - if num_variables == 1 { - // first time fix variable should be create new instance - if mle.num_vars() > 0 { - *mle = mle.fix_variables(&[p.elements]).into(); - } else { - *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) - } - } else { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - } - }); + prover_state.fix_var_parallel(p.elements); }; exit_span!(span); @@ -588,6 +522,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { ); let max_degree = polynomial.aux_info.max_degree; + let num_polys = polynomial.flattened_ml_extensions.len(); let prover_state = Self { challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, @@ -599,6 +534,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { (points, weights) }) .collect(), + poly_index_fixvar_in_place: vec![false; num_polys], }; end_timer!(start); @@ -644,33 +580,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.challenges.push(chal); let r = self.challenges[self.round - 1]; - if self.challenges.len() == 1 { - self.poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|f| { - if f.num_vars() > 0 { - *f = Arc::new(f.fix_variables_parallel(&[r.elements])); - } else { - panic!("calling sumcheck on constant") - } - }); - } else { - self.poly - .flattened_ml_extensions - .par_iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, - // which can be +5% overhead rust docs doen't explain the - // reason - .map(Arc::get_mut) - .for_each(|f| { - if let Some(f) = f { - if f.num_vars() > 0 { - f.fix_variables_in_place_parallel(&[r.elements]) - } - } - }); - } + self.fix_var(r.elements); } exit_span!(span); // end_timer!(fix_argument); @@ -728,4 +638,27 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { evaluations: products_sum, } } + + /// fix_var + pub fn fix_var_parallel(&mut self, r: E) { + self.poly_index_fixvar_in_place + .par_iter_mut() + .zip_eq(self.poly.flattened_ml_extensions.par_iter_mut()) + .for_each(|(has_fixvar_in_place, poly)| { + if *has_fixvar_in_place { + // in place + let poly = Arc::get_mut(poly); + if let Some(f) = poly { + if f.num_vars() > 0 { + f.fix_variables_in_place_parallel(&[r]) + } + }; + } else if poly.num_vars() > 0 { + *poly = Arc::new(poly.fix_variables_parallel(&[r])); + *has_fixvar_in_place = true; + } else { + panic!("calling sumcheck on constant") + } + }); + } } diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index a910f368f..959eb2766 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -44,6 +44,8 @@ pub struct IOPProverState<'a, E: ExtensionField> { /// points with precomputed barycentric weights for extrapolating smaller /// degree uni-polys to `max_degree + 1` evaluations. pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, + /// record poly should fix variable in place or not + pub(crate) poly_index_fixvar_in_place: Vec, } /// Prover State of a PolyIOP diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 87a0f4e79..1638bf4d9 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::{ structs::{IOPProverState, IOPVerifierState}, util::interpolate_uni_poly, @@ -7,12 +5,10 @@ use crate::{ use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2}; use multilinear_extensions::{ - mle::DenseMultilinearExtension, virtual_poly::{VPAuxInfo, VirtualPolynomial}, virtual_polys::VirtualPolynomials, }; use p3_field::PrimeCharacteristicRing; -use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; #[test] @@ -144,28 +140,7 @@ fn test_sumcheck_internal( if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .par_iter_mut() - .for_each(|mle| { - if num_variables == 1 { - // first time fix variable should be create new instance - if mle.num_vars() > 0 { - *mle = mle.fix_variables(&[p.elements]).into(); - } else { - *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) - } - } else { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - } - }); + prover_state.fix_var(p.elements); }; let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum); assert!( From 758cc8935e64f390ee2187f1040661295add67c1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 20 Mar 2025 15:11:24 +0800 Subject: [PATCH 3/6] rename --- sumcheck/src/prover.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index aac041912..30b07a0cd 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -423,8 +423,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.poly_index_fixvar_in_place .iter_mut() .zip_eq(self.poly.flattened_ml_extensions.iter_mut()) - .for_each(|(has_fixvar_in_place, poly)| { - if *has_fixvar_in_place { + .for_each(|(can_fixvar_in_place, poly)| { + if *can_fixvar_in_place { // in place let poly = Arc::get_mut(poly); if let Some(f) = poly { @@ -434,7 +434,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; } else if poly.num_vars() > 0 { *poly = Arc::new(poly.fix_variables(&[r])); - *has_fixvar_in_place = true; + *can_fixvar_in_place = true; } else { panic!("calling sumcheck on constant") } @@ -644,8 +644,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { self.poly_index_fixvar_in_place .par_iter_mut() .zip_eq(self.poly.flattened_ml_extensions.par_iter_mut()) - .for_each(|(has_fixvar_in_place, poly)| { - if *has_fixvar_in_place { + .for_each(|(can_fixvar_in_place, poly)| { + if *can_fixvar_in_place { // in place let poly = Arc::get_mut(poly); if let Some(f) = poly { @@ -655,7 +655,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; } else if poly.num_vars() > 0 { *poly = Arc::new(poly.fix_variables_parallel(&[r])); - *has_fixvar_in_place = true; + *can_fixvar_in_place = true; } else { panic!("calling sumcheck on constant") } From fe0207efe647fbd385d9c53be98e05ced009c662 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 20 Mar 2025 22:12:19 +0800 Subject: [PATCH 4/6] work and test pass --- Cargo.lock | 1 + multilinear_extensions/src/virtual_polys.rs | 1 + sumcheck/src/prover.rs | 21 --------- sumcheck/src/test.rs | 16 ++++--- sumcheck_macro/Cargo.toml | 1 + sumcheck_macro/examples/expand.rs | 17 +++++-- sumcheck_macro/src/lib.rs | 52 ++++++++++++++++++--- 7 files changed, 72 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 47fcbec86..380bb0cdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2701,6 +2701,7 @@ dependencies = [ "p3", "proc-macro2", "quote", + "rand", "sumcheck", "syn 2.0.98", ] diff --git a/multilinear_extensions/src/virtual_polys.rs b/multilinear_extensions/src/virtual_polys.rs index 46d6dad20..13cd70e06 100644 --- a/multilinear_extensions/src/virtual_polys.rs +++ b/multilinear_extensions/src/virtual_polys.rs @@ -18,6 +18,7 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> { impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { pub fn new(num_threads: usize, max_num_variables: usize) -> Self { + debug_assert!(num_threads > 0); VirtualPolynomials { num_threads, polys: (0..num_threads) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index a8f5cc6ee..0518fc016 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -345,22 +345,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) - // - // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, - // for it evaluation value we need to times 2^(max_num_vars - num_vars) - // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n - // For i round univariate poly, f^i(x) - // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds - // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' - // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) - // same applied on f^i[1] - // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value let span = entered_span!("products_sum"); let AdditiveVec(products_sum) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut products_sum, (coefficient, products)| { let span = entered_span!("sum"); - let f = &self.poly.flattened_ml_extensions; let mut sum: Vec = match products.len() { 1 => sumcheck_code_gen!(1, false, |i| &f[products[i]]).to_vec(), @@ -444,11 +433,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; } else if poly.num_vars() > 0 { if expected_numvars_at_round == poly.num_vars() { - println!( - "what bind at round {}, num_var {}", - self.round, - poly.num_vars() - ); *poly = Arc::new(poly.fix_variables(&[r])); *can_fixvar_in_place = true; } @@ -673,11 +657,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; } else if poly.num_vars() > 0 { if expected_numvars_at_round == poly.num_vars() { - println!( - "what bind at round {}, num_var {}", - self.round, - poly.num_vars() - ); *poly = Arc::new(poly.fix_variables_parallel(&[r])); *can_fixvar_in_place = true; } diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index f3dd5afe4..4e6cbfac1 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -5,6 +5,7 @@ use crate::{ use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2}; use multilinear_extensions::{ + util::max_usable_threads, virtual_poly::{VPAuxInfo, VirtualPolynomial}, virtual_polys::VirtualPolynomials, }; @@ -13,17 +14,19 @@ use transcript::{BasicTranscript, Transcript}; #[test] fn test_sumcheck_with_different_degree() { - let nv = vec![4, 5]; // test polynomial mixed with different num_var - test_sumcheck_with_different_degree_helper::(nv); + // test polynomial mixed with different num_var + let nv = vec![3, 4, 5]; + let num_polys = nv.len(); + for num_threads in 1..num_polys.min(max_usable_threads()) { + test_sumcheck_with_different_degree_helper::(num_threads, &nv); + } } -fn test_sumcheck_with_different_degree_helper(nv: Vec) { +fn test_sumcheck_with_different_degree_helper(num_threads: usize, nv: &[usize]) { let mut rng = test_rng(); let degree = 2; let num_multiplicands_range = (degree, degree + 1); let num_products = 1; - // TODO investigate error when num_threads > 1 - let num_threads = 1; let mut transcript = BasicTranscript::::new(b"test"); let max_num_variables = *nv.iter().max().unwrap(); @@ -69,10 +72,11 @@ fn test_sumcheck_with_different_degree_helper(nv: Vec) .map(|c| c.elements) .collect::>(); assert_eq!(r.len(), max_num_variables); + // r are right alignment assert!( input_polys .iter() - .map(|(poly, _)| { poly.evaluate(&r[..poly.aux_info.max_num_variables]) }) + .map(|(poly, _)| { poly.evaluate(&r[r.len() - poly.aux_info.max_num_variables..]) }) .sum::() == subclaim.expected_evaluation, "wrong subclaim" diff --git a/sumcheck_macro/Cargo.toml b/sumcheck_macro/Cargo.toml index e1f842693..214ed0d14 100644 --- a/sumcheck_macro/Cargo.toml +++ b/sumcheck_macro/Cargo.toml @@ -18,6 +18,7 @@ p3 = { path = "../p3" } proc-macro2 = "1.0.92" quote = "1.0" syn = { version = "2.0", features = ["full"] } +rand.workspace = true [dev-dependencies] ff_ext = { path = "../ff_ext" } diff --git a/sumcheck_macro/examples/expand.rs b/sumcheck_macro/examples/expand.rs index 97e3412a3..80641be5e 100644 --- a/sumcheck_macro/examples/expand.rs +++ b/sumcheck_macro/examples/expand.rs @@ -2,22 +2,21 @@ /// ```sh /// cargo expand --example expand /// ``` -use ff_ext::ExtensionField; -use ff_ext::GoldilocksExt2; +use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::{ mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial, }; use p3::field::PrimeCharacteristicRing; +use rand::rngs::OsRng; use sumcheck::util::{AdditiveArray, ceil_log2}; #[derive(Default)] struct Container<'a, E: ExtensionField> { poly: VirtualPolynomial<'a, E>, - round: usize, } fn main() { - let c = Container::::default(); + let c = Container::::new(); c.run(); } @@ -26,4 +25,14 @@ impl Container<'_, E> { let _result: AdditiveArray<_, 4> = sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]); } + + pub fn expected_numvars_at_round(&self) -> usize { + 1 + } + + pub fn new() -> Self { + Self { + poly: VirtualPolynomial::random(3, (4, 5), 2, &mut OsRng).0, + } + } } diff --git a/sumcheck_macro/src/lib.rs b/sumcheck_macro/src/lib.rs index 821eb8b3a..1afb84742 100644 --- a/sumcheck_macro/src/lib.rs +++ b/sumcheck_macro/src/lib.rs @@ -219,28 +219,68 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr }; let iter = if parallalize { - quote! {.into_par_iter().step_by(2).with_min_len(64)} + quote! {.into_par_iter().step_by(2).rev().with_min_len(64)} } else { quote! {.step_by(2).rev()} }; // Generate the final AdditiveArray expression. + + // special case: generate product for polynomial num_var less than current expected num_var + // which happened when we batching sumcheck with different num_vars + let product = mul_exprs( + (1..=degree) + .map(|j: u32| { + let v = ident(format!("v{j}")); + quote! {#v[b]} + }) + .collect(), + ); + let degree_plus_one = (degree + 1) as usize; quote! { - let res = (0..largest_even_below(v1.len())) + // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars + // we actually need to have a full sum, times 2^(bh_num_vars - num_vars) to accumulate into univariate computation + // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n + // For i < n - n', to compute univariate poly, f^i(x), b is i-th round boolean hypercube + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds + // = \sum_b f_1(b) + f_2(r, 0, b) + // = 2^(|b| - |b1|) * \sum_b1 f_1(b1) + \sum_b f_2(r, 0, b) + // b1 is suffix alignment with b + // same applied on f^i[1], f^i[2], ... f^i[degree + 1] + // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value + + // NOTE: current method work in suffix alignment order + let num_var = ceil_log2(v1.len()); + let expected_numvars_at_round = self.expected_numvars_at_round(); + let res = if num_var < expected_numvars_at_round { + // TODO optimize by caching computed result for later round reuse + // need to figure out how to cache in one place to support base/extension field + let sum = (0..largest_even_below(v1.len())).map( + |b| { + #product + }, + ).sum(); + AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one]) + } else { + (0..largest_even_below(v1.len())) #iter .map(|b| { #additive_array_items }) - .sum::>(); + .sum::>() + }; let res = if v1.len() == 1 { let b = 0; AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) } else { res }; - println!("self.expected_numvars_in_current_round() {}, ceil_log2(v1.len()) {}", self.expected_numvars_at_round(), ceil_log2(v1.len())); - let num_vars_multiplicity = self.expected_numvars_at_round() - ceil_log2(v1.len()); + + // calculate multiplicity term + // minus one because when expected num of var is i, the boolean hypercube dimension only i-1 + let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var); + if num_vars_multiplicity > 0 { AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity))) } else { @@ -315,7 +355,7 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr // Generate the second match statement that maps f vars to AdditiveArray. out = quote! { { - #out + #out match (#match_input) { #match_arms _ => unreachable!(), From be756bbe33c77aa599425e2b2998dd55ec4c73b2 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 21 Mar 2025 00:33:31 +0800 Subject: [PATCH 5/6] code cosmetics --- sumcheck_macro/src/lib.rs | 45 +++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/sumcheck_macro/src/lib.rs b/sumcheck_macro/src/lib.rs index 1afb84742..1e090c2dc 100644 --- a/sumcheck_macro/src/lib.rs +++ b/sumcheck_macro/src/lib.rs @@ -253,40 +253,35 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr // NOTE: current method work in suffix alignment order let num_var = ceil_log2(v1.len()); let expected_numvars_at_round = self.expected_numvars_at_round(); - let res = if num_var < expected_numvars_at_round { + if num_var < expected_numvars_at_round { // TODO optimize by caching computed result for later round reuse // need to figure out how to cache in one place to support base/extension field - let sum = (0..largest_even_below(v1.len())).map( + let mut sum = (0..largest_even_below(v1.len())).map( |b| { #product }, ).sum(); + // calculate multiplicity term + // minus one because when expected num of var is n_i, the boolean hypercube dimension only n_i-1 + let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var); + if num_vars_multiplicity > 0 { + sum *= E::BaseField::from_u64(1 << num_vars_multiplicity); + } AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one]) } else { - (0..largest_even_below(v1.len())) - #iter - .map(|b| { - #additive_array_items - }) - .sum::>() - }; - let res = if v1.len() == 1 { - let b = 0; - AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) - } else { - res - }; - - // calculate multiplicity term - // minus one because when expected num of var is i, the boolean hypercube dimension only i-1 - let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var); - - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity))) - } else { - res + let res = (0..largest_even_below(v1.len())) + #iter + .map(|b| { + #additive_array_items + }) + .sum::>(); + if v1.len() == 1 { + let b = 0; + AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) + } else { + res + } } - } }; From 474b9854aa2d1e3ae1ff0ebd073f98788ac46bd9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 21 Mar 2025 10:10:56 +0800 Subject: [PATCH 6/6] address review feedback --- sumcheck_macro/src/lib.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sumcheck_macro/src/lib.rs b/sumcheck_macro/src/lib.rs index 1e090c2dc..df77f9b37 100644 --- a/sumcheck_macro/src/lib.rs +++ b/sumcheck_macro/src/lib.rs @@ -243,7 +243,7 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr // we actually need to have a full sum, times 2^(bh_num_vars - num_vars) to accumulate into univariate computation // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n // For i < n - n', to compute univariate poly, f^i(x), b is i-th round boolean hypercube - // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} challenge get from prev rounds // = \sum_b f_1(b) + f_2(r, 0, b) // = 2^(|b| - |b1|) * \sum_b1 f_1(b1) + \sum_b f_2(r, 0, b) // b1 is suffix alignment with b @@ -269,17 +269,16 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr } AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one]) } else { - let res = (0..largest_even_below(v1.len())) - #iter - .map(|b| { - #additive_array_items - }) - .sum::>(); if v1.len() == 1 { let b = 0; AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) } else { - res + (0..largest_even_below(v1.len())) + #iter + .map(|b| { + #additive_array_items + }) + .sum::>() } } }