diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index 34fcd94ab..5f526615f 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -1,64 +1,65 @@ use field::PrimeField64; -use symetric::Compression; +use koala_bear::symmetric::Permutation; pub(crate) const RATE: usize = 8; pub(crate) const WIDTH: usize = RATE * 2; +pub(crate) const CAPACITY: usize = WIDTH - RATE; #[derive(Clone, Debug)] pub struct Challenger { - pub compressor: P, - pub state: [F; RATE], + pub permutation: P, + pub state: [F; WIDTH], + rate_fresh: bool, } -impl> Challenger { - pub fn new(compressor: P) -> Self +impl> Challenger { + pub fn new(permutation: P) -> Self where F: Default, { Self { - compressor, - state: [F::ZERO; RATE], + permutation, + state: [F::ZERO; WIDTH], + rate_fresh: false, } } pub fn observe(&mut self, value: [F; RATE]) { - self.state = self.compressor.compress({ - let mut concat = [F::ZERO; WIDTH]; - concat[..RATE].copy_from_slice(&self.state); - concat[RATE..].copy_from_slice(&value); - concat - })[..RATE] - .try_into() - .unwrap(); + self.state[CAPACITY..].copy_from_slice(&value); + self.permutation.permute_mut(&mut self.state); + self.rate_fresh = true; } - pub fn observe_scalars(&mut self, scalars: &[F]) { + pub fn observe_many(&mut self, scalars: &[F]) { for chunk in scalars.chunks(RATE) { let mut buffer = [F::ZERO; RATE]; - for (i, val) in chunk.iter().enumerate() { - buffer[i] = *val; - } + buffer[..chunk.len()].copy_from_slice(chunk); self.observe(buffer); } } + pub fn duplex(&mut self) { + self.observe([F::ZERO; RATE]); + } + + pub fn sample(&mut self) -> [F; RATE] { + assert!(self.rate_fresh, "stale rate. 
insert a duplex() before."); + let out: [F; RATE] = self.state[CAPACITY..].try_into().unwrap(); + self.rate_fresh = false; + out + } + pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> { - let mut sampled = Vec::with_capacity(n + 1); - for i in 0..n + 1 { - let mut domain_sep = [F::ZERO; RATE]; - domain_sep[0] = F::from_usize(i); - let hashed = self.compressor.compress({ - let mut concat = [F::ZERO; WIDTH]; - concat[..RATE].copy_from_slice(&domain_sep); - concat[RATE..].copy_from_slice(&self.state); - concat - })[..RATE] - .try_into() - .unwrap(); - sampled.push(hashed); + if n == 0 { + return Vec::new(); + } + let mut out = Vec::with_capacity(n); + out.push(self.sample()); + for _ in 1..n { + self.duplex(); + out.push(self.sample()); } - self.state = sampled.pop().unwrap(); - sampled + out } /// Warning: not perfectly uniform diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 2ea95580d..ee7e87be0 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -1,6 +1,6 @@ use crate::{ MerklePaths, PrunedMerklePaths, - challenger::{Challenger, RATE, WIDTH}, + challenger::{CAPACITY, Challenger, RATE, WIDTH}, *, }; use field::Field; @@ -8,11 +8,11 @@ use field::PackedValue; use field::PrimeCharacteristicRing; use field::integers::QuotientMap; use field::{ExtensionField, PrimeField64}; +use koala_bear::symmetric::Permutation; use rayon::prelude::*; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; -use symetric::Compression; static POW_GRINDING_NANOS: AtomicU64 = AtomicU64::new(0); @@ -31,15 +31,15 @@ pub struct ProverState>, P> { merkle_paths: Vec, PF>>, } -impl>, P: Compression<[PF; WIDTH]>> ProverState +impl>, P: Permutation<[PF; WIDTH]>> ProverState where PF: PrimeField64, { #[must_use] - pub fn new(compressor: P) -> Self { + pub fn new(permutation: P) -> Self { assert!(EF::DIMENSION <= RATE); 
Self { - challenger: Challenger::new(compressor), + challenger: Challenger::new(permutation), transcript: Vec::new(), merkle_paths: Vec::new(), } @@ -53,7 +53,7 @@ where } } -impl>, P: Compression<[PF; WIDTH]>> ChallengeSampler for ProverState +impl>, P: Permutation<[PF; WIDTH]>> ChallengeSampler for ProverState where PF: PrimeField64, { @@ -66,18 +66,22 @@ where } } -impl>, P: Compression<[PF; WIDTH]> + Compression<[ as Field>::Packing; WIDTH]>> +impl>, P: Permutation<[PF; WIDTH]> + Permutation<[ as Field>::Packing; WIDTH]>> FSProver for ProverState where PF: PrimeField64, { fn add_base_scalars(&mut self, scalars: &[PF]) { - self.challenger.observe_scalars(scalars); + self.challenger.observe_many(scalars); self.transcript.extend_from_slice(scalars); } fn observe_scalars(&mut self, scalars: &[PF]) { - self.challenger.observe_scalars(scalars); + self.challenger.observe_many(scalars); + } + + fn duplex(&mut self) { + self.challenger.duplex(); } fn state(&self) -> String { @@ -97,13 +101,13 @@ where match eq_alpha { None => { let scalars = flatten_scalars_to_base(coeffs); - self.challenger.observe_scalars(&scalars); + self.challenger.observe_many(&scalars); self.transcript.extend_from_slice(&scalars[EF::DIMENSION..]); // c0 reconstructed by verifier from claimed_sum } Some(alpha) => { let bare_scalars = flatten_scalars_to_base(coeffs); let full_scalars = flatten_scalars_to_base(&expand_bare_to_full(coeffs, alpha)); - self.challenger.observe_scalars(&full_scalars); + self.challenger.observe_many(&full_scalars); self.transcript.extend_from_slice(&bare_scalars[EF::DIMENSION..]); // h0 reconstructed by verifier from claimed_sum } } @@ -140,15 +144,17 @@ where }); let mut packed_state = [Packed::::ZERO; WIDTH]; - packed_state[..RATE] + for (slot, val) in packed_state[..CAPACITY] .iter_mut() - .zip(&self.challenger.state) - .for_each(|(val, state)| *val = Packed::::from(*state)); - packed_state[RATE] = packed_witnesses; + .zip(&self.challenger.state[..CAPACITY]) + { + *slot 
= Packed::::from(*val); + } + packed_state[CAPACITY] = packed_witnesses; - self.challenger.compressor.compress_mut(&mut packed_state); + self.challenger.permutation.permute_mut(&mut packed_state); - let samples = packed_state[0].as_slice(); + let samples = packed_state[CAPACITY].as_slice(); for (sample, witness) in samples.iter().zip(packed_witnesses.as_slice()) { let rand_usize = sample.as_canonical_u64() as usize; if (rand_usize & ((1 << bits) - 1)) == 0 { @@ -162,8 +168,8 @@ where let witness = witness_found.lock().unwrap().unwrap(); - self.challenger.observe_scalars(&[witness]); - assert!(self.challenger.state[0].as_canonical_u64() & ((1 << bits) - 1) == 0); + self.challenger.observe_many(&[witness]); + assert!(self.challenger.state[CAPACITY].as_canonical_u64() & ((1 << bits) - 1) == 0); self.transcript.push(witness); let elapsed = time.elapsed(); diff --git a/crates/backend/fiat-shamir/src/traits.rs b/crates/backend/fiat-shamir/src/traits.rs index 5aba9f667..e2e0f13d7 100644 --- a/crates/backend/fiat-shamir/src/traits.rs +++ b/crates/backend/fiat-shamir/src/traits.rs @@ -16,6 +16,7 @@ pub trait FSProver>>: ChallengeSampler { fn state(&self) -> String; fn add_base_scalars(&mut self, scalars: &[PF]); fn observe_scalars(&mut self, scalars: &[PF]); + fn duplex(&mut self); fn pow_grinding(&mut self, bits: usize); fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>); fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option); @@ -46,6 +47,7 @@ pub trait FSVerifier>>: ChallengeSampler { fn state(&self) -> String; fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError>; fn observe_scalars(&mut self, scalars: &[PF]); + fn duplex(&mut self); fn next_merkle_opening(&mut self) -> Result>, ProofError>; fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError>; fn next_sumcheck_polynomial( diff --git a/crates/backend/fiat-shamir/src/utils.rs b/crates/backend/fiat-shamir/src/utils.rs index 6ffcb63c2..b4e743aef 100644 --- 
a/crates/backend/fiat-shamir/src/utils.rs +++ b/crates/backend/fiat-shamir/src/utils.rs @@ -1,5 +1,5 @@ use field::{BasedVectorSpace, ExtensionField, Field, PrimeCharacteristicRing, PrimeField64}; -use symetric::Compression; +use koala_bear::symmetric::Permutation; use crate::challenger::{Challenger, RATE, WIDTH}; @@ -40,7 +40,7 @@ pub fn expand_bare_to_full(bare: &[EF], alpha: EF) -> Vec { full } -pub(crate) fn sample_vec, P: Compression<[F; WIDTH]>>( +pub(crate) fn sample_vec, P: Permutation<[F; WIDTH]>>( challenger: &mut Challenger, len: usize, ) -> Vec { diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..a6a8a45bb 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -3,14 +3,14 @@ use std::iter::repeat_n; use crate::{ MerkleOpening, MerklePaths, PrunedMerklePaths, RawProof, - challenger::{Challenger, RATE, WIDTH}, + challenger::{CAPACITY, Challenger, RATE, WIDTH}, transcript::{DIGEST_LEN_FE, Proof}, *, }; use field::PrimeCharacteristicRing; use field::{ExtensionField, PrimeField64}; +use koala_bear::symmetric::Permutation; use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; -use symetric::Compression; pub struct VerifierState>, P> { challenger: Challenger, P>, @@ -21,11 +21,11 @@ pub struct VerifierState>, P> { raw_transcript: Vec>, // reconstructed during the proof verification, it's the format that the zkVM recursion program expects (no Merkle pruning, no sumcheck optimization to send less data, etc) } -impl>, C: Compression<[PF; WIDTH]>> VerifierState +impl>, P: Permutation<[PF; WIDTH]>> VerifierState where PF: PrimeField64, { - pub fn new(proof: Proof>, compressor: C) -> Result { + pub fn new(proof: Proof>, permutation: P) -> Result { let mut merkle_openings = Vec::new(); for paths in proof.merkle_paths { let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?; @@ -33,7 +33,7 @@ where } Ok(Self { - 
challenger: Challenger::new(compressor), + challenger: Challenger::new(permutation), transcript: proof.transcript, transcript_offset: 0, merkle_openings, @@ -50,7 +50,7 @@ where } fn absorb_and_record(&mut self, scalars: &[PF]) { - self.challenger.observe_scalars(scalars); + self.challenger.observe_many(scalars); let total_padded = scalars.len().next_multiple_of(RATE); self.raw_transcript.extend_from_slice(scalars); self.raw_transcript @@ -90,7 +90,7 @@ where } } -impl>, C: Compression<[PF; WIDTH]>> ChallengeSampler for VerifierState +impl>, P: Permutation<[PF; WIDTH]>> ChallengeSampler for VerifierState where PF: PrimeField64, { @@ -102,7 +102,7 @@ where } } -impl>, C: Compression<[PF; WIDTH]>> FSVerifier for VerifierState +impl>, P: Permutation<[PF; WIDTH]>> FSVerifier for VerifierState where PF: PrimeField64, { @@ -121,7 +121,11 @@ where } fn observe_scalars(&mut self, scalars: &[PF]) { - self.challenger.observe_scalars(scalars); + self.challenger.observe_many(scalars); + } + + fn duplex(&mut self) { + self.challenger.duplex(); } fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError> { @@ -144,8 +148,8 @@ where return Ok(()); } let witness = self.read_transcript(1)?[0]; - self.challenger.observe_scalars(&[witness]); - if self.challenger.state[0].as_canonical_u64() & ((1 << bits) - 1) != 0 { + self.challenger.observe_many(&[witness]); + if self.challenger.state[CAPACITY].as_canonical_u64() & ((1 << bits) - 1) != 0 { return Err(ProofError::InvalidGrindingWitness); } self.raw_transcript.push(witness); diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index e700b31e5..5c9621779 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -89,6 +89,12 @@ def poseidon16_compress_half_hardcoded_left(left, right, output, offset): _ = left, right, output, offset +def poseidon16_permute(left, right, output): + """Raw Poseidon1 permutation (no feed-forward). 
Writes the 16-cell result in natural order: + m[output .. output + 16] = poseidon(left || right)""" + _ = left, right, output + + def add_be(a, b, result, length=None): _ = a, b, result, length diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index e784d27f8..1a7786396 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -7,8 +7,8 @@ use crate::{ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, - POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, - PrecompileCompTimeArgs, SourceLocation, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_PERMUTE_NAME, + PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -2259,13 +2259,14 @@ fn simplify_lines( continue; } - // Special handling for poseidon16 precompile (4 variants) + // Special handling for poseidon16 precompile (5 variants). 
if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) { if !targets.is_empty() { return Err(format!( "Precompile {function_name} should not return values, at {location}" )); } + let permute = function_name.as_str() == POSEIDON16_PERMUTE_NAME; let half_output = [POSEIDON16_HALF_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] .contains(&function_name.as_str()); let is_hardcoded_left = @@ -2303,6 +2304,7 @@ fn simplify_lines( data: PrecompileCompTimeArgs::Poseidon16 { half_output, hardcoded_offset_left, + permute, }, })); continue; diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index 1060e3be4..b9bb1ea08 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -51,10 +51,12 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { PrecompileCompTimeArgs::Poseidon16 { half_output, hardcoded_offset_left, + permute, } => { let flag_left = hardcoded_offset_left.is_some() as usize; let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); POSEIDON_PRECOMPILE_DATA + + POSEIDON_PERMUTE_SHIFT * (*permute as usize) + POSEIDON_HALF_OUTPUT_SHIFT * (*half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index fa86a3ae2..3e948f3a6 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -121,6 +121,7 @@ pub fn prove_execution( // logup (GKR) let logup_c = prover_state.sample(); + prover_state.duplex(); let logup_alphas = prover_state.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); let logup_alphas_eq_poly = eval_eq(&logup_alphas); @@ -149,8 +150,10 @@ pub fn prove_execution( } let bus_beta = prover_state.sample(); + prover_state.duplex(); let air_alpha = 
prover_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); + prover_state.duplex(); let air_eta: EF = prover_state.sample(); let tables_log_heights: BTreeMap = diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index f4e87c947..ceeff205e 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -3,7 +3,7 @@ use backend::*; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::{init_tracing, poseidon16_compress}; +use utils::{init_tracing, poseidon16_compress, poseidon16_permute}; #[test] fn test_zk_vm_all_precompiles() { @@ -49,6 +49,11 @@ def main(): for i in unroll(0, HALF_DIGEST_LEN): assert hardcoded_full_out[i] == hardcoded_half_out[i] + # poseidon16_permute: full 16-element permutation (no feed-forward), written in natural order: + # m[res .. res + 16] = poseidon(left || right) + permute_out = pub_start + 1600 + poseidon16_permute(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, permute_out) + base_ptr = pub_start + 88 ext_a_ptr = pub_start + 88 + N ext_b_ptr = pub_start + 88 + N * (DIM + 1) @@ -123,6 +128,10 @@ def main(): F::from_usize(888), ]); + // poseidon16_permute output at 1600..1616: raw permutation result. 
+ let permute_output = poseidon16_permute(poseidon_16_compress_input); + public_input[1600..1616].copy_from_slice(&permute_output); + // Extension op operands: base[N], ext_a[N], ext_b[N] let base_slice: [F; N] = rng.random(); let ext_a_slice: [EF; N] = rng.random(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 1801a5b62..915434caa 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -108,23 +108,31 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let poseidon_trace = traces.get_mut(&Table::poseidon16()).unwrap(); fill_trace_poseidon_16(&mut poseidon_trace.columns); - // For half_output rows, override last 4 output columns with actual memory values - // (the AIR doesn't constrain them, but the lookup checks against memory). + // For permute=0 rows the AIR leaves the 8 right-half output columns unconstrained (plus the last 4 + // left-half output columns when half_output=1): override them with memory values so the lookup matches. { - let split = POSEIDON_16_COL_OUTPUT_START + HALF_DIGEST_LEN; + let split = POSEIDON_16_COL_OUTPUT_LEFT + HALF_DIGEST_LEN; let (left, right) = poseidon_trace.columns.split_at_mut(split); let half_output_col = &left[POSEIDON_16_COL_FLAG_HALF_OUTPUT]; + let permute_col = &left[POSEIDON_16_COL_FLAG_PERMUTE]; let res_col = &left[POSEIDON_16_COL_INDEX_INPUT_RES]; - let output_cols: &mut [Vec; HALF_DIGEST_LEN] = (&mut right[..HALF_DIGEST_LEN]).try_into().unwrap(); + const N: usize = HALF_DIGEST_LEN + DIGEST_LEN; + let cols: &mut [Vec; N] = (&mut right[..N]).try_into().unwrap(); - transposed_par_iter_mut(output_cols) + transposed_par_iter_mut(cols) .zip(half_output_col) + .zip(permute_col) .zip(res_col) - .for_each(|((row, &half), &res)| { - if half == F::ONE { - let base = res.to_usize() + HALF_DIGEST_LEN; - for j in 0..HALF_DIGEST_LEN { - *row[j] = memory_padded[base + j]; + .for_each(|(((row, &half), &permute), &res)| { + if permute == F::ZERO { + let base = res.to_usize(); + if half == F::ONE 
{ + for j in 0..HALF_DIGEST_LEN { + *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; + } + } + for j in 0..DIGEST_LEN { + *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j]; } } }); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 0909691d7..f3e05d0b5 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -71,6 +71,7 @@ pub fn verify_execution( )?; let logup_c = verifier_state.sample(); + verifier_state.duplex(); let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); let logup_alphas_eq_poly = eval_eq(&logup_alphas); @@ -98,8 +99,10 @@ pub fn verify_execution( } let bus_beta = verifier_state.sample(); + verifier_state.duplex(); let air_alpha = verifier_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); + verifier_state.duplex(); let eta: EF = verifier_state.sample(); // batching the sumchecks proving validity of AIR tables let tables_sorted = sort_tables_by_height(&table_n_vars); diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..d52f7513e 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -2,12 +2,12 @@ use super::Operation; use super::operands::{MemOrConstant, MemOrFpOrConstant}; -use crate::POSEIDON16_NAME; use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; use crate::tables::TableT; use crate::{ExtensionOpMode, Table, TableTrace}; +use crate::{POSEIDON16_NAME, POSEIDON16_PERMUTE_NAME}; use backend::*; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; @@ -68,6 +68,8 @@ pub enum PrecompileCompTimeArgs { // hardcoded_offset_left = None: left_input = m[arg_a..arg_a+8] // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | 
m[arg_a..arg_a+4] (arg_a is the first runtime parameter) hardcoded_offset_left: Option, + // Mutually exclusive with `half_output` and with `hardcoded_offset_left = Some(_)`. + permute: bool, }, ExtensionOp { size: S, @@ -88,9 +90,11 @@ impl PrecompileCompTimeArgs { Self::Poseidon16 { half_output, hardcoded_offset_left: hardcoded_left_4, + permute, } => PrecompileCompTimeArgs::Poseidon16 { half_output, hardcoded_offset_left: hardcoded_left_4.map(&mut f), + permute, }, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } @@ -253,15 +257,24 @@ impl Display for PrecompileArgs { PrecompileCompTimeArgs::Poseidon16 { half_output, hardcoded_offset_left: hardcoded_left_4, - } => match (*half_output, hardcoded_left_4) { - (false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), - (true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), - (false, Some(off)) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})"), - (true, Some(off)) => write!( - f, - "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" - ), - }, + permute, + } => { + if *permute { + write!(f, "{POSEIDON16_PERMUTE_NAME}({arg_0}, {arg_1}, {res})") + } else { + match (*half_output, hardcoded_left_4) { + (false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), + (true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), + (false, Some(off)) => { + write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})") + } + (true, Some(off)) => write!( + f, + "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" + ), + } + } + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index bf9523291..cc652a185 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -19,7 +19,7 @@ 
pub(crate) use utils::*; // `PRECOMPILE_DATA` is the bus discriminator separating the two precompile // tables. Disjointness is by parity of bit 0: // -// Poseidon16 (odd): 1 + 2·flag_half + 4·flag_left + 8·flag_left·offset_left +// Poseidon16 (odd): 1 + 2·flag_permute + 4·flag_half + 8·flag_left + 16·flag_left·offset_left // ExtensionOp (even): 4·is_be + 8·flag_add + 16·flag_mul + 32·flag_poly_eq + 64·len // // Multiplying `offset_left` by `flag_left` is needed for soundness: see 3.4.1 in minimal_zkVM.pdf diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 9d0994d12..bd4666fb2 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -3,7 +3,7 @@ use std::any::TypeId; use crate::*; use crate::{execution::memory::MemoryAccess, tables::poseidon_16::trace_gen::generate_trace_rows_for_perm}; use backend::*; -use utils::{ToUsize, poseidon16_compress}; +use utils::{ToUsize, poseidon16_compress, poseidon16_permute}; /// Dispatch `mds_fft_16` through concrete types. /// For `SymbolicExpression` we use the dense form so the zkDSL generator can @@ -91,9 +91,10 @@ const HALF_FINAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS / 2; // `PRECOMPILE_DATA` encoding: see `tables/mod.rs`. 
pub const POSEIDON_PRECOMPILE_DATA: usize = 1; -pub const POSEIDON_HALF_OUTPUT_SHIFT: usize = 1 << 1; -pub const POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT: usize = 1 << 2; -pub const POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT: usize = 1 << 3; +pub const POSEIDON_PERMUTE_SHIFT: usize = 1 << 1; +pub const POSEIDON_HALF_OUTPUT_SHIFT: usize = 1 << 2; +pub const POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT: usize = 1 << 3; +pub const POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT: usize = 1 << 4; pub const POSEIDON_16_COL_FLAG: ColIndex = 0; pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 1; @@ -103,8 +104,10 @@ pub const POSEIDON_16_COL_FLAG_HARDCODED_LEFT: ColIndex = 4; pub const POSEIDON_16_COL_OFFSET_LEFT_HARDCODED: ColIndex = 5; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST: ColIndex = 6; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND: ColIndex = 7; -pub const POSEIDON_16_COL_INPUT_START: ColIndex = 8; -pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; +pub const POSEIDON_16_COL_FLAG_PERMUTE: ColIndex = 8; +pub const POSEIDON_16_COL_INPUT_START: ColIndex = 9; +pub const POSEIDON_16_COL_OUTPUT_LEFT: ColIndex = num_cols_poseidon_16() - 16; +pub const POSEIDON_16_COL_OUTPUT_RIGHT: ColIndex = num_cols_poseidon_16() - 8; /// Non-committed columns ("virtual"): pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = num_cols_poseidon_16(); pub const POSEIDON_16_COL_PRECOMPILE_DATA: ColIndex = num_cols_poseidon_16() + 1; @@ -113,11 +116,13 @@ pub const POSEIDON16_NAME: &str = "poseidon16_compress"; pub const POSEIDON16_HALF_NAME: &str = "poseidon16_compress_half"; pub const POSEIDON16_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_hardcoded_left"; pub const POSEIDON16_HALF_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_half_hardcoded_left"; -pub const ALL_POSEIDON16_NAMES: [&str; 4] = [ +pub const POSEIDON16_PERMUTE_NAME: &str = "poseidon16_permute"; +pub const ALL_POSEIDON16_NAMES: [&str; 5] = [ POSEIDON16_NAME, POSEIDON16_HALF_NAME, 
POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME, + POSEIDON16_PERMUTE_NAME, ]; pub const HALF_DIGEST_LEN: usize = DIGEST_LEN / 2; @@ -151,7 +156,7 @@ impl TableT for Poseidon16Precompile { }, LookupIntoMemory { index: POSEIDON_16_COL_INDEX_INPUT_RES, - values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), + values: (POSEIDON_16_COL_OUTPUT_LEFT..POSEIDON_16_COL_OUTPUT_LEFT + DIGEST_LEN * 2).collect(), }, ] } @@ -190,7 +195,8 @@ impl TableT for Poseidon16Precompile { *perm.offset_hardcoded_left = F::ZERO; *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); - // Non-committed columns + *perm.flag_permute = F::ZERO; + perm.outputs_right.iter_mut().for_each(|x| **x = F::ZERO); row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); @@ -210,10 +216,15 @@ impl TableT for Poseidon16Precompile { let PrecompileCompTimeArgs::Poseidon16 { half_output, hardcoded_offset_left, + permute, } = args else { unreachable!("Poseidon16 table called with non-Poseidon16 args"); }; + assert!( + !(permute && (half_output || hardcoded_offset_left.is_some())), + "Poseidon16 permute is mutually exclusive with half_output and hardcoded_left" + ); let trace = ctx.traces.get_mut(&self.table()).unwrap(); let arg_a_usize = arg_a.to_usize(); @@ -238,13 +249,17 @@ impl TableT for Poseidon16Precompile { input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); input[DIGEST_LEN..].copy_from_slice(&arg1); - let output = poseidon16_compress(input); - - if half_output { - ctx.memory - .set_slice(index_res_a.to_usize(), &output[..HALF_DIGEST_LEN])?; + let res_addr = index_res_a.to_usize(); + if permute { + let permuted = poseidon16_permute(input); + ctx.memory.set_slice(res_addr, &permuted)?; } else { - ctx.memory.set_slice(index_res_a.to_usize(), 
&output)?; + let output = poseidon16_compress(input); + if half_output { + ctx.memory.set_slice(res_addr, &output[..HALF_DIGEST_LEN])?; + } else { + ctx.memory.set_slice(res_addr, &output)?; + } } let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); @@ -257,12 +272,14 @@ impl TableT for Poseidon16Precompile { trace.columns[POSEIDON_16_COL_OFFSET_LEFT_HARDCODED].push(F::from_usize(hardcoded_offset_left_val)); trace.columns[POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST].push(F::from_usize(left_first_addr)); trace.columns[POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND].push(F::from_usize(left_second_addr)); + trace.columns[POSEIDON_16_COL_FLAG_PERMUTE].push(F::from_bool(permute)); for (i, value) in input.iter().enumerate() { trace.columns[POSEIDON_16_COL_INPUT_START + i].push(*value); } // Non-committed columns trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); let precompile_data = POSEIDON_PRECOMPILE_DATA + + POSEIDON_PERMUTE_SHIFT * (permute as usize) + POSEIDON_HALF_OUTPUT_SHIFT * (half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * (flag_hardcoded as usize) + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val; @@ -280,7 +297,11 @@ impl Air for Poseidon16Precompile { num_cols_poseidon_16() } fn degree_air(&self) -> usize { - 10 // Last 4 output constraints are gated by (1 - half_output), raising degree from 9 to 10 + // Last 4 output constraints (i in 4..8) are gated by the single linear factor + // `(1 - flag_permute - flag_half_output)`, which is boolean thanks to the mutex + // `flag_permute * flag_half_output = 0`. The permutation expression has degree 9, so + // the gated constraint stays at degree 10. 
+ 10 } fn low_degree_air(&self) -> Option<(usize, usize)> { // Each partial round contributes one `assert_eq_low` per round (1 S-box / round), of degree 3 (= the "low" degree part) @@ -290,7 +311,7 @@ impl Air for Poseidon16Precompile { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 80 + BUS as usize + 99 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon1Cols16 = { @@ -307,7 +328,8 @@ impl Air for Poseidon16Precompile { + cols.flag_hardcoded_left * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT) + cols.flag_hardcoded_left * cols.offset_hardcoded_left - * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT); + * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT) + + cols.flag_permute * AB::F::from_usize(POSEIDON_PERMUTE_SHIFT); // effective_index_left_first = index_a * (1 - flag_hardcoded_left_4) + offset * flag_hardcoded_left_4 let one_minus_flag_hardcoded_left = AB::IF::ONE - cols.flag_hardcoded_left; @@ -329,6 +351,8 @@ impl Air for Poseidon16Precompile { builder.assert_bool(cols.flag_active); builder.assert_bool(cols.flag_half_output); builder.assert_bool(cols.flag_hardcoded_left); + builder.assert_bool(cols.flag_permute); + builder.assert_zero(cols.flag_permute * (cols.flag_half_output + cols.flag_hardcoded_left)); builder.assert_zero(cols.flag_hardcoded_left * (cols.offset_hardcoded_left - cols.effective_index_left_first)); builder.assert_zero(one_minus_flag_hardcoded_left * (index_a - cols.effective_index_left_first)); @@ -348,12 +372,14 @@ pub(super) struct Poseidon1Cols16 { pub offset_hardcoded_left: T, pub effective_index_left_first: T, pub effective_index_left_second: T, + pub flag_permute: T, pub inputs: [T; WIDTH], pub beginning_full_rounds: [[T; WIDTH]; HALF_INITIAL_FULL_ROUNDS], pub partial_rounds: [T; PARTIAL_ROUNDS], pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1], - pub outputs: [T; WIDTH / 2], + pub outputs_left: [T; WIDTH / 2], + pub outputs_right: [T; WIDTH / 
2], } fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16) { @@ -412,10 +438,12 @@ fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16( } #[inline] +#[allow(clippy::too_many_arguments)] fn eval_last_2_full_rounds_16( initial_state: &[AB::IF; WIDTH], state: &mut [AB::IF; WIDTH], - outputs: &[AB::IF; WIDTH / 2], + outputs_left: &[AB::IF; WIDTH / 2], + outputs_right: &[AB::IF; WIDTH / 2], round_constants_1: &[F; WIDTH], round_constants_2: &[F; WIDTH], flag_half_output: AB::IF, + flag_permute: AB::IF, builder: &mut AB, ) { for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { @@ -473,20 +504,17 @@ fn eval_last_2_full_rounds_16( *s = s.cube(); } mds_air_16(state); - // add inputs to outputs (for compression) - for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { - *state_i += *init_state_i; - } - let one_minus_flag_half_output = AB::IF::ONE - flag_half_output; - for (idx, (state_i, output_i)) in state.iter_mut().zip(outputs).enumerate() { - if idx < HALF_DIGEST_LEN { - // First 4 outputs: always constrained - builder.assert_eq(*state_i, *output_i); + let not_permute = AB::IF::ONE - flag_permute; + let compression_last4 = not_permute - flag_half_output; + for i in 0..(WIDTH / 2) { + let compression_gate = if i < HALF_DIGEST_LEN { + not_permute } else { - // Last 4 outputs: constrained only when half_output = 0 - builder.assert_zero(one_minus_flag_half_output * (*state_i - *output_i)); - } - *state_i = *output_i; + compression_last4 + }; + builder.assert_zero(compression_gate * (state[i] + initial_state[i] - outputs_left[i])); + builder.assert_zero(flag_permute * (state[i] - outputs_left[i])); + builder.assert_zero(flag_permute * (state[i + WIDTH / 2] - outputs_right[i])); } } diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index fca712257..c7b93cf56 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ 
b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -99,11 +99,13 @@ pub(super) fn generate_trace_rows_for_perm + Copy>(perm: & generate_2_full_round(&mut state, full_round, &constants[0], &constants[1]); } - // Last 2 full rounds with compression (add inputs to outputs) + let flag_permute = *perm.flag_permute; generate_last_2_full_rounds( &mut state, &inputs, - &mut perm.outputs, + &mut perm.outputs_left, + &mut perm.outputs_right, + flag_permute, &poseidon1_final_constants()[2 * n_ending_full_rounds], &poseidon1_final_constants()[2 * n_ending_full_rounds + 1], ); @@ -137,7 +139,9 @@ fn generate_2_full_round + Copy>( fn generate_last_2_full_rounds + Copy>( state: &mut [F; WIDTH], inputs: &[F; WIDTH], - outputs: &mut [&mut F; WIDTH / 2], + outputs_left: &mut [&mut F; WIDTH / 2], + outputs_right: &mut [&mut F; WIDTH / 2], + flag_permute: F, round_constants_1: &[KoalaBear; WIDTH], round_constants_2: &[KoalaBear; WIDTH], ) { @@ -153,8 +157,9 @@ fn generate_last_2_full_rounds + Copy>( } mds_circ_16(state); - // Add inputs to outputs (compression) - for ((output, state_i), &input_i) in outputs.iter_mut().zip(state).zip(inputs) { - **output = *state_i + input_i; + for i in 0..(WIDTH / 2) { + let compression_value = state[i] + inputs[i]; + *outputs_left[i] = (F::ONE - flag_permute) * compression_value + flag_permute * state[i]; + *outputs_right[i] = flag_permute * state[i + WIDTH / 2]; } } diff --git a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py index 5eb3c1316..3cee15d76 100644 --- a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py +++ b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py @@ -1,28 +1,35 @@ from snark_lib import * -# FIAT SHAMIR layout: 17 field elements -# 0..8 -> first half of sponge state -# 8 -> transcript pointer - from utils import * +# fs layout (17 cells): +# fs[0..8] = capacity +# fs[8..16] = rate +# fs[16] = transcript pointer +# This matches the normal-ordering poseidon precompile 
output [cap | rate]. + + def fs_new(transcript_ptr): - fs_state = Array(9) - set_to_8_zeros(fs_state) - fs_state[8] = transcript_ptr - return fs_state + fs = Array(17) + set_to_16_zeros(fs) + fs[16] = transcript_ptr + return fs @inline -def fs_observe_chunks(fs, data, n_chunks): - result: Mut = Array(9) - poseidon16_compress(fs, data, result) +def _absorb_chunks(fs, data, n_chunks, new_transcript_ptr): + assert n_chunks != 0 + chain = Array(n_chunks * 16 + 1) + poseidon16_permute(fs, data, chain) for i in unroll(1, n_chunks): - new_result = Array(9) - poseidon16_compress(result, data + i * DIGEST_LEN, new_result) - result = new_result - result[8] = fs[8] # preserve transcript pointer - return result + poseidon16_permute(chain + (i - 1) * 16, data + i * DIGEST_LEN, chain + i * 16) + chain[n_chunks * 16] = new_transcript_ptr + return chain + (n_chunks - 1) * 16 + + +@inline +def fs_observe_chunks(fs, data, n_chunks): + return _absorb_chunks(fs, data, n_chunks, fs[16]) def fs_observe(fs, data, length: Const): @@ -36,28 +43,23 @@ def fs_observe(fs, data, length: Const): padded[j] = data[n_full_chunks * DIGEST_LEN + j] for j in unroll(remainder, DIGEST_LEN): padded[j] = 0 - final_result = Array(9) - poseidon16_compress(intermediate, padded, final_result) - final_result[8] = fs[8] # preserve transcript pointer - return final_result + return fs_observe_chunks(intermediate, padded, 1) def fs_grinding(fs, bits): if bits == 0: return fs # no grinding - transcript_ptr = fs[8] - set_to_7_zeros(transcript_ptr + 1) - - new_fs = Array(9) - poseidon16_compress(fs, transcript_ptr, new_fs) - new_fs[8] = transcript_ptr + 8 + transcript_ptr = fs[16] + new_fs = _absorb_chunks(fs, transcript_ptr, 1, transcript_ptr + DIGEST_LEN) - sampled = new_fs[0] + # Rate is at new_fs[8..16]; sample the first cell of it for the grinding check. 
+ sampled = new_fs[8] debug_assert(bits <= 24) match_range(bits, range(0, 25), lambda b: assert_trailing_bits_are_zeros(sampled, b)) return new_fs + def assert_trailing_bits_are_zeros(value, bits: Const): debug_assert(bits != 0) @@ -68,7 +70,7 @@ def assert_trailing_bits_are_zeros(value, bits: Const): hint_decompose_bits_merkle_whir(chunks, value, chunk_size) for i in unroll(0, num_chunks): assert chunks[i] < 2**chunk_size - + partial_sums = Array(num_chunks) partial_sums[0] = chunks[0] for i in unroll(1, num_chunks): @@ -90,72 +92,76 @@ def assert_trailing_bits_are_zeros(value, bits: Const): debug_assert(bits == 24) assert chunks[0] == 0 assert chunks[1] == 0 - + return - + + +@inline +def fs_duplex(fs): + # (equivalent to absorbing 8 zeros) + # Refreshes the rate so a subsequent sample doesn't repeat the previous one. + new_fs = Array(17) + poseidon16_permute(fs, ZERO_VEC_PTR, new_fs) + new_fs[16] = fs[16] + return new_fs def fs_sample_chunks(fs, n_chunks: Const): - # return the updated fiat-shamir, and a pointer to n_chunks chunks of 8 field elements - - sampled = Array((n_chunks + 1) * 8 + 1) - for i in unroll(0, (n_chunks + 1)): - domain_sep = Array(8) - domain_sep[0] = i - set_to_7_zeros(domain_sep + 1) - poseidon16_compress( - domain_sep, - fs, - sampled + i * 8, - ) - sampled[(n_chunks + 1) * 8] = fs[8] # same transcript pointer - new_fs = sampled + n_chunks * 8 - return new_fs, sampled + # Returns (new_fs, samples_ptr) where samples_ptr points to a contiguous + # n_chunks * 8-cell buffer holding the squeezed chunks. Assumes the rate at + # fs[8..16] is "fresh" (just-permuted, not yet emitted); caller must duplex + # (or observe) between independent sample sequences. + if n_chunks == 0: + return fs, ZERO_VEC_PTR + if n_chunks == 1: + # Chunk 0 is the current fs itself: its rate is fs[8..16], no permute needed. 
+ return fs, fs + 8 + samples = Array(n_chunks * 8) + copy_8(samples, fs + 8) + chain = Array((n_chunks - 1) * 16 + 1) + poseidon16_permute(fs, ZERO_VEC_PTR, chain) + copy_8(samples + 8, chain + 8) + for i in unroll(2, n_chunks): + poseidon16_permute(chain + (i - 2) * 16, ZERO_VEC_PTR, chain + (i - 1) * 16) + copy_8(samples + i * 8, chain + (i - 1) * 16 + 8) + chain[(n_chunks - 1) * 16] = fs[16] + new_fs = chain + (n_chunks - 2) * 16 + return new_fs, samples @inline def fs_sample_ef(fs): - sampled = Array(8) - poseidon16_compress(ZERO_VEC_PTR, fs, sampled) - new_fs = Array(9) - poseidon16_compress(SAMPLING_DOMAIN_SEPARATOR_PTR, fs, new_fs) - new_fs[8] = fs[8] # same transcript pointer - return new_fs, sampled + # Single-chunk sample: read the fresh rate at fs[8..16]; the new fs is unchanged. + return fs, fs + 8 +@inline def fs_sample_many_ef(fs, n): # return the updated fiat-shamir, and a pointer to n (continuous) extension field elements - n_chunks = div_ceil_dynamic(n * DIM, 8) + n_chunks = div_ceil(n * DIM, 8) debug_assert(n_chunks <= 31) debug_assert(1 <= n_chunks) - new_fs, sampled = match_range(n_chunks, range(1, 32), lambda nc: fs_sample_chunks(fs, nc)) + new_fs, sampled = fs_sample_chunks(fs, n_chunks) return new_fs, sampled @inline def fs_hint(fs, n): - # return the updated fiat-shamir, and a pointer to n field elements from the transcript - transcript_ptr = fs[8] - new_fs = Array(9) - copy_8(fs, new_fs) - new_fs[8] = fs[8] + n # advance transcript pointer - return new_fs, transcript_ptr + # Hint = read `n` cells from the transcript without absorbing them. Just advance the + # transcript pointer; the sponge state is unchanged. 
+ new_fs = Array(17) + copy_8(new_fs, fs) + copy_8(new_fs + 8, fs + 8) + new_fs[16] = fs[16] + n + return new_fs, fs[16] def fs_receive_chunks(fs, n_chunks: Const): - # each chunk = 8 field elements - new_fs = Array(1 + 8 * n_chunks) - transcript_ptr = fs[8] - new_fs[8 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer - - poseidon16_compress(fs, transcript_ptr, new_fs) - for i in unroll(1, n_chunks): - poseidon16_compress( - new_fs + ((i - 1) * 8), - transcript_ptr + i * 8, - new_fs + i * 8, - ) - return new_fs + 8 * (n_chunks - 1), transcript_ptr + # Read n_chunks * 8 cells from the transcript and absorb them. Returns the new fs + # and a pointer to the just-consumed transcript region. + transcript_ptr = fs[16] + new_fs = _absorb_chunks(fs, transcript_ptr, n_chunks, transcript_ptr + n_chunks * DIGEST_LEN) + return new_fs, transcript_ptr @inline @@ -183,37 +189,15 @@ def fs_receive_ef(fs, n: Const): def fs_print_state(fs_state): - for i in unroll(0, 9): + for i in unroll(0, 17): print(i, fs_state[i]) return -def fs_sample_data_with_offset(fs, n_chunks: Const, offset): - # Like fs_sample_chunks but uses domain separators [offset..offset+n_chunks-1]. - # Only returns the sampled data, does NOT update fs. - sampled = Array(n_chunks * 8) - for i in unroll(0, n_chunks): - domain_sep = Array(8) - domain_sep[0] = offset + i - set_to_7_zeros(domain_sep + 1) - poseidon16_compress(domain_sep, fs, sampled + i * 8) - return sampled - - -def fs_finalize_sample(fs, total_n_chunks): - # Compute new fs state using domain_sep = total_n_chunks - # (same as the last poseidon call in fs_sample_chunks). - new_fs = Array(9) - domain_sep = Array(8) - domain_sep[0] = total_n_chunks - set_to_7_zeros(domain_sep + 1) - poseidon16_compress(domain_sep, fs, new_fs) - new_fs[8] = fs[8] # same transcript pointer - return new_fs - - @inline def fs_sample_queries(fs, n_samples): + # Sample `n_samples` query bit-strings. 
Each chunk yields 8 base field elements that + # can be downsampled to query indices. We squeeze `ceil(n_samples / 8)` chunks. debug_assert(n_samples < 512) # Compute total_chunks = ceil(n_samples / 8) via bit decomposition. # Big-endian: nb[0]=bit8 (MSB), nb[8]=bit0 (LSB). @@ -221,7 +205,5 @@ def fs_sample_queries(fs, n_samples): floor_div = nb[0] * 32 + nb[1] * 16 + nb[2] * 8 + nb[3] * 4 + nb[4] * 2 + nb[5] has_remainder = 1 - (1 - nb[6]) * (1 - nb[7]) * (1 - nb[8]) total_chunks = floor_div + has_remainder - # Sample exactly the needed chunks (dispatch via match_range to keep n_chunks const) - sampled = match_range(total_chunks, range(0, 65), lambda nc: fs_sample_data_with_offset(fs, nc, 0)) - new_fs = fs_finalize_sample(fs, total_chunks) + new_fs, sampled = match_range(total_chunks, range(0, 65), lambda nc: fs_sample_chunks(fs, nc)) return sampled, new_fs diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 29cb66860..4433da0f0 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -95,6 +95,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): fs, logup_c = fs_sample_ef(fs) + fs = fs_duplex(fs) fs, logup_alphas = fs_sample_many_ef(fs, log2_ceil(MAX_BUS_WIDTH)) logup_alphas_eq_poly = compute_eq_mle_extension(logup_alphas, log2_ceil(MAX_BUS_WIDTH)) @@ -382,8 +383,10 @@ def continue_recursion_ordered( # VERIFY BUS AND AIR — back-loaded batched sumcheck (see https://hackmd.io/s/HyxaupAAA) fs, bus_beta = fs_sample_ef(fs) + fs = fs_duplex(fs) fs, air_alpha = fs_sample_ef(fs) air_alpha_powers = powers_const(air_alpha, MAX_NUM_AIR_CONSTRAINTS + 1) + fs = fs_duplex(fs) fs, eta = fs_sample_ef(fs) eta_powers = powers_const(eta, N_TABLES) @@ -464,6 +467,7 @@ def continue_recursion_ordered( dot_product_be(inner_public_memory, poly_eq_public_mem, public_memory_eval, 2**INNER_PUBLIC_MEMORY_LOG_SIZE) # WHIR BASE + fs = fs_duplex(fs) 
combination_randomness_gen: Mut fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers: Mut = powers(combination_randomness_gen, num_ood_at_commitment + TOTAL_WHIR_STATEMENTS) @@ -717,6 +721,7 @@ def verify_gkr_quotient(fs: Mut, n_vars): def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): + fs = fs_duplex(fs) fs, alpha = fs_sample_ef(fs) alpha_mul_claim_den = mul_extension_ret(alpha, claim_den) num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den) diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index 2526bc09a..7196c06f1 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -396,14 +396,15 @@ def set_to_8_zeros(a): dot_product_ee(a + (8 - DIM), ONE_EF_PTR, zero_ptr) return - @inline -def copy_8(a, b): - dot_product_ee(a, ONE_EF_PTR, b) - dot_product_ee(a + (8 - DIM), ONE_EF_PTR, b + (8 - DIM)) +def set_to_16_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product_ee(a, ONE_EF_PTR, zero_ptr) + dot_product_ee(a + 5, ONE_EF_PTR, zero_ptr) + dot_product_ee(a + 10, ONE_EF_PTR, zero_ptr) + a[15] = 0 return - @inline def copy_16(a, b): dot_product_ee(a, ONE_EF_PTR, b) @@ -412,6 +413,13 @@ def copy_16(a, b): a[15] = b[15] return +@inline +def copy_8(a, b): + dot_product_ee(a, ONE_EF_PTR, b) + dot_product_ee(a + (8 - DIM), ONE_EF_PTR, b + (8 - DIM)) + return + + @inline def copy_32(a, b): chunks = div_floor(32, DIM) diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index 23099e91d..423acf335 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -342,6 +342,7 @@ def whir_round( query_grinding_bits, ) + fs = fs_duplex(fs) fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers = powers(combination_randomness_gen, num_queries + num_ood) diff --git 
a/crates/sub_protocols/src/quotient_gkr/mod.rs b/crates/sub_protocols/src/quotient_gkr/mod.rs index ae0e264c3..ba798f010 100644 --- a/crates/sub_protocols/src/quotient_gkr/mod.rs +++ b/crates/sub_protocols/src/quotient_gkr/mod.rs @@ -84,6 +84,7 @@ fn prove_gkr_layer>>( claim_num: EF, claim_den: EF, ) -> (MultilinearPoint, EF, EF) { + prover_state.duplex(); let alpha = prover_state.sample(); let expected_sum = claim_num + alpha * claim_den; @@ -168,6 +169,7 @@ fn verify_gkr_quotient_step>>( claims_num: EF, claims_den: EF, ) -> Result<(MultilinearPoint, EF, EF), ProofError> { + verifier_state.duplex(); let alpha = verifier_state.sample(); let expected_sum = claims_num + alpha * claims_den; let eq_alphas_rev: Vec = point.0.iter().rev().copied().collect(); diff --git a/crates/sub_protocols/tests/prove_poseidon_16.rs b/crates/sub_protocols/tests/prove_poseidon_16.rs index 8371aadf5..941993662 100644 --- a/crates/sub_protocols/tests/prove_poseidon_16.rs +++ b/crates/sub_protocols/tests/prove_poseidon_16.rs @@ -68,6 +68,7 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let air_alpha_powers: Vec = alpha.powers().collect_n(n_constraints + 1); // BUS=false => `logup_alphas_eq_poly` and `bus_beta` are unused; only `alpha_powers` matter. 
let extra_data = ExtraDataForBuses::new(Vec::new(), EF::ZERO, air_alpha_powers); + prover_state.duplex(); let eq_factor: Vec = prover_state.sample_vec(log_n_rows); let column_refs: Vec<&[F]> = trace.iter().map(Vec::as_slice).collect(); let packed = MleGroupRef::::Base(column_refs).pack(); @@ -110,6 +111,7 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let air_alpha_powers: Vec = alpha.powers().collect_n(n_constraints + 1); let extra_data = ExtraDataForBuses::new(Vec::new(), EF::ZERO, air_alpha_powers); + verifier_state.duplex(); let eq_factor_v: Vec = verifier_state.sample_vec(log_n_rows); let Evaluation { diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index 3eb2557ba..f8c09e226 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -1,3 +1,4 @@ +use backend::symmetric::Permutation; use backend::*; use std::sync::OnceLock; @@ -24,6 +25,11 @@ pub fn poseidon16_compress(input: [KoalaBear; 16]) -> [KoalaBear; 8] { get_poseidon16().compress(input)[0..8].try_into().unwrap() } +#[inline(always)] +pub fn poseidon16_permute(input: [KoalaBear; 16]) -> [KoalaBear; 16] { + get_poseidon16().permute(input) +} + pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) -> [KoalaBear; 8] { let mut input = [KoalaBear::default(); 16]; input[..8].copy_from_slice(left); diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 8b8b4031c..f9634c918 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -142,6 +142,7 @@ where }; // Randomness for combination + prover_state.duplex(); let combination_randomness_gen: EF = prover_state.sample(); let ood_combination_randomness: Vec<_> = combination_randomness_gen.powers().collect_n(ood_challenges.len()); round_state @@ -484,6 +485,7 @@ where statement.splice(0..0, ood_statements); + prover_state.duplex(); let combination_randomness_gen: EF = prover_state.sample(); let (sumcheck_prover, folding_randomness) = 
SumcheckSingle::run_initial_sumcheck_rounds( diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b287..4ab38d1b3 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -101,7 +101,8 @@ where let mut claimed_sum = EF::ZERO; let mut prev_commitment = parsed_commitment.clone(); - // Combine OODS and statement constraints to claimed_sum + // Combine OODS and statement constraints to claimed_sum. + verifier_state.duplex(); let constraints: Vec<_> = prev_commitment .oods_constraints() .into_iter() @@ -143,6 +144,7 @@ where .chain(stir_constraints) .collect(); + verifier_state.duplex(); let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; round_constraints.push((combination_randomness.clone(), constraints)); diff --git a/misc/minimal_zkVM.tex b/misc/minimal_zkVM.tex index 1567d3a23..211ee2863 100644 --- a/misc/minimal_zkVM.tex +++ b/misc/minimal_zkVM.tex @@ -270,15 +270,22 @@ \subsection{Precompiles} \subsubsection{POSEIDON} -Compression (feed-forward) of 16 field elements (two blocks of 8) into 8 field elements. +Two sub-modes are supported, selected by a compile-time \texttt{flag\_permute} bit: + +\textbf{Compression mode} (default, \texttt{flag\_permute} $= 0$): feed-forward compression of 16 field elements (two blocks of 8) into 8 field elements. 
$$ \textbf{m}[\nu_C..\nu_C + 8] = \text{Poseidon}(\textbf{m}[\nu_A..\nu_A + 8] \;\|\; \textbf{m}[\nu_B..\nu_B + 8]) + \textbf{m}[\nu_A..\nu_A + 8] $$ +\textbf{Permute mode} (\texttt{flag\_permute} $= 1$): the raw Poseidon permutation (no feed-forward), with the 16-element output written into memory in natural order: +$$ +\textbf{m}[\nu_C..\nu_C + 16] = \text{Poseidon}(\textbf{m}[\nu_A..\nu_A + 8] \;\|\; \textbf{m}[\nu_B..\nu_B + 8]) +$$ + + \vspace{2mm} -\texttt{PRECOMPILE\_DATA} $= 1$ Recently some additonal paramters were introduced (see \ref{efficiently_verrifying_hash-based_signatures} for details), allowing more granular input / output. @@ -557,13 +564,33 @@ \subsubsection{AIR transition constraints} \subsection{Poseidon table} -We use poseidon \cite{poseidon1} over 16 field elements, in compression mode, i.e. for $\texttt{input} \in \Fp^{16}$ and interpreting the addition as coordinate-wise in $\Fp^8$: - $$\text{poseidon\_compress}(\texttt{input}) = \text{poseidon}(\texttt{input})[..8] + \texttt{input}[..8]$$. +Poseidon \cite{poseidon1} is a permutation of $\Fp^{16}$. The table runs this permutation once per row, and the final part of the row, which is exposed to memory via lookup, is a function of a compile-time bit $\texttt{flag\_permute}$: we perform either a $16 \to 8$ \emph{compression} or a $16 \to 16$ \emph{permutation}. + +The execution table hands the precompile three pointers $\nu_A$, $\nu_B$, $\nu_C$ over the precompile bus, and the precompile fetches the two input halves +$$\texttt{left} = \textbf{m}[\nu_A..\nu_A + 8], \qquad \texttt{right} = \textbf{m}[\nu_B..\nu_B + 8].$$ +Write $s = \text{poseidon}(\texttt{left} \;\|\; \texttt{right}) \in \Fp^{16}$ for the raw permutation output. + +\vspace{2mm} + +\paragraph{Compression mode} ($\texttt{flag\_permute} = 0$) Using pairwise addition in $\Fp$: -The Poseidon precompile receives 3 runtime arguments: $\nu_A$, $\nu_B$, $\nu_C$ interpreted as memory pointers. 
3 lookups (each of size 8) into the memory are used to fetch $\texttt{left}$ = $\textbf{m}[\nu_A..\nu_A + 8] \in \Fp^8$, $\texttt{right}$ = $\textbf{m}[\nu_B..\nu_B + 8] \in \Fp^8$, and $\texttt{res}$ = $\textbf{m}[\nu_C..\nu_C + 8] \in \Fp^8$. AIR constraints, of degree 9, assert that $\texttt{res} = \text{poseidon\_compress}(\texttt{left} \;\|\; \texttt{right})$. Degree 3 is also an alternative, at the cost of more committed columns ($\approx 160$ vs. $\approx 100$). +$$\textbf{m}[\nu_C..\nu_C + 8] = s[0..8] + \texttt{left}$$ + +\vspace{2mm} + +\paragraph{Permute mode} ($\texttt{flag\_permute} = 1$). + +$$\textbf{m}[\nu_C..\nu_C + 16] = s.$$ + + +\vspace{2mm} + +The AIR has degree 10 and $\approx 100$ committed columns; a degree-3 variant exists at the cost of $\approx 160$ columns. \subsubsection{Efficiently verrifying hash-based signatures}\label{efficiently_verrifying_hash-based_signatures} +The following optimizes the $\emph{compression}$ mode, and is incompatible with the $\emph{permutation}$ mode. + Hash-based signatures often rely on tweaks and public parameters (see \cite{ethereum_signatures}). We present two independent (and composable) optimizations of the poseidon precompile, when both the hash digest and the public parameters are composed of $n = 4$ field elements. We also assume each tweak occupies less than $n$ field elements (in practice 2 is enough). @@ -586,7 +613,7 @@ \subsubsection{Efficiently verrifying hash-based signatures}\label{efficiently_v Both optimizations can be implemented (together), at the cost of 4 additional columns, and incrementing the degree of the constraints by 1. 
Both flags ($\texttt{flag\_short}$ and $\texttt{flag\_left}$) and the offset ($\texttt{offset\_left}$) can be encoded in a single (compile-time) parameter, as follows: -$$\texttt{AUX} = 1 + 2 \cdot \texttt{flag\_short} + 4 \cdot \texttt{flag\_left} + 8 \cdot \texttt{offset\_left} \cdot \texttt{flag\_left}$$ +$$\texttt{AUX} = 1 + 2 \cdot \texttt{flag\_permute} + 4 \cdot \texttt{flag\_short} + 8 \cdot \texttt{flag\_left} + 16 \cdot \texttt{offset\_left} \cdot \texttt{flag\_left}$$ Soundness of this encoding (encoding multiple data into a single field element requires care, to avoid overflows that would break injectivity): \begin{itemize}