@@ -285,6 +285,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
285
285
assert ! ( extrapolation_aux. len( ) == max_degree - 1 ) ;
286
286
let num_polys = polynomial. flattened_ml_extensions . len ( ) ;
287
287
Self {
288
+ max_num_variables : polynomial. aux_info . max_num_variables ,
288
289
challenges : Vec :: with_capacity ( polynomial. aux_info . max_num_variables ) ,
289
290
round : 0 ,
290
291
poly : polynomial,
@@ -335,7 +336,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
335
336
let chal = challenge. unwrap ( ) ;
336
337
self . challenges . push ( chal) ;
337
338
let r = self . challenges [ self . round - 1 ] ;
338
-
339
339
self . fix_var ( r. elements ) ;
340
340
}
341
341
exit_span ! ( span) ;
@@ -345,22 +345,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
345
345
346
346
// Step 2: generate sum for the partial evaluated polynomial:
347
347
// f(r_1, ... r_m,, x_{m+1}... x_n)
348
- //
349
- // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars,
350
- // for it evaluation value we need to times 2^(max_num_vars - num_vars)
351
- // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n
352
- // For i round univariate poly, f^i(x)
353
- // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds
354
- // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n'
355
- // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b)
356
- // same applied on f^i[1]
357
- // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value
358
348
let span = entered_span ! ( "products_sum" ) ;
359
349
let AdditiveVec ( products_sum) = self . poly . products . iter ( ) . fold (
360
350
AdditiveVec :: new ( self . poly . aux_info . max_degree + 1 ) ,
361
351
|mut products_sum, ( coefficient, products) | {
362
352
let span = entered_span ! ( "sum" ) ;
363
-
364
353
let f = & self . poly . flattened_ml_extensions ;
365
354
let mut sum: Vec < E > = match products. len ( ) {
366
355
1 => sumcheck_code_gen ! ( 1 , false , |i| & f[ products[ i] ] ) . to_vec ( ) ,
@@ -418,12 +407,22 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
418
407
. collect ( )
419
408
}
420
409
410
+ pub fn expected_numvars_at_round ( & self ) -> usize {
411
+ // first round start from 1
412
+ let num_vars = self . max_num_variables + 1 - self . round ;
413
+ debug_assert ! ( num_vars > 0 , "make sumcheck work on constant" ) ;
414
+ num_vars
415
+ }
416
+
421
417
/// fix_var
422
418
pub fn fix_var ( & mut self , r : E ) {
419
+ let expected_numvars_at_round = self . expected_numvars_at_round ( ) ;
423
420
self . poly_index_fixvar_in_place
424
421
. iter_mut ( )
425
422
. zip_eq ( self . poly . flattened_ml_extensions . iter_mut ( ) )
426
423
. for_each ( |( can_fixvar_in_place, poly) | {
424
+ debug_assert ! ( poly. num_vars( ) <= expected_numvars_at_round) ;
425
+ debug_assert ! ( poly. num_vars( ) > 0 ) ;
427
426
if * can_fixvar_in_place {
428
427
// in place
429
428
let poly = Arc :: get_mut ( poly) ;
@@ -433,8 +432,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
433
432
}
434
433
} ;
435
434
} else if poly. num_vars ( ) > 0 {
436
- * poly = Arc :: new ( poly. fix_variables ( & [ r] ) ) ;
437
- * can_fixvar_in_place = true ;
435
+ if expected_numvars_at_round == poly. num_vars ( ) {
436
+ * poly = Arc :: new ( poly. fix_variables ( & [ r] ) ) ;
437
+ * can_fixvar_in_place = true ;
438
+ }
438
439
} else {
439
440
panic ! ( "calling sumcheck on constant" )
440
441
}
@@ -524,6 +525,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
524
525
let max_degree = polynomial. aux_info . max_degree ;
525
526
let num_polys = polynomial. flattened_ml_extensions . len ( ) ;
526
527
let prover_state = Self {
528
+ max_num_variables : polynomial. aux_info . max_num_variables ,
527
529
challenges : Vec :: with_capacity ( polynomial. aux_info . max_num_variables ) ,
528
530
round : 0 ,
529
531
poly : polynomial,
@@ -579,7 +581,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
579
581
let chal = challenge. unwrap ( ) ;
580
582
self . challenges . push ( chal) ;
581
583
let r = self . challenges [ self . round - 1 ] ;
582
-
583
584
self . fix_var ( r. elements ) ;
584
585
}
585
586
exit_span ! ( span) ;
@@ -641,6 +642,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
641
642
642
643
/// fix_var
643
644
pub fn fix_var_parallel ( & mut self , r : E ) {
645
+ let expected_numvars_at_round = self . expected_numvars_at_round ( ) ;
644
646
self . poly_index_fixvar_in_place
645
647
. par_iter_mut ( )
646
648
. zip_eq ( self . poly . flattened_ml_extensions . par_iter_mut ( ) )
@@ -654,8 +656,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
654
656
}
655
657
} ;
656
658
} else if poly. num_vars ( ) > 0 {
657
- * poly = Arc :: new ( poly. fix_variables_parallel ( & [ r] ) ) ;
658
- * can_fixvar_in_place = true ;
659
+ if expected_numvars_at_round == poly. num_vars ( ) {
660
+ * poly = Arc :: new ( poly. fix_variables_parallel ( & [ r] ) ) ;
661
+ * can_fixvar_in_place = true ;
662
+ }
659
663
} else {
660
664
panic ! ( "calling sumcheck on constant" )
661
665
}
0 commit comments