Skip to content

Commit 9840fdf

Browse files
committed
simplify design without transcript fork
1 parent d1c7be6 commit 9840fdf

File tree

3 files changed

+109
-118
lines changed

3 files changed

+109
-118
lines changed

ceno_zkvm/src/scheme/prover.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use sumcheck::{
1919
structs::{IOPProverMessage, IOPProverState},
2020
util::optimal_sumcheck_threads,
2121
};
22-
use transcript::{ForkableTranscript, Transcript};
22+
use transcript::Transcript;
2323
use witness::{RowMajorMatrix, next_pow2_instance_padding};
2424

2525
use crate::{
@@ -62,7 +62,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
6262
&self,
6363
witnesses: ZKVMWitnesses<E>,
6464
pi: PublicValues<u32>,
65-
mut transcript: impl ForkableTranscript<E>,
65+
mut transcript: impl Transcript<E>,
6666
) -> Result<ZKVMProof<E, PCS>, ZKVMError> {
6767
let span = entered_span!("commit_to_fixed_commit", profiling_1 = true);
6868
let mut vm_proof = ZKVMProof::empty(pi);
@@ -147,19 +147,14 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
147147
tracing::debug!("challenges in prover: {:?}", challenges);
148148

149149
let main_proofs_span = entered_span!("main_proofs", profiling_1 = true);
150-
let mut transcripts = transcript.fork(self.pk.circuit_pks.len());
151-
for ((circuit_name, pk), (i, transcript)) in self
152-
.pk
153-
.circuit_pks
154-
.iter() // Sorted by key.
155-
.zip_eq(transcripts.iter_mut().enumerate())
156-
{
150+
for (index, (circuit_name, pk)) in self.pk.circuit_pks.iter().enumerate() {
157151
let (witness, num_instances) = wits
158152
.remove(circuit_name)
159153
.ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string()))?;
160154
if witness.is_empty() {
161155
continue;
162156
}
157+
transcript.append_field_element(&E::BaseField::from_u64(index as u64));
163158
let wits_commit = commitments.remove(circuit_name).unwrap();
164159
// TODO: add an enum for circuit type either in constraint_system or vk
165160
let cs = pk.get_cs();
@@ -184,7 +179,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
184179
wits_commit,
185180
&pi,
186181
num_instances,
187-
transcript,
182+
&mut transcript,
188183
&challenges,
189184
)?;
190185
tracing::info!(
@@ -194,7 +189,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
194189
);
195190
vm_proof
196191
.opcode_proofs
197-
.insert(circuit_name.clone(), (i, opcode_proof));
192+
.insert(circuit_name.clone(), (index, opcode_proof));
198193
} else {
199194
let (structural_witness, structural_num_instances) = structural_wits
200195
.remove(circuit_name)
@@ -207,7 +202,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
207202
wits_commit,
208203
structural_witness,
209204
&pi,
210-
transcript,
205+
&mut transcript,
211206
&challenges,
212207
)?;
213208
tracing::info!(
@@ -218,7 +213,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
218213
);
219214
vm_proof
220215
.table_proofs
221-
.insert(circuit_name.clone(), (i, table_proof));
216+
.insert(circuit_name.clone(), (index, table_proof));
222217
for (idx, eval) in pi_in_evals {
223218
vm_proof.update_pi_eval(idx, eval);
224219
}

ceno_zkvm/src/scheme/verifier.rs

Lines changed: 90 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use multilinear_extensions::{
1010
util::ceil_log2,
1111
virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval},
1212
};
13+
use p3::field::PrimeCharacteristicRing;
1314
use sumcheck::structs::{IOPProof, IOPVerifierState};
1415
use transcript::{ForkableTranscript, Transcript};
1516
use witness::next_pow2_instance_padding;
@@ -107,15 +108,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
107108
}
108109
}
109110

110-
for (name, (_, proof)) in vm_proof.opcode_proofs.iter() {
111-
tracing::debug!("read {}'s commit", name);
112-
PCS::write_commitment(&proof.wits_commit, &mut transcript)
113-
.map_err(ZKVMError::PCSError)?;
114-
}
115-
for (name, (_, proof)) in vm_proof.table_proofs.iter() {
116-
tracing::debug!("read {}'s commit", name);
117-
PCS::write_commitment(&proof.wits_commit, &mut transcript)
118-
.map_err(ZKVMError::PCSError)?;
111+
for (circuit_name, _) in self.vk.circuit_vks.iter() {
112+
if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) {
113+
tracing::debug!("read {}'s commit", circuit_name);
114+
PCS::write_commitment(&opcode_proof.wits_commit, &mut transcript)
115+
.map_err(ZKVMError::PCSError)?;
116+
} else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) {
117+
tracing::debug!("read {}'s commit", circuit_name);
118+
PCS::write_commitment(&table_proof.wits_commit, &mut transcript)
119+
.map_err(ZKVMError::PCSError)?;
120+
} else {
121+
// all proof are optional
122+
}
119123
}
120124

121125
// alpha, beta
@@ -128,94 +132,84 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
128132
let dummy_table_item = challenges[0];
129133
let mut dummy_table_item_multiplicity = 0;
130134
let point_eval = PointAndEval::default();
131-
let mut transcripts = transcript.fork(self.vk.circuit_vks.len());
132-
133-
for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
134-
let transcript = &mut transcripts[i];
135-
136-
let circuit_vk = self
137-
.vk
138-
.circuit_vks
139-
.get(&name)
140-
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
141-
let _rand_point = self.verify_opcode_proof(
142-
&name,
143-
&self.vk.vp,
144-
circuit_vk,
145-
&opcode_proof,
146-
pi_evals,
147-
transcript,
148-
NUM_FANIN,
149-
&point_eval,
150-
&challenges,
151-
)?;
152-
tracing::info!("verified proof for opcode {}", name);
153-
154-
// getting the number of dummy padding item that we used in this opcode circuit
155-
let num_lks = circuit_vk.get_cs().lk_expressions.len();
156-
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
157-
let num_padded_instance =
158-
next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances;
159-
dummy_table_item_multiplicity += num_padded_lks_per_instance
160-
* opcode_proof.num_instances
161-
+ num_lks.next_power_of_two() * num_padded_instance;
162-
163-
prod_r *= opcode_proof
164-
.record_r_out_evals
165-
.iter()
166-
.copied()
167-
.product::<E>();
168-
prod_w *= opcode_proof
169-
.record_w_out_evals
170-
.iter()
171-
.copied()
172-
.product::<E>();
173-
174-
logup_sum += opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.inverse();
175-
logup_sum += opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.inverse();
176-
}
177-
178-
for (name, (i, table_proof)) in vm_proof.table_proofs {
179-
let transcript = &mut transcripts[i];
180-
181-
let circuit_vk = self
182-
.vk
183-
.circuit_vks
184-
.get(&name)
185-
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
186-
let _rand_point = self.verify_table_proof(
187-
&name,
188-
&self.vk.vp,
189-
circuit_vk,
190-
&table_proof,
191-
&vm_proof.raw_pi,
192-
&vm_proof.pi_evals,
193-
transcript,
194-
NUM_FANIN_LOGUP,
195-
&point_eval,
196-
&challenges,
197-
)?;
198-
tracing::info!("verified proof for table {}", name);
199-
200-
logup_sum = table_proof
201-
.lk_out_evals
202-
.iter()
203-
.fold(logup_sum, |acc, [p1, p2, q1, q2]| {
204-
acc - *p1 * q1.inverse() - *p2 * q2.inverse()
205-
});
135+
for (index, (circuit_name, circuit_vk)) in self.vk.circuit_vks.iter().enumerate() {
136+
if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) {
137+
transcript.append_field_element(&E::BaseField::from_u64(index as u64));
138+
let name = circuit_name;
139+
let _rand_point = self.verify_opcode_proof(
140+
name,
141+
&self.vk.vp,
142+
circuit_vk,
143+
opcode_proof,
144+
pi_evals,
145+
&mut transcript,
146+
NUM_FANIN,
147+
&point_eval,
148+
&challenges,
149+
)?;
150+
tracing::info!("verified proof for opcode {}", name);
151+
152+
// getting the number of dummy padding item that we used in this opcode circuit
153+
let num_lks = circuit_vk.get_cs().lk_expressions.len();
154+
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
155+
let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances)
156+
- opcode_proof.num_instances;
157+
dummy_table_item_multiplicity += num_padded_lks_per_instance
158+
* opcode_proof.num_instances
159+
+ num_lks.next_power_of_two() * num_padded_instance;
160+
161+
prod_r *= opcode_proof
162+
.record_r_out_evals
163+
.iter()
164+
.copied()
165+
.product::<E>();
166+
prod_w *= opcode_proof
167+
.record_w_out_evals
168+
.iter()
169+
.copied()
170+
.product::<E>();
171+
172+
logup_sum += opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.inverse();
173+
logup_sum += opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.inverse();
174+
} else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) {
175+
transcript.append_field_element(&E::BaseField::from_u64(index as u64));
176+
let name = circuit_name;
177+
let _rand_point = self.verify_table_proof(
178+
name,
179+
&self.vk.vp,
180+
circuit_vk,
181+
table_proof,
182+
&vm_proof.raw_pi,
183+
&vm_proof.pi_evals,
184+
&mut transcript,
185+
NUM_FANIN_LOGUP,
186+
&point_eval,
187+
&challenges,
188+
)?;
189+
tracing::info!("verified proof for table {}", name);
190+
191+
logup_sum = table_proof
192+
.lk_out_evals
193+
.iter()
194+
.fold(logup_sum, |acc, [p1, p2, q1, q2]| {
195+
acc - *p1 * q1.inverse() - *p2 * q2.inverse()
196+
});
206197

207-
prod_w *= table_proof
208-
.w_out_evals
209-
.iter()
210-
.flatten()
211-
.copied()
212-
.product::<E>();
213-
prod_r *= table_proof
214-
.r_out_evals
215-
.iter()
216-
.flatten()
217-
.copied()
218-
.product::<E>();
198+
prod_w *= table_proof
199+
.w_out_evals
200+
.iter()
201+
.flatten()
202+
.copied()
203+
.product::<E>();
204+
prod_r *= table_proof
205+
.r_out_evals
206+
.iter()
207+
.flatten()
208+
.copied()
209+
.product::<E>();
210+
} else {
211+
// all proof are optional
212+
}
219213
}
220214
logup_sum -= E::from_u64(dummy_table_item_multiplicity as u64) * dummy_table_item.inverse();
221215

ceno_zkvm/src/structs.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
};
1010
use ceno_emul::{CENO_PLATFORM, Platform, StepRecord};
1111
use ff_ext::ExtensionField;
12-
use itertools::{Itertools, chain};
12+
use itertools::Itertools;
1313
use mpcs::PolynomialCommitmentScheme;
1414
use multilinear_extensions::{
1515
mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension,
@@ -350,14 +350,16 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
350350
pub fn into_iter_sorted(
351351
self,
352352
) -> impl Iterator<Item = (String, Vec<RowMajorMatrix<E::BaseField>>)> {
353-
chain(
354-
self.witnesses_opcodes
355-
.into_iter()
356-
.map(|(name, witnesses)| (name, vec![witnesses])),
357-
self.witnesses_tables
358-
.into_iter()
359-
.map(|(name, witnesses)| (name, witnesses.to_vec())),
360-
)
353+
self.witnesses_opcodes
354+
.into_iter()
355+
.map(|(name, witness)| (name, vec![witness]))
356+
.chain(
357+
self.witnesses_tables
358+
.into_iter()
359+
.map(|(name, witnesses)| (name, witnesses.into())),
360+
)
361+
.collect::<BTreeMap<_, _>>()
362+
.into_iter()
361363
}
362364
}
363365
pub struct ZKVMProvingKey<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {

0 commit comments

Comments
 (0)