diff --git a/ceno_recursion/src/transcript/mod.rs b/ceno_recursion/src/transcript/mod.rs index 9ea0e1ed4..b98ae4243 100644 --- a/ceno_recursion/src/transcript/mod.rs +++ b/ceno_recursion/src/transcript/mod.rs @@ -54,32 +54,3 @@ pub fn transcript_check_pow_witness( builder.assert_eq::>(bit, Usize::from(0)); }); } - -pub fn clone_challenger_state( - builder: &mut Builder, - src: &DuplexChallengerVariable, -) -> DuplexChallengerVariable { - let dst = DuplexChallengerVariable::new(builder); - builder - .range(0, dst.sponge_state.len()) - .for_each(|idx_vec, builder| { - let value = builder.get(&src.sponge_state, idx_vec[0]); - builder.set(&dst.sponge_state, idx_vec[0], value); - }); - - let input_offset = src.input_ptr - src.io_empty_ptr; - builder.assign(&dst.input_ptr, input_offset + dst.io_empty_ptr); - - let output_offset = src.output_ptr - src.io_empty_ptr; - builder.assign(&dst.output_ptr, output_offset + dst.io_empty_ptr); - dst -} - -pub fn challenger_add_forked_index( - builder: &mut Builder, - challenger: &mut DuplexChallengerVariable, - index: &Usize, -) { - let felt = builder.unsafe_cast_var_to_felt(index.get_var()); - challenger.observe(builder, felt); -} diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index fe29ea462..8911b290e 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -32,7 +32,6 @@ use crate::{ use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; use ff_ext::BabyBearExt4; -use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; use gkr_iop::{ evaluation::EvalExpression, gkr::{GKRCircuit, booleanhypercube::BooleanHypercube, layer::Layer}, @@ -182,18 +181,6 @@ pub fn verify_zkvm_proof>( challenger.observe(builder, log2_max_codeword_size_felt); } - iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| { - let chip_proof = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]); - let chip_idx = chip_proof.idx_felt; - challenger.observe(builder, chip_idx); - - iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| { - let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]); - let num_instance = builder.unsafe_cast_var_to_felt(num_instance); - challenger.observe(builder, num_instance); - }); - }); - challenger_multi_observe( builder, &mut challenger, @@ -284,9 +271,15 @@ pub fn verify_zkvm_proof>( let chip_proof = builder.get(&zkvm_proof_input.chip_proofs, num_chips_verified.get_var()); - // fork transcript to support chip concurrently proved - let mut chip_challenger = clone_challenger_state(builder, &challenger); - challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index); + // Fork chip transcript independently and bind challenges/metadata in verifier order. + let mut chip_challenger = DuplexChallengerVariable::new(builder); + transcript_observe_label(builder, &mut chip_challenger, b"fork"); + let alpha_felts = builder.ext2felt(alpha); + chip_challenger.observe_slice(builder, alpha_felts); + let beta_felts = builder.ext2felt(beta); + chip_challenger.observe_slice(builder, beta_felts); + let fork_id_felt = builder.unsafe_cast_var_to_felt(forked_sample_index.get_var()); + chip_challenger.observe(builder, fork_id_felt); builder.assert_usize_eq( chip_proof.rw_out_evals.length.clone(), Usize::from( @@ -298,6 +291,11 @@ pub fn verify_zkvm_proof>( Usize::from(circuit_vk.get_cs().num_lks() * 4), ); chip_challenger.observe(builder, chip_proof.idx_felt); + iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| { + let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]); + let num_instance = builder.unsafe_cast_var_to_felt(num_instance); + chip_challenger.observe(builder, num_instance); + }); // getting the number of dummy padding item that we used in this opcode circuit let num_lks: Var = diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 22cea0f4d..259e868d5 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -24,7 +24,7 @@ use sumcheck::{ structs::IOPProverMessage, }; use tracing::info_span; -use transcript::{ForkableTranscript, Transcript}; +use transcript::{BasicTranscript, ForkableTranscript, Transcript}; use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice}; use crate::{ @@ -357,45 +357,14 @@ impl< } exit_span!(span); - // only keep track of circuits that have non-zero instances - for (name, chip_inputs) in &witnesses.witnesses { - let pk = self.pk.circuit_pks.get(name).ok_or(ZKVMError::VKNotFound( - format!("proving key for circuit {} not found", name).into(), - ))?; - - // include omc init tables iff it's in first shard - if !shard_ctx.is_first_shard() && pk.get_cs().with_omc_init_only() { - continue; - } - - // num_instance from witness might include rotation - let num_instances = chip_inputs - .iter() - .flat_map(|chip_input| chip_input.num_instances) - .collect_vec(); - - if num_instances.iter().sum::() == 0 { - continue; - } - - let circuit_idx = self.pk.circuit_name_to_index.get(name).unwrap(); - // write (circuit_idx, num_var) to transcript - transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx)); - for num_instance in num_instances { - transcript - .append_field_element(&E::BaseField::from_canonical_usize(num_instance)); - } - } - - // extract chip meta info before consuming witnesses - // (circuit_name, num_instances) - let name_and_instances = witnesses.get_witnesses_name_instance(); - let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); let mut wits_rmms = BTreeMap::new(); #[cfg(feature = "gpu")] let mut gpu_witness_traces = BTreeMap::new(); + // Extract chip metadata before consuming witnesses; task closures bind it + // into their per-chip forked transcripts. + let name_and_instances = witnesses.get_witnesses_name_instance(); let mut structural_rmms = Vec::with_capacity(name_and_instances.len()); #[cfg(feature = "gpu")] let mut witness_trace_rows = Vec::with_capacity(name_and_instances.len()); @@ -551,7 +520,6 @@ impl< transcript.read_challenge().elements, ]; tracing::debug!("global challenges in prover: {:?}", challenges); - let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); // Phase 1: Build all ChipTasks @@ -575,8 +543,9 @@ impl< // GPU concurrent: memory-aware backfilling with standalone impl. // Sequential (GPU + CPU): unified path via self.create_chip_proof. let execute_tasks_span = entered_span!("execute_chip_tasks", profiling_1 = true); + let fork_transcript = BasicTranscript::::new(b"fork"); let (results, forked_samples) = - self.run_chip_proofs(tasks, &transcript, &witness_data)?; + self.run_chip_proofs(tasks, &fork_transcript, &witness_data)?; exit_span!(execute_tasks_span); // Phase 3: Collect results @@ -664,9 +633,20 @@ impl< let mut task = cast_gpu_chip_task::(task); + // Bind global challenges and metadata in the same order as verifier. + transcript.append_field_element_ext(&task.challenges[0]); + transcript.append_field_element_ext(&task.challenges[1]); + transcript + .append_field_element(&E::BaseField::from_canonical_usize(task.task_id)); + // Append circuit_idx to per-task forked transcript (matching verifier) transcript.append_field_element(&E::BaseField::from_canonical_u64( task.circuit_idx as u64, )); + for num_instance in task.input.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize( + num_instance, + )); + } prepare_gpu_chip_input::(&mut task, gpu_witness_data); let (proof, main_constraint_job) = @@ -718,9 +698,16 @@ impl< // Sequential path (CPU and non-GPU fallback): // Uses execute_sequentially directly to avoid Send+Sync requirement on the closure. scheduler.execute_sequentially(tasks, transcript, |mut task, transcript| { + // Bind global challenges and metadata in the same order as verifier. + transcript.append_field_element_ext(&task.challenges[0]); + transcript.append_field_element_ext(&task.challenges[1]); + transcript.append_field_element(&E::BaseField::from_canonical_usize(task.task_id)); // Append circuit_idx to per-task forked transcript (matching verifier) transcript .append_field_element(&E::BaseField::from_canonical_u64(task.circuit_idx as u64)); + for num_instance in task.input.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize(num_instance)); + } // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 3fc9c30ef..e3c4c4411 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -21,7 +21,6 @@ use crate::{ use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; -use p3::field::FieldAlgebra; use std::sync::OnceLock; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -159,8 +158,8 @@ impl ChipScheduler { /// Execute tasks sequentially with automatic transcript forking and sampling. /// - /// Each task gets a transcript cloned from `parent_transcript` with `task_id` - /// appended (identical to `ForkableTranscript::fork` default impl). + /// Each task gets a transcript cloned from `parent_transcript`. + /// Task-specific transcript appends are performed by the task closure. /// Returns `(results, forked_samples)` both sorted by task_id. #[allow(clippy::type_complexity)] pub(crate) fn execute_sequentially<'a, PB, T, F>( @@ -193,12 +192,8 @@ impl ChipScheduler { for task in tasks { let task_id = task.task_id; - // Fork: clone parent + append task_id - // (identical to ForkableTranscript::fork default impl) + // Fork: clone parent transcript template. let mut forked = parent_transcript.clone(); - forked.append_field_element(&::BaseField::from_canonical_u64( - task_id as u64, - )); let result = execute_task(task, &mut forked)?; results.push(result); @@ -220,8 +215,7 @@ impl ChipScheduler { /// Tasks are sorted by memory requirement (descending) and scheduled to /// maximize GPU utilization while respecting memory constraints. /// - /// Each worker thread clones the parent `transcript` and appends its task_id - /// (reproducing `ForkableTranscript::fork` locally). After proving, the worker + /// Each worker thread clones the parent `transcript`. After proving, the worker /// samples one extension-field element from its local transcript and returns it. /// This avoids sending non-`Send` transcript objects across threads. /// @@ -257,9 +251,6 @@ impl ChipScheduler { if tasks.len() == 1 { let task = tasks.remove(0); let mut fork = transcript.clone(); - fork.append_field_element(&::BaseField::from_canonical_u64( - task.task_id as u64, - )); let result = execute_task(task, &mut fork)?; let sample = fork.sample_vec(1)[0]; return Ok((vec![result], vec![sample])); @@ -371,14 +362,8 @@ impl ChipScheduler { // waiting for a CompletionMessage that never arrives). let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - // Fork locally: clone parent transcript + append task_id - // (identical to ForkableTranscript::fork default impl) + // Fork locally: clone parent transcript template. let mut local_transcript = tr.0.clone(); - local_transcript.append_field_element( - &::BaseField::from_canonical_u64( - task_id as u64, - ), - ); let result = execute_fn(task, &mut local_transcript); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 76ed4ccb0..b22c74d77 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -59,7 +59,7 @@ use sumcheck::{ structs::{IOPProof, IOPVerifierState, SumCheckSubClaim}, util::get_challenge_pows, }; -use transcript::{ForkableTranscript, Transcript}; +use transcript::{BasicTranscript, ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; pub use crate::structs::RV32imMemStateConfig; @@ -525,15 +525,6 @@ impl> PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } - // write (circuit_idx, num_instance) to transcript - for (circuit_idx, proof) in vm_proof.chip_proofs.iter() { - transcript.append_field_element(&E::BaseField::from_canonical_u32(*circuit_idx as u32)); - // length of proof.num_instances will be constrained in verify_chip_proof - for num_instance in &proof.num_instances { - transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); - } - } - // write witin commitment to transcript PCS::write_commitment(&vm_proof.witin_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; @@ -587,11 +578,12 @@ impl> // fork transcript to support chip concurrently proved let mut pending_main_constraints = Vec::with_capacity(num_proofs); - let mut forked_transcripts = transcript.fork(num_proofs); - for ((index, proof), transcript) in vm_proof + let mut forked_transcripts = vec![BasicTranscript::new(b"fork"); num_proofs]; + for (index, ((circuit_index, proof), transcript)) in vm_proof .chip_proofs .iter() .zip_eq(forked_transcripts.iter_mut()) + .enumerate() { let num_instance: usize = proof.num_instances.iter().sum(); if num_instance == 0 { @@ -599,12 +591,18 @@ impl> format!("{shard_id}th shard chip {index} has zero instances").into(), )); } - let circuit_name = self.vk.circuit_index_to_name.get(index).ok_or_else(|| { - ZKVMError::VKNotFound( - format!("{shard_id}th shard circuit index {index} missing from vk index map") - .into(), + let circuit_name = self + .vk + .circuit_index_to_name + .get(circuit_index) + .ok_or_else(|| { + ZKVMError::VKNotFound( + format!( + "{shard_id}th shard circuit index {circuit_index} missing from vk index map" + ) + .into(), ) - })?; + })?; let circuit_vk = self.vk.circuit_vks.get(circuit_name).ok_or_else(|| { ZKVMError::VKNotFound( format!("{shard_id}th shard circuit name {circuit_name} missing from vk") @@ -687,7 +685,14 @@ impl> }) .sum::>()?; - transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64)); + transcript.append_field_element_ext(&challenges[0]); + transcript.append_field_element_ext(&challenges[1]); + transcript.append_field_element(&E::BaseField::from_canonical_usize(index)); + transcript + .append_field_element(&E::BaseField::from_canonical_u64(*circuit_index as u64)); + for num_instance in &proof.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); + } // compute logup_sum padding // getting the number of dummy padding item that we used in this opcode circuit