diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..c9053d905 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -72,7 +72,20 @@ where // SAFETY: We've confirmed PF == KoalaBear let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); - let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); + let hash_fn = |data: &[KoalaBear]| { + // Pad data up to the smallest sponge-aligned length so that + // (padded - WIDTH) is a multiple of RATE. The prover's + // build_merkle_tree_koalabear pads identically before hashing. + const W: usize = 16; + const R: usize = 12; + let mut padded_len = data.len().max(W); + while !(padded_len - W).is_multiple_of(R) { + padded_len += 1; + } + let mut buf: Vec = data.to_vec(); + buf.resize(padded_len, KoalaBear::default()); + symetric::mmo_hash_slice::<_, _, 16, 12, DIGEST_LEN_FE>(&perm, &buf) + }; let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs index 41be54ff5..1072c09d2 100644 --- a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs @@ -25,86 +25,47 @@ const MDS_CIRC_COL: [KoalaBear; 16] = KoalaBear::new_array([1, 3, 13, 22, 67, 2, // Forward twiddles for 16-point FFT: W_k = omega^k // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W1: KoalaBear = KoalaBear::new(0x08dbd69c); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W2: KoalaBear = KoalaBear::new(0x6832fe4a); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W3: KoalaBear = KoalaBear::new(0x27ae21e2); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W4: KoalaBear = KoalaBear::new(0x7e010002); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W5: KoalaBear = KoalaBear::new(0x3a89a025); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W6: KoalaBear = KoalaBear::new(0x174e3650); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W7: KoalaBear = KoalaBear::new(0x27dfce22); // ========================================================================= // 16-point FFT / IFFT (radix-2, fully unrolled, in-place) // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { +fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { let (a, b) = (v[lo], v[hi]); v[lo] = a + b; v[hi] = a - b; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { +fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { let a = v[lo]; let tb = v[hi] * t; v[lo] = a + tb; v[hi] = a - tb; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { +fn neg_dif>( + v: &mut [R; 16], + lo: usize, + hi: usize, + t: KoalaBear, +) { let (a, b) = (v[lo], v[hi]); v[lo] = a + b; v[hi] = (b - a) * t; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dif_ifft_16_mut>(f: &mut [R; 16]) { +fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 8); neg_dif(f, 1, 9, W7); neg_dif(f, 2, 10, W6); @@ -139,12 +100,8 @@ fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 14, 15); } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dit_fft_16_mut>(f: &mut [R; 16]) { +fn dit_fft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 1); bt(f, 2, 3); bt(f, 4, 5); @@ -543,6 +500,11 @@ struct Precomputed { /// Length = POSEIDON1_PARTIAL_ROUNDS - 1. sparse_round_constants: Vec, + // --- FFT MDS eigenvalues (unpacked, for AIR / generic types) --- + /// `lambda_over_16[i]` = (DIF_IFFT(MDS_CIRC_COL))[i] * 16^{-1}. + /// Used by `mds_fft_16` to compute the circulant MDS via FFT. + lambda_over_16: [KoalaBear; 16], + // --- SIMD pre-packed constants (NEON / AVX2 / AVX512) --- #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), @@ -634,6 +596,16 @@ fn precomputed() -> &'static Precomputed { .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] })) .collect(); + // --- FFT MDS eigenvalues (unpacked) --- + // C * x = DIT_FFT((lambda/16) ⊙ DIF_IFFT(x)) — same identity used in + // `permute_simd`, factored out for the AIR / generic mds_fft_16 path. + let lambda_over_16: [KoalaBear; 16] = { + let mut lambda_br = MDS_CIRC_COL; + dif_ifft_16_mut(&mut lambda_br); + let inv16 = KoalaBear::new(1997537281); // 16^{-1} mod p + lambda_br.map(|l| l * inv16) + }; + // --- SIMD pre-packed constants (same layout for NEON / AVX2 / AVX512) --- #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), @@ -675,10 +647,8 @@ fn precomputed() -> &'static Precomputed { let packed_fused_bias: [PackedKB; 16] = fused_bias.map(pack); // Pre-packed eigenvalues * INV16 (absorbs /16 into eigenvalues). - let mut lambda_br = MDS_CIRC_COL; - dif_ifft_16_mut(&mut lambda_br); - let inv16 = KoalaBear::new(1997537281); // 16^{-1} mod p - let packed_lambda_over_16: [PackedKB; 16] = core::array::from_fn(|i| pack(lambda_br[i] * inv16)); + // Reuse the unpacked lambda computed above. + let packed_lambda_over_16: [PackedKB; 16] = lambda_over_16.map(pack); SimdPrecomputed { packed_initial_rc, @@ -699,6 +669,7 @@ fn precomputed() -> &'static Precomputed { sparse_first_row, sparse_v, sparse_round_constants: scalar_round_constants, + lambda_over_16, #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "x86_64", target_feature = "avx2") @@ -708,6 +679,29 @@ fn precomputed() -> &'static Precomputed { }) } +/// Eigenvalues of the circulant MDS matrix, divided by 16 (the unnormalized +/// FFT round-trip scaling). Used by [`mds_fft_16`]. +#[inline(always)] +pub fn poseidon1_lambda_over_16() -> &'static [KoalaBear; 16] { + &precomputed().lambda_over_16 +} + +/// Circulant MDS multiply via 16-point FFT (50 mults vs 72 for Karatsuba). +/// +/// Computes `state = C * state = DIT_FFT((lambda/16) o DIF_IFFT(state))`. +/// Bitwise-identical to `mds_circ_16` but with fewer multiplications. +/// Used by the AIR constraint folder where MDS is evaluated per row over +/// (packed) field types. +#[inline(always)] +pub fn mds_fft_16>(state: &mut [R; 16]) { + let lambda = poseidon1_lambda_over_16(); + dif_ifft_16_mut(state); + for i in 0..16 { + state[i] = state[i] * lambda[i]; + } + dit_fft_16_mut(state); +} + // ========================================================================= // Round constants (Grain LFSR, matching Plonky3) // ========================================================================= @@ -1049,6 +1043,7 @@ impl Poseidon1KoalaBear16 { impl + InjectiveMonomial<3> + Send + Sync + 'static> Permutation<[R; 16]> for Poseidon1KoalaBear16 { + #[inline] fn permute_mut(&self, input: &mut [R; 16]) { // On targets with a SIMD fast path, dispatch to it when R is the arch-specific packed type. #[cfg(any( diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..f0834ae8d 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -98,14 +98,14 @@ pub fn merkle_verify bool where - F: Default + Copy + PartialEq, + F: field::PrimeCharacteristicRing + PartialEq, Comp: Compression<[F; WIDTH]>, { if opening_proof.len() != log_height { return false; } - let mut root = crate::hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values); + let mut root = crate::mmo_hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values); for &sibling in opening_proof.iter() { let (left, right) = if index & 1 == 0 { diff --git a/crates/backend/symetric/src/permutation.rs b/crates/backend/symetric/src/permutation.rs index c129a1dc4..4068f2af3 100644 --- a/crates/backend/symetric/src/permutation.rs +++ b/crates/backend/symetric/src/permutation.rs @@ -16,6 +16,7 @@ pub trait Compression: Clone + Sync { impl + InjectiveMonomial<3> + Send + Sync + 'static> Compression<[R; 16]> for Poseidon1KoalaBear16 { + #[inline] fn compress_mut(&self, input: &mut [R; 16]) { self.compress_in_place(input); } diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index ebea80a9e..e195d2a06 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -1,22 +1,25 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). +use field::PrimeCharacteristicRing; + use crate::Compression; -// IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). Maybe we should re-add IV all the time for simplicity? -// assumes data length is a multiple of RATE (= 8 in practice). +// IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). +// Sponge construction with capacity = WIDTH - RATE. +// Constraint: data.len() >= WIDTH and (data.len() - WIDTH) is a multiple of RATE. pub fn hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] where T: Default + Copy, Comp: Compression<[T; WIDTH]>, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); - debug_assert!(data.len().is_multiple_of(RATE)); - let n_chunks = data.len() / RATE; - debug_assert!(n_chunks >= 2); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(data.len() >= WIDTH); + debug_assert!((data.len() - WIDTH).is_multiple_of(RATE)); let mut state: [T; WIDTH] = data[data.len() - WIDTH..].try_into().unwrap(); comp.compress_mut(&mut state); - for chunk_idx in (0..n_chunks - 2).rev() { + let n_remaining_chunks = (data.len() - WIDTH) / RATE; + for chunk_idx in (0..n_remaining_chunks).rev() { let offset = chunk_idx * RATE; state[WIDTH - RATE..].copy_from_slice(&data[offset..offset + RATE]); comp.compress_mut(&mut state); @@ -24,7 +27,8 @@ where state[..OUT].try_into().unwrap() } -/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks. +/// Precompute sponge state after `n_zero_chunks - 1` zero compresses +/// (1 for initial WIDTH zeros + (n-2) RATE-zero absorbs). pub fn precompute_zero_suffix_state( comp: &Comp, n_zero_chunks: usize, @@ -33,8 +37,8 @@ where T: Default + Copy, Comp: Compression<[T; WIDTH]>, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); debug_assert!(n_zero_chunks >= 2); let mut state = [T::default(); WIDTH]; comp.compress_mut(&mut state); @@ -58,8 +62,8 @@ where Comp: Compression<[T; WIDTH]>, I: IntoIterator, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); let mut state = [T::default(); WIDTH]; let mut iter = rtl_iter.into_iter(); for pos in (0..WIDTH).rev() { @@ -106,3 +110,231 @@ where } state[..OUT].try_into().unwrap() } + +// ============================================================================= +// MMO-mode (Davies-Meyer / Matyas-Meyer-Oseas) feedforward sponge +// ============================================================================= +// +// Standard PaddingFreeSponge ("oSponge") collision security is c·log2(p)/2 bits +// because of the inner-state birthday attack on the capacity portion. With +// (WIDTH=16, RATE=12, capacity=4) over KoalaBear (p ~= 2^31), that bound is +// 4*31/2 = 62 bits — short of the 124-bit target. +// +// This MMO variant treats each absorb step as the Matyas-Meyer-Oseas +// compression F(state, M) = state + perm(state + (M, 0_cap)), i.e. message is +// ADDED into the rate positions (not overwritten) and the full pre-perm state +// is fed forward. The chaining variable is then the FULL 16-element state +// (496 bits), not the 4-element capacity, so generic compression collision is +// 2^{b/2} = 2^248 in the random-permutation model, and after truncation to +// OUT=8 elements the digest birthday gives 2^{output_bits/2} = 2^124. +// +// IMPORTANT: `Compression::compress_mut` for Poseidon-16 in this codebase ALREADY +// computes `output = perm(input) + input` (matching the AIR's `eval_last_2_full_rounds_16` +// which adds initial state to post-perm state — see lean_vm/src/tables/poseidon_16/mod.rs). +// So a single `compress_mut` call IS one MMO step; we must NOT add `prev` again +// after, or we'd double-feedforward and disagree with the zk-DSL precompile. +// +// Convention matches the existing hash_slice: the first 16 elements of data +// are loaded directly into the state and the precompile is invoked once +// (zero IV implicit). Subsequent RATE-sized blocks are absorbed with ADD into +// rate positions, then a single compression invocation gives the next state. + +/// MMO-mode (feedforward) variant of `hash_slice`. Same input format and +/// alignment requirements; collision security is bounded by the digest size +/// rather than the capacity. +#[inline] +pub fn mmo_hash_slice( + comp: &Comp, + data: &[T], +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(data.len() >= WIDTH); + debug_assert!((data.len() - WIDTH).is_multiple_of(RATE)); + let mut state: [T; WIDTH] = data[data.len() - WIDTH..].try_into().unwrap(); + // First MMO compression: state ← perm(state) + state (compress_mut already does this). + comp.compress_mut(&mut state); + let n_remaining_chunks = (data.len() - WIDTH) / RATE; + for chunk_idx in (0..n_remaining_chunks).rev() { + let offset = chunk_idx * RATE; + // ADD message into rate positions (not overwrite). + for i in 0..RATE { + state[WIDTH - RATE + i] += data[offset + i]; + } + // One MMO compression: state ← perm(state) + state. compress_mut already + // performs the full-state feedforward. + comp.compress_mut(&mut state); + } + state[..OUT].try_into().unwrap() +} + +/// MMO-mode variant of `precompute_zero_suffix_state`. Same number of perm +/// calls as the standard variant (n_zero_chunks - 1 total). +#[inline] +pub fn mmo_precompute_zero_suffix_state( + comp: &Comp, + n_zero_chunks: usize, +) -> [T; WIDTH] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(n_zero_chunks >= 2); + let mut state = [T::ZERO; WIDTH]; + // First absorb (16 zeros). compress_mut applies one MMO compression. + comp.compress_mut(&mut state); + // Subsequent (n_zero_chunks - 2) absorbs of zero RATE-chunks. ADD 0 is a + // no-op, so each iteration is just one MMO compression. + for _ in 0..n_zero_chunks - 2 { + comp.compress_mut(&mut state); + } + state +} + +/// RTL = Right-to-left. MMO-mode counterpart of `hash_rtl_iter`. +#[inline(always)] +pub fn mmo_hash_rtl_iter( + comp: &Comp, + rtl_iter: I, +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: IntoIterator, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + let mut state = [T::ZERO; WIDTH]; + let mut iter = rtl_iter.into_iter(); + for pos in (0..WIDTH).rev() { + state[pos] = iter.next().unwrap(); + } + comp.compress_mut(&mut state); + mmo_absorb_rtl_chunks::(comp, &mut state, &mut iter) +} + +/// RTL = Right-to-left. MMO-mode counterpart of `hash_rtl_iter_with_initial_state`. +#[inline(always)] +pub fn mmo_hash_rtl_iter_with_initial_state( + comp: &Comp, + mut iter: I, + initial_state: &[T; WIDTH], +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: Iterator, +{ + let mut state = *initial_state; + mmo_absorb_rtl_chunks::(comp, &mut state, &mut iter) +} + +/// RTL = Right-to-left. MMO-mode chunk absorption: ADD message into rate, then +/// one MMO compression (compress_mut already does perm + input feedforward). +#[inline(always)] +fn mmo_absorb_rtl_chunks( + comp: &Comp, + state: &mut [T; WIDTH], + iter: &mut I, +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: Iterator, +{ + while let Some(elem) = iter.next() { + // ADD into rate positions (last RATE elements), reading the iterator + // from right to left. + state[WIDTH - 1] += elem; + for pos in (WIDTH - RATE..WIDTH - 1).rev() { + state[pos] += iter.next().unwrap(); + } + comp.compress_mut(state); + } + state[..OUT].try_into().unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use field::PrimeCharacteristicRing; + use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; + + /// Verify hash_slice(D) == hash_rtl_iter(D.iter().rev()) for arbitrary D with valid length. + #[test] + fn hash_slice_matches_rtl_iter_rate12() { + let perm = default_koalabear_poseidon1_16(); + // 100 = 16 + 12*7, compatible with WIDTH=16, RATE=12 + let data: Vec = (0..100u32).map(KoalaBear::from_u32).collect(); + let h_slice = hash_slice::(&perm, &data); + let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!( + h_slice, h_rtl, + "hash_slice and hash_rtl_iter must agree on equivalent inputs" + ); + } + + /// Same as above but for the existing RATE=8 case. + #[test] + fn hash_slice_matches_rtl_iter_rate8() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..64u32).map(KoalaBear::from_u32).collect(); + let h_slice = hash_slice::(&perm, &data); + let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!( + h_slice, h_rtl, + "hash_slice and hash_rtl_iter must agree on equivalent inputs (RATE=8)" + ); + } + + /// MMO-mode counterpart of hash_slice_matches_rtl_iter_rate12. + #[test] + fn mmo_hash_slice_matches_rtl_iter_rate12() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..100u32).map(KoalaBear::from_u32).collect(); + let h_slice = mmo_hash_slice::(&perm, &data); + let h_rtl = mmo_hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!( + h_slice, h_rtl, + "mmo_hash_slice and mmo_hash_rtl_iter must agree on equivalent inputs" + ); + } + + /// MMO-mode is structurally distinct from oSponge — verify they produce + /// different digests on the same input (sanity check that we are not + /// accidentally falling back to the standard sponge). + #[test] + fn mmo_differs_from_standard_sponge() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..28u32).map(KoalaBear::from_u32).collect(); // 16 + 12, two-block input + let h_std = hash_slice::(&perm, &data); + let h_mmo = mmo_hash_slice::(&perm, &data); + assert_ne!( + h_std, h_mmo, + "MMO must differ from standard sponge for multi-block inputs" + ); + } + + /// Verify the MMO precompute is consistent with directly hashing zeros. + #[test] + fn mmo_precompute_zero_suffix_matches_full_zero_hash() { + let perm = default_koalabear_poseidon1_16(); + let n_zero_chunks: usize = 4; // WIDTH absorb + 3 RATE absorbs of zero + let zeros: Vec = std::iter::repeat_n(KoalaBear::ZERO, 16 + 12 * (n_zero_chunks - 1)).collect(); + let direct = mmo_hash_slice::(&perm, &zeros); + let pre = mmo_precompute_zero_suffix_state::(&perm, n_zero_chunks); + // The precompute does (n_zero_chunks - 1) MMO compressions; mmo_hash_slice + // does n_zero_chunks total. To finalize we need ONE MORE compression + // (ADDing zero rate is a no-op). + let mut state = pre; + perm.compress_mut(&mut state); + let advanced: [KoalaBear; 8] = state[..8].try_into().unwrap(); + assert_eq!(advanced, direct); + } +} diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index 5d13b761f..e15e474f4 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -89,6 +89,20 @@ def poseidon16_compress_half_hardcoded_left(left, right, output, offset): _ = left, right, output, offset +def poseidon16_permute(left, right, output): + """Apply Poseidon-16 with input feedforward (MMO compression) and write all + 16 output elements to memory[output..output+16]. + + output[0..8] = perm(left || right)[0..8] + left + output[8..16] = perm(left || right)[8..16] + right + + Used for MMO sponge leaf hashing — the FULL 16-element state must be + chained between rounds to achieve `output_bits/2 = 124`-bit collision + security. Allocate Array(16) (NOT Array(8)) for the result. + """ + _ = left, right, output + + def add_be(a, b, result, length=None): _ = a, b, result, length diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index b3c121d88..cba454bf2 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -7,8 +7,8 @@ use crate::{ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, - POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, - PrecompileCompTimeArgs, SourceLocation, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_PERMUTE_NAME, + PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -2259,7 +2259,7 @@ fn simplify_lines( continue; } - // Special handling for poseidon16 precompile (4 variants) + // Special handling for poseidon16 precompile (5 variants) if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) { if !targets.is_empty() { return Err(format!( @@ -2268,6 +2268,7 @@ fn simplify_lines( } let half_output = [POSEIDON16_HALF_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] .contains(&function_name.as_str()); + let full_output = function_name.as_str() == POSEIDON16_PERMUTE_NAME; let is_hardcoded_left = [POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] .contains(&function_name.as_str()); @@ -2302,6 +2303,7 @@ fn simplify_lines( res: simplified_args[2].clone(), data: PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, }, })); diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index 1060e3be4..42b284624 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -50,6 +50,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let precompile_data = match &precompile.data { PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, } => { let flag_left = hardcoded_offset_left.is_some() as usize; @@ -58,6 +59,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { + POSEIDON_HALF_OUTPUT_SHIFT * (*half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val + + POSEIDON_FULL_OUTPUT_SHIFT * (*full_output as usize) } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 1801a5b62..f3b0be56b 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -130,6 +130,23 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul }); } + // For non-full-output rows, zero outputs_high (AIR constrains them to zero) and point + // index_input_res_high at the zero-vec region so the high-half memory lookup is a no-op + // (m[zero_vec_ptr + i] = 0 = outputs_high[i]). + { + // Snapshot flag column (immutable copy) before taking mutable references to the trace. + let full_output_flags: Vec = poseidon_trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].clone(); + let zero_ptr = F::from_usize(padding_zero_vec_ptr); + for (row_idx, flag) in full_output_flags.iter().enumerate() { + if *flag != F::ONE { + poseidon_trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES_HIGH][row_idx] = zero_ptr; + for j in 0..DIGEST_LEN { + poseidon_trace.columns[POSEIDON_16_COL_OUTPUTS_HIGH_START + j][row_idx] = F::ZERO; + } + } + } + } + let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); fill_trace_extension_op(extension_op_trace, &memory_padded); diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..c3e42fd4d 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -65,6 +65,9 @@ pub struct PrecompileArgs { pub enum PrecompileCompTimeArgs { Poseidon16 { half_output: bool, + /// Permute mode: write all 16 elements of perm(input)+input to memory at `res`. + /// Mutually exclusive with `half_output`. Used by the MMO sponge leaf hash. + full_output: bool, // hardcoded_offset_left = None: left_input = m[arg_a..arg_a+8] // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | m[arg_a..arg_a+4] (arg_a is the first runtime parameter) hardcoded_offset_left: Option, @@ -87,9 +90,11 @@ impl PrecompileCompTimeArgs { match self { Self::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4, } => PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4.map(&mut f), }, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, @@ -252,12 +257,16 @@ impl Display for PrecompileArgs { match data { PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4, - } => match (*half_output, hardcoded_left_4) { - (false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), - (true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), - (false, Some(off)) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})"), - (true, Some(off)) => write!( + } => match (*full_output, *half_output, hardcoded_left_4) { + (true, _, _) => write!(f, "poseidon16_permute({arg_0}, {arg_1}, {res})"), + (false, false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), + (false, true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), + (false, false, Some(off)) => { + write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})") + } + (false, true, Some(off)) => write!( f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" ), diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 5cffe5194..af1b8b9bc 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -5,9 +5,13 @@ use crate::{execution::memory::MemoryAccess, tables::poseidon_16::trace_gen::gen use backend::*; use utils::{ToUsize, poseidon16_compress}; -/// Dispatch `mds_circ_16` through concrete types. -/// For `SymbolicExpression` we use the dense form so the zkDSL generator can -/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. +/// Dispatch the circulant MDS multiply through concrete types. +/// +/// - `SymbolicExpression`: dense matrix-vector form so the zkDSL generator can +/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. +/// - Runtime field types (F, EF, FPacking, EFPacking): FFT-based MDS +/// (`mds_fft_16`, 50 mults) instead of Karatsuba (`mds_circ_16`, 72 mults). +/// Same algebraic result; ~30% fewer mults per call. #[inline(always)] fn mds_air_16(state: &mut [A; WIDTH]) { if TypeId::of::() == TypeId::of::>() { @@ -17,7 +21,7 @@ fn mds_air_16(state: &mut [A; WIDTH]) { macro_rules! dispatch { ($t:ty) => { if TypeId::of::() == TypeId::of::<$t>() { - mds_circ_16::<$t>(unsafe { &mut *(state as *mut [A; WIDTH] as *mut [$t; WIDTH]) }); + mds_fft_16::<$t>(unsafe { &mut *(state as *mut [A; WIDTH] as *mut [$t; WIDTH]) }); return; } }; @@ -94,6 +98,9 @@ pub const POSEIDON_PRECOMPILE_DATA: usize = 1; pub const POSEIDON_HALF_OUTPUT_SHIFT: usize = 1 << 1; pub const POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT: usize = 1 << 2; pub const POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT: usize = 1 << 3; +// Bit 30 is safely beyond `8 * MAX_LOG_MEMORY_SIZE = 2^29` so it cannot +// alias with `POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * offset`. +pub const POSEIDON_FULL_OUTPUT_SHIFT: usize = 1 << 30; pub const POSEIDON_16_COL_FLAG: ColIndex = 0; pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 1; @@ -104,7 +111,14 @@ pub const POSEIDON_16_COL_OFFSET_LEFT_HARDCODED: ColIndex = 5; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST: ColIndex = 6; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND: ColIndex = 7; pub const POSEIDON_16_COL_INPUT_START: ColIndex = 8; -pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; +// Layout at end of struct (in field-declaration order): +// ... outputs (DIGEST_LEN cols) ... flag_full_output (1) ... index_input_res_high (1) ... outputs_high (DIGEST_LEN) +// So OUTPUTS_HIGH_START = num_cols - DIGEST_LEN, INDEX_INPUT_RES_HIGH = num_cols - DIGEST_LEN - 1, +// FLAG_FULL_OUTPUT = num_cols - DIGEST_LEN - 2, OUTPUT_START = num_cols - DIGEST_LEN - 2 - DIGEST_LEN. +pub const POSEIDON_16_COL_OUTPUTS_HIGH_START: ColIndex = num_cols_poseidon_16() - DIGEST_LEN; +pub const POSEIDON_16_COL_INDEX_INPUT_RES_HIGH: ColIndex = POSEIDON_16_COL_OUTPUTS_HIGH_START - 1; +pub const POSEIDON_16_COL_FLAG_FULL_OUTPUT: ColIndex = POSEIDON_16_COL_INDEX_INPUT_RES_HIGH - 1; +pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = POSEIDON_16_COL_FLAG_FULL_OUTPUT - DIGEST_LEN; /// Non-committed columns ("virtual"): pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = num_cols_poseidon_16(); pub const POSEIDON_16_COL_PRECOMPILE_DATA: ColIndex = num_cols_poseidon_16() + 1; @@ -113,11 +127,16 @@ pub const POSEIDON16_NAME: &str = "poseidon16_compress"; pub const POSEIDON16_HALF_NAME: &str = "poseidon16_compress_half"; pub const POSEIDON16_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_hardcoded_left"; pub const POSEIDON16_HALF_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_half_hardcoded_left"; -pub const ALL_POSEIDON16_NAMES: [&str; 4] = [ +/// Permute mode: writes ALL 16 perm-output elements (with input feedforward) to memory. +/// Used for MMO sponge leaf hashing where the FULL 16-element state must be chained +/// between rounds to achieve `output_bits/2 = 124`-bit collision security at any rate. +pub const POSEIDON16_PERMUTE_NAME: &str = "poseidon16_permute"; +pub const ALL_POSEIDON16_NAMES: [&str; 5] = [ POSEIDON16_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME, + POSEIDON16_PERMUTE_NAME, ]; pub const HALF_DIGEST_LEN: usize = DIGEST_LEN / 2; @@ -153,6 +172,13 @@ impl TableT for Poseidon16Precompile { index: POSEIDON_16_COL_INDEX_INPUT_RES, values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), }, + // High-half output lookup (only meaningful in permute mode, but always active). + // For non-permute rows the trace_gen sets index_input_res_high = zero_vec_ptr and + // outputs_high = 0, so this lookup checks `m[zero_vec_ptr+i] == 0` (trivially true). + LookupIntoMemory { + index: POSEIDON_16_COL_INDEX_INPUT_RES_HIGH, + values: (POSEIDON_16_COL_OUTPUTS_HIGH_START..POSEIDON_16_COL_OUTPUTS_HIGH_START + DIGEST_LEN).collect(), + }, ] } @@ -190,11 +216,22 @@ impl TableT for Poseidon16Precompile { *perm.offset_hardcoded_left = F::ZERO; *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); + *perm.flag_full_output = F::ZERO; + // Padding rows are non-permute → high-output index points at zero_vec_ptr (a 16-cell zero region). + *perm.index_input_res_high = F::from_usize(zero_vec_ptr); + // outputs_high is zeroed by the constraint `(1 - flag_full_output) * outputs_high[i] = 0`; + // the trace generator below leaves them at F::ZERO via the Vec::new() default. // Non-committed columns row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); generate_trace_rows_for_perm(perm); + // generate_trace_rows_for_perm fills outputs[0..8] with state[i] + inputs[i]; for padding + // rows inputs are all zero so outputs ≡ Poseidon-16(0) (8 elements). outputs_high however + // must be zero per the AIR constraint, so explicitly clear it after the perm trace fill. + for output_high in perm.outputs_high.iter_mut() { + **output_high = F::ZERO; + } row } @@ -209,11 +246,16 @@ impl TableT for Poseidon16Precompile { ) -> Result<(), RunnerError> { let PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, } = args else { unreachable!("Poseidon16 table called with non-Poseidon16 args"); }; + debug_assert!( + !(half_output && full_output), + "half_output and full_output are mutually exclusive" + ); let trace = ctx.traces.get_mut(&self.table()).unwrap(); let arg_a_usize = arg_a.to_usize(); @@ -238,13 +280,17 @@ impl TableT for Poseidon16Precompile { input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); input[DIGEST_LEN..].copy_from_slice(&arg1); - let output = poseidon16_compress(input); - - if half_output { - ctx.memory - .set_slice(index_res_a.to_usize(), &output[..HALF_DIGEST_LEN])?; + let res_addr = index_res_a.to_usize(); + if full_output { + // Write all 16 elements (perm output + input feedforward) to memory. + let full = utils::poseidon16_permute_full(input); + ctx.memory.set_slice(res_addr, &full)?; + } else if half_output { + let output = poseidon16_compress(input); + ctx.memory.set_slice(res_addr, &output[..HALF_DIGEST_LEN])?; } else { - ctx.memory.set_slice(index_res_a.to_usize(), &output)?; + let output = poseidon16_compress(input); + ctx.memory.set_slice(res_addr, &output)?; } let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); @@ -260,12 +306,23 @@ impl TableT for Poseidon16Precompile { for (i, value) in input.iter().enumerate() { trace.columns[POSEIDON_16_COL_INPUT_START + i].push(*value); } + trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].push(if full_output { F::ONE } else { F::ZERO }); + // index_input_res_high: real address (res+8) when in permute mode; otherwise a placeholder + // that will be rewritten in lean_prover post-processing to a zero region. Use 0 for now; + // soundness is maintained because the AIR constraint + // `flag_full_output * (index_input_res_high - index_input_res - 8) = 0` + // only forces the value when permute mode is on. + let index_high = if full_output { res_addr + DIGEST_LEN } else { 0 }; + trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES_HIGH].push(F::from_usize(index_high)); + // outputs_high columns are filled by trace_gen (perm output high half) for permute rows, + // and overwritten to zero in lean_prover post-processing for non-permute rows. // Non-committed columns trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); let precompile_data = POSEIDON_PRECOMPILE_DATA + POSEIDON_HALF_OUTPUT_SHIFT * (half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * (flag_hardcoded as usize) - + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val; + + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val + + POSEIDON_FULL_OUTPUT_SHIFT * (full_output as usize); trace.columns[POSEIDON_16_COL_PRECOMPILE_DATA].push(F::from_usize(precompile_data)); // the rest of the trace is filled at the end of the execution (to get parallelism + SIMD) @@ -290,7 +347,9 @@ impl Air for Poseidon16Precompile { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 80 + // 80 (existing) + 1 (full_output bool) + 1 (full*half mutex) + 1 (high index) + // + 8 (full * (outputs_high[i] - state - input)) + 8 ((1-full) * outputs_high[i]) + BUS as usize + 80 + 19 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon1Cols16 = { @@ -307,7 +366,8 @@ impl Air for Poseidon16Precompile { + cols.flag_hardcoded_left * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT) + cols.flag_hardcoded_left * cols.offset_hardcoded_left - * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT); + * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT) + + cols.flag_full_output * AB::F::from_usize(POSEIDON_FULL_OUTPUT_SHIFT); // effective_index_left_first = index_a * (1 - flag_hardcoded_left_4) + offset * flag_hardcoded_left_4 let one_minus_flag_hardcoded_left = AB::IF::ONE - cols.flag_hardcoded_left; @@ -329,6 +389,16 @@ impl Air for Poseidon16Precompile { builder.assert_bool(cols.flag_active); builder.assert_bool(cols.flag_half_output); builder.assert_bool(cols.flag_hardcoded_left); + builder.assert_bool(cols.flag_full_output); + // Mutually exclusive: a row cannot be both half-output and full-output. + builder.assert_zero(cols.flag_full_output * cols.flag_half_output); + // When full_output is set, index_input_res_high MUST equal index_res + DIGEST_LEN so that + // outputs_high lands at m[res+8..res+16]. When full_output is unset, the trace generator + // is free to set index_input_res_high to any zero-page address; the lookup will check + // outputs_high (= 0) against m[that_address+i] which is zero by construction. + builder.assert_zero( + cols.flag_full_output * (cols.index_input_res_high - cols.index_res - AB::F::from_usize(DIGEST_LEN)), + ); builder.assert_zero(cols.flag_hardcoded_left * (cols.offset_hardcoded_left - cols.effective_index_left_first)); builder.assert_zero(one_minus_flag_hardcoded_left * (index_a - cols.effective_index_left_first)); @@ -354,6 +424,15 @@ pub(super) struct Poseidon1Cols16 { pub partial_rounds: [T; PARTIAL_ROUNDS], pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1], pub outputs: [T; WIDTH / 2], + /// 1 = expose all 16 perm-output elements (writes outputs_high to m[res+8..res+16]). + /// Mutually exclusive with flag_half_output. + pub flag_full_output: T, + /// Memory address for the high-half outputs. = index_res + DIGEST_LEN when flag_full_output; + /// otherwise points at zero_vec_ptr (a region pre-filled with zeros) so the lookup is a no-op. + pub index_input_res_high: T, + /// High-half perm output (state[8..16] + inputs[8..16]). Constrained when flag_full_output; + /// forced to zero when not, so the lookup against zero_vec_ptr is trivially satisfied. + pub outputs_high: [T; WIDTH / 2], } fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16) { @@ -413,9 +492,11 @@ fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16( } #[inline] +#[allow(clippy::too_many_arguments)] fn eval_last_2_full_rounds_16( initial_state: &[AB::IF; WIDTH], state: &mut [AB::IF; WIDTH], outputs: &[AB::IF; WIDTH / 2], + outputs_high: &[AB::IF; WIDTH / 2], round_constants_1: &[F; WIDTH], round_constants_2: &[F; WIDTH], flag_half_output: AB::IF, + flag_full_output: AB::IF, builder: &mut AB, ) { for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { @@ -473,20 +557,28 @@ fn eval_last_2_full_rounds_16( *s = s.cube(); } mds_air_16(state); - // add inputs to outputs (for compression) + // add inputs to outputs (for compression / MMO feedforward) for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { *state_i += *init_state_i; } let one_minus_flag_half_output = AB::IF::ONE - flag_half_output; - for (idx, (state_i, output_i)) in state.iter_mut().zip(outputs).enumerate() { + let one_minus_flag_full_output = AB::IF::ONE - flag_full_output; + // First 8 outputs: existing behavior (always 0..4, gated by half on 4..8). + for (idx, (state_i, output_i)) in state.iter().take(WIDTH / 2).zip(outputs).enumerate() { if idx < HALF_DIGEST_LEN { - // First 4 outputs: always constrained builder.assert_eq(*state_i, *output_i); } else { - // Last 4 outputs: constrained only when half_output = 0 builder.assert_zero(one_minus_flag_half_output * (*state_i - *output_i)); } - *state_i = *output_i; + } + // Outputs_high: constrained to state[8..16] when full_output, else forced to zero. + for (state_i, output_high_i) in state.iter().skip(WIDTH / 2).zip(outputs_high) { + builder.assert_zero(flag_full_output * (*state_i - *output_high_i)); + builder.assert_zero(one_minus_flag_full_output * *output_high_i); + } + // Mirror the original "advance state to output" so any downstream code sees the canonical state. + for (idx, output_i) in outputs.iter().enumerate() { + state[idx] = *output_i; } } diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index fca712257..c1350e02c 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -104,6 +104,7 @@ pub(super) fn generate_trace_rows_for_perm + Copy>(perm: & &mut state, &inputs, &mut perm.outputs, + &mut perm.outputs_high, &poseidon1_final_constants()[2 * n_ending_full_rounds], &poseidon1_final_constants()[2 * n_ending_full_rounds + 1], ); @@ -138,6 +139,7 @@ fn generate_last_2_full_rounds + Copy>( state: &mut [F; WIDTH], inputs: &[F; WIDTH], outputs: &mut [&mut F; WIDTH / 2], + outputs_high: &mut [&mut F; WIDTH / 2], round_constants_1: &[KoalaBear; WIDTH], round_constants_2: &[KoalaBear; WIDTH], ) { @@ -153,8 +155,16 @@ fn generate_last_2_full_rounds + Copy>( } mds_circ_16(state); - // Add inputs to outputs (compression) - for ((output, state_i), &input_i) in outputs.iter_mut().zip(state).zip(inputs) { - **output = *state_i + input_i; + // Add inputs to outputs (compression / MMO feedforward). + // First half of state goes into `outputs`; second half into `outputs_high`. + // Note: the AIR forces outputs_high to zero when flag_full_output = 0; the + // lean_prover post-processing pass overwrites these columns to zero for + // non-full-output rows. For full-output rows the values written here are + // exactly what the AIR + lookup expect (state[i+8] + inputs[i+8]). + for (idx, (output, &input_i)) in outputs.iter_mut().zip(inputs.iter().take(WIDTH / 2)).enumerate() { + **output = state[idx] + input_i; + } + for (idx, (output_high, &input_i)) in outputs_high.iter_mut().zip(inputs.iter().skip(WIDTH / 2)).enumerate() { + **output_high = state[idx + WIDTH / 2] + input_i; } } diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index ef501732a..0459c9c88 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -54,14 +54,101 @@ def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hash return -@inline def slice_hash_rtl(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) + """RATE=12 sponge over data of length num_chunks * 8 base elements. + Pads internally so that the absorbed length is 16 + 12*k (sponge-aligned), + matching the native prover's padded_full_base_width helper. + + `num_chunks` is `Const`, so all arithmetic and the if-branches below + resolve at compile time. + + Algorithm (mirrors Rust hash_rtl_iter for RATE=12, WIDTH=16): + state = padded_data[L-16..L] # initial state from last 16 elements + compress(state) + for chunk_idx descending from k-1 to 0: + state[0..4] persists (capacity) + state[4..16] = padded_data[chunk_idx*12..(chunk_idx+1)*12] + compress(state) + return state[0..8] + """ + if num_chunks == 1: + # data_len = 8 ; pad to 16 ; one permute. + buf = Array(16) + for i in unroll(0, 8): + buf[i] = data[i] + for i in unroll(8, 16): + buf[i] = 0 + result = Array(DIGEST_LEN) + poseidon16_compress(buf, buf + DIGEST_LEN, result) + return result + if num_chunks == 4: + return slice_hash_rtl_rate12(data, 32, 40, 2) + if num_chunks == 5: + return slice_hash_rtl_rate12(data, 40, 40, 2) + if num_chunks == 8: + return slice_hash_rtl_rate12(data, 64, 64, 4) + if num_chunks == 10: + return slice_hash_rtl_rate12(data, 80, 88, 6) + if num_chunks == 16: + return slice_hash_rtl_rate12(data, 128, 136, 10) + if num_chunks == 20: + return slice_hash_rtl_rate12(data, 160, 160, 12) + print(num_chunks) + assert False, "slice_hash_rtl called with unsupported num_chunks" + + +def slice_hash_rtl_rate12(data, data_len: Const, padded_len: Const, n_chunks_12: Const): + """Internal helper for RATE=12 sponge with explicit padding params. + Pre: padded_len = 16 + n_chunks_12 * 12 ; padded_len >= data_len. + """ + if padded_len == data_len: + # No padding needed; absorb directly from data. + return slice_hash_rtl_rate12_no_pad(data, padded_len, n_chunks_12) + # Build a local padded buffer once, then absorb from it. + padded_data = Array(padded_len) + for i in unroll(0, data_len): + padded_data[i] = data[i] + for i in unroll(data_len, padded_len): + padded_data[i] = 0 + return slice_hash_rtl_rate12_no_pad(padded_data, padded_len, n_chunks_12) + + +def slice_hash_rtl_rate12_no_pad(padded_data, padded_len: Const, n_chunks_12: Const): + """MMO sponge: ADD message into rate, full-state feedforward. + + The chaining variable is the FULL 16-element state across rounds, giving + output_bits/2 = 124-bit collision security regardless of capacity. + The output is the first 8 elements of the final 16-element state. + + states[k*16..(k+1)*16] holds the full 16-element state after round k. + """ + states = Array((n_chunks_12 + 1) * 16) + + # Round 0: states[0..16] = padded_data[len-16..len] + perm(padded_data[len-16..len]) + # (zero IV implicit; first absorb feeds the last 16 elements of input as the initial state). + poseidon16_permute( + padded_data + padded_len - 16, + padded_data + padded_len - 8, + states, + ) - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 2) * DIGEST_LEN + # Subsequent rounds: absorb 12-element chunks RTL using MMO compression. + # pre[0..4] = states[j*16..j*16+4] (capacity unchanged) + # pre[4..16] = states[j*16+4..j*16+16] + chunk (rate gets ADDED with chunk) + # states[(j+1)*16..(j+2)*16] = pre + perm(pre) (full-state feedforward) + for j in unroll(0, n_chunks_12): + chunk_idx = n_chunks_12 - 1 - j + + pre = Array(16) + for k in unroll(0, 4): + pre[k] = states[j * 16 + k] + for k in unroll(0, 12): + pre[4 + k] = states[j * 16 + 4 + k] + padded_data[chunk_idx * 12 + k] + + poseidon16_permute(pre, pre + 8, states + (j + 1) * 16) + + # Output the first 8 elements of the final state. + return states + n_chunks_12 * 16 @inline diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index 3eb2557ba..b8d60d5e5 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -24,6 +24,14 @@ pub fn poseidon16_compress(input: [KoalaBear; 16]) -> [KoalaBear; 8] { get_poseidon16().compress(input)[0..8].try_into().unwrap() } +/// Like `poseidon16_compress` but exposes the FULL 16-element output (with +/// input feedforward = MMO compression). Used by the `poseidon16_permute` +/// precompile to support MMO sponge leaf hashing. +#[inline(always)] +pub fn poseidon16_permute_full(input: [KoalaBear; 16]) -> [KoalaBear; 16] { + get_poseidon16().compress(input) +} + pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) -> [KoalaBear; 8] { let mut input = [KoalaBear::default(); 16]; input[..8].copy_from_slice(left); diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..6b1758a44 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,6 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; +use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; @@ -55,6 +56,23 @@ pub(crate) fn merkle_commit>( } } +// Sponge rate for Merkle leaf hashing. WIDTH=16 (Poseidon1KoalaBear16) gives +// capacity = WIDTH - RATE = 4 with RATE=12. See iter 29 WIP analysis for the +// security tradeoff. +const SPONGE_RATE: usize = 12; +const SPONGE_WIDTH: usize = 16; + +/// Pad up so that (padded - WIDTH) is divisible by RATE. Used only for sponge +/// alignment; the protocol-visible leaf width stays unpadded. +#[inline] +fn padded_full_base_width(full_base_width: usize) -> usize { + let mut padded = full_base_width.max(SPONGE_WIDTH); + while !(padded - SPONGE_WIDTH).is_multiple_of(SPONGE_RATE) { + padded += 1; + } + padded +} + #[instrument(name = "build merkle tree", skip_all)] fn build_merkle_tree_koalabear( leaf: DenseMatrix, @@ -62,24 +80,45 @@ fn build_merkle_tree_koalabear( effective_base_width: usize, ) -> RoundMerkleTree { let perm = default_koalabear_poseidon1_16(); - let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; + // Internal padding for sponge alignment. NOT exposed to the protocol layer. + let padded_full_width = padded_full_base_width(full_base_width); + // n_zero_suffix_rate_chunks = number of "zero RATE-chunks" the precompute + // must cover so that the remaining iter (effective + n_pad elements, + // where n_pad rounds effective up to a multiple of RATE) takes the sponge + // exactly to padded_full_width. precompute(n) does n-1 compresses, + // absorbing 16 zeros initially + (n-2)*RATE more = WIDTH + (n-2)*RATE. + // Total: WIDTH + (n-2)*RATE + (effective + n_pad) = padded. + // Solving: n = 2 + (padded - WIDTH - effective - n_pad) / RATE. + let n_pad = (SPONGE_RATE - effective_base_width % SPONGE_RATE) % SPONGE_RATE; + let n_zero_suffix_rate_chunks = if padded_full_width >= SPONGE_WIDTH + effective_base_width + n_pad { + 2 + (padded_full_width - SPONGE_WIDTH - effective_base_width - n_pad) / SPONGE_RATE + } else { + 0 + }; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - &perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [PFPacking; 16] = + let scalar_state = + symetric::mmo_precompute_zero_suffix_state::( + &perm, + n_zero_suffix_rate_chunks, + ); + let packed_state: [PFPacking; SPONGE_WIDTH] = std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( + first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &leaf, &packed_state, effective_base_width, ) } else { - first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width) + first_digest_layer::, _, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( + &perm, + &leaf, + padded_full_width, + ) }; - let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); + let tree = + symetric::merkle::MerkleTree::from_first_layer::, _, SPONGE_WIDTH>(&perm, first_layer); + // Expose UNPADDED width to the protocol; padding is purely a sponge detail. WhirMerkleTree { leaf, tree, @@ -125,8 +164,11 @@ pub(crate) fn merkle_verify>( let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; let data = unsafe { std::mem::transmute::<_, Vec>(data) }; let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = QuinticExtensionFieldKB::flatten_to_base(data); - symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + let mut base_data = QuinticExtensionFieldKB::flatten_to_base(data); + // Pad to the sponge-aligned width (matches the prover's internal padding). + let padded = padded_full_base_width(base_data.len()); + base_data.resize(padded, KoalaBear::ZERO); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &merkle_root, log_max_height, @@ -138,8 +180,10 @@ pub(crate) fn merkle_verify>( let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; let data = unsafe { std::mem::transmute::<_, Vec>(data) }; let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = KoalaBear::flatten_to_base(data); - symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + let mut base_data = KoalaBear::flatten_to_base(data); + let padded = padded_full_base_width(base_data.len()); + base_data.resize(padded, KoalaBear::ZERO); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &merkle_root, log_max_height, @@ -170,12 +214,13 @@ impl, const DIGEST_ELEMS: effective_base_width: usize, ) -> Self where - P: PackedValue + Default, + F: field::PrimeCharacteristicRing, + P: PackedValue + Default + field::PrimeCharacteristicRing, Perm: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { let n_zero_suffix_rate_chunks = (full_leaf_base_width - effective_base_width) / RATE; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( + let scalar_state = symetric::mmo_precompute_zero_suffix_state::( perm, n_zero_suffix_rate_chunks, ); @@ -218,7 +263,7 @@ fn first_digest_layer Vec<[P::Value; DIGEST_ELEMS]> where - P: PackedValue + Default, + P: PackedValue + Default + field::PrimeCharacteristicRing, P::Value: Default + Copy, Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, M: Matrix, @@ -238,7 +283,7 @@ where let first_row = i * width; let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, matrix_width, n_trailing_zeros); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); + symetric::mmo_hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { *dst = src; } @@ -255,7 +300,7 @@ fn first_digest_layer_with_initial_state Vec<[P::Value; DIGEST_ELEMS]> where - P: PackedValue + Default, + P: PackedValue + Default + field::PrimeCharacteristicRing, P::Value: Default + Copy, Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, M: Matrix, @@ -274,7 +319,7 @@ where let first_row = i * width; let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( + symetric::mmo_hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( perm, rtl_iter, packed_initial_state,