Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 35 additions & 34 deletions crates/backend/fiat-shamir/src/challenger.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,65 @@
use field::PrimeField64;
use symetric::Compression;
use koala_bear::symmetric::Permutation;

pub(crate) const RATE: usize = 8;
pub(crate) const WIDTH: usize = RATE * 2;
pub(crate) const CAPACITY: usize = WIDTH - RATE;

#[derive(Clone, Debug)]
pub struct Challenger<F, P> {
pub compressor: P,
pub state: [F; RATE],
pub permutation: P,
pub state: [F; WIDTH],
rate_fresh: bool,
}

impl<F: PrimeField64, P: Compression<[F; WIDTH]>> Challenger<F, P> {
pub fn new(compressor: P) -> Self
impl<F: PrimeField64, P: Permutation<[F; WIDTH]>> Challenger<F, P> {
pub fn new(permutation: P) -> Self
where
F: Default,
{
Self {
compressor,
state: [F::ZERO; RATE],
permutation,
state: [F::ZERO; WIDTH],
rate_fresh: false,
}
}

pub fn observe(&mut self, value: [F; RATE]) {
self.state = self.compressor.compress({
let mut concat = [F::ZERO; WIDTH];
concat[..RATE].copy_from_slice(&self.state);
concat[RATE..].copy_from_slice(&value);
concat
})[..RATE]
.try_into()
.unwrap();
self.state[CAPACITY..].copy_from_slice(&value);
self.permutation.permute_mut(&mut self.state);
self.rate_fresh = true;
}

pub fn observe_scalars(&mut self, scalars: &[F]) {
pub fn observe_many(&mut self, scalars: &[F]) {
for chunk in scalars.chunks(RATE) {
let mut buffer = [F::ZERO; RATE];
for (i, val) in chunk.iter().enumerate() {
buffer[i] = *val;
}
buffer[..chunk.len()].copy_from_slice(chunk);
self.observe(buffer);
}
}

pub fn duplex(&mut self) {
self.observe([F::ZERO; RATE]);
}

pub fn sample(&mut self) -> [F; RATE] {
assert!(self.rate_fresh, "stale rate. insert a duplex() before.");
let out: [F; RATE] = self.state[CAPACITY..].try_into().unwrap();
self.rate_fresh = false;
out
}

pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> {
let mut sampled = Vec::with_capacity(n + 1);
for i in 0..n + 1 {
let mut domain_sep = [F::ZERO; RATE];
domain_sep[0] = F::from_usize(i);
let hashed = self.compressor.compress({
let mut concat = [F::ZERO; WIDTH];
concat[..RATE].copy_from_slice(&domain_sep);
concat[RATE..].copy_from_slice(&self.state);
concat
})[..RATE]
.try_into()
.unwrap();
sampled.push(hashed);
if n == 0 {
return Vec::new();
}
let mut out = Vec::with_capacity(n);
out.push(self.sample());
for _ in 1..n {
self.duplex();
out.push(self.sample());
}
self.state = sampled.pop().unwrap();
sampled
out
}

/// Warning: not perfectly uniform
Expand Down
44 changes: 25 additions & 19 deletions crates/backend/fiat-shamir/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use crate::{
MerklePaths, PrunedMerklePaths,
challenger::{Challenger, RATE, WIDTH},
challenger::{CAPACITY, Challenger, RATE, WIDTH},
*,
};
use field::Field;
use field::PackedValue;
use field::PrimeCharacteristicRing;
use field::integers::QuotientMap;
use field::{ExtensionField, PrimeField64};
use koala_bear::symmetric::Permutation;
use rayon::prelude::*;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use std::{fmt::Debug, sync::Mutex, time::Instant};
use symetric::Compression;

static POW_GRINDING_NANOS: AtomicU64 = AtomicU64::new(0);

Expand All @@ -31,15 +31,15 @@ pub struct ProverState<EF: ExtensionField<PF<EF>>, P> {
merkle_paths: Vec<PrunedMerklePaths<PF<EF>, PF<EF>>>,
}

impl<EF: ExtensionField<PF<EF>>, P: Compression<[PF<EF>; WIDTH]>> ProverState<EF, P>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]>> ProverState<EF, P>
where
PF<EF>: PrimeField64,
{
#[must_use]
pub fn new(compressor: P) -> Self {
pub fn new(permutation: P) -> Self {
assert!(EF::DIMENSION <= RATE);
Self {
challenger: Challenger::new(compressor),
challenger: Challenger::new(permutation),
transcript: Vec::new(),
merkle_paths: Vec::new(),
}
Expand All @@ -53,7 +53,7 @@ where
}
}

impl<EF: ExtensionField<PF<EF>>, P: Compression<[PF<EF>; WIDTH]>> ChallengeSampler<EF> for ProverState<EF, P>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]>> ChallengeSampler<EF> for ProverState<EF, P>
where
PF<EF>: PrimeField64,
{
Expand All @@ -66,18 +66,22 @@ where
}
}

impl<EF: ExtensionField<PF<EF>>, P: Compression<[PF<EF>; WIDTH]> + Compression<[<PF<EF> as Field>::Packing; WIDTH]>>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]> + Permutation<[<PF<EF> as Field>::Packing; WIDTH]>>
FSProver<EF> for ProverState<EF, P>
where
PF<EF>: PrimeField64,
{
fn add_base_scalars(&mut self, scalars: &[PF<EF>]) {
self.challenger.observe_scalars(scalars);
self.challenger.observe_many(scalars);
self.transcript.extend_from_slice(scalars);
}

fn observe_scalars(&mut self, scalars: &[PF<EF>]) {
self.challenger.observe_scalars(scalars);
self.challenger.observe_many(scalars);
}

fn duplex(&mut self) {
self.challenger.duplex();
}

fn state(&self) -> String {
Expand All @@ -97,13 +101,13 @@ where
match eq_alpha {
None => {
let scalars = flatten_scalars_to_base(coeffs);
self.challenger.observe_scalars(&scalars);
self.challenger.observe_many(&scalars);
self.transcript.extend_from_slice(&scalars[EF::DIMENSION..]); // c0 reconstructed by verifier from claimed_sum
}
Some(alpha) => {
let bare_scalars = flatten_scalars_to_base(coeffs);
let full_scalars = flatten_scalars_to_base(&expand_bare_to_full(coeffs, alpha));
self.challenger.observe_scalars(&full_scalars);
self.challenger.observe_many(&full_scalars);
self.transcript.extend_from_slice(&bare_scalars[EF::DIMENSION..]); // h0 reconstructed by verifier from claimed_sum
}
}
Expand Down Expand Up @@ -140,15 +144,17 @@ where
});

let mut packed_state = [Packed::<EF>::ZERO; WIDTH];
packed_state[..RATE]
for (slot, val) in packed_state[..CAPACITY]
.iter_mut()
.zip(&self.challenger.state)
.for_each(|(val, state)| *val = Packed::<EF>::from(*state));
packed_state[RATE] = packed_witnesses;
.zip(&self.challenger.state[..CAPACITY])
{
*slot = Packed::<EF>::from(*val);
}
packed_state[CAPACITY] = packed_witnesses;

self.challenger.compressor.compress_mut(&mut packed_state);
self.challenger.permutation.permute_mut(&mut packed_state);

let samples = packed_state[0].as_slice();
let samples = packed_state[CAPACITY].as_slice();
for (sample, witness) in samples.iter().zip(packed_witnesses.as_slice()) {
let rand_usize = sample.as_canonical_u64() as usize;
if (rand_usize & ((1 << bits) - 1)) == 0 {
Expand All @@ -162,8 +168,8 @@ where

let witness = witness_found.lock().unwrap().unwrap();

self.challenger.observe_scalars(&[witness]);
assert!(self.challenger.state[0].as_canonical_u64() & ((1 << bits) - 1) == 0);
self.challenger.observe_many(&[witness]);
assert!(self.challenger.state[CAPACITY].as_canonical_u64() & ((1 << bits) - 1) == 0);
self.transcript.push(witness);

let elapsed = time.elapsed();
Expand Down
2 changes: 2 additions & 0 deletions crates/backend/fiat-shamir/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub trait FSProver<EF: ExtensionField<PF<EF>>>: ChallengeSampler<EF> {
fn state(&self) -> String;
fn add_base_scalars(&mut self, scalars: &[PF<EF>]);
fn observe_scalars(&mut self, scalars: &[PF<EF>]);
fn duplex(&mut self);
fn pow_grinding(&mut self, bits: usize);
fn hint_merkle_paths_base(&mut self, paths: Vec<MerklePath<PF<EF>, PF<EF>>>);
fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option<EF>);
Expand Down Expand Up @@ -46,6 +47,7 @@ pub trait FSVerifier<EF: ExtensionField<PF<EF>>>: ChallengeSampler<EF> {
fn state(&self) -> String;
fn next_base_scalars_vec(&mut self, n: usize) -> Result<Vec<PF<EF>>, ProofError>;
fn observe_scalars(&mut self, scalars: &[PF<EF>]);
fn duplex(&mut self);
fn next_merkle_opening(&mut self) -> Result<MerkleOpening<PF<EF>>, ProofError>;
fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError>;
fn next_sumcheck_polynomial(
Expand Down
4 changes: 2 additions & 2 deletions crates/backend/fiat-shamir/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use field::{BasedVectorSpace, ExtensionField, Field, PrimeCharacteristicRing, PrimeField64};
use symetric::Compression;
use koala_bear::symmetric::Permutation;

use crate::challenger::{Challenger, RATE, WIDTH};

Expand Down Expand Up @@ -40,7 +40,7 @@ pub fn expand_bare_to_full<EF: Field>(bare: &[EF], alpha: EF) -> Vec<EF> {
full
}

pub(crate) fn sample_vec<F: PrimeField64, EF: ExtensionField<F>, P: Compression<[F; WIDTH]>>(
pub(crate) fn sample_vec<F: PrimeField64, EF: ExtensionField<F>, P: Permutation<[F; WIDTH]>>(
challenger: &mut Challenger<F, P>,
len: usize,
) -> Vec<EF> {
Expand Down
26 changes: 15 additions & 11 deletions crates/backend/fiat-shamir/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ use std::iter::repeat_n;

use crate::{
MerkleOpening, MerklePaths, PrunedMerklePaths, RawProof,
challenger::{Challenger, RATE, WIDTH},
challenger::{CAPACITY, Challenger, RATE, WIDTH},
transcript::{DIGEST_LEN_FE, Proof},
*,
};
use field::PrimeCharacteristicRing;
use field::{ExtensionField, PrimeField64};
use koala_bear::symmetric::Permutation;
use koala_bear::{KoalaBear, default_koalabear_poseidon1_16};
use symetric::Compression;

pub struct VerifierState<EF: ExtensionField<PF<EF>>, P> {
challenger: Challenger<PF<EF>, P>,
Expand All @@ -21,19 +21,19 @@ pub struct VerifierState<EF: ExtensionField<PF<EF>>, P> {
raw_transcript: Vec<PF<EF>>, // reconstructed during the proof verification, it's the format that the zkVM recursion program expects (no Merkle pruning, no sumcheck optimization to send less data, etc)
}

impl<EF: ExtensionField<PF<EF>>, C: Compression<[PF<EF>; WIDTH]>> VerifierState<EF, C>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]>> VerifierState<EF, P>
where
PF<EF>: PrimeField64,
{
pub fn new(proof: Proof<PF<EF>>, compressor: C) -> Result<Self, ProofError> {
pub fn new(proof: Proof<PF<EF>>, permutation: P) -> Result<Self, ProofError> {
let mut merkle_openings = Vec::new();
for paths in proof.merkle_paths {
let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?;
merkle_openings.extend(restored);
}

Ok(Self {
challenger: Challenger::new(compressor),
challenger: Challenger::new(permutation),
transcript: proof.transcript,
transcript_offset: 0,
merkle_openings,
Expand All @@ -50,7 +50,7 @@ where
}

fn absorb_and_record(&mut self, scalars: &[PF<EF>]) {
self.challenger.observe_scalars(scalars);
self.challenger.observe_many(scalars);
let total_padded = scalars.len().next_multiple_of(RATE);
self.raw_transcript.extend_from_slice(scalars);
self.raw_transcript
Expand Down Expand Up @@ -90,7 +90,7 @@ where
}
}

impl<EF: ExtensionField<PF<EF>>, C: Compression<[PF<EF>; WIDTH]>> ChallengeSampler<EF> for VerifierState<EF, C>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]>> ChallengeSampler<EF> for VerifierState<EF, P>
where
PF<EF>: PrimeField64,
{
Expand All @@ -102,7 +102,7 @@ where
}
}

impl<EF: ExtensionField<PF<EF>>, C: Compression<[PF<EF>; WIDTH]>> FSVerifier<EF> for VerifierState<EF, C>
impl<EF: ExtensionField<PF<EF>>, P: Permutation<[PF<EF>; WIDTH]>> FSVerifier<EF> for VerifierState<EF, P>
where
PF<EF>: PrimeField64,
{
Expand All @@ -121,7 +121,11 @@ where
}

fn observe_scalars(&mut self, scalars: &[PF<EF>]) {
self.challenger.observe_scalars(scalars);
self.challenger.observe_many(scalars);
}

fn duplex(&mut self) {
self.challenger.duplex();
}

fn next_base_scalars_vec(&mut self, n: usize) -> Result<Vec<PF<EF>>, ProofError> {
Expand All @@ -144,8 +148,8 @@ where
return Ok(());
}
let witness = self.read_transcript(1)?[0];
self.challenger.observe_scalars(&[witness]);
if self.challenger.state[0].as_canonical_u64() & ((1 << bits) - 1) != 0 {
self.challenger.observe_many(&[witness]);
if self.challenger.state[CAPACITY].as_canonical_u64() & ((1 << bits) - 1) != 0 {
return Err(ProofError::InvalidGrindingWitness);
}
self.raw_transcript.push(witness);
Expand Down
6 changes: 6 additions & 0 deletions crates/lean_compiler/snark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def poseidon16_compress_half_hardcoded_left(left, right, output, offset):
_ = left, right, output, offset


def poseidon16_permute(left, right, output):
"""Raw Poseidon1 permutation (no feed-forward). Writes the 16-cell result in natural order:
m[output .. output + 16] = poseidon(left || right)"""
_ = left, right, output


def add_be(a, b, result, length=None):
_ = a, b, result, length

Expand Down
8 changes: 5 additions & 3 deletions crates/lean_compiler/src/a_simplify_lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::{
use backend::PrimeCharacteristicRing;
use lean_vm::{
ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName,
POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs,
PrecompileCompTimeArgs, SourceLocation,
POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_PERMUTE_NAME,
PrecompileArgs, PrecompileCompTimeArgs, SourceLocation,
};
use std::{
collections::{BTreeMap, BTreeSet},
Expand Down Expand Up @@ -2259,13 +2259,14 @@ fn simplify_lines(
continue;
}

// Special handling for poseidon16 precompile (4 variants)
// Special handling for poseidon16 precompile (5 variants).
if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) {
if !targets.is_empty() {
return Err(format!(
"Precompile {function_name} should not return values, at {location}"
));
}
let permute = function_name.as_str() == POSEIDON16_PERMUTE_NAME;
let half_output = [POSEIDON16_HALF_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME]
.contains(&function_name.as_str());
let is_hardcoded_left =
Expand Down Expand Up @@ -2303,6 +2304,7 @@ fn simplify_lines(
data: PrecompileCompTimeArgs::Poseidon16 {
half_output,
hardcoded_offset_left,
permute,
},
}));
continue;
Expand Down
Loading
Loading