diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index cbcefd0ae..e8f024271 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -222,7 +222,7 @@ fn verify(proof_file: &str, vk_file: &str) { let transcript = TranscriptWithStat::new(&stat_recorder, b"riscv"); assert!( verifier - .verify_proof_halt(zkvm_proof.clone(), transcript, zkvm_proof.has_halt()) + .verify_proof_halt(zkvm_proof.clone(), transcript, zkvm_proof.has_halt(&vk)) .is_ok() ); println!("e2e proof stat: {}", zkvm_proof); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index d47ed5108..6b38e2b8b 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -12,7 +12,7 @@ use sumcheck::structs::IOPProverMessage; use crate::{ instructions::{Instruction, riscv::ecall::HaltInstruction}, - structs::TowerProofs, + structs::{TowerProofs, ZKVMVerifyingKey}, }; pub mod constants; @@ -30,9 +30,6 @@ mod tests; deserialize = "E::BaseField: DeserializeOwned" ))] pub struct ZKVMOpcodeProof> { - // TODO support >1 opcodes - pub num_instances: usize, - // product constraints pub record_r_out_evals: Vec, pub record_w_out_evals: Vec, @@ -73,9 +70,6 @@ pub struct ZKVMTableProof> pub tower_proof: TowerProofs, - // num_vars hint for rw dynamic address to work - pub rw_hints_num_vars: Vec, - pub fixed_in_evals: Vec, pub fixed_opening_proof: Option, pub wits_commit: PCS::Commitment, @@ -144,8 +138,10 @@ pub struct ZKVMProof> { pub raw_pi: Vec>, // the evaluation of raw_pi. pub pi_evals: Vec, - opcode_proofs: BTreeMap)>, - table_proofs: BTreeMap)>, + // circuit size -> instance mapping + pub num_instances: Vec<(usize, usize)>, + opcode_proofs: BTreeMap>, + table_proofs: BTreeMap>, } impl> ZKVMProof { @@ -167,6 +163,7 @@ impl> ZKVMProof { Self { raw_pi, pi_evals, + num_instances: vec![], opcode_proofs: BTreeMap::new(), table_proofs: BTreeMap::new(), } @@ -180,11 +177,19 @@ impl> ZKVMProof { self.opcode_proofs.len() + self.table_proofs.len() } - pub fn has_halt(&self) -> bool { + pub fn has_halt(&self, vk: &ZKVMVerifyingKey) -> bool { let halt_instance_count = self - .opcode_proofs - .get(&HaltInstruction::::name()) - .map(|(_, p)| p.num_instances) + .num_instances + .iter() + .find_map(|(circuit_index, num_instances)| { + (*circuit_index + == vk + .circuit_vks + .keys() + .position(|circuit_name| *circuit_name == HaltInstruction::::name()) + .expect("halt circuit not exist")) + .then_some(*num_instances) + }) .unwrap_or(0); if halt_instance_count > 0 { assert_eq!( @@ -208,10 +213,10 @@ impl + Serialize> fmt::Dis let mpcs_opcode_commitment = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.wits_commit); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -221,10 +226,10 @@ impl + Serialize> fmt::Dis let mpcs_opcode_opening = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.wits_opening_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -235,10 +240,10 @@ impl + Serialize> fmt::Dis let tower_proof_opcode = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.tower_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -249,10 +254,10 @@ impl + Serialize> fmt::Dis let main_sumcheck_opcode = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.main_sel_sumcheck_proofs); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -263,10 +268,10 @@ impl + Serialize> fmt::Dis let mpcs_table_commitment = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.wits_commit); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -276,10 +281,10 @@ impl + Serialize> fmt::Dis let mpcs_table_opening = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.wits_opening_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -289,10 +294,10 @@ impl + Serialize> fmt::Dis let mpcs_table_fixed_opening = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.fixed_opening_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -303,10 +308,10 @@ impl + Serialize> fmt::Dis let tower_proof_table = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.tower_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -317,10 +322,10 @@ impl + Serialize> fmt::Dis let same_r_sumcheck_table = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.same_r_sumcheck_proofs); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1ba59c4cc..56ba0c0e7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -64,13 +64,14 @@ impl> ZKVMProver { pi: PublicValues, mut transcript: impl Transcript, ) -> Result, ZKVMError> { - let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); let mut vm_proof = ZKVMProof::empty(pi); + let span = entered_span!("commit_to_pi", profiling_1 = true); // including raw public input to transcript for v in vm_proof.raw_pi.iter().flatten() { transcript.append_field_element(v); } + exit_span!(span); let pi: Vec> = vm_proof .raw_pi @@ -82,6 +83,7 @@ impl> ZKVMProver { .collect(); // commit to fixed commitment + let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); for pk in self.pk.circuit_pks.values() { if let Some(fixed_commit) = &pk.vk.fixed_commit { PCS::write_commitment(fixed_commit, &mut transcript) @@ -91,10 +93,48 @@ impl> ZKVMProver { exit_span!(span); // commit to main traces + let circuit_name_index_mapping = self + .pk + .circuit_pks + .keys() + .enumerate() + .map(|(k, v)| (v, k)) + .collect::>(); let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); + let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); + for (index, (circuit_name, _)) in self.pk.circuit_pks.iter().enumerate() { + if let Some(num_instance) = witnesses + .get_opcode_witness(circuit_name) + .or_else(|| { + witnesses + .get_table_witness(circuit_name) + .map(|rmms| &rmms[0]) + }) + .map(|rmm| rmm.num_instances()) + .and_then(|num_instance| { + if num_instance > 0 { + Some(num_instance) + } else { + None + } + }) + { + num_instances.push((index, num_instance)); + } + } + + // verifier need this information from prover to achieve non-uniform design. + vm_proof.num_instances = num_instances; + + // write (circuit_size, num_var) to transcript + for (circuit_size, num_var) in &vm_proof.num_instances { + transcript.append_message(&circuit_size.to_le_bytes()); + transcript.append_message(&num_var.to_le_bytes()); + } + let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, mut rmm) in witnesses.into_iter_sorted() { @@ -120,7 +160,7 @@ impl> ZKVMProver { PCS::batch_commit_and_write(&self.pk.pp, witness_rmm, &mut transcript) .map_err(ZKVMError::PCSError)?; let witness = PCS::get_arc_mle_witness_from_commitment(&commit); - commitments.insert(circuit_name.clone(), commit); + commitments.insert(circuit_name_index_mapping[&circuit_name], commit); (witness, structural_witness) } }; @@ -155,7 +195,7 @@ impl> ZKVMProver { continue; } transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let wits_commit = commitments.remove(circuit_name).unwrap(); + let wits_commit = commitments.remove(&index).unwrap(); // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); let is_opcode_circuit = cs.lk_table_expressions.is_empty() @@ -187,9 +227,7 @@ impl> ZKVMProver { circuit_name, num_instances ); - vm_proof - .opcode_proofs - .insert(circuit_name.clone(), (index, opcode_proof)); + vm_proof.opcode_proofs.insert(index, opcode_proof); } else { let (structural_witness, structural_num_instances) = structural_wits .remove(circuit_name) @@ -211,9 +249,7 @@ impl> ZKVMProver { num_instances, structural_num_instances ); - vm_proof - .table_proofs - .insert(circuit_name.clone(), (index, table_proof)); + vm_proof.table_proofs.insert(index, table_proof); for (idx, eval) in pi_in_evals { vm_proof.update_pi_eval(idx, eval); } @@ -652,7 +688,6 @@ impl> ZKVMProver { let wits_commit = PCS::get_pure_commitment(&wits_commit); Ok(ZKVMOpcodeProof { - num_instances, record_r_out_evals, record_w_out_evals, lk_p1_out_eval, @@ -919,15 +954,6 @@ impl> ZKVMProver { }) .collect_vec(); - // (non uniform) collect dynamic address hints as witness for verifier - let rw_hints_num_vars = structural_witnesses - .iter() - .map(|mle| mle.num_vars()) - .collect_vec(); - for var in rw_hints_num_vars.iter() { - transcript.append_message(&var.to_le_bytes()); - } - let (rt_tower, tower_proof) = TowerProver::create_proof( // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers @@ -1130,7 +1156,6 @@ impl> ZKVMProver { tower_proof, fixed_in_evals, fixed_opening_proof, - rw_hints_num_vars, wits_in_evals, wits_commit, wits_opening_proof, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 190ab3ede..996a6bd72 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -166,6 +166,7 @@ fn test_rw_lk_expression_combination() { &vk.vp, verifier.vk.circuit_vks.get(&name).unwrap(), &proof, + num_instances, &[], &mut v_transcript, NUM_FANIN, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 00b99a58f..a1b0d71fa 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -11,6 +11,7 @@ use multilinear_extensions::{ virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; use p3::field::PrimeCharacteristicRing; +use std::collections::HashSet; use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; @@ -57,7 +58,7 @@ impl> ZKVMVerifier expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - let has_halt = vm_proof.has_halt(); + let has_halt = vm_proof.has_halt(&self.vk); if has_halt != expect_halt { return Err(ZKVMError::VerifyError(format!( "ecall/halt mismatch: expected {expect_halt} != {has_halt}", @@ -79,6 +80,45 @@ impl> ZKVMVerifier let pi_evals = &vm_proof.pi_evals; + // make sure circuit index are + // 1. unique + // 2. less than self.vk.circuit_vks.len() + assert!( + vm_proof + .num_instances + .iter() + .fold(None, |prev, &(circuit_index, _)| { + (circuit_index < self.vk.circuit_vks.len() + && prev.is_none_or(|p| p < circuit_index)) + .then_some(circuit_index) + }) + .is_some(), + "num_instances validity check failed" + ); + + assert_eq!( + vm_proof + .num_instances + .iter() + .map(|(x, _)| x) + .collect::>(), + vm_proof + .opcode_proofs + .keys() + .chain(vm_proof.table_proofs.keys()) + .collect::>(), + "num_instance circuit index exactly equal with provided proofs" + ); + + assert!( + vm_proof + .opcode_proofs + .keys() + .collect::>() + .is_disjoint(&vm_proof.table_proofs.keys().collect::>()), + "there is duplicated circuit index" + ); + // TODO fix soundness: construct raw public input by ourself and trustless from proof // including raw public input to transcript vm_proof @@ -108,17 +148,26 @@ impl> ZKVMVerifier } } - for (circuit_name, _) in self.vk.circuit_vks.iter() { - if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) { + // write (circuit_size, num_var) to transcript + for (circuit_size, num_var) in &vm_proof.num_instances { + transcript.append_message(&circuit_size.to_le_bytes()); + transcript.append_message(&num_var.to_le_bytes()); + } + + let circuit_vks: Vec<&VerifyingKey> = self.vk.circuit_vks.values().collect_vec(); + let circuit_names: Vec<&String> = self.vk.circuit_vks.keys().collect_vec(); + for (index, _) in &vm_proof.num_instances { + let circuit_name = circuit_names[*index]; + if let Some(opcode_proof) = vm_proof.opcode_proofs.get(index) { tracing::debug!("read {}'s commit", circuit_name); PCS::write_commitment(&opcode_proof.wits_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; - } else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) { + } else if let Some(table_proof) = vm_proof.table_proofs.get(index) { tracing::debug!("read {}'s commit", circuit_name); PCS::write_commitment(&table_proof.wits_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; } else { - // all proof are optional + unreachable!("respective proof of index {} should exist", index) } } @@ -132,15 +181,17 @@ impl> ZKVMVerifier let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); - for (index, (circuit_name, circuit_vk)) in self.vk.circuit_vks.iter().enumerate() { - if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) { - transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let name = circuit_name; + for (index, num_instances) in &vm_proof.num_instances { + let circuit_vk = circuit_vks[*index]; + let name = circuit_names[*index]; + if let Some(opcode_proof) = vm_proof.opcode_proofs.get(index) { + transcript.append_field_element(&E::BaseField::from_u64(*index as u64)); let _rand_point = self.verify_opcode_proof( name, &self.vk.vp, circuit_vk, opcode_proof, + *num_instances, pi_evals, &mut transcript, NUM_FANIN, @@ -152,10 +203,9 @@ impl> ZKVMVerifier // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().lk_expressions.len(); let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; - let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances) - - opcode_proof.num_instances; - dummy_table_item_multiplicity += num_padded_lks_per_instance - * opcode_proof.num_instances + let num_padded_instance = + next_pow2_instance_padding(*num_instances) - num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance * num_instances + num_lks.next_power_of_two() * num_padded_instance; prod_r *= opcode_proof @@ -171,14 +221,14 @@ impl> ZKVMVerifier logup_sum += opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.inverse(); logup_sum += opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.inverse(); - } else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) { - transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let name = circuit_name; + } else if let Some(table_proof) = vm_proof.table_proofs.get(index) { + transcript.append_field_element(&E::BaseField::from_u64(*index as u64)); let _rand_point = self.verify_table_proof( name, &self.vk.vp, circuit_vk, table_proof, + *num_instances, &vm_proof.raw_pi, &vm_proof.pi_evals, &mut transcript, @@ -208,7 +258,7 @@ impl> ZKVMVerifier .copied() .product::(); } else { - // all proof are optional + unreachable!("respective proof of index {} should exist", index) } } logup_sum -= E::from_u64(dummy_table_item_multiplicity as u64) * dummy_table_item.inverse(); @@ -255,6 +305,7 @@ impl> ZKVMVerifier vp: &PCS::VerifierParam, circuit_vk: &VerifyingKey, proof: &ZKVMOpcodeProof, + num_instances: usize, pi: &[E], transcript: &mut impl Transcript, num_product_fanin: usize, @@ -274,7 +325,6 @@ impl> ZKVMVerifier ); let (chip_record_alpha, _) = (challenges[0], challenges[1]); - let num_instances = proof.num_instances; let next_pow2_instance = next_pow2_instance_padding(num_instances); let log2_num_instances = ceil_log2(next_pow2_instance); @@ -501,6 +551,7 @@ impl> ZKVMVerifier vp: &PCS::VerifierParam, circuit_vk: &VerifyingKey, proof: &ZKVMTableProof, + num_instances: usize, raw_pi: &[Vec], pi: &[E], transcript: &mut impl Transcript, @@ -515,6 +566,9 @@ impl> ZKVMVerifier .zip_eq(cs.w_table_expressions.iter()) .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); + + let log2_num_instances = ceil_log2(num_instances); + // in table proof, we always skip same point sumcheck for now // as tower sumcheck batch product argument/logup in same length let is_skip_same_point_sumcheck = true; @@ -522,6 +576,7 @@ impl> ZKVMVerifier // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; + // NOTE: for all structural witness within same constrain system should got same hints num variable via `log2_num_instances` let expected_rounds = cs // only iterate r set, as read/write set round should match .r_table_expressions @@ -532,14 +587,15 @@ impl> ZKVMVerifier r.table_spec .structural_witins .iter() - .map(|StructuralWitIn { id, max_len, .. }| { - let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + .map(|StructuralWitIn { max_len, .. }| { + let hint_num_vars = log2_num_instances; assert!((1 << hint_num_vars) <= *max_len); hint_num_vars }) .max() .unwrap() }); + assert_eq!(num_vars, log2_num_instances); [num_vars, num_vars] // format: [read_round, write_round] }) .chain(cs.lk_table_expressions.iter().map(|l| { @@ -548,22 +604,19 @@ impl> ZKVMVerifier l.table_spec .structural_witins .iter() - .map(|StructuralWitIn { id, max_len, .. }| { - let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + .map(|StructuralWitIn { max_len, .. }| { + let hint_num_vars = log2_num_instances; assert!((1 << hint_num_vars) <= *max_len); hint_num_vars }) .max() .unwrap() }); + assert_eq!(num_vars, log2_num_instances); num_vars })) .collect_vec(); - for var in proof.rw_hints_num_vars.iter() { - transcript.append_message(&var.to_le_bytes()); - } - let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify(