diff --git a/ceno_recursion_v2/src/bus.rs b/ceno_recursion_v2/src/bus.rs index 21562d770..4672a3f7e 100644 --- a/ceno_recursion_v2/src/bus.rs +++ b/ceno_recursion_v2/src/bus.rs @@ -60,6 +60,9 @@ pub struct TowerModuleMessage { pub idx: T, pub tidx: T, pub n_logup: T, + pub num_read_count: T, + pub num_write_count: T, + pub num_logup_count: T, } define_typed_per_proof_permutation_bus!(TowerModuleBus, TowerModuleMessage); diff --git a/ceno_recursion_v2/src/circuit/inner/mod.rs b/ceno_recursion_v2/src/circuit/inner/mod.rs index 2d66e353d..a754f9b17 100644 --- a/ceno_recursion_v2/src/circuit/inner/mod.rs +++ b/ceno_recursion_v2/src/circuit/inner/mod.rs @@ -29,6 +29,8 @@ pub use trace::*; pub struct InnerCircuit { pub verifier_circuit: Arc, pub def_hook_commit: Option, + pub has_fixed_commit: bool, + pub has_fixed_no_omc_init_commit: bool, pub instance_public_value_indices: Arc>>, } @@ -67,6 +69,8 @@ impl, S: AggregationSubCircuit> Circuit for I lookup_challenge_bus, pvs_air_consistency_bus, deferral_enabled, + has_fixed_commit: self.has_fixed_commit, + has_fixed_no_omc_init_commit: self.has_fixed_no_omc_init_commit, instance_public_value_indices: self.instance_public_value_indices.clone(), }) as AirRef; diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs index 63ae7727d..0c1cde41b 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs @@ -1,10 +1,13 @@ use std::{borrow::Borrow, sync::Arc}; use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; -use ceno_zkvm::instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, EXIT_PC, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, - HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, - SHARD_ID_IDX, SHARD_RW_SUM_IDX, +use ceno_zkvm::{ + instructions::riscv::constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, EXIT_PC, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, + 
HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, + SHARD_ID_IDX, SHARD_RW_SUM_IDX, + }, + scheme::PublicValues, }; use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; use openvm_stark_backend::{ @@ -21,7 +24,11 @@ use stark_recursion_circuit_derive::AlignedBorrow; use crate::{ bus::{LookupChallengeBus, LookupChallengeKind, LookupChallengeMessage}, - circuit::inner::{bus::PvsAirConsistencyBus, vm_pvs::VmPvs}, + circuit::inner::{ + bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}, + vm_pvs::VmPvs, + }, + utils::TranscriptLabel, }; #[repr(C)] @@ -35,6 +42,9 @@ pub struct VmPvsCols { pub lookup_challenge_beta: [F; D_EF], pub lookup_challenge_alpha_lookup_count: F, pub lookup_challenge_beta_lookup_count: F, + pub fixed_commit_log2_max_codeword_size: F, + pub fixed_no_omc_init_commit_log2_max_codeword_size: F, + pub witness_commit_log2_max_codeword_size: F, pub child_pvs: VmPvs, } @@ -45,6 +55,8 @@ pub struct VmPvsAir { pub lookup_challenge_bus: LookupChallengeBus, pub pvs_air_consistency_bus: PvsAirConsistencyBus, pub deferral_enabled: bool, + pub has_fixed_commit: bool, + pub has_fixed_no_omc_init_commit: bool, pub instance_public_value_indices: Arc>>, } @@ -189,41 +201,60 @@ impl Air f // ); // Commitments are observed after transcript-visible public values in preflight. 
- let start_tidx_after_public_value = VmPvs::::width() - 3 * DIGEST_SIZE; - for (didx, value) in local.child_pvs.fixed_commit.iter().enumerate() { - self.transcript_bus.receive( + let mut commit_tidx = TranscriptLabel::Riscv.field_len() + PublicValues::flattened_len(); + if self.has_fixed_commit { + receive_commitment( + self, builder, local.proof_idx, - TranscriptBusMessage { - tidx: AB::Expr::from_usize(start_tidx_after_public_value + didx), - value: (*value).into(), - is_sample: AB::Expr::ZERO, - }, + commit_tidx, + local.child_pvs.fixed_commit, + local.fixed_commit_log2_max_codeword_size, local.is_valid, ); + commit_tidx += DIGEST_SIZE + 1; } - for (didx, value) in local.child_pvs.fixed_no_omc_init_commit.iter().enumerate() { + if self.has_fixed_no_omc_init_commit { + receive_commitment( + self, + builder, + local.proof_idx, + commit_tidx, + local.child_pvs.fixed_no_omc_init_commit, + local.fixed_no_omc_init_commit_log2_max_codeword_size, + local.is_valid, + ); + commit_tidx += DIGEST_SIZE + 1; + } + receive_commitment( + self, + builder, + local.proof_idx, + commit_tidx, + local.child_pvs.witness_commit, + local.witness_commit_log2_max_codeword_size, + local.is_valid, + ); + commit_tidx += DIGEST_SIZE + 1; + + for i in 0..D_EF { self.transcript_bus.receive( builder, local.proof_idx, TranscriptBusMessage { - tidx: AB::Expr::from_usize(start_tidx_after_public_value + DIGEST_SIZE + didx), - value: (*value).into(), - is_sample: AB::Expr::ZERO, + tidx: AB::Expr::from_usize(commit_tidx + i), + value: local.lookup_challenge_alpha[i].into(), + is_sample: AB::Expr::ONE, }, local.is_valid, ); - } - for (didx, value) in local.child_pvs.witness_commit.iter().enumerate() { self.transcript_bus.receive( builder, local.proof_idx, TranscriptBusMessage { - tidx: AB::Expr::from_usize( - start_tidx_after_public_value + 2 * DIGEST_SIZE + didx, - ), - value: (*value).into(), - is_sample: AB::Expr::ZERO, + tidx: AB::Expr::from_usize(commit_tidx + D_EF + i), + value: 
local.lookup_challenge_beta[i].into(), + is_sample: AB::Expr::ONE, }, local.is_valid, ); @@ -253,15 +284,15 @@ impl Air f } // We look up proof metadata from VerifierPvsAir here to ensure consistency on each row. - // self.pvs_air_consistency_bus.lookup_key( - // builder, - // local.proof_idx, - // PvsAirConsistencyMessage { - // deferral_flag, - // has_verifier_pvs: local.has_verifier_pvs.into(), - // }, - // local.is_valid, - // ); + self.pvs_air_consistency_bus.lookup_key( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag, + has_verifier_pvs: local.has_verifier_pvs.into(), + }, + local.is_valid, + ); // Finally, constrain that this AIR's output public values are consistent with child_pvs. let &VmPvs::<_> { @@ -383,6 +414,41 @@ where } } +fn receive_commitment( + air: &VmPvsAir, + builder: &mut AB, + proof_idx: AB::Var, + start_tidx: usize, + commit: [AB::Var; DIGEST_SIZE], + log2_max_codeword_size: AB::Var, + mult: AB::Var, +) where + AB: AirBuilder + InteractionBuilder + AirBuilderWithPublicValues, +{ + for (didx, value) in commit.into_iter().enumerate() { + air.transcript_bus.receive( + builder, + proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(start_tidx + didx), + value: value.into(), + is_sample: AB::Expr::ZERO, + }, + mult, + ); + } + air.transcript_bus.receive( + builder, + proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(start_tidx + DIGEST_SIZE), + value: log2_max_codeword_size.into(), + is_sample: AB::Expr::ZERO, + }, + mult, + ); +} + impl VmPvsAir { fn eval_deferrals( &self, diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs index 102876658..0f5723497 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs @@ -77,9 +77,9 @@ pub fn run_preflight( let alpha_ext = ts.sample_ext(); let beta_ext = ts.sample_ext(); - eprintln!("vm_pvs alpha {} beta {}", alpha_ext, beta_ext); 
preflight.vm_pvs.lookup_challenge_alpha = alpha_ext; preflight.vm_pvs.lookup_challenge_beta = beta_ext; - preflight.vm_pvs.lookup_challenge_alpha_lookup_count = 0; - preflight.vm_pvs.lookup_challenge_beta_lookup_count = 0; + let present_air_count = proof.chip_proofs.len(); + preflight.vm_pvs.lookup_challenge_alpha_lookup_count = present_air_count; + preflight.vm_pvs.lookup_challenge_beta_lookup_count = present_air_count; } diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs index a1262f04b..790ff95fa 100644 --- a/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs @@ -46,6 +46,16 @@ pub fn generate_proving_ctx( .as_ref() .map(|commitment| commitment.commit.clone()), ); + let fixed_commit_log2_max_codeword_size = child_vk + .fixed_commit + .as_ref() + .map(|commitment| F::from_u64(commitment.log2_max_codeword_size as u64)) + .unwrap_or(F::ZERO); + let fixed_no_omc_init_commit_log2_max_codeword_size = child_vk + .fixed_no_omc_init_commit + .as_ref() + .map(|commitment| F::from_u64(commitment.log2_max_codeword_size as u64)) + .unwrap_or(F::ZERO); for (row_idx, row) in trace.chunks_exact_mut(width).enumerate() { let (base_row, def_row) = row.split_at_mut(VmPvsCols::::width()); @@ -64,6 +74,11 @@ pub fn generate_proving_ctx( F::from_usize(preflight.vm_pvs.lookup_challenge_alpha_lookup_count); cols.lookup_challenge_beta_lookup_count = F::from_usize(preflight.vm_pvs.lookup_challenge_beta_lookup_count); + cols.fixed_commit_log2_max_codeword_size = fixed_commit_log2_max_codeword_size; + cols.fixed_no_omc_init_commit_log2_max_codeword_size = + fixed_no_omc_init_commit_log2_max_codeword_size; + cols.witness_commit_log2_max_codeword_size = + F::from_u64(proof.witin_commit.log2_max_codeword_size as u64); cols.child_pvs = build_vm_pvs(fixed_commit, fixed_no_omc_init_commit, proof); } diff --git a/ceno_recursion_v2/src/continuation/prover/inner/mod.rs 
b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs index a546298bb..2977d01bd 100644 --- a/ceno_recursion_v2/src/continuation/prover/inner/mod.rs +++ b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs @@ -82,6 +82,8 @@ impl< let circuit = Arc::new(InnerCircuit::new( Arc::new(verifier_circuit), def_hook_commit.map(|d| d.into()), + child_vk.fixed_commit.is_some(), + child_vk.fixed_no_omc_init_commit.is_some(), instance_public_value_indices, )); let (pk, vk) = engine.keygen(&circuit.airs()); @@ -127,6 +129,8 @@ impl< let circuit = Arc::new(InnerCircuit::new( Arc::new(verifier_circuit), def_hook_commit.map(|d| d.into()), + child_vk.fixed_commit.is_some(), + child_vk.fixed_no_omc_init_commit.is_some(), instance_public_value_indices, )); let vk = Arc::new(pk.get_vk()); diff --git a/ceno_recursion_v2/src/continuation/prover/mod.rs b/ceno_recursion_v2/src/continuation/prover/mod.rs index a11802970..8c56dd1bf 100644 --- a/ceno_recursion_v2/src/continuation/prover/mod.rs +++ b/ceno_recursion_v2/src/continuation/prover/mod.rs @@ -67,9 +67,8 @@ impl AggregationOptions { } type CenoProof = ZKVMProof>; -type Engine = BabyBearPoseidon2CpuEngine< - openvm_stark_sdk::config::baby_bear_poseidon2::DuplexSponge, ->; +type Engine = + BabyBearPoseidon2CpuEngine; /// Full recursion pipeline that aggregates N Ceno base-layer shard proofs /// into a single compact root proof. @@ -91,9 +90,7 @@ pub struct AggProver { options: AggregationOptions, } -impl - AggProver -{ +impl AggProver { /// Create a new aggregation prover from the base-layer verifying key. 
pub fn new(child_vk: Arc, options: AggregationOptions) -> Self { let leaf_prover = InnerCpuProver::::new::( diff --git a/ceno_recursion_v2/src/main/mod.rs b/ceno_recursion_v2/src/main/mod.rs index 5f0e1aa28..73f02a635 100644 --- a/ceno_recursion_v2/src/main/mod.rs +++ b/ceno_recursion_v2/src/main/mod.rs @@ -77,47 +77,71 @@ impl MainModule { for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { let mut chip_pf_iter = preflight.main.chips.iter(); let mut saw_chip = false; - for (&chip_idx, chip_instances) in &proof.chip_proofs { - for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { - saw_chip = true; - let pf_entry = chip_pf_iter - .next() - .ok_or_else(|| eyre!( - "missing main preflight entry for chip {chip_idx} instance {instance_idx}" - ))?; - if pf_entry.chip_idx != chip_idx || pf_entry.instance_idx != instance_idx { - bail!( - "main preflight chip mismatch: expected ({}, {}), got ({}, {})", - chip_idx, - instance_idx, - pf_entry.chip_idx, - pf_entry.instance_idx - ); - } - let claim = input_layer_claim(chip_proof); - // Access the fork log directly using fork_idx and fork-local tidx. - let fork_log = preflight.fork_log(pf_entry.fork_idx); - let mut ts = ReadOnlyTranscript::new(fork_log, pf_entry.tidx); - record_main_transcript(&mut ts, chip_idx, chip_proof); + let sorted_idx_by_chip: std::collections::BTreeMap = preflight + .proof_shape + .sorted_trace_vdata + .iter() + .enumerate() + .map(|(sorted_idx, (chip_idx, _))| (*chip_idx, sorted_idx)) + .collect(); + let mut chip_entries = proof + .chip_proofs + .iter() + .flat_map(|(&chip_idx, chip_instances)| { + chip_instances + .iter() + .enumerate() + .map(move |(instance_idx, chip_proof)| (chip_idx, instance_idx, chip_proof)) + }) + .collect::>(); + chip_entries.sort_by_key(|(chip_idx, instance_idx, _)| { + ( + sorted_idx_by_chip + .get(chip_idx) + .copied() + .unwrap_or(usize::MAX), + *instance_idx, + ) + }); - // Compute global tidx for trace column values. 
- let global_tidx = - preflight.fork_global_offset(pf_entry.fork_idx) + pf_entry.tidx; - let main_record = MainRecord { - proof_idx, - idx: chip_idx, - tidx: global_tidx, - claim, - }; - let sumcheck_record = build_sumcheck_record_from_chip( - proof_idx, + for (chip_idx, instance_idx, chip_proof) in chip_entries { + saw_chip = true; + let pf_entry = chip_pf_iter.next().ok_or_else(|| { + eyre!( + "missing main preflight entry for chip {chip_idx} instance {instance_idx}" + ) + })?; + if pf_entry.chip_idx != chip_idx || pf_entry.instance_idx != instance_idx { + bail!( + "main preflight chip mismatch: expected ({}, {}), got ({}, {})", chip_idx, - claim, - chip_proof, - global_tidx, + instance_idx, + pf_entry.chip_idx, + pf_entry.instance_idx ); - paired.push((main_record, sumcheck_record)); } + let claim = input_layer_claim(chip_proof); + // Access the fork log directly using fork_idx and fork-local tidx. + let fork_log = preflight.fork_log(pf_entry.fork_idx); + let mut ts = ReadOnlyTranscript::new(fork_log, pf_entry.tidx); + record_main_transcript(&mut ts, chip_idx, chip_proof); + + // Compute global tidx for trace column values. 
+ let global_tidx = preflight.fork_global_offset(pf_entry.fork_idx) + pf_entry.tidx; + let main_record = MainRecord { + proof_idx, + idx: chip_idx, + tidx: global_tidx, + claim, + }; + let sumcheck_record = build_sumcheck_record_from_chip( + proof_idx, + chip_idx, + claim, + chip_proof, + global_tidx, + ); + paired.push((main_record, sumcheck_record)); } if !saw_chip { diff --git a/ceno_recursion_v2/src/proof_shape/mod.rs b/ceno_recursion_v2/src/proof_shape/mod.rs index cf931b2b4..b093d1019 100644 --- a/ceno_recursion_v2/src/proof_shape/mod.rs +++ b/ceno_recursion_v2/src/proof_shape/mod.rs @@ -7,7 +7,7 @@ use openvm_stark_backend::{ AirRef, FiatShamirTranscript, StarkProtocolConfig, TranscriptHistory, p3_maybe_rayon::prelude::*, prover::AirProvingContext, }; -use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, F}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, D_EF, F}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; @@ -127,8 +127,21 @@ impl ProofShapeModule { // TODO remove l_skip preflight.proof_shape.l_skip = 0; + let num_airs = child_vk.circuit_vks.len(); + let mut num_present_at_air = vec![0usize; num_airs]; + let mut present_count = 0usize; + for (air_idx, _) in &preflight.proof_shape.sorted_trace_vdata { + present_count += 1; + num_present_at_air[*air_idx] = present_count; + } + for (air_idx, num_present) in num_present_at_air.iter_mut().enumerate() { + if !proof.chip_proofs.contains_key(&air_idx) { + *num_present = present_count; + } + } + let mut current_tidx = transcript_start_tidx; - let mut starting_tidx = vec![0usize; child_vk.circuit_vks.len()]; + let mut starting_tidx = vec![0usize; num_airs]; let n_max = preflight .proof_shape .sorted_trace_vdata @@ -137,21 +150,13 @@ impl ProofShapeModule { .max() .unwrap_or(0); - for air_idx in 0..child_vk.circuit_vks.len() { - let metadata = &self.per_air[air_idx]; + for air_idx in 0..num_airs { let is_present = 
proof.chip_proofs.contains_key(&air_idx); starting_tidx[air_idx] = current_tidx; - if !metadata.is_required { - current_tidx += 1; - } - + current_tidx += num_present_at_air[air_idx] * 2 * D_EF; if is_present { - current_tidx += 1; - - if metadata.num_public_values != 0 { - current_tidx += metadata.num_public_values; - } + current_tidx += 2 * D_EF; } } diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs index 853848d02..9166b7667 100644 --- a/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs @@ -24,7 +24,6 @@ use crate::{ NLiftBus, NLiftMessage, TowerModuleBus, TowerModuleMessage, TranscriptBus, TranscriptBusMessage, }, - circuit::inner::vm_pvs::VmPvs, primitives::bus::{RangeCheckerBus, RangeCheckerBusMessage}, proof_shape::{ AirMetadata, @@ -35,7 +34,7 @@ use crate::{ }, subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, tower::tower_transcript_len, - utils::TranscriptLabel, + utils::{TranscriptLabel, label_field_len, transcript_receive_label}, }; #[repr(C)] @@ -51,11 +50,14 @@ pub struct ProofShapeCols { /// /// Has a special use on summary row (when `is_last`). pub log_height: F, + /// Number of tower layers for this chip proof when `is_present`. + pub tower_layers: F, /// Whether this AIR needs rotation openings. pub need_rot: F, // First possible transcript index of the current AIR. pub starting_tidx: F, + pub fork_start_tidx: F, // Columns that may be read from the transcript. pub is_present: F, @@ -157,6 +159,7 @@ where let local: &ProofShapeCols = (*local)[..const_width].borrow(); let next: &ProofShapeCols = (*next)[..const_width].borrow(); let n = local.log_height.into(); + let n_logup = local.tower_layers.into(); self.idx_encoder.eval(builder, localv.idx_flags); @@ -240,7 +243,7 @@ where // transcript span model. 
tower_tidx_bump += is_current_air * per_air_tower_span::( - n.clone(), + n_logup.clone(), air_data.num_read_count, air_data.num_write_count, air_data.num_logup_count, @@ -327,13 +330,6 @@ where let is_first_idx = self.idx_encoder.get_flag_expr::(0, localv.idx_flags); - // The first AIR starts immediately after the fixed trunk transcript prefix. - builder.when(is_first_idx.clone()).assert_eq( - local.starting_tidx, - AB::Expr::from_usize(TranscriptLabel::Riscv.field_len() + VmPvs::::width()) - + AB::Expr::from_usize(2 * D_EF), - ); - self.starting_tidx_bus.receive( builder, local.proof_idx, @@ -403,8 +399,29 @@ where // Receive fork transcript words after the fork label prefix. let fork_tidx_base = TranscriptLabel::Fork.field_len(); let fork_id = local.num_present - AB::F::ONE; + let fork_enabled = local.is_present * local.is_valid; + transcript_receive_label( + &self.transcript_bus, + builder, + local.proof_idx, + local.fork_start_tidx, + TranscriptLabel::Fork.as_bytes(), + fork_enabled.clone(), + ); + let fork_global_base = local.fork_start_tidx.into() + + AB::Expr::from_usize(label_field_len(TranscriptLabel::Fork.as_bytes())); // observe lookup alpha/beta for i in 0..D_EF { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(i), + value: local.lookup_challenge_alpha[i].into(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), + ); self.forked_transcript_bus.receive( builder, local.proof_idx, @@ -414,7 +431,17 @@ where value: local.lookup_challenge_alpha[i].into(), is_sample: AB::Expr::ZERO, }, - local.is_present * local.is_valid, + fork_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(D_EF + i), + value: local.lookup_challenge_beta[i].into(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), ); self.forked_transcript_bus.receive( builder, @@ 
-425,9 +452,19 @@ where value: local.lookup_challenge_beta[i].into(), is_sample: AB::Expr::ZERO, }, - local.is_present * local.is_valid, + fork_enabled.clone(), ); } + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(2 * D_EF), + value: fork_id.clone(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), + ); self.forked_transcript_bus.receive( builder, local.proof_idx, @@ -437,9 +474,19 @@ where value: fork_id.clone(), is_sample: AB::Expr::ZERO, }, - local.is_present * local.is_valid, + fork_enabled.clone(), ); - // Fork transcript metadata order is fixed: num_present, air_idx, then log_height. + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(2 * D_EF + 1), + value: air_idx.clone(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), + ); + // Fork transcript metadata order is fixed: fork_id, air_idx, height_1, height_2. 
self.forked_transcript_bus.receive( builder, local.proof_idx, @@ -449,7 +496,17 @@ where value: air_idx.clone(), is_sample: AB::Expr::ZERO, }, - local.is_present * local.is_valid, + fork_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(2 * D_EF + 2), + value: local.height_1.into(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), ); self.forked_transcript_bus.receive( builder, @@ -457,17 +514,38 @@ where ForkedTranscriptBusMessage { fork_id: fork_id.clone().into(), tidx: AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 2), - value: local.log_height.into(), + value: local.height_1.into(), is_sample: AB::Expr::ZERO, }, - local.is_present * local.is_valid, + fork_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: fork_global_base.clone() + AB::Expr::from_usize(2 * D_EF + 3), + value: local.height_2.into(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), + ); + self.forked_transcript_bus.receive( + builder, + local.proof_idx, + ForkedTranscriptBusMessage { + fork_id: fork_id.clone().into(), + tidx: AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 3), + value: local.height_2.into(), + is_sample: AB::Expr::ZERO, + }, + fork_enabled.clone(), ); // Skip the full per-air tower transcript span (out-evals, alpha/beta, // and all GKR/sumcheck layer transcript activity) before binding the // post-fork sampled challenges. let forked_challenge_1_tidx = - AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 3) + tower_tidx_bump; + AB::Expr::from_usize(fork_tidx_base + 2 * D_EF + 4) + tower_tidx_bump; // Challenge 2 starts after challenge 1 plus the product_sum label span. 
let forked_challenge_2_tidx = forked_challenge_1_tidx.clone() + AB::Expr::from_usize(tower_transcript_len::BETA_LEN); @@ -559,7 +637,7 @@ where AirShapeBusMessage { sort_idx: local.sorted_idx.into(), property_idx: AirShapeProperty::NumLk.to_field(), - value: num_logup_count, + value: num_logup_count.clone(), }, local.is_present * n.clone(), ); @@ -689,15 +767,20 @@ where local.is_last, ); + let fork_tower_tidx = local.fork_start_tidx.into() + + AB::Expr::from_usize(TranscriptLabel::Fork.field_len() + 2 * D_EF + 4); self.tower_module_bus.send( builder, local.proof_idx, TowerModuleMessage { - idx: air_idx.clone(), - tidx: local.starting_tidx.into(), - n_logup: n, + idx: local.sorted_idx.into(), + tidx: fork_tower_tidx, + n_logup: local.tower_layers.into(), + num_read_count: num_read_count.clone(), + num_write_count: num_write_count.clone(), + num_logup_count: num_logup_count.clone(), }, - local.is_last, + local.is_present * local.is_valid, ); // Send n_max value to expression claim air @@ -740,7 +823,7 @@ fn per_air_tower_span( num_logup_count: usize, ) -> AB::Expr { use tower_transcript_len::{ - ALPHA_BETA_LEN, ALPHA_LEN, POST_SUMCHECK_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, + ALPHA_BETA_LEN, ALPHA_LEN, MERGE_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, }; // Derivation notes (matches tower transcript replay order used by verifier): @@ -758,11 +841,12 @@ fn per_air_tower_span( let gkr_span = if out_eval_words == 0 { AB::Expr::ZERO } else { - let gkr_inner = n_logup.clone() * AB::Expr::from_usize(ROUND_LEN / 2) - + AB::Expr::from_usize( - ALPHA_LEN + SUMCHECK_INIT_LEN + POST_SUMCHECK_LEN - ROUND_LEN / 2, - ); - n_logup * gkr_inner - AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN) + let post_sumcheck_len = AB::Expr::from_usize(out_eval_words * D_EF + MERGE_LEN); + n_logup.clone() * (AB::Expr::from_usize(SUMCHECK_INIT_LEN) + post_sumcheck_len) + + n_logup.clone() + * (n_logup.clone() + AB::Expr::ONE) + * AB::Expr::from_usize(ROUND_LEN / 2) + + (n_logup - AB::Expr::ONE) * 
AB::Expr::from_usize(ALPHA_LEN) }; out_eval_span + AB::Expr::from_usize(ALPHA_BETA_LEN) + gkr_span diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs index 444e78e4d..d5e7d46a2 100644 --- a/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs @@ -115,6 +115,12 @@ impl RowMajorChip .get(air_idx) .map(|instances| two_instance_heights_from_chip_instances(instances)) .unwrap_or((0, 0)); + let tower_layers = proof + .chip_proofs + .get(air_idx) + .and_then(|instances| instances.first()) + .map(|instance| instance.tower_proof.proofs.len()) + .unwrap_or(0); num_present += 1; cols.proof_idx = F::from_usize(proof_idx); @@ -123,8 +129,10 @@ impl RowMajorChip cols.is_last = F::ZERO; cols.sorted_idx = F::from_usize(sorted_idx); cols.log_height = F::from_usize(log_height); + cols.tower_layers = F::from_usize(tower_layers); cols.need_rot = F::ZERO; cols.starting_tidx = F::from_usize(preflight.proof_shape.starting_tidx[*air_idx]); + cols.fork_start_tidx = F::from_usize(preflight.fork_global_offset(num_present - 1)); cols.is_present = F::ONE; cols.height_1 = F::from_usize(height_1); cols.height_2 = F::from_usize(height_2); @@ -171,8 +179,10 @@ impl RowMajorChip cols.is_last = F::ZERO; cols.sorted_idx = F::from_usize(sorted_idx); cols.log_height = F::ZERO; + cols.tower_layers = F::ZERO; cols.need_rot = F::ZERO; cols.starting_tidx = F::from_usize(preflight.proof_shape.starting_tidx[air_idx]); + cols.fork_start_tidx = F::ZERO; cols.is_present = F::ZERO; cols.height_1 = F::ZERO; cols.height_2 = F::ZERO; @@ -211,8 +221,10 @@ impl RowMajorChip cols.is_last = F::ONE; cols.sorted_idx = F::ZERO; cols.log_height = F::from_usize(preflight.proof_shape.n_logup); + cols.tower_layers = F::from_usize(preflight.proof_shape.n_logup); cols.need_rot = F::ZERO; cols.starting_tidx = F::from_usize(preflight.proof_shape.post_tidx); + cols.fork_start_tidx = F::ZERO; 
cols.is_present = F::ZERO; cols.height_1 = F::ZERO; cols.height_2 = F::ZERO; diff --git a/ceno_recursion_v2/src/proof_shape/pvs/air.rs b/ceno_recursion_v2/src/proof_shape/pvs/air.rs index 7626001b4..0b7d713b8 100644 --- a/ceno_recursion_v2/src/proof_shape/pvs/air.rs +++ b/ceno_recursion_v2/src/proof_shape/pvs/air.rs @@ -90,12 +90,12 @@ where .assert_one(next.is_first_in_air); let is_same_air = local.is_valid * next.is_valid * not(next.is_first_in_air); - // TODO fix first tidx to be TranscriptLabel::Riscv.field_len() - // TODO fix comment as well - // first tidx happened here builder .when(local.is_valid * local.is_first_in_proof * local.is_first_in_air) - .assert_zero(local.tidx); + .assert_eq( + local.tidx, + AB::Expr::from_usize(TranscriptLabel::Riscv.field_len()), + ); // self.num_pvs_bus.receive( // builder, diff --git a/ceno_recursion_v2/src/proof_shape/pvs/trace.rs b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs index 13e6b645b..6c176f434 100644 --- a/ceno_recursion_v2/src/proof_shape/pvs/trace.rs +++ b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs @@ -8,6 +8,7 @@ use crate::{ proof_shape::pvs::PublicValuesCols, system::{Preflight, RecursionField, RecursionProof, RecursionVk}, tracegen::RowMajorChip, + utils::TranscriptLabel, }; pub struct PublicValuesTraceGenerator; @@ -43,10 +44,16 @@ impl RowMajorChip for PublicValuesTraceGenerator { for (proof_idx, proof) in proofs.iter().enumerate() { let mut is_first_in_proof = true; - // TODO first tidx start from TranscriptLabel::Riscv.field_len() - let mut tidx = 0usize; + let mut tidx = TranscriptLabel::Riscv.field_len(); - for (air_idx, (_, circuit_vk)) in child_vk.circuit_vks.iter().enumerate() { + for air_idx in 0..child_vk.circuit_vks.len() { + let Some(circuit_vk) = child_vk + .circuit_index_to_name + .get(&air_idx) + .and_then(|name| child_vk.circuit_vks.get(name)) + else { + continue; + }; let instance_openings = &circuit_vk.get_cs().zkvm_v1_css.instance; if instance_openings.is_empty() { continue; diff 
--git a/ceno_recursion_v2/src/system/mod.rs b/ceno_recursion_v2/src/system/mod.rs index 0012264cb..a68514299 100644 --- a/ceno_recursion_v2/src/system/mod.rs +++ b/ceno_recursion_v2/src/system/mod.rs @@ -351,6 +351,28 @@ impl VerifierSubCircuit { { let mut preflight = Preflight::default(); + // The pre-verifier stage has already run VmPvs preflight and passes + // this subcircuit the advanced transcript. Rehydrate the sampled + // lookup challenges into this module's fresh Preflight so proof-shape + // rows look up the same values VmPvsAir publishes. + let initial_log = sponge.clone().into_log(); + let values = initial_log.values(); + let samples = initial_log.samples(); + debug_assert_eq!(values.len(), samples.len()); + if values.len() >= 2 * D_EF { + let alpha_start = values.len() - 2 * D_EF; + let beta_start = values.len() - D_EF; + debug_assert!(samples[alpha_start..].iter().all(|is_sample| *is_sample)); + preflight.vm_pvs.lookup_challenge_alpha = + EF::from_basis_coefficients_slice(&values[alpha_start..beta_start]) + .unwrap_or(EF::ZERO); + preflight.vm_pvs.lookup_challenge_beta = + EF::from_basis_coefficients_slice(&values[beta_start..]).unwrap_or(EF::ZERO); + let present_air_count = proof.chip_proofs.len(); + preflight.vm_pvs.lookup_challenge_alpha_lookup_count = present_air_count; + preflight.vm_pvs.lookup_challenge_beta_lookup_count = present_air_count; + } + // Phase 1: Trunk operations. // Proof-shape metadata and alpha/beta sampling after pre-verifier transcript observes. self.proof_shape @@ -367,7 +389,23 @@ impl VerifierSubCircuit { let fork_offset = sponge.len(); // Phase 2: Fork — fresh transcript per chip proof instance. 
- let chip_proof_list = Self::build_chip_proof_list(proof); + let mut chip_proof_list = Self::build_chip_proof_list(proof); + let sorted_idx_by_chip: std::collections::BTreeMap = preflight + .proof_shape + .sorted_trace_vdata + .iter() + .enumerate() + .map(|(sorted_idx, (chip_idx, _))| (*chip_idx, sorted_idx)) + .collect(); + chip_proof_list.sort_by_key(|(chip_idx, instance_idx, _)| { + ( + sorted_idx_by_chip + .get(chip_idx) + .copied() + .unwrap_or(usize::MAX), + *instance_idx, + ) + }); // `TS::from(poseidon2_perm())` is the generic equivalent of // `default_duplex_sponge_recorder()` used by the inner prover. let mut fork_sponges: Vec = (0..chip_proof_list.len()) diff --git a/ceno_recursion_v2/src/tower/bus.rs b/ceno_recursion_v2/src/tower/bus.rs index ab41c3c30..5d4b4878b 100644 --- a/ceno_recursion_v2/src/tower/bus.rs +++ b/ceno_recursion_v2/src/tower/bus.rs @@ -18,6 +18,7 @@ define_typed_per_proof_permutation_bus!(TowerXiSamplerBus, TowerXiSamplerMessage pub struct TowerLayerInputMessage { pub idx: T, pub tidx: T, + pub beta_logup: [T; D_EF], pub r0_claim: [T; D_EF], pub w0_claim: [T; D_EF], pub q0_claim: [T; D_EF], @@ -48,6 +49,7 @@ pub struct TowerProdLayerChallengeMessage { pub lambda: [T; D_EF], pub lambda_prime: [T; D_EF], pub mu: [T; D_EF], + pub root_prime_claim: [T; D_EF], } define_typed_per_proof_permutation_bus!(TowerProdReadClaimInputBus, TowerProdLayerChallengeMessage); @@ -78,6 +80,7 @@ pub struct TowerLogupLayerChallengeMessage { pub lambda: [T; D_EF], pub lambda_prime: [T; D_EF], pub mu: [T; D_EF], + pub root_prime_claim: [T; D_EF], } define_typed_per_proof_permutation_bus!(TowerLogupClaimInputBus, TowerLogupLayerChallengeMessage); diff --git a/ceno_recursion_v2/src/tower/input/air.rs b/ceno_recursion_v2/src/tower/input/air.rs index d30ed4c7e..2731c8587 100644 --- a/ceno_recursion_v2/src/tower/input/air.rs +++ b/ceno_recursion_v2/src/tower/input/air.rs @@ -2,9 +2,11 @@ use core::borrow::Borrow; use crate::{ bus::{MainBus, MainMessage, 
TowerModuleBus, TowerModuleMessage, TranscriptBus}, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, tower::bus::{ TowerLayerInputBus, TowerLayerInputMessage, TowerLayerOutputBus, TowerLayerOutputMessage, }, + utils::{label_field_len, transcript_receive_label}, }; use openvm_circuit_primitives::{ SubAir, @@ -18,10 +20,7 @@ use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{Field, PrimeCharacteristicRing}; use p3_matrix::Matrix; -use recursion_circuit::{ - subairs::proof_idx::{ProofIdxIoCols, ProofIdxSubAir}, - utils::assert_zeros, -}; +use recursion_circuit::utils::assert_zeros; use stark_recursion_circuit_derive::AlignedBorrow; #[repr(C)] @@ -32,8 +31,13 @@ pub struct TowerInputCols { pub proof_idx: T, pub idx: T, + pub is_first_idx: T, + pub is_first: T, pub n_logup: T, + pub num_read_count: T, + pub num_write_count: T, + pub num_logup_count: T, /// Flag indicating whether there are any interactions /// n_logup = 0 <=> total_interactions = 0 @@ -49,6 +53,7 @@ pub struct TowerInputCols { pub q0_claim: [T; D_EF], pub alpha_logup: [T; D_EF], + pub beta_logup: [T; D_EF], pub input_layer_claim: [T; D_EF], pub layer_output_lambda: [T; D_EF], @@ -88,21 +93,20 @@ impl Air for TowerInputAir { // Proof Index Constraints /////////////////////////////////////////////////////////////////////// - // This subair has the following constraints: - // 1. Boolean enabled flag - // 2. Disabled rows are followed by disabled rows - // 3. 
Proof index increments by exactly one between enabled rows - ProofIdxSubAir.eval( + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( builder, ( - ProofIdxIoCols { + NestedForLoopIoCols { is_enabled: local.is_enabled, - proof_idx: local.proof_idx, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_idx, local.is_first], } .map_into(), - ProofIdxIoCols { + NestedForLoopIoCols { is_enabled: next.is_enabled, - proof_idx: next.proof_idx, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_idx, next.is_first], } .map_into(), ), @@ -152,20 +156,22 @@ impl Air for TowerInputAir { // Add PoW (if any) and alpha label+sample, beta label+sample use crate::tower::tower_transcript_len::{ - ALPHA_BETA_LEN, ALPHA_LEN, POST_SUMCHECK_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, + ALPHA_BETA_LEN, ALPHA_LEN, MERGE_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, }; let tidx_after_alpha_beta = local.tidx + AB::Expr::from_usize(ALPHA_BETA_LEN); // Add GKR layers + Sumcheck. - // Total GKR span: n*(10n+25) - 13 for n>0. - // layers_cumulative(n) = 10n² + 25n - 13. 
- let gkr_inner = num_layers.clone() * AB::Expr::from_usize(ROUND_LEN / 2) - + AB::Expr::from_usize( - ALPHA_LEN + SUMCHECK_INIT_LEN + POST_SUMCHECK_LEN - ROUND_LEN / 2, - ); - let tidx_after_gkr_layers = tidx_after_alpha_beta.clone() - + has_interactions.clone() - * (num_layers.clone() * gkr_inner - - AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN)); + let claim_span = (local.num_read_count.into() * AB::Expr::from_usize(2 * D_EF)) + + (local.num_write_count.into() * AB::Expr::from_usize(2 * D_EF)) + + (local.num_logup_count.into() * AB::Expr::from_usize(4 * D_EF)); + let post_sumcheck_len = claim_span + AB::Expr::from_usize(MERGE_LEN); + let gkr_span = num_layers.clone() + * (AB::Expr::from_usize(SUMCHECK_INIT_LEN) + post_sumcheck_len) + + num_layers.clone() + * (num_layers.clone() + AB::Expr::ONE) + * AB::Expr::from_usize(ROUND_LEN / 2) + + (num_layers.clone() - AB::Expr::ONE) * AB::Expr::from_usize(ALPHA_LEN); + let tidx_after_gkr_layers = + tidx_after_alpha_beta.clone() + has_interactions.clone() * gkr_span; // 1. TowerLayerInputBus // 1a. Send input to TowerLayerAir self.layer_input_bus.send( @@ -174,6 +180,7 @@ impl Air for TowerInputAir { TowerLayerInputMessage { idx: local.idx.into(), tidx: tidx_after_alpha_beta.clone() * has_interactions.clone(), + beta_logup: local.beta_logup.map(Into::into), r0_claim: local.r0_claim.map(Into::into), w0_claim: local.w0_claim.map(Into::into), q0_claim: local.q0_claim.map(Into::into), @@ -208,19 +215,49 @@ impl Air for TowerInputAir { idx: local.idx.into(), tidx: local.tidx.into(), n_logup: local.n_logup.into(), + num_read_count: local.num_read_count.into(), + num_write_count: local.num_write_count.into(), + num_logup_count: local.num_logup_count.into(), }, local.is_enabled, ); // 2. TranscriptBus - // 2a. Sample alpha_logup challenge - self.transcript_bus.sample_ext( + // 2a. Observe native labels and sample alpha/beta challenges. 
+ let alpha_label = b"combine subset evals"; + transcript_receive_label( + &self.transcript_bus, builder, local.proof_idx, local.tidx, + alpha_label, + local.is_enabled, + ); + let beta_label_tidx = + local.tidx + AB::Expr::from_usize(label_field_len(alpha_label) + D_EF); + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + local.tidx + AB::Expr::from_usize(label_field_len(alpha_label)), local.alpha_logup.map(Into::into), local.is_enabled, ); + let beta_label = b"product_sum"; + transcript_receive_label( + &self.transcript_bus, + builder, + local.proof_idx, + beta_label_tidx.clone(), + beta_label, + local.is_enabled, + ); + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + beta_label_tidx + AB::Expr::from_usize(label_field_len(beta_label)), + local.beta_logup.map(Into::into), + local.is_enabled, + ); self.main_bus.send( builder, local.proof_idx, diff --git a/ceno_recursion_v2/src/tower/input/trace.rs b/ceno_recursion_v2/src/tower/input/trace.rs index 6b7340604..8c398b099 100644 --- a/ceno_recursion_v2/src/tower/input/trace.rs +++ b/ceno_recursion_v2/src/tower/input/trace.rs @@ -13,7 +13,11 @@ pub struct TowerInputRecord { pub idx: usize, pub tidx: usize, pub n_logup: usize, + pub num_read_count: usize, + pub num_write_count: usize, + pub num_logup_count: usize, pub alpha_logup: EF, + pub beta_logup: EF, pub input_layer_claim: EF, pub layer_output_lambda: EF, pub layer_output_mu: EF, @@ -50,19 +54,28 @@ impl RowMajorChip for TowerInputTraceGenerator { let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + let mut prev_proof_idx = usize::MAX; + let mut prev_idx = usize::MAX; for (row_data, (record, q0_claim)) in data_slice .chunks_exact_mut(width) .zip(gkr_input_records.iter().zip(q0_claims.iter())) { let cols: &mut TowerInputCols = row_data.borrow_mut(); + let is_new_proof_idx = prev_proof_idx != record.proof_idx; + let is_new_idx = is_new_proof_idx || prev_idx != record.idx; cols.is_enabled = F::ONE; cols.proof_idx = 
F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); + cols.is_first_idx = F::from_bool(is_new_proof_idx); + cols.is_first = F::from_bool(is_new_idx); cols.tidx = F::from_usize(record.tidx); cols.n_logup = F::from_usize(record.n_logup); + cols.num_read_count = F::from_usize(record.num_read_count); + cols.num_write_count = F::from_usize(record.num_write_count); + cols.num_logup_count = F::from_usize(record.num_logup_count); IsZeroSubAir.generate_subrow( cols.n_logup, (&mut cols.is_n_logup_zero_aux.inv, &mut cols.is_n_logup_zero), @@ -77,6 +90,11 @@ impl RowMajorChip for TowerInputTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); + cols.beta_logup = record + .beta_logup + .as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.input_layer_claim = record .input_layer_claim .as_basis_coefficients_slice() @@ -92,6 +110,9 @@ impl RowMajorChip for TowerInputTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); + + prev_proof_idx = record.proof_idx; + prev_idx = record.idx; } Some(RowMajorMatrix::new(trace, width)) diff --git a/ceno_recursion_v2/src/tower/layer/air.rs b/ceno_recursion_v2/src/tower/layer/air.rs index ce786b746..e119179bc 100644 --- a/ceno_recursion_v2/src/tower/layer/air.rs +++ b/ceno_recursion_v2/src/tower/layer/air.rs @@ -25,12 +25,13 @@ use crate::{ TowerSumcheckOutputMessage, }, }, + utils::{label_field_len, transcript_receive_label}, }; use recursion_circuit::{ - bus::TranscriptBus, + bus::{TranscriptBus, TranscriptBusMessage}, subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, - utils::{assert_zeros, ext_field_add, ext_field_multiply}, + utils::ext_field_add, }; #[repr(C)] @@ -59,6 +60,8 @@ pub struct TowerLayerCols { pub lambda_prime: [T; D_EF], /// Reduction point pub mu: [T; D_EF], + /// Initial point sampled before the tower layer loop. 
+ pub beta_logup: [T; D_EF], pub sumcheck_claim_in: [T; D_EF], @@ -71,6 +74,7 @@ pub struct TowerLayerCols { pub num_read_count: T, pub num_write_count: T, pub num_logup_count: T, + pub sumcheck_claim_out: [T; D_EF], /// Received from TowerLayerSumcheckAir pub eq_at_r_prime: [T; D_EF], @@ -183,15 +187,6 @@ where local.lambda, ); - /////////////////////////////////////////////////////////////////////// - // Root Layer Constraints - /////////////////////////////////////////////////////////////////////// - - assert_zeros( - &mut builder.when(local.is_first), - local.sumcheck_claim_in.map(Into::into), - ); - /////////////////////////////////////////////////////////////////////// // Inter-Layer Constraints /////////////////////////////////////////////////////////////////////// @@ -206,14 +201,26 @@ where // Transcript index increment use crate::tower::tower_transcript_len::{ - ALPHA_LEN, POST_SUMCHECK_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, + ALPHA_LEN, MERGE_LEN, ROUND_LEN, SUMCHECK_INIT_LEN, }; + let non_root_layer = AB::Expr::ONE - local.is_first; + let sumcheck_init_tidx = + local.tidx + non_root_layer.clone() * AB::Expr::from_usize(ALPHA_LEN); let tidx_after_sumcheck = local.tidx - // Sample lambda label+sample on non-root layer - + (AB::Expr::ONE - local.is_first) - * AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN) - + local.layer_idx * AB::Expr::from_usize(ROUND_LEN); - let tidx_end = tidx_after_sumcheck.clone() + AB::Expr::from_usize(POST_SUMCHECK_LEN); + + non_root_layer.clone() * AB::Expr::from_usize(ALPHA_LEN) + + AB::Expr::from_usize(SUMCHECK_INIT_LEN) + + (local.layer_idx + AB::Expr::ONE) * AB::Expr::from_usize(ROUND_LEN); + let read_count: AB::Expr = local.num_read_count.into(); + let write_count: AB::Expr = local.num_write_count.into(); + let logup_count: AB::Expr = local.num_logup_count.into(); + let read_claim_span = read_count.clone() * AB::Expr::from_usize(2 * D_EF); + let write_claim_span = write_count.clone() * AB::Expr::from_usize(2 * D_EF); + let 
logup_claim_span = logup_count.clone() * AB::Expr::from_usize(4 * D_EF); + let read_claim_tidx = tidx_after_sumcheck.clone(); + let write_claim_tidx = read_claim_tidx.clone() + read_claim_span; + let logup_claim_tidx = write_claim_tidx.clone() + write_claim_span; + let merge_label_tidx = logup_claim_tidx.clone() + logup_claim_span; + let tidx_end = merge_label_tidx.clone() + AB::Expr::from_usize(MERGE_LEN); builder .when(is_transition.clone()) .assert_eq(next.tidx, tidx_end.clone()); @@ -257,19 +264,19 @@ where lookup_enable.clone(), ); - let tidx_for_claims = tidx_after_sumcheck.clone(); self.prod_read_claim_input_bus.send( builder, local.proof_idx, TowerProdLayerChallengeMessage { idx: local.idx.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), + tidx: read_claim_tidx, lambda: local.lambda.map(Into::into), lambda_prime: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + root_prime_claim: local.read_claim_prime.map(Into::into), }, - is_not_dummy.clone(), + is_not_dummy.clone() * read_count.clone(), ); // TODO separate lambda, lambda_prime for prod-write the relation should be local.lambda^(num_read) self.prod_write_claim_input_bus.send( @@ -278,12 +285,13 @@ where TowerProdLayerChallengeMessage { idx: local.idx.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), + tidx: write_claim_tidx, lambda: local.lambda.map(Into::into), lambda_prime: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + root_prime_claim: local.write_claim_prime.map(Into::into), }, - is_not_dummy.clone(), + is_not_dummy.clone() * write_count.clone(), ); // TODO separate lambda, lambda_prime for logup the relation should be local.lambda^(num_read + num_write) self.logup_claim_input_bus.send( @@ -292,12 +300,13 @@ where TowerLogupLayerChallengeMessage { idx: local.idx.into(), layer_idx: local.layer_idx.into(), - tidx: tidx_for_claims.clone(), + tidx: logup_claim_tidx, lambda: local.lambda.map(Into::into), 
lambda_prime: local.lambda_prime.map(Into::into), mu: local.mu.map(Into::into), + root_prime_claim: local.logup_claim_prime.map(Into::into), }, - is_not_dummy.clone(), + is_not_dummy.clone() * logup_count.clone(), ); self.prod_read_claim_bus.receive( builder, @@ -309,7 +318,7 @@ where lambda_prime_claim: local.read_claim_prime.map(Into::into), num_prod_count: local.num_read_count.into(), }, - is_not_dummy.clone(), + is_not_dummy.clone() * read_count, ); self.prod_write_claim_bus.receive( builder, @@ -321,7 +330,7 @@ where lambda_prime_claim: local.write_claim_prime.map(Into::into), num_prod_count: local.num_write_count.into(), }, - is_not_dummy.clone(), + is_not_dummy.clone() * write_count, ); self.logup_claim_bus.receive( builder, @@ -333,7 +342,7 @@ where lambda_prime_claim: local.logup_claim_prime.map(Into::into), num_logup_count: local.num_logup_count.into(), }, - is_not_dummy.clone(), + is_not_dummy.clone() * logup_count, ); let root_layer_mask = local.is_first * is_not_dummy.clone(); @@ -348,7 +357,7 @@ where local.w0_claim, ); assert_array_eq( - &mut builder.when(root_layer_mask), + &mut builder.when(root_layer_mask.clone()), local.logup_claim_prime, local.q0_claim, ); @@ -361,11 +370,12 @@ where TowerLayerInputMessage { idx: local.idx.into(), tidx: local.tidx.into(), + beta_logup: local.beta_logup.map(Into::into), r0_claim: local.r0_claim.map(Into::into), w0_claim: local.w0_claim.map(Into::into), q0_claim: local.q0_claim.map(Into::into), }, - local.is_first_air_idx * is_not_dummy.clone(), + local.is_first * is_not_dummy.clone(), ); // 2. TowerLayerOutputBus // 2a. 
Send GKR input layer claims back @@ -392,18 +402,13 @@ where idx: local.idx.into(), layer_idx: local.layer_idx.into(), is_last_layer: is_last.clone(), - tidx: local.tidx + AB::Expr::from_usize(ALPHA_LEN + SUMCHECK_INIT_LEN), + tidx: sumcheck_init_tidx.clone() + AB::Expr::from_usize(SUMCHECK_INIT_LEN), claim: local.sumcheck_claim_in.map(Into::into), }, - is_non_root_layer.clone() * is_not_dummy.clone(), + local.is_enabled * is_not_dummy.clone(), ); // 3. TowerSumcheckOutputBus // 3a. Receive sumcheck results - let prime_fold = ext_field_add::(local.read_claim_prime, local.write_claim_prime); - let sumcheck_claim_out = ext_field_multiply::( - ext_field_add::(prime_fold, local.logup_claim_prime), - local.eq_at_r_prime, - ); self.sumcheck_output_bus.receive( builder, local.proof_idx, @@ -411,20 +416,32 @@ where idx: local.idx.into(), layer_idx: local.layer_idx.into(), tidx: tidx_after_sumcheck.clone(), - claim_out: sumcheck_claim_out.map(Into::into), + claim_out: local.sumcheck_claim_out.map(Into::into), eq_at_r_prime: local.eq_at_r_prime.map(Into::into), }, - is_non_root_layer.clone() * is_not_dummy.clone(), + local.is_enabled * is_not_dummy.clone(), ); // 4. TowerSumcheckChallengeBus - // 4a. Send challenge mu + // 4a. Send the root sumcheck's initial point. self.sumcheck_challenge_bus.send( builder, local.proof_idx, TowerSumcheckChallengeMessage { idx: local.idx.into(), - layer_idx: local.layer_idx.into(), + layer_idx: AB::Expr::ZERO, sumcheck_round: AB::Expr::ZERO, + challenge: local.beta_logup.map(Into::into), + }, + root_layer_mask, + ); + // 4b. Send merge challenge to the final round slot of the next layer. 
+ self.sumcheck_challenge_bus.send( + builder, + local.proof_idx, + TowerSumcheckChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx + AB::Expr::ONE, + sumcheck_round: local.layer_idx + AB::Expr::ONE, challenge: local.mu.map(Into::into), }, is_transition.clone() * is_not_dummy.clone(), @@ -440,20 +457,78 @@ where // in last layer: for send back to GKR input layer // 1a. Sample `lambda` — only on non-root layers. // Root layer uses alpha_logup (set in trace), not a transcript sample. - self.transcript_bus.sample_ext( + let combine_label = b"combine subset evals"; + transcript_receive_label( + &self.transcript_bus, builder, local.proof_idx, local.tidx, + combine_label, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + local.tidx + AB::Expr::from_usize(label_field_len(combine_label)), local.lambda, is_non_root_layer.clone() * is_not_dummy.clone(), ); - // 1b. Observe layer claims - let tidx = tidx_after_sumcheck; + let init_enabled = local.is_enabled * is_not_dummy.clone(); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: sumcheck_init_tidx.clone(), + value: local.layer_idx + AB::Expr::ONE, + is_sample: AB::Expr::ZERO, + }, + init_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: sumcheck_init_tidx.clone() + AB::Expr::ONE, + value: AB::Expr::ZERO, + is_sample: AB::Expr::ZERO, + }, + init_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: sumcheck_init_tidx.clone() + AB::Expr::TWO, + value: AB::Expr::from_u32(3), + is_sample: AB::Expr::ZERO, + }, + init_enabled.clone(), + ); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: sumcheck_init_tidx + AB::Expr::from_u32(3), + value: AB::Expr::ZERO, + is_sample: AB::Expr::ZERO, + }, + init_enabled, + ); + // 1b. 
Observe layer claims in the claim AIRs, then sample the merge challenge. + let merge_label = b"merge"; + transcript_receive_label( + &self.transcript_bus, + builder, + local.proof_idx, + merge_label_tidx.clone(), + merge_label, + local.is_enabled * is_not_dummy.clone(), + ); // 1c. Sample `mu` self.transcript_bus.sample_ext( builder, local.proof_idx, - tidx, + merge_label_tidx + AB::Expr::from_usize(label_field_len(merge_label)), local.mu, local.is_enabled * is_not_dummy.clone(), ); diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs index 0c0fef221..6c4493612 100644 --- a/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs @@ -27,6 +27,7 @@ pub struct TowerLogupSumCheckClaimCols { pub idx: T, pub is_first_layer: T, pub is_first: T, + pub is_root_layer: T, pub is_dummy: T, pub layer_idx: T, @@ -36,6 +37,7 @@ pub struct TowerLogupSumCheckClaimCols { pub lambda: [T; D_EF], pub lambda_prime: [T; D_EF], pub mu: [T; D_EF], + pub root_prime_claim: [T; D_EF], pub p_xi_0: [T; D_EF], pub p_xi_1: [T; D_EF], @@ -83,6 +85,10 @@ where builder.assert_bool(local.is_dummy); builder.assert_bool(local.is_first_layer); + builder.assert_bool(local.is_root_layer); + builder + .when(local.is_root_layer) + .assert_zero(local.layer_idx); /////////////////////////////////////////////////////////////////////// // Structural constraints (replaces NestedForLoopSubAir<2>) @@ -260,6 +266,12 @@ where q_cross_term, ); let acc_q_with_cur = ext_field_add::(local.acc_q_cross, scaled_q_term); + let is_root_layer: AB::Expr = local.is_root_layer.into(); + let is_non_root_layer = AB::Expr::ONE - is_root_layer.clone(); + let lambda_prime_claim = core::array::from_fn(|i| { + local.root_prime_claim[i].into() * is_root_layer.clone() + + acc_q_with_cur[i].clone() * is_non_root_layer.clone() + }); assert_array_eq( &mut builder.when(is_within_layer.clone()), next.acc_q_cross, @@ 
-273,6 +285,7 @@ where pow_lambda_prime_next, ); + let num_logup_count: AB::Expr = local.num_logup_count.into(); self.logup_claim_input_bus.receive( builder, local.proof_idx, @@ -283,8 +296,9 @@ where lambda: lambda.clone(), lambda_prime: lambda_prime.clone(), mu: local.mu.map(Into::into), + root_prime_claim: local.root_prime_claim.map(Into::into), }, - local.is_first.into(), + local.is_first * is_not_dummy.clone() * num_logup_count.clone(), ); self.logup_claim_bus.send( @@ -294,10 +308,10 @@ where idx: local.idx.into(), layer_idx: local.layer_idx.into(), lambda_claim: acc_sum_export.map(Into::into), - lambda_prime_claim: acc_q_with_cur.map(Into::into), + lambda_prime_claim: lambda_prime_claim.map(Into::into), num_logup_count: local.num_logup_count.into(), }, - is_layer_end, + is_layer_end * is_not_dummy.clone() * num_logup_count, ); let mut tidx = local.tidx.into(); diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs index b1b0d92f0..986869a8b 100644 --- a/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs @@ -76,6 +76,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.is_enabled = F::ONE; cols.is_first_layer = F::from_bool(record.is_first_air_idx); cols.is_first = F::ONE; // single row = first of its (degenerate) layer + cols.is_root_layer = F::ONE; cols.is_dummy = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); @@ -87,6 +88,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { lambda_prime_one[0] = F::ONE; cols.lambda_prime = lambda_prime_one; cols.mu = [F::ZERO; D_EF]; + cols.root_prime_claim = [F::ZERO; D_EF]; cols.p_xi_0 = [F::ZERO; D_EF]; cols.p_xi_1 = [F::ZERO; D_EF]; cols.q_xi_0 = [F::ZERO; D_EF]; @@ -111,7 +113,8 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .get(layer_idx) .map(|rows| rows.as_slice()) .unwrap_or(&[]); - 
let total_rows = record.logup_count_at(layer_idx).max(1); + let actual_rows = record.logup_count_at(layer_idx); + let total_rows = actual_rows.max(1); debug_assert!( total_rows == logup_rows.len().max(1), "unexpected logup count mismatch at layer {layer_idx}" @@ -127,7 +130,16 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .try_into() .unwrap(); let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); - let layer_tidx = record.claim_tidx(layer_idx); + let root_prime_claim = if layer_idx == 0 { + record.root_logup_prime_claim + } else { + record.logup_claim_at(layer_idx).1 + }; + let root_prime_basis: [F; D_EF] = root_prime_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let layer_tidx = record.logup_claim_tidx(layer_idx); let mut pow_lambda = EF::ONE; let mut pow_lambda_prime = EF::ONE; @@ -181,6 +193,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.is_first_layer = F::from_bool(is_first_row_of_record && record.is_first_air_idx); cols.is_first = F::from_bool(is_first_row_of_layer); + cols.is_root_layer = F::from_bool(layer_idx == 0); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); cols.layer_idx = F::from_usize(layer_idx); @@ -189,6 +202,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { cols.lambda = lambda_basis; cols.lambda_prime = lambda_prime_basis; cols.mu = mu_basis; + cols.root_prime_claim = root_prime_basis; cols.p_xi_0 = p_xi_0.as_basis_coefficients_slice().try_into().unwrap(); cols.p_xi_1 = p_xi_1.as_basis_coefficients_slice().try_into().unwrap(); cols.q_xi_0 = q_xi_0.as_basis_coefficients_slice().try_into().unwrap(); @@ -210,7 +224,7 @@ impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); - cols.num_logup_count = F::from_usize(total_rows); + cols.num_logup_count = F::from_usize(actual_rows); acc_sum += contribution; acc_p_cross += p_cross_contribution; 
diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs index 27ad5a1de..98a7a0316 100644 --- a/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs @@ -27,6 +27,7 @@ pub struct TowerProdSumCheckClaimCols { pub idx: T, pub is_first_layer: T, pub is_first: T, + pub is_root_layer: T, pub is_dummy: T, pub layer_idx: T, @@ -36,6 +37,7 @@ pub struct TowerProdSumCheckClaimCols { pub lambda: [T; D_EF], pub lambda_prime: [T; D_EF], pub mu: [T; D_EF], + pub root_prime_claim: [T; D_EF], pub p_xi_0: [T; D_EF], pub p_xi_1: [T; D_EF], pub p_xi: [T; D_EF], @@ -91,6 +93,10 @@ impl TowerProdSumCheckClaimAir { builder.assert_bool(local.is_dummy); builder.assert_bool(local.is_first_layer); + builder.assert_bool(local.is_root_layer); + builder + .when(local.is_root_layer) + .assert_zero(local.layer_idx); /////////////////////////////////////////////////////////////////////// // Structural constraints (replaces NestedForLoopSubAir<2>) @@ -240,7 +246,12 @@ impl TowerProdSumCheckClaimAir { ext_field_multiply::(pow_lambda_prime.clone(), prime_product); let acc_sum_prime_with_cur = ext_field_add::(local.acc_sum_prime, prime_contribution); - let acc_sum_prime_export = acc_sum_prime_with_cur.clone(); + let is_root_layer: AB::Expr = local.is_root_layer.into(); + let is_non_root_layer = AB::Expr::ONE - is_root_layer.clone(); + let acc_sum_prime_export = core::array::from_fn(|i| { + local.root_prime_claim[i].into() * is_root_layer.clone() + + acc_sum_prime_with_cur[i].clone() * is_non_root_layer.clone() + }); assert_array_eq( &mut builder.when(is_within_layer.clone()), @@ -269,6 +280,7 @@ impl TowerProdSumCheckClaimAir { pow_lambda_prime_next, ); + let num_prod_count: AB::Expr = local.num_prod_count.into(); recv_challenge( &self.prod_claim_input_bus, builder, @@ -280,8 +292,9 @@ impl TowerProdSumCheckClaimAir { lambda, lambda_prime: lambda_prime.clone(), mu: 
local.mu.map(Into::into), + root_prime_claim: local.root_prime_claim.map(Into::into), }, - local.is_first.into(), + local.is_first * is_not_dummy.clone() * num_prod_count.clone(), ); send_claim( @@ -295,7 +308,7 @@ impl TowerProdSumCheckClaimAir { lambda_prime_claim: acc_sum_prime_export.map(Into::into), num_prod_count: local.num_prod_count.into(), }, - is_layer_end, + is_layer_end * is_not_dummy.clone() * num_prod_count, ); let mut tidx = local.tidx.into(); diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs index 8253341dd..8368bfbb7 100644 --- a/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs @@ -84,6 +84,7 @@ fn generate_prod_trace( cols.is_enabled = F::ONE; cols.is_first_layer = F::from_bool(record.is_first_air_idx); cols.is_first = F::ONE; // single row = first of its (degenerate) layer + cols.is_root_layer = F::ONE; cols.is_dummy = F::ONE; cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); @@ -95,6 +96,7 @@ fn generate_prod_trace( lambda_prime_one[0] = F::ONE; cols.lambda_prime = lambda_prime_one; cols.mu = [F::ZERO; D_EF]; + cols.root_prime_claim = [F::ZERO; D_EF]; cols.p_xi_0 = [F::ZERO; D_EF]; cols.p_xi_1 = [F::ZERO; D_EF]; cols.p_xi = [F::ZERO; D_EF]; @@ -123,11 +125,12 @@ fn generate_prod_trace( .map(|rows| rows.as_slice()) .unwrap_or(&[]) }; - let total_rows = if is_write { - record.write_count_at(layer_idx).max(1) + let actual_rows = if is_write { + record.write_count_at(layer_idx) } else { - record.read_count_at(layer_idx).max(1) + record.read_count_at(layer_idx) }; + let total_rows = actual_rows.max(1); debug_assert!( total_rows == active_rows.len().max(1), "unexpected prod count mismatch at layer {layer_idx}" @@ -142,7 +145,26 @@ fn generate_prod_trace( .try_into() .unwrap(); let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); - let layer_tidx = 
record.claim_tidx(layer_idx); + let root_prime_claim = if layer_idx == 0 { + if is_write { + record.root_write_prime_claim + } else { + record.root_read_prime_claim + } + } else if is_write { + record.write_claim_at(layer_idx).1 + } else { + record.read_claim_at(layer_idx).1 + }; + let root_prime_basis: [F; D_EF] = root_prime_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let layer_tidx = if is_write { + record.write_claim_tidx(layer_idx) + } else { + record.read_claim_tidx(layer_idx) + }; let mut pow_lambda = EF::ONE; let mut pow_lambda_prime = EF::ONE; @@ -178,6 +200,7 @@ fn generate_prod_trace( cols.is_first_layer = F::from_bool(is_first_row_of_record && record.is_first_air_idx); cols.is_first = F::from_bool(is_first_row_of_layer); + cols.is_root_layer = F::from_bool(layer_idx == 0); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); cols.layer_idx = F::from_usize(layer_idx); @@ -186,6 +209,7 @@ fn generate_prod_trace( cols.lambda = lambda_basis; cols.lambda_prime = lambda_prime_basis; cols.mu = mu_basis; + cols.root_prime_claim = root_prime_basis; cols.p_xi_0 = p_xi_0.as_basis_coefficients_slice().try_into().unwrap(); cols.p_xi_1 = p_xi_1.as_basis_coefficients_slice().try_into().unwrap(); cols.p_xi = p_xi.as_basis_coefficients_slice().try_into().unwrap(); @@ -199,7 +223,7 @@ fn generate_prod_trace( .as_basis_coefficients_slice() .try_into() .unwrap(); - cols.num_prod_count = F::from_usize(total_rows); + cols.num_prod_count = F::from_usize(actual_rows); acc_sum += contribution; acc_sum_prime += prime_contribution; diff --git a/ceno_recursion_v2/src/tower/layer/trace.rs b/ceno_recursion_v2/src/tower/layer/trace.rs index 1da4f96fe..ee492c76a 100644 --- a/ceno_recursion_v2/src/tower/layer/trace.rs +++ b/ceno_recursion_v2/src/tower/layer/trace.rs @@ -27,7 +27,13 @@ pub struct TowerLayerRecord { pub write_prime_claims: Vec, pub logup_claims: Vec, pub logup_prime_claims: Vec, + pub beta_logup: EF, pub 
sumcheck_claims: Vec, + pub sumcheck_claim_outs: Vec, + pub sumcheck_eq_outs: Vec, + pub root_read_prime_claim: EF, + pub root_write_prime_claim: EF, + pub root_logup_prime_claim: EF, } impl TowerLayerRecord { @@ -111,28 +117,57 @@ impl TowerLayerRecord { #[inline] pub(crate) fn layer_tidx(&self, layer_idx: usize) -> usize { - self.tidx + tower_transcript_len::layers_cumulative(layer_idx) + let mut span = 0; + for i in 0..layer_idx { + span += self.layer_span_at(i); + } + self.tidx + span } #[inline] pub(crate) fn read_count_at(&self, layer_idx: usize) -> usize { - self.read_counts.get(layer_idx).copied().unwrap_or(1) + self.read_counts.get(layer_idx).copied().unwrap_or(0) } #[inline] pub(crate) fn write_count_at(&self, layer_idx: usize) -> usize { - self.write_counts.get(layer_idx).copied().unwrap_or(1) + self.write_counts.get(layer_idx).copied().unwrap_or(0) } #[inline] pub(crate) fn logup_count_at(&self, layer_idx: usize) -> usize { - self.logup_counts.get(layer_idx).copied().unwrap_or(1) + self.logup_counts.get(layer_idx).copied().unwrap_or(0) + } + + #[inline] + pub(crate) fn layer_span_at(&self, layer_idx: usize) -> usize { + tower_transcript_len::layer_span( + layer_idx, + self.read_count_at(layer_idx), + self.write_count_at(layer_idx), + self.logup_count_at(layer_idx), + ) } #[inline] pub(crate) fn claim_tidx(&self, layer_idx: usize) -> usize { self.layer_tidx(layer_idx) + tower_transcript_len::claim_offset_in_layer(layer_idx) } + + #[inline] + pub(crate) fn read_claim_tidx(&self, layer_idx: usize) -> usize { + self.claim_tidx(layer_idx) + } + + #[inline] + pub(crate) fn write_claim_tidx(&self, layer_idx: usize) -> usize { + self.read_claim_tidx(layer_idx) + 2 * D_EF * self.read_count_at(layer_idx) + } + + #[inline] + pub(crate) fn logup_claim_tidx(&self, layer_idx: usize) -> usize { + self.write_claim_tidx(layer_idx) + 2 * D_EF * self.write_count_at(layer_idx) + } } pub struct TowerLayerTraceGenerator; @@ -207,6 +242,7 @@ impl RowMajorChip for 
TowerLayerTraceGenerator { lambda_prime_one[0] = F::ONE; cols.lambda_prime = lambda_prime_one; cols.mu = [F::ZERO; D_EF]; + cols.beta_logup = [F::ZERO; D_EF]; cols.sumcheck_claim_in = [F::ZERO; D_EF]; cols.read_claim = [F::ZERO; D_EF]; cols.read_claim_prime = [F::ZERO; D_EF]; @@ -217,6 +253,7 @@ impl RowMajorChip for TowerLayerTraceGenerator { cols.num_read_count = F::ZERO; cols.num_write_count = F::ZERO; cols.num_logup_count = F::ZERO; + cols.sumcheck_claim_out = [F::ZERO; D_EF]; cols.eq_at_r_prime = [F::ZERO; D_EF]; cols.r0_claim.copy_from_slice(q0_basis); cols.w0_claim.copy_from_slice(q0_basis); @@ -251,11 +288,21 @@ impl RowMajorChip for TowerLayerTraceGenerator { .unwrap(); let mu = mus_for_proof.get(layer_idx).copied().unwrap_or(EF::ZERO); cols.mu = mu.as_basis_coefficients_slice().try_into().unwrap(); - let sumcheck_claim = if layer_idx == 0 { + cols.beta_logup = record + .beta_logup + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let fallback_sumcheck_claim = if layer_idx == 0 { EF::ZERO } else { prev_folded_claim.unwrap_or(EF::ZERO) }; + let sumcheck_claim = record + .sumcheck_claims + .get(layer_idx) + .copied() + .unwrap_or(fallback_sumcheck_claim); cols.sumcheck_claim_in = sumcheck_claim .as_basis_coefficients_slice() .try_into() @@ -272,11 +319,22 @@ impl RowMajorChip for TowerLayerTraceGenerator { .as_basis_coefficients_slice() .try_into() .unwrap(); - cols.num_read_count = F::from_usize(record.read_count_at(layer_idx).max(1)); - cols.num_write_count = F::from_usize(record.write_count_at(layer_idx).max(1)); - cols.num_logup_count = F::from_usize(record.logup_count_at(layer_idx).max(1)); + cols.num_read_count = F::from_usize(record.read_count_at(layer_idx)); + cols.num_write_count = F::from_usize(record.write_count_at(layer_idx)); + cols.num_logup_count = F::from_usize(record.logup_count_at(layer_idx)); + cols.sumcheck_claim_out = record + .sumcheck_claim_outs + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO) + 
.as_basis_coefficients_slice() + .try_into() + .unwrap(); cols.eq_at_r_prime = record - .eq_at(layer_idx) + .sumcheck_eq_outs + .get(layer_idx) + .copied() + .unwrap_or_else(|| record.eq_at(layer_idx)) .as_basis_coefficients_slice() .try_into() .unwrap(); diff --git a/ceno_recursion_v2/src/tower/mod.rs b/ceno_recursion_v2/src/tower/mod.rs index 15f48dbb8..eac4e6cd3 100644 --- a/ceno_recursion_v2/src/tower/mod.rs +++ b/ceno_recursion_v2/src/tower/mod.rs @@ -54,9 +54,9 @@ use openvm_cpu_backend::CpuBackend; use openvm_poseidon2_air::POSEIDON2_WIDTH; use openvm_stark_backend::{ AirRef, FiatShamirTranscript, ReadOnlyTranscript, StarkProtocolConfig, TranscriptHistory, - p3_maybe_rayon::prelude::*, prover::AirProvingContext, + p3_maybe_rayon::prelude::*, poly_common::interpolate_cubic_at_0123, prover::AirProvingContext, }; -use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, D_EF, EF, F}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; use recursion_circuit::primitives::exp_bits_len::ExpBitsLenTraceGenerator; @@ -130,47 +130,67 @@ pub mod tower_transcript_len { /// Merge: label "merge" (2) + sample mu (D_EF) pub const MERGE_LEN: usize = LABEL_MERGE + D_EF; - /// Post-sumcheck tidx span: claim observation slots (4*D_EF) + MERGE_LEN. - /// This is the tidx gap from `tidx_after_sumcheck` to `tidx_end`. - pub const POST_SUMCHECK_LEN: usize = 4 * D_EF + MERGE_LEN; + /// Tidx span used by active claim observations after sumcheck. + pub const fn claim_span( + num_read_count: usize, + num_write_count: usize, + num_logup_count: usize, + ) -> usize { + (2 * num_read_count + 2 * num_write_count + 4 * num_logup_count) * D_EF + } - /// Gap between consecutive sumcheck blocks across GKR layers: - /// post-sumcheck of previous layer + pre-sumcheck of next layer. 
- pub const LAYER_GAP_LEN: usize = POST_SUMCHECK_LEN + ALPHA_LEN + SUMCHECK_INIT_LEN; + /// Tidx gap from `tidx_after_sumcheck` to `tidx_end`. + pub const fn post_sumcheck_len( + num_read_count: usize, + num_write_count: usize, + num_logup_count: usize, + ) -> usize { + claim_span(num_read_count, num_write_count, num_logup_count) + MERGE_LEN + } - /// Tidx span of layer `layer_idx` (includes claim slots + transcript ops). - /// Layer 0 (root): POST_SUMCHECK_LEN (no lambda sample — uses alpha_logup). - /// Layer j>0: ALPHA_LEN + SUMCHECK_INIT_LEN + j*ROUND_LEN + POST_SUMCHECK_LEN. - pub const fn layer_span(layer_idx: usize) -> usize { + /// Offset from the start of layer `layer_idx` to `tidx_after_sumcheck` + /// (where claims start). + /// Layer 0: SUMCHECK_INIT_LEN + ROUND_LEN (no lambda sample). + /// Layer j>0: ALPHA_LEN + SUMCHECK_INIT_LEN + (j + 1)*ROUND_LEN. + pub const fn claim_offset_in_layer(layer_idx: usize) -> usize { + let lambda_len = if layer_idx == 0 { 0 } else { ALPHA_LEN }; + lambda_len + SUMCHECK_INIT_LEN + (layer_idx + 1) * ROUND_LEN + } + + pub const fn sumcheck_round_offset_in_layer(layer_idx: usize) -> usize { if layer_idx == 0 { - POST_SUMCHECK_LEN + SUMCHECK_INIT_LEN } else { - ALPHA_LEN + SUMCHECK_INIT_LEN + layer_idx * ROUND_LEN + POST_SUMCHECK_LEN + ALPHA_LEN + SUMCHECK_INIT_LEN } } + /// Tidx span of layer `layer_idx` (includes claim slots + transcript ops). + pub const fn layer_span( + layer_idx: usize, + num_read_count: usize, + num_write_count: usize, + num_logup_count: usize, + ) -> usize { + claim_offset_in_layer(layer_idx) + + post_sumcheck_len(num_read_count, num_write_count, num_logup_count) + } + /// Cumulative tidx span for layers 0..layer_idx (exclusive).
- pub const fn layers_cumulative(layer_idx: usize) -> usize { + pub const fn layers_cumulative( + layer_idx: usize, + num_read_count: usize, + num_write_count: usize, + num_logup_count: usize, + ) -> usize { let mut total = 0; let mut i = 0; while i < layer_idx { - total += layer_span(i); + total += layer_span(i, num_read_count, num_write_count, num_logup_count); i += 1; } total } - - /// Offset from the start of layer `layer_idx` to `tidx_after_sumcheck` - /// (where claims start). - /// Layer 0: 0. - /// Layer j>0: ALPHA_LEN + SUMCHECK_INIT_LEN + j*ROUND_LEN. - pub const fn claim_offset_in_layer(layer_idx: usize) -> usize { - if layer_idx == 0 { - 0 - } else { - ALPHA_LEN + SUMCHECK_INIT_LEN + layer_idx * ROUND_LEN - } - } } // Sub-modules for different AIRs @@ -461,16 +481,22 @@ fn build_chip_records( layer_claims: Vec::with_capacity(layer_count), lambdas: vec![EF::ZERO; layer_count], eq_at_r_primes: vec![EF::ZERO; layer_count], - read_counts: vec![1; layer_count], - write_counts: vec![1; layer_count], - logup_counts: vec![1; layer_count], + read_counts: vec![0; layer_count], + write_counts: vec![0; layer_count], + logup_counts: vec![0; layer_count], read_claims: vec![EF::ZERO; layer_count], read_prime_claims: vec![EF::ZERO; layer_count], write_claims: vec![EF::ZERO; layer_count], write_prime_claims: vec![EF::ZERO; layer_count], logup_claims: vec![EF::ZERO; layer_count], logup_prime_claims: vec![EF::ZERO; layer_count], + beta_logup: schedule.beta, sumcheck_claims: vec![EF::ZERO; layer_count], + sumcheck_claim_outs: Vec::new(), + sumcheck_eq_outs: Vec::new(), + root_read_prime_claim: EF::ZERO, + root_write_prime_claim: EF::ZERO, + root_logup_prime_claim: EF::ZERO, }; for layer_idx in 0..layer_count { @@ -494,9 +520,9 @@ fn build_chip_records( // read_len == write_len, // "read/write prod spec count mismatch at layer {layer_idx}: read={read_len}, write={write_len}" // ); - layer_record.read_counts[layer_idx] = read_len.max(1); - 
layer_record.write_counts[layer_idx] = write_len.max(1); - layer_record.logup_counts[layer_idx] = logup_len.max(1); + layer_record.read_counts[layer_idx] = read_len; + layer_record.write_counts[layer_idx] = write_len; + layer_record.logup_counts[layer_idx] = logup_len; } for layer_idx in 0..layer_count { @@ -505,39 +531,29 @@ fn build_chip_records( .push(convert_logup_claim(chip_proof, layer_idx)); } - let input_layer_claim = layer_record - .layer_claims - .last() - .map(|claim| claim[0]) - .unwrap_or(EF::ZERO); - + let num_sumcheck_layers = layer_count; let mut sumcheck_record = TowerSumcheckRecord { proof_idx, idx, is_first_air_idx, - // First sumcheck transcript row starts at layer_tidx(1) + ALPHA_LEN + SUMCHECK_INIT_LEN. - tidx: tidx - + tower_transcript_len::ALPHA_BETA_LEN - + tower_transcript_len::POST_SUMCHECK_LEN - + tower_transcript_len::ALPHA_LEN - + tower_transcript_len::SUMCHECK_INIT_LEN, + tidx: 0, + sumcheck_tidxs: (0..num_sumcheck_layers) + .map(|layer_idx| { + layer_record.layer_tidx(layer_idx) + + tower_transcript_len::sumcheck_round_offset_in_layer(layer_idx) + }) + .collect(), + beta: schedule.beta, evals: Vec::new(), ris: Vec::new(), - claims: vec![EF::ZERO; layer_count.saturating_sub(1)], + claims: vec![EF::ZERO; layer_count], }; - // The sumcheck trace processes num_sumcheck_layers = layer_count - 1 layers. - // Layer k (0-indexed) has layer_rounds(k) = k+1 sumcheck rounds. - // Total rounds = num_sumcheck_layers*(num_sumcheck_layers+1)/2. - // record_gkr_transcript produces ris/evals for ALL layer_count layers, - // but the last layer is not processed by the sumcheck AIR (it corresponds - // to the final input layer claim, not a sumcheck). Truncate to - // total_rounds. - let num_sumcheck_layers = layer_count.saturating_sub(1); + // The tower proof carries one sumcheck for every GKR layer, including + // layer 0. Layer k has k + 1 sumcheck rounds. 
let total_sumcheck_rounds = num_sumcheck_layers * (num_sumcheck_layers + 1) / 2; for (k, round_msgs) in chip_proof.tower_proof.proofs.iter().enumerate() { - // Only include sumcheck evals for the first num_sumcheck_layers layers if k >= num_sumcheck_layers { break; } @@ -553,19 +569,12 @@ fn build_chip_records( .and_then(|evals| evals.get(2)) .copied() .unwrap_or(EF::ZERO); + layer_record.root_read_prime_claim = q0_claim; + layer_record.root_write_prime_claim = q0_claim; + layer_record.root_logup_prime_claim = q0_claim; let layer_output_lambda = schedule.lambdas.last().copied().unwrap_or(EF::ZERO); let layer_output_mu = schedule.mus.last().copied().unwrap_or(EF::ZERO); - let input_record = TowerInputRecord { - proof_idx, - idx, - tidx, - n_logup: layer_count, - alpha_logup: schedule.alpha_logup, - input_layer_claim, - layer_output_lambda, - layer_output_mu, - }; // Truncate ris to match the sumcheck trace's expected total_rounds. sumcheck_record.ris = schedule.ris[..total_sumcheck_rounds.min(schedule.ris.len())].to_vec(); if !replay.layers.is_empty() && total_sumcheck_rounds > 0 { @@ -583,13 +592,11 @@ fn build_chip_records( schedule.lambdas.get(layer_idx).copied().unwrap_or(EF::ZERO); mus_record[layer_idx] = schedule.mus.get(layer_idx).copied().unwrap_or(EF::ZERO); } - if layer_idx + 1 < layer_count { - if layer_idx < sumcheck_record.claims.len() { - sumcheck_record.claims[layer_idx] = data.claim_in; - } - if layer_idx < layer_record.sumcheck_claims.len() { - layer_record.sumcheck_claims[layer_idx] = data.claim_in; - } + if layer_idx < sumcheck_record.claims.len() { + sumcheck_record.claims[layer_idx] = data.claim_in; + } + if layer_idx < layer_record.sumcheck_claims.len() { + layer_record.sumcheck_claims[layer_idx] = data.claim_in; } } @@ -619,29 +626,80 @@ fn build_chip_records( } } - // Sync sumcheck claims with accumulated values so that the sumcheck trace - // uses the same claim_in that TowerLayerAir sends on the sumcheck_input_bus. 
- // TowerLayerAir layer j (j >= 1) sends: sumcheck_claim_in = read[j-1] + write[j-1] + logup[j-1] - // Sumcheck internal layer k uses: claims[k], where k = j - 1. - for k in 0..layer_count.saturating_sub(1) { - let folded = layer_record.read_claims[k] - + layer_record.write_claims[k] - + layer_record.logup_claims[k]; + // The layer AIR enforces that each non-root sumcheck claim is the folded + // claim emitted by the previous layer. + for k in 1..layer_count { + let prev_layer = k - 1; + let folded = layer_record.read_claims[prev_layer] + + layer_record.write_claims[prev_layer] + + layer_record.logup_claims[prev_layer]; sumcheck_record.claims[k] = folded; layer_record.sumcheck_claims[k] = folded; } // Compute eq_at_r_primes from ris and mus so that TowerLayerAir's eq values // match the sumcheck trace's eq_out on the sumcheck_output_bus. - // Sumcheck internal layer k (0-indexed) → TowerLayerAir layer k+1. - let num_sumcheck_layers = layer_count.saturating_sub(1); for k in 0..num_sumcheck_layers { - let eq = TowerSumcheckRecord::compute_eq_for_layer(k, &mus_record, &sumcheck_record.ris); - if k + 1 < layer_record.eq_at_r_primes.len() { - layer_record.eq_at_r_primes[k + 1] = eq; + let eq = TowerSumcheckRecord::compute_eq_for_layer( + k, + schedule.beta, + &mus_record, + &sumcheck_record.ris, + ); + if k < layer_record.eq_at_r_primes.len() { + layer_record.eq_at_r_primes[k] = eq; + } + } + + layer_record.sumcheck_claim_outs.clear(); + layer_record.sumcheck_eq_outs.clear(); + let mut global_round_idx = 0usize; + for layer_idx in 0..num_sumcheck_layers { + let mut claim = sumcheck_record.claims[layer_idx]; + let mut eq = EF::ONE; + for round_in_layer in 0..TowerSumcheckRecord::layer_rounds(layer_idx) { + let challenge = sumcheck_record.ris[global_round_idx]; + let evals = sumcheck_record.evals[global_round_idx]; + let prev_challenge = TowerSumcheckRecord::prev_challenge( + layer_idx, + round_in_layer, + schedule.beta, + &mus_record, + &sumcheck_record.ris, + ); + let 
ev0 = claim - evals[0]; + let evals_full = [ev0, evals[0], evals[1], evals[2]]; + claim = interpolate_cubic_at_0123(&evals_full, challenge); + eq *= prev_challenge * challenge + (EF::ONE - prev_challenge) * (EF::ONE - challenge); + global_round_idx += 1; } + layer_record.sumcheck_claim_outs.push(claim); + layer_record.sumcheck_eq_outs.push(eq); } + let input_layer_claim = layer_count + .checked_sub(1) + .map(|last_layer| { + layer_record.read_claims[last_layer] + + layer_record.write_claims[last_layer] + + layer_record.logup_claims[last_layer] + }) + .unwrap_or(EF::ZERO); + let input_record = TowerInputRecord { + proof_idx, + idx, + tidx, + n_logup: layer_count, + num_read_count: read_count, + num_write_count: write_count, + num_logup_count: logup_count, + alpha_logup: schedule.alpha_logup, + beta_logup: schedule.beta, + input_layer_claim, + layer_output_lambda, + layer_output_mu, + }; + Ok(( input_record, layer_record, @@ -753,11 +811,6 @@ pub(crate) fn build_gkr_blob( for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { let mut has_chip = false; - let mut first_chip_alpha = EF::ZERO; - let mut first_chip_q0 = EF::ZERO; - let mut last_input_layer_claim = EF::ZERO; - let mut last_layer_output_lambda = EF::ZERO; - let mut last_layer_output_mu = EF::ZERO; let sorted_idx_by_chip: std::collections::BTreeMap = preflight .proof_shape @@ -827,17 +880,8 @@ pub(crate) fn build_gkr_blob( global_tidx, )?; - // Capture first chip's alpha and q0 for the proof-level record - if entry_idx == 0 { - first_chip_alpha = chip_input_record.alpha_logup; - first_chip_q0 = q0_claim; - } - // Always update to latest chip for combined values - last_input_layer_claim = chip_input_record.input_layer_claim; - last_layer_output_lambda = chip_input_record.layer_output_lambda; - last_layer_output_mu = chip_input_record.layer_output_mu; - - // Per-chip records (not input_records) + input_records.push(chip_input_record); + proof_q0_claims.push(q0_claim); 
layer_records.push(layer_record); tower_records.push(tower_record); sumcheck_records.push(sumcheck_record); @@ -845,20 +889,13 @@ pub(crate) fn build_gkr_blob( q0_claims.push(q0_claim); } - // ONE input record per proof (matching ProofIdxSubAir constraint) - input_records.push(TowerInputRecord { - proof_idx, - idx: 0, - tidx: preflight.proof_shape.post_tidx, - n_logup: preflight.proof_shape.n_logup, - alpha_logup: first_chip_alpha, - input_layer_claim: last_input_layer_claim, - layer_output_lambda: last_layer_output_lambda, - layer_output_mu: last_layer_output_mu, - }); - proof_q0_claims.push(first_chip_q0); - if !has_chip { + input_records.push(TowerInputRecord { + proof_idx, + idx: 0, + ..Default::default() + }); + proof_q0_claims.push(EF::ZERO); layer_records.push(TowerLayerRecord { idx: 0, proof_idx, @@ -935,6 +972,7 @@ where // This keeps preflight transcript history aligned with TowerLayer/Sumcheck/ // ProdClaim/LogupClaim transcript bus interactions. let read_count = chip_proof.r_out_evals.len(); + let write_count = chip_proof.w_out_evals.len(); let layer_count = chip_proof .tower_proof .logup_specs_eval @@ -984,13 +1022,42 @@ where } } + for rounds in chip_proof + .tower_proof + .prod_specs_eval + .iter() + .take(read_count) + { + let values = rounds.get(layer_idx).map(Vec::as_slice).unwrap_or(&[]); + for i in 0..2 { + ts.observe_ext(values.get(i).copied().unwrap_or(EF::ZERO)); + } + } + for rounds in chip_proof + .tower_proof + .prod_specs_eval + .iter() + .skip(read_count) + .take(write_count) + { + let values = rounds.get(layer_idx).map(Vec::as_slice).unwrap_or(&[]); + for i in 0..2 { + ts.observe_ext(values.get(i).copied().unwrap_or(EF::ZERO)); + } + } + for rounds in &chip_proof.tower_proof.logup_specs_eval { + let values = rounds.get(layer_idx).map(Vec::as_slice).unwrap_or(&[]); + for i in 0..4 { + ts.observe_ext(values.get(i).copied().unwrap_or(EF::ZERO)); + } + } + // Mirror native: sample_and_append_vec(b"merge", log2_num_fanin) 
transcript_observe_label(ts, b"merge"); let mu = FiatShamirTranscript::::sample_ext(ts); mus.push(mu); } - let _ = read_count; TowerTranscriptSchedule { alpha_logup, beta, diff --git a/ceno_recursion_v2/src/tower/sumcheck/air.rs b/ceno_recursion_v2/src/tower/sumcheck/air.rs index 39c0bdec6..2c4fef053 100644 --- a/ceno_recursion_v2/src/tower/sumcheck/air.rs +++ b/ceno_recursion_v2/src/tower/sumcheck/air.rs @@ -10,9 +10,12 @@ use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; use p3_matrix::Matrix; use stark_recursion_circuit_derive::AlignedBorrow; -use crate::tower::bus::{ - TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, TowerSumcheckInputBus, - TowerSumcheckInputMessage, TowerSumcheckOutputBus, TowerSumcheckOutputMessage, +use crate::{ + tower::bus::{ + TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, TowerSumcheckInputBus, + TowerSumcheckInputMessage, TowerSumcheckOutputBus, TowerSumcheckOutputMessage, + }, + utils::{label_field_len, transcript_receive_label}, }; use recursion_circuit::{ bus::{TranscriptBus, XiRandomnessBus, XiRandomnessMessage}, @@ -252,7 +255,7 @@ where // Sumcheck round flag end builder .when(is_last_round.clone()) - .assert_eq(local.round, local.layer_idx - AB::Expr::ONE); + .assert_eq(local.round, local.layer_idx); /////////////////////////////////////////////////////////////////////// // Round Constraints @@ -337,7 +340,7 @@ where local.proof_idx, TowerSumcheckChallengeMessage { idx: local.idx.into(), - layer_idx: local.layer_idx - AB::Expr::ONE, + layer_idx: local.layer_idx.into(), sumcheck_round: local.round.into(), challenge: local.prev_challenge.map(Into::into), }, @@ -349,8 +352,8 @@ where local.proof_idx, TowerSumcheckChallengeMessage { idx: local.idx.into(), - layer_idx: local.layer_idx.into(), - sumcheck_round: local.round.into() + AB::Expr::ONE, + layer_idx: local.layer_idx + AB::Expr::ONE, + sumcheck_round: local.round.into(), challenge: local.challenge.map(Into::into), }, 
local.is_enabled * (AB::Expr::ONE - local.is_last_layer) * is_not_dummy.clone(), @@ -374,10 +377,19 @@ where tidx += AB::Expr::from_usize(D_EF); } // 1b. Sample challenge `ri` + let round_label = b"Internal round"; + transcript_receive_label( + &self.transcript_bus, + builder, + local.proof_idx, + tidx.clone(), + round_label, + local.is_enabled * is_not_dummy.clone(), + ); self.transcript_bus.sample_ext( builder, local.proof_idx, - tidx, + tidx + AB::Expr::from_usize(label_field_len(round_label)), local.challenge, local.is_enabled * is_not_dummy.clone(), ); diff --git a/ceno_recursion_v2/src/tower/sumcheck/trace.rs b/ceno_recursion_v2/src/tower/sumcheck/trace.rs index 9d84a04a3..15badd845 100644 --- a/ceno_recursion_v2/src/tower/sumcheck/trace.rs +++ b/ceno_recursion_v2/src/tower/sumcheck/trace.rs @@ -14,6 +14,8 @@ pub struct TowerSumcheckRecord { pub idx: usize, pub is_first_air_idx: bool, pub tidx: usize, + pub sumcheck_tidxs: Vec, + pub beta: EF, pub evals: Vec<[EF; 3]>, pub ris: Vec, pub claims: Vec, @@ -43,33 +45,39 @@ impl TowerSumcheckRecord { #[inline] fn derive_tidx(&self, layer_idx: usize, round_in_layer: usize) -> usize { - let rounds_before_layer = Self::layer_start_index(layer_idx); - self.tidx - + tower_transcript_len::ROUND_LEN * (rounds_before_layer + round_in_layer) - + tower_transcript_len::LAYER_GAP_LEN * layer_idx + self.sumcheck_tidxs + .get(layer_idx) + .copied() + .unwrap_or(self.tidx) + + tower_transcript_len::ROUND_LEN * round_in_layer } #[inline] - pub fn prev_challenge(layer_idx: usize, round_in_layer: usize, mus: &[EF], ris: &[EF]) -> EF { - if round_in_layer == 0 { - mus[layer_idx] - } else { - let prev_layer = layer_idx - .checked_sub(1) - .expect("round_in_layer > 0 only occurs for non-root layers"); - let offset = Self::layer_start_index(prev_layer) + (round_in_layer - 1); + pub fn prev_challenge( + layer_idx: usize, + round_in_layer: usize, + beta: EF, + mus: &[EF], + ris: &[EF], + ) -> EF { + if layer_idx == 0 { + beta + } else if 
round_in_layer < layer_idx { + let offset = Self::layer_start_index(layer_idx - 1) + round_in_layer; ris[offset] + } else { + mus[layer_idx - 1] } } /// Compute the eq evaluation for a given sumcheck layer from ris and mus. /// This produces the same eq_out value that the sumcheck trace generates. - pub fn compute_eq_for_layer(layer_idx: usize, mus: &[EF], ris: &[EF]) -> EF { + pub fn compute_eq_for_layer(layer_idx: usize, beta: EF, mus: &[EF], ris: &[EF]) -> EF { let rounds = Self::layer_rounds(layer_idx); let start = Self::layer_start_index(layer_idx); let mut eq = EF::ONE; for round in 0..rounds { - let prev = Self::prev_challenge(layer_idx, round, mus, ris); + let prev = Self::prev_challenge(layer_idx, round, beta, mus, ris); let challenge = ris[start + round]; eq *= prev * challenge + (EF::ONE - prev) * (EF::ONE - challenge); } @@ -145,7 +153,7 @@ impl RowMajorChip for TowerSumcheckTraceGenerator { cols.tidx = F::from_usize(D_EF); cols.proof_idx = F::from_usize(record.proof_idx); cols.idx = F::from_usize(record.idx); - cols.layer_idx = F::ONE; + cols.layer_idx = F::ZERO; cols.is_first_round = F::ONE; cols.is_first_idx = F::from_bool(record.is_first_air_idx); cols.is_first_layer = F::ONE; @@ -163,7 +171,7 @@ impl RowMajorChip for TowerSumcheckTraceGenerator { for layer_idx in 0..num_layers { let layer_rounds = TowerSumcheckRecord::layer_rounds(layer_idx); - let layer_idx_value = layer_idx + 1; + let layer_idx_value = layer_idx; let is_last_layer = layer_idx == num_layers.saturating_sub(1); let mut claim = record.claims[layer_idx]; @@ -175,6 +183,7 @@ impl RowMajorChip for TowerSumcheckTraceGenerator { let prev_challenge = TowerSumcheckRecord::prev_challenge( layer_idx, round_in_layer, + record.beta, mus_for_proof, &record.ris, ); diff --git a/ceno_recursion_v2/src/tower/tower.rs b/ceno_recursion_v2/src/tower/tower.rs index 0d288adf9..4743a8aa2 100644 --- a/ceno_recursion_v2/src/tower/tower.rs +++ b/ceno_recursion_v2/src/tower/tower.rs @@ -76,7 +76,13 @@ impl 
PrecomputedTranscript { for k in 0..num_layers { let n_ris = k + 1; for i in 0..n_ris { - challenges.push_back(schedule.ris[ri_offset + i]); + challenges.push_back( + schedule + .ris + .get(ri_offset + i) + .copied() + .unwrap_or(RecursionField::ZERO), + ); } ri_offset += n_ris; challenges.push_back(schedule.mus[k]); diff --git a/ceno_recursion_v2/src/transcript/mod.rs b/ceno_recursion_v2/src/transcript/mod.rs index 1588acd78..0fe90751a 100644 --- a/ceno_recursion_v2/src/transcript/mod.rs +++ b/ceno_recursion_v2/src/transcript/mod.rs @@ -118,6 +118,7 @@ impl TranscriptModule { let is_sample = log.samples()[tidx]; cols.is_sample = F::from_bool(is_sample); cols.tidx = F::from_usize(tidx + tidx_offset); + cols.local_tidx = F::from_usize(tidx); cols.mask[0] = F::ONE; cols.prev_state = prev_poseidon_state; @@ -234,7 +235,8 @@ impl TranscriptModule { ); offset = trunk_end; - // Fill fork rows with fork-local tidx offsets. + // Fill fork rows with global transcript offsets for TranscriptBus. + // ForkedTranscriptBus uses the separate local_tidx column. for (fi, fork_log) in preflight.fork_transcripts.iter().enumerate() { let fork_rows = info.fork_rows[fi]; let fork_end = offset + fork_rows; @@ -251,7 +253,7 @@ impl TranscriptModule { false, // is_proof_start true, // is_fork_start [F::ZERO; POSEIDON2_WIDTH], - 0, + preflight.fork_global_offset(fi), &mut poseidon2_perm_inputs, ); offset = fork_end; diff --git a/ceno_recursion_v2/src/transcript/transcript_air.rs b/ceno_recursion_v2/src/transcript/transcript_air.rs index ca12a2c28..9592b6814 100644 --- a/ceno_recursion_v2/src/transcript/transcript_air.rs +++ b/ceno_recursion_v2/src/transcript/transcript_air.rs @@ -47,6 +47,7 @@ pub struct ForkedTranscriptCols { pub is_proof_start: T, pub tidx: T, + pub local_tidx: T, /// Indicator for sample/observe. pub is_sample: T, /// 0/1 indicators for positions being absorbed/squeezed. 
@@ -65,12 +66,12 @@ pub struct ForkedTranscriptCols { impl ForkedTranscriptCols { pub const fn width() -> usize { - // proof_idx, is_proof_start, tidx, is_sample = 4 + // proof_idx, is_proof_start, tidx, local_tidx, is_sample = 5 // mask = CHUNK // prev_state = POSEIDON2_WIDTH // post_state = POSEIDON2_WIDTH // is_fork_start, fork_id = 2 - 4 + CHUNK + 2 * POSEIDON2_WIDTH + 2 + 5 + CHUNK + 2 * POSEIDON2_WIDTH + 2 } } @@ -157,15 +158,21 @@ impl Air for ForkedTranscriptAir { // When is_proof_start: tidx = 0, sponge state = 0 (trunk start) builder.when(local.is_proof_start).assert_zero(local.tidx); + builder + .when(local.is_proof_start) + .assert_zero(local.local_tidx); builder.when(local.is_proof_start).assert_one(is_valid); builder .when(local.is_proof_start) .assert_zero(local.fork_id); builder.assert_bool(local.is_sample); - // When is_fork_start: fork chain begins (tidx is NOT zero; it's the - // fork's global tidx offset). Only constrain validity. + // When is_fork_start: fork chain begins. `tidx` is global; `local_tidx` + // is fork-local for the forked transcript bus. builder.when(local.is_fork_start).assert_one(is_valid); + builder + .when(local.is_fork_start) + .assert_zero(local.local_tidx); // Initial state for proof start (trunk): all-zero sponge for i in 0..CHUNK { @@ -212,6 +219,9 @@ impl Air for ForkedTranscriptAir { builder .when(local_next_same_chain.clone()) .assert_eq(next.tidx, local.tidx + count.clone()); + builder + .when(local_next_same_chain.clone()) + .assert_eq(next.local_tidx, local.local_tidx + count.clone()); // If local.is_sample == next.is_sample within the same chain, // there must be exactly CHUNK operations. 
@@ -263,7 +273,7 @@ impl Air for ForkedTranscriptAir { local.proof_idx, ForkedTranscriptBusMessage { fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(i), + tidx: local.local_tidx + AB::Expr::from_usize(i), value: local.prev_state[i].into(), is_sample: AB::Expr::ZERO, }, @@ -274,7 +284,7 @@ impl Air for ForkedTranscriptAir { local.proof_idx, ForkedTranscriptBusMessage { fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(D_EF + i), + tidx: local.local_tidx + AB::Expr::from_usize(D_EF + i), value: local.prev_state[i].into(), is_sample: AB::Expr::ZERO, }, @@ -285,7 +295,7 @@ impl Air for ForkedTranscriptAir { local.proof_idx, ForkedTranscriptBusMessage { fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(2 * D_EF + i), + tidx: local.local_tidx + AB::Expr::from_usize(2 * D_EF + i), value: local.prev_state[i].into(), is_sample: AB::Expr::ONE, }, @@ -296,7 +306,7 @@ impl Air for ForkedTranscriptAir { local.proof_idx, ForkedTranscriptBusMessage { fork_id: local.fork_id.into(), - tidx: local.tidx + AB::Expr::from_usize(3 * D_EF + i), + tidx: local.local_tidx + AB::Expr::from_usize(3 * D_EF + i), value: local.prev_state[i].into(), is_sample: AB::Expr::ONE, }, diff --git a/ceno_recursion_v2/src/utils.rs b/ceno_recursion_v2/src/utils.rs index e486ab544..60f3344f0 100644 --- a/ceno_recursion_v2/src/utils.rs +++ b/ceno_recursion_v2/src/utils.rs @@ -9,6 +9,7 @@ use openvm_stark_sdk::config::baby_bear_poseidon2::{ use p3_air::AirBuilder; use p3_field::{PrimeCharacteristicRing, extension::BinomiallyExtendable}; use p3_symmetric::Permutation; +use recursion_circuit::bus::{TranscriptBus, TranscriptBusMessage}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TranscriptLabel { @@ -45,12 +46,54 @@ pub fn transcript_observe_label(transcript: &mut TS, label: &[u8]) where TS: FiatShamirTranscript, { - let label_f = ::BaseField::bytes_to_field_elements(label); - for elem in label_f { + for elem in 
label_field_values(label) { transcript.observe(elem); } } +pub fn label_field_values(label: &[u8]) -> Vec +where + FA: PrimeCharacteristicRing, +{ + label + .chunks(4) + .map(|chunk| { + let mut bytes = [0u8; 4]; + bytes[..chunk.len()].copy_from_slice(chunk); + FA::from_u32(u32::from_le_bytes(bytes)) + }) + .collect() +} + +pub fn transcript_receive_label( + transcript_bus: &TranscriptBus, + builder: &mut AB, + proof_idx: AB::Var, + tidx: impl Into, + label: &[u8], + is_enabled: impl Into, +) where + AB: openvm_stark_backend::interaction::InteractionBuilder, +{ + let tidx = tidx.into(); + let is_enabled = is_enabled.into(); + for (i, value) in label_field_values::(label) + .into_iter() + .enumerate() + { + transcript_bus.receive( + builder, + proof_idx, + TranscriptBusMessage { + tidx: tidx.clone() + AB::Expr::from_usize(i), + value, + is_sample: AB::Expr::ZERO, + }, + is_enabled.clone(), + ); + } +} + pub fn base_to_ext(x: impl Into) -> [FA; D_EF] where FA: PrimeCharacteristicRing,