diff --git a/Cargo.lock b/Cargo.lock index d938586b8..61403de2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -608,6 +608,7 @@ name = "mt-air" version = "0.1.0" dependencies = [ "mt-field", + "mt-koala-bear", "mt-poly", ] diff --git a/crates/backend/air/Cargo.toml b/crates/backend/air/Cargo.toml index 3868ecf2b..8f92951aa 100644 --- a/crates/backend/air/Cargo.toml +++ b/crates/backend/air/Cargo.toml @@ -6,3 +6,6 @@ edition.workspace = true [dependencies] field = { path = "../field", package = "mt-field" } poly = { path = "../poly", package = "mt-poly" } + +[dev-dependencies] +koala_bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/air/src/symbolic.rs b/crates/backend/air/src/symbolic.rs index 0e1fcf06e..cb53e13bc 100644 --- a/crates/backend/air/src/symbolic.rs +++ b/crates/backend/air/src/symbolic.rs @@ -84,6 +84,10 @@ fn alloc_node(node: SymbolicNode) -> u32 { let mut bytes = arena.borrow_mut(); let node_size = std::mem::size_of::>(); let idx = bytes.len(); + assert!( + idx <= u32::MAX as usize, + "symbolic arena overflow: index {idx} exceeds u32::MAX" + ); bytes.resize(idx + node_size, 0); unsafe { std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode, node); @@ -95,7 +99,14 @@ fn alloc_node(node: SymbolicNode) -> u32 { pub fn get_node(idx: u32) -> SymbolicNode { ARENA.with(|arena| { let bytes = arena.borrow(); - unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode) } + let offset = idx as usize; + let node_size = std::mem::size_of::>(); + assert!( + offset.checked_add(node_size).is_some_and(|end| end <= bytes.len()), + "arena out-of-bounds: index {idx} (need {node_size} bytes, arena has {} bytes)", + bytes.len() + ); + unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(offset) as *const SymbolicNode) } }) } diff --git a/crates/backend/air/tests/arena_safety.rs b/crates/backend/air/tests/arena_safety.rs new file mode 100644 index 000000000..87829941f --- /dev/null +++ b/crates/backend/air/tests/arena_safety.rs @@ -0,0 +1,75 @@ +// Tests for arena bounds checking in symbolic expression handling. +// Addresses issue #170: unchecked arena index enables out-of-bounds reads. + +use field::PrimeCharacteristicRing; +use koala_bear::KoalaBear; +use mt_air::{SymbolicExpression, get_node}; + +type F = KoalaBear; + +#[test] +#[should_panic(expected = "arena out-of-bounds")] +fn get_node_oob_panics_on_empty_arena() { + // Constructing Operation(0) on an empty arena must panic with a bounds + // check, not trigger undefined behavior via read_unaligned. + let _: mt_air::SymbolicNode = get_node(0); +} + +#[test] +#[should_panic(expected = "arena out-of-bounds")] +fn get_node_oob_panics_on_large_index() { + // A fabricated large index must be caught by the bounds check. + let _: mt_air::SymbolicNode = get_node(u32::MAX); +} + +#[test] +fn alloc_and_get_node_roundtrip() { + // Building a real expression through arithmetic triggers alloc_node + // internally. The resulting Operation index must be readable. + let a = SymbolicExpression::::Variable(mt_air::SymbolicVariable::new(0)); + let b = SymbolicExpression::::Variable(mt_air::SymbolicVariable::new(1)); + let sum = a + b; + + if let SymbolicExpression::Operation(idx) = sum { + let node = get_node::(idx); + assert_eq!(node.op, mt_air::SymbolicOperation::Add); + } else { + panic!("expected Operation variant from variable addition"); + } +} + +#[test] +fn nested_expression_indices_are_valid() { + // Multiple levels of nesting must all produce valid arena indices. + let a = SymbolicExpression::::Variable(mt_air::SymbolicVariable::new(0)); + let b = SymbolicExpression::::Constant(F::TWO); + let c = SymbolicExpression::::Variable(mt_air::SymbolicVariable::new(1)); + + // (a * b) + c — two alloc_node calls + let product = a * b; + let sum = product + c; + + if let SymbolicExpression::Operation(idx) = sum { + let node = get_node::(idx); + assert_eq!(node.op, mt_air::SymbolicOperation::Add); + // The lhs should be the product Operation + if let SymbolicExpression::Operation(mul_idx) = node.lhs { + let mul_node = get_node::(mul_idx); + assert_eq!(mul_node.op, mt_air::SymbolicOperation::Mul); + } else { + panic!("expected Operation for nested lhs"); + } + } else { + panic!("expected Operation variant from nested expression"); + } +} + +#[test] +fn constant_folding_bypasses_arena() { + // Constant + Constant should fold without arena allocation. + let a = SymbolicExpression::::Constant(F::ONE); + let b = SymbolicExpression::::Constant(F::TWO); + let sum = a + b; + + assert!(matches!(sum, SymbolicExpression::Constant(_))); +} diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 16492d708..cc1235343 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -23,6 +23,7 @@ pub enum RunnerError { range: usize, }, InvalidExtensionOp, + AddressOverflow, ParallelSegmentFailed(usize, Box), } @@ -55,6 +56,7 @@ impl Display for RunnerError { ) } Self::InvalidExtensionOp => write!(f, "invalid extension op"), + Self::AddressOverflow => write!(f, "address overflow: fp + offset exceeds usize"), Self::ParallelSegmentFailed(id, err) => { write!(f, "parallel segment {id} failed: {err}") } diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 0b812dea8..793115250 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -283,7 +283,8 @@ impl Hint { let size = size.read_value(ctx.memory, ctx.fp)?.to_usize(); let allocation_start_addr = *ctx.ap; - ctx.memory.set(ctx.fp + *offset, F::from_usize(allocation_start_addr))?; + let addr = ctx.fp.checked_add(*offset).ok_or(RunnerError::AddressOverflow)?; + ctx.memory.set(addr, F::from_usize(allocation_start_addr))?; *ctx.ap += size; } Self::Custom(hint, args) => { @@ -292,7 +293,8 @@ impl Hint { Self::Inverse { arg, res_offset } => { let value = arg.read_value(ctx.memory, ctx.fp)?; let result = value.try_inverse().unwrap_or(F::ZERO); - ctx.memory.set(ctx.fp + *res_offset, result)?; + let addr = ctx.fp.checked_add(*res_offset).ok_or(RunnerError::AddressOverflow)?; + ctx.memory.set(addr, result)?; } Self::Print { line_info, content } => { if let Some(diag) = &mut ctx.hints.diagnostics { @@ -364,8 +366,8 @@ impl Hint { offset_target, } => { // Record a deref constraint: memory[target_addr] = memory[memory[src_addr]] - let src_addr = ctx.fp + offset_src; - let target_addr = ctx.fp + offset_target; + let src_addr = ctx.fp.checked_add(*offset_src).ok_or(RunnerError::AddressOverflow)?; + let target_addr = ctx.fp.checked_add(*offset_target).ok_or(RunnerError::AddressOverflow)?; ctx.pending_deref_hints.push((target_addr, src_addr)); } Self::Panic { message } => { @@ -380,8 +382,13 @@ impl Hint { Self::HintWitness { name, destination } => { let data = consume_next_hint_entry(ctx.hints.named_hints, name); let dest_addr = match destination { - HintWitnessDestination::Inline { offset } => ctx.fp + *offset, - HintWitnessDestination::Indirect { ptr_offset } => ctx.memory.get(ctx.fp + *ptr_offset)?.to_usize(), + HintWitnessDestination::Inline { offset } => { + ctx.fp.checked_add(*offset).ok_or(RunnerError::AddressOverflow)? + } + HintWitnessDestination::Indirect { ptr_offset } => { + let addr = ctx.fp.checked_add(*ptr_offset).ok_or(RunnerError::AddressOverflow)?; + ctx.memory.get(addr)?.to_usize() + } }; ctx.memory.set_slice(dest_addr, data)?; } diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..dd5106f4d 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -188,18 +188,27 @@ impl Instruction { Ok(()) } Self::Deref { shift_0, shift_1, res } => { + let base_addr = ctx.fp.checked_add(*shift_0).ok_or(RunnerError::AddressOverflow)?; if res.is_value_unknown(ctx.memory, *ctx.fp) { let memory_address_res = res.memory_address(*ctx.fp)?; - let ptr = ctx.memory.get(*ctx.fp + shift_0)?; - if let Ok(value) = ctx.memory.get(ptr.to_usize() + shift_1) { + let ptr = ctx.memory.get(base_addr)?; + let deref_addr = ptr + .to_usize() + .checked_add(*shift_1) + .ok_or(RunnerError::AddressOverflow)?; + if let Ok(value) = ctx.memory.get(deref_addr) { ctx.memory.set(memory_address_res, value)?; } else { // Do nothing, we are probably in a range check, will be resolved later } } else { let value = res.read_value(ctx.memory, *ctx.fp).unwrap(); - let ptr = ctx.memory.get(*ctx.fp + shift_0)?; - ctx.memory.set(ptr.to_usize() + shift_1, value)?; + let ptr = ctx.memory.get(base_addr)?; + let deref_addr = ptr + .to_usize() + .checked_add(*shift_1) + .ok_or(RunnerError::AddressOverflow)?; + ctx.memory.set(deref_addr, value)?; } ctx.counts.deref += 1; diff --git a/crates/lean_vm/src/isa/operands/mem_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_constant.rs index 1082ba33a..e49dc95e8 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_constant.rs @@ -29,7 +29,10 @@ impl MemOrConstant { pub fn read_value(&self, memory: &impl MemoryAccess, fp: usize) -> Result { match self { Self::Constant(c) => Ok(*c), - Self::MemoryAfterFp { offset } => memory.get(fp + *offset), + Self::MemoryAfterFp { offset } => { + let addr = fp.checked_add(*offset).ok_or(RunnerError::AddressOverflow)?; + memory.get(addr) + } } } @@ -42,7 +45,10 @@ impl MemOrConstant { pub const fn memory_address(&self, fp: usize) -> Result { match self { Self::Constant(_) => Err(RunnerError::NotAPointer), - Self::MemoryAfterFp { offset } => Ok(fp + *offset), + Self::MemoryAfterFp { offset } => match fp.checked_add(*offset) { + Some(addr) => Ok(addr), + None => Err(RunnerError::AddressOverflow), + }, } } } diff --git a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs index 28c984c3d..027378585 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs @@ -20,8 +20,14 @@ impl MemOrFpOrConstant { /// Read the value from memory, return fp, or return the constant pub fn read_value(&self, memory: &impl MemoryAccess, fp: usize) -> Result { match self { - Self::MemoryAfterFp { offset } => memory.get(fp + *offset), - Self::FpRelative { offset } => Ok(F::from_usize(fp + *offset)), + Self::MemoryAfterFp { offset } => { + let addr = fp.checked_add(*offset).ok_or(RunnerError::AddressOverflow)?; + memory.get(addr) + } + Self::FpRelative { offset } => { + let addr = fp.checked_add(*offset).ok_or(RunnerError::AddressOverflow)?; + Ok(F::from_usize(addr)) + } Self::Constant(c) => Ok(*c), } } @@ -34,7 +40,10 @@ impl MemOrFpOrConstant { /// Get the memory address (returns error for Fp and constants) pub const fn memory_address(&self, fp: usize) -> Result { match self { - Self::MemoryAfterFp { offset } => Ok(fp + *offset), + Self::MemoryAfterFp { offset } => match fp.checked_add(*offset) { + Some(addr) => Ok(addr), + None => Err(RunnerError::AddressOverflow), + }, Self::FpRelative { .. } => Err(RunnerError::NotAPointer), Self::Constant(_) => Err(RunnerError::NotAPointer), } diff --git a/crates/lean_vm/tests/address_safety.rs b/crates/lean_vm/tests/address_safety.rs new file mode 100644 index 000000000..73b7a2f44 --- /dev/null +++ b/crates/lean_vm/tests/address_safety.rs @@ -0,0 +1,97 @@ +// Tests for checked arithmetic in memory address computation. +// Addresses issue #176: fp+offset overflow enables wrong memory access. + +use backend::PrimeCharacteristicRing; +use lean_vm::{F, MemOrConstant, MemOrFpOrConstant, Memory, RunnerError}; + +#[test] +fn read_value_overflow_returns_error() { + // fp + offset that overflows usize must return AddressOverflow, not wrap. + let mem = Memory::new(vec![F::ONE; 16]); + let operand = MemOrConstant::MemoryAfterFp { offset: usize::MAX }; + let fp = 1usize; + + let result = operand.read_value(&mem, fp); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), RunnerError::AddressOverflow); +} + +#[test] +fn memory_address_overflow_returns_error() { + let operand = MemOrConstant::MemoryAfterFp { offset: usize::MAX }; + let fp = 1usize; + + let result = operand.memory_address(fp); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), RunnerError::AddressOverflow); +} + +#[test] +fn read_value_no_overflow_works() { + // Normal operation: fp=4, offset=2 → address 6 → reads mem[6]. + let mut data = vec![F::ZERO; 8]; + data[6] = F::from_u32(42); + let mem = Memory::new(data); + let operand = MemOrConstant::MemoryAfterFp { offset: 2 }; + + let result = operand.read_value(&mem, 4); + assert_eq!(result.unwrap(), F::from_u32(42)); +} + +#[test] +fn memory_address_no_overflow_works() { + let operand = MemOrConstant::MemoryAfterFp { offset: 10 }; + assert_eq!(operand.memory_address(5).unwrap(), 15); +} + +#[test] +fn constant_read_ignores_fp() { + // Constants don't compute addresses, so overflow is irrelevant. + let mem = Memory::new(vec![F::ZERO; 4]); + let operand = MemOrConstant::Constant(F::from_u32(99)); + + let result = operand.read_value(&mem, usize::MAX); + assert_eq!(result.unwrap(), F::from_u32(99)); +} + +#[test] +fn is_value_unknown_with_overflow() { + // Overflow makes the value "unknown" (returns error → true). + let mem = Memory::new(vec![F::ONE; 4]); + let operand = MemOrConstant::MemoryAfterFp { offset: usize::MAX }; + assert!(operand.is_value_unknown(&mem, 1)); +} + +// --- MemOrFpOrConstant: same overflow pattern --- + +#[test] +fn fp_or_constant_memory_read_overflow() { + let mem = Memory::new(vec![F::ONE; 16]); + let operand = MemOrFpOrConstant::MemoryAfterFp { offset: usize::MAX }; + assert_eq!(operand.read_value(&mem, 1).unwrap_err(), RunnerError::AddressOverflow); +} + +#[test] +fn fp_or_constant_fp_relative_overflow() { + // FpRelative computes fp + offset as a field element value, not a memory address. + // Overflow must still be caught before the conversion. + let mem = Memory::new(vec![F::ONE; 4]); + let operand = MemOrFpOrConstant::FpRelative { offset: usize::MAX }; + assert_eq!(operand.read_value(&mem, 1).unwrap_err(), RunnerError::AddressOverflow); +} + +#[test] +fn fp_or_constant_memory_address_overflow() { + let operand = MemOrFpOrConstant::MemoryAfterFp { offset: usize::MAX }; + assert_eq!(operand.memory_address(1).unwrap_err(), RunnerError::AddressOverflow); +} + +#[test] +fn fp_or_constant_normal_operation() { + let mut data = vec![F::ZERO; 8]; + data[5] = F::from_u32(77); + let mem = Memory::new(data); + let operand = MemOrFpOrConstant::MemoryAfterFp { offset: 2 }; + assert_eq!(operand.read_value(&mem, 3).unwrap(), F::from_u32(77)); + assert_eq!(operand.memory_address(3).unwrap(), 5); +} diff --git a/crates/whir/tests/verifier_protocol_audit.rs b/crates/whir/tests/verifier_protocol_audit.rs new file mode 100644 index 000000000..88ad84c30 --- /dev/null +++ b/crates/whir/tests/verifier_protocol_audit.rs @@ -0,0 +1,469 @@ +// WHIR verifier protocol audit — tests pinned to ACFY24 (eprint 2024/1586). +// +// Each test cites the specific Construction/Theorem/Definition from the paper +// that it validates. Convention differences between the implementation and the +// paper are documented inline. + +use fiat_shamir::{ProverState, VerifierState}; +use field::{Field, PrimeCharacteristicRing}; +use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; +use mt_whir::*; +use poly::*; +use rand::{RngExt, SeedableRng, rngs::StdRng}; + +type F = KoalaBear; +type EF = QuinticExtensionFieldKB; + +/// Small-parameter WHIR instance for fast tests. +/// num_variables=12 keeps prove+verify under 1s on M4. +fn small_whir_config() -> (WhirConfig, usize) { + let num_variables = 12; + let builder = WhirConfigBuilder { + security_level: 80, + max_num_variables_to_send_coeffs: 4, + pow_bits: 0, + folding_factor: FoldingFactor::constant(4), + soundness_type: SecurityAssumption::JohnsonBound, + starting_log_inv_rate: 2, + rs_domain_initial_reduction_factor: 1, + }; + (WhirConfig::new(&builder, num_variables), num_variables) +} + +/// Produce a valid proof for the given polynomial and statement. +fn prove_and_get_proof( + config: &WhirConfig, + polynomial: &[F], + statement: &[SparseStatement], +) -> (fiat_shamir::Proof, Witness) { + let poseidon = default_koalabear_poseidon1_16(); + let mut prover_state = ProverState::new(poseidon); + precompute_dft_twiddles::(1 << 16); + + let poly_owned: MleOwned = MleOwned::Base(polynomial.to_vec()); + let witness = config.commit(&mut prover_state, &poly_owned, polynomial.len()); + let witness_clone = witness.clone(); + config.prove( + &mut prover_state, + statement.to_vec(), + witness_clone, + &poly_owned.by_ref(), + ); + (prover_state.into_proof(), witness) +} + +/// Verify a proof. Returns Ok(folding_randomness) or Err. +fn verify_proof( + config: &WhirConfig, + proof: fiat_shamir::Proof, + statement: &[SparseStatement], +) -> Result, fiat_shamir::ProofError> { + let poseidon = default_koalabear_poseidon1_16(); + let mut verifier_state = VerifierState::::new(proof, poseidon).unwrap(); + let parsed = config.parse_commitment::(&mut verifier_state)?; + config.verify::(&mut verifier_state, &parsed, statement.to_vec()) +} + +// --------------------------------------------------------------------------- +// §5, Construction 5.1: End-to-end completeness +// --------------------------------------------------------------------------- + +#[test] +fn completeness_simple_evaluation_constraint() { + // Construction 5.1 completeness: honest prover's proof is accepted. + // Weight polynomial ŵ = Z · eq(z, X) for a single evaluation constraint. + let (config, num_variables) = small_whir_config(); + let mut rng = StdRng::seed_from_u64(42); + let polynomial: Vec = (0..1 << num_variables).map(|_| rng.random()).collect(); + + // Evaluate at a random point to create the statement + let point: Vec = (0..num_variables).map(|_| rng.random()).collect(); + let eval = polynomial.evaluate(&MultilinearPoint(point.clone())); + let statement = vec![SparseStatement::dense(MultilinearPoint(point), eval)]; + + let (proof, _) = prove_and_get_proof(&config, &polynomial, &statement); + verify_proof(&config, proof, &statement).expect("honest proof must verify"); +} + +#[test] +fn completeness_multiple_evaluation_constraints() { + // §5.2 Construction 5.5: batching multiple constraints. + // ŵ(Z,X) = Σ γ^{i-1} · ŵ_i(Z,X), σ = Σ γ^{i-1} · σ_i. + let (config, num_variables) = small_whir_config(); + let mut rng = StdRng::seed_from_u64(99); + let polynomial: Vec = (0..1 << num_variables).map(|_| rng.random()).collect(); + + let mut statement = Vec::new(); + for _ in 0..4 { + let point: Vec = (0..num_variables).map(|_| rng.random()).collect(); + let eval = polynomial.evaluate(&MultilinearPoint(point.clone())); + statement.push(SparseStatement::dense(MultilinearPoint(point), eval)); + } + + let (proof, _) = prove_and_get_proof(&config, &polynomial, &statement); + verify_proof(&config, proof, &statement).expect("batched proof must verify"); +} + +// --------------------------------------------------------------------------- +// §5.1, Theorem 5.2: Soundness — verifier rejects invalid proofs +// --------------------------------------------------------------------------- + +#[test] +fn soundness_wrong_evaluation_rejected() { + // If the statement claims f(z) = v but v is wrong, the verifier must reject. + // This tests the full verification chain: sumcheck consistency (Decision §1), + // STIR queries (Decision §2c), and the final polynomial check (Decision §3). + let (config, num_variables) = small_whir_config(); + let mut rng = StdRng::seed_from_u64(7); + let polynomial: Vec = (0..1 << num_variables).map(|_| rng.random()).collect(); + + let point: Vec = (0..num_variables).map(|_| rng.random()).collect(); + let correct_eval = polynomial.evaluate(&MultilinearPoint(point.clone())); + // Deliberately wrong evaluation + let wrong_eval = correct_eval + EF::ONE; + + let correct_statement = vec![SparseStatement::dense(MultilinearPoint(point.clone()), correct_eval)]; + let wrong_statement = vec![SparseStatement::dense(MultilinearPoint(point), wrong_eval)]; + + // Prove with correct statement, verify with wrong statement + let (proof, _) = prove_and_get_proof(&config, &polynomial, &correct_statement); + let result = verify_proof(&config, proof, &wrong_statement); + assert!(result.is_err(), "verifier must reject proof for wrong evaluation"); +} + +// --------------------------------------------------------------------------- +// §6.2: Parameter choices — error budget validation +// --------------------------------------------------------------------------- + +#[test] +fn error_budget_johnson_bound_parameters() { + // Theorem 5.2 error bounds. For JohnsonBound at security_level λ: + // - ε^fold ≤ d*·ℓ/|F| + err*(C, 2, δ) [folding error] + // - ε^out ≤ 2^m · ℓ² / (2·|F|) [OOD error, Lemma 4.25] + // - ε^shift ≤ (1-δ)^t + ℓ·(t+1)/|F| [query error] + // + // The config must allocate PoW grinding bits so that + // ε^fold + ε^out + ε^shift ≤ 2^{-λ} per round. + let security_level = 100; + let num_variables = 20; + let builder = WhirConfigBuilder { + security_level, + max_num_variables_to_send_coeffs: 6, + pow_bits: 16, + folding_factor: FoldingFactor::constant(4), + soundness_type: SecurityAssumption::JohnsonBound, + starting_log_inv_rate: 2, + rs_domain_initial_reduction_factor: 1, + }; + let config = WhirConfig::::new(&builder, num_variables); + + // Verify that the PoW budget doesn't exceed the configured maximum + // (per-round pow_bits ≤ builder.pow_bits). + assert!( + config.starting_folding_pow_bits <= builder.pow_bits, + "initial folding PoW {} exceeds budget {}", + config.starting_folding_pow_bits, + builder.pow_bits + ); + for (i, round) in config.round_parameters.iter().enumerate() { + assert!( + round.folding_pow_bits <= builder.pow_bits, + "round {} folding PoW {} exceeds budget {}", + i, + round.folding_pow_bits, + builder.pow_bits + ); + assert!( + round.query_pow_bits <= builder.pow_bits, + "round {} query PoW {} exceeds budget {}", + i, + round.query_pow_bits, + builder.pow_bits + ); + } + + // Verify round structure: Σ kᵢ + final_sumcheck_rounds = num_variables + // (Construction 5.1 parameter constraint) + let total_folding = config.folding_factor.total_number(config.n_rounds()); + assert_eq!( + total_folding + config.final_sumcheck_rounds, + num_variables, + "folding + final sumcheck must cover all variables (Construction 5.1)" + ); +} + +// --------------------------------------------------------------------------- +// §3.3, Definition 3.4: eq polynomial and multilinear evaluation +// --------------------------------------------------------------------------- + +#[test] +fn expand_from_univariate_matches_pow_definition() { + // §3 notation: pow(x, m) := (x^{2^0}, ..., x^{2^{m-1}}). + // MultilinearPoint::expand_from_univariate must produce this. + let alpha = EF::from(KoalaBear::from_u32(17)); + let m = 5; + let point = MultilinearPoint::::expand_from_univariate(alpha, m); + + let mut expected = Vec::with_capacity(m); + let mut current = alpha; + for _ in 0..m { + expected.push(current); + current = current * current; // x^{2^i} → x^{2^{i+1}} + } + + assert_eq!(point.0, expected, "expand_from_univariate must equal pow(x, m) per §3"); +} + +// --------------------------------------------------------------------------- +// §4.3, Definition 4.14: Folding operation +// --------------------------------------------------------------------------- + +#[test] +fn folding_definition_consistency() { + // Definition 4.14: Fold(f, α)(x²) = (f(x) + f(-x))/2 + α · (f(x) - f(-x))/(2x). + // + // For a multilinear polynomial, this means: + // f̂(α, X₂, ..., Xₘ) should equal the evaluation of the folded function. + // + // We verify this by constructing a polynomial, folding it at a random point, + // and checking that the result matches the multilinear extension evaluation. + let m = 4; + let mut rng = StdRng::seed_from_u64(314); + let coeffs: Vec = (0..1 << m).map(|_| EF::from(rng.random::())).collect(); + + // Evaluate f̂(α, r₂, r₃, r₄) using eval_multilinear_coeffs + let alpha: EF = rng.random(); + let r: Vec = (1..m).map(|_| rng.random()).collect(); + + let mut full_point = vec![alpha]; + full_point.extend_from_slice(&r); + let direct_eval = eval_multilinear_coeffs(&coeffs, &full_point); + + // Now fold the coefficients by fixing X₁ = α, then evaluate at (r₂, r₃, r₄) + let half = coeffs.len() / 2; + let folded: Vec = (0..half).map(|i| coeffs[i] + alpha * coeffs[i + half]).collect(); + let folded_eval = eval_multilinear_coeffs(&folded, &r); + + assert_eq!( + direct_eval, folded_eval, + "folding X₁=α then evaluating must equal direct evaluation (Def 4.14)" + ); +} + +// --------------------------------------------------------------------------- +// verify.rs:385-398: verify_constraint_coeffs — univariate ↔ multilinear +// --------------------------------------------------------------------------- + +#[test] +fn univariate_horner_matches_multilinear_on_evals_to_coeffs_output() { + // The final-round check uses two evaluation methods on the SAME coefficient array + // (output of evals_to_coeffs): + // + // (a) verify.rs:199 — eval_multilinear_coeffs(coeffs, reversed_sumcheck_point) + // for the final sumcheck consistency check + // (b) verify.rs:396 — Horner evaluation coeffs[0] + coeffs[1]·α + ... for STIR + // constraint checks at domain points pow(z, m) + // + // evals_to_coeffs applies a bit-reverse permutation (evals.rs:55) after the + // butterfly, which reorders coefficients so that Horner at z gives the correct + // univariate evaluation f(z) = f̂(pow(z, m)). We verify this by starting from + // evaluations, converting to coefficients, and checking both methods agree. + let n = 4; + let mut rng = StdRng::seed_from_u64(271); + + // Start from evaluation form: evals[k] = f(b₀, b₁, ...) where k = b₀ + 2b₁ + ... + let mut evals: Vec = (0..1 << n).map(|_| EF::from(rng.random::())).collect(); + + // Evaluate the polynomial at a random point BEFORE converting to coefficients + // (using evaluation-form multilinear evaluation) + let alpha: EF = rng.random(); + let pow_point = MultilinearPoint::::expand_from_univariate(alpha, n); + let eval_form_result = evals.evaluate(&pow_point); + + // Convert evaluations to coefficients (includes bit-reverse permutation) + evals_to_coeffs(&mut evals); + let coeffs = evals; + + // Horner evaluation on the coefficients + let horner_result = coeffs.iter().rfold(EF::ZERO, |acc, &c| acc * alpha + c); + + assert_eq!( + eval_form_result, horner_result, + "Horner on evals_to_coeffs output must match evaluation-form result at pow(α,n)" + ); +} + +// --------------------------------------------------------------------------- +// §5.1, Theorem 5.2: domain structure validation +// --------------------------------------------------------------------------- + +#[test] +fn folded_domain_generator_order() { + // Construction 5.1 step 2d: STIR queries sample from L^{(2^k)}_{i-1}. + // The folded_domain_gen must be a generator of the correct-order subgroup. + let (config, _) = small_whir_config(); + + for (i, round) in config.round_parameters.iter().enumerate() { + let expected_order = round.domain_size >> round.folding_factor; + // gen^{expected_order} should equal 1 (it generates a group of that order) + let gen_to_order = round.folded_domain_gen.exp_u64(expected_order as u64); + assert_eq!( + gen_to_order, + as PrimeCharacteristicRing>::ONE, + "round {} folded_domain_gen must have order {} (domain_size/2^folding_factor)", + i, + expected_order + ); + // gen^{expected_order/2} should NOT equal 1 (primitive generator) + if expected_order > 1 { + let gen_to_half = round.folded_domain_gen.exp_u64((expected_order / 2) as u64); + assert_ne!( + gen_to_half, + as PrimeCharacteristicRing>::ONE, + "round {} folded_domain_gen must be primitive", + i + ); + } + } +} + +// --------------------------------------------------------------------------- +// Convention documentation: combination weight shift +// --------------------------------------------------------------------------- + +#[test] +fn combination_weights_start_from_one() { + // Convention difference from ACFY24 Construction 5.1: + // + // Paper (step 2e, page 32): new constraints get weights γ^{j+1} (j=0,...,t-1). + // Implementation (verify.rs:214): weights are [γ^0, γ^1, ...] = [1, γ, γ², ...]. + // + // The shift doesn't affect soundness: the polynomial identity lemma + // argument in Claim 5.4 works with degree t instead of t+1, giving + // a tighter bound. Both prover (open.rs:146) and verifier agree on + // this convention. + // + // This test documents the convention by verifying that the first + // constraint weight is 1 (not γ). We do this indirectly by checking + // that a single-constraint proof works — if the weight were γ instead + // of 1, the claimed_sum would be γ·v instead of v, and the sumcheck + // would need to account for that. + let (config, num_variables) = small_whir_config(); + let mut rng = StdRng::seed_from_u64(111); + let polynomial: Vec = (0..1 << num_variables).map(|_| rng.random()).collect(); + + // Single constraint: weight = 1 (first in the combination) + let point: Vec = (0..num_variables).map(|_| rng.random()).collect(); + let eval = polynomial.evaluate(&MultilinearPoint(point.clone())); + let statement = vec![SparseStatement::dense(MultilinearPoint(point), eval)]; + + let (proof, _) = prove_and_get_proof(&config, &polynomial, &statement); + verify_proof(&config, proof, &statement).expect("single-constraint proof verifies with weight=1"); +} + +// --------------------------------------------------------------------------- +// Sumcheck degree: d = max{d*, 3} where d* = 1 + deg_Z(ŵ) + max_i deg_{X_i}(ŵ) +// --------------------------------------------------------------------------- + +#[test] +fn sumcheck_degree_correct_for_eq_weight() { + // Construction 5.1: d* = 1 + deg_Z(ŵ₀) + max_{i∈[m₀]} deg_{X_i}(ŵ₀). + // For ŵ = Z · eq(z, X): + // deg_Z = 1, deg_{X_i} = 1, so d* = 1 + 1 + 1 = 3. + // d = max{d*, 3} = 3. + // + // The sumcheck polynomial has degree < d = 3, i.e., at most degree 2, + // requiring 3 coefficients. This matches the hardcoded value in + // verify_sumcheck_rounds (verify.rs:417). + // + // NOTE: If future weight polynomials have higher degree (d* > 3), + // the hardcoded 3 must be updated. This test will catch such a + // regression if the config exposes the degree. + let (config, num_variables) = small_whir_config(); + let mut rng = StdRng::seed_from_u64(55); + let polynomial: Vec = (0..1 << num_variables).map(|_| rng.random()).collect(); + + // If the degree were wrong (too low), the sumcheck would produce + // incorrect evaluations and the final check would fail. + let point: Vec = (0..num_variables).map(|_| rng.random()).collect(); + let eval = polynomial.evaluate(&MultilinearPoint(point.clone())); + let statement = vec![SparseStatement::dense(MultilinearPoint(point), eval)]; + + let (proof, _) = prove_and_get_proof(&config, &polynomial, &statement); + verify_proof(&config, proof, &statement).expect("degree-3 sumcheck is correct for eq weights"); +} + +// --------------------------------------------------------------------------- +// §6.2: SecurityAssumption distance parameters +// --------------------------------------------------------------------------- + +#[test] +fn security_assumptions_distance_parameters() { + // §6.2 parameter choices: + // - UD: δ = (1-ρ)/2 + // - JB: δ = 1 - √ρ - η, where η = √ρ/c + // - CB: δ = 1 - ρ - η, where η = ρ/c + // + // Verify log(1-δ) computation for each assumption at rate ρ = 1/4. + let log_inv_rate = 2; // ρ = 1/4 + let log_c = 3.0; // c = 8 + + // UD: δ = (1 - 1/4)/2 = 3/8, so 1-δ = 5/8, log(5/8) ≈ -0.678 + let ud_log = SecurityAssumption::UniqueDecoding.log_1_delta(log_inv_rate, log_c); + assert!(ud_log < -0.67 && ud_log > -0.69, "UD log(1-δ) = {}", ud_log); + + // JB: η = √(1/4)/8 = 1/16, δ = 1 - 1/2 - 1/16 = 7/16, 1-δ = 9/16 + let jb_log = SecurityAssumption::JohnsonBound.log_1_delta(log_inv_rate, log_c); + let expected_jb = (9.0_f64 / 16.0).log2(); + assert!( + (jb_log - expected_jb).abs() < 0.01, + "JB log(1-δ) = {}, expected {}", + jb_log, + expected_jb + ); + + // CB: η = (1/4)/8 = 1/32, δ = 1 - 1/4 - 1/32 = 23/32, 1-δ = 9/32 + let cb_log = SecurityAssumption::CapacityBound.log_1_delta(log_inv_rate, log_c); + let expected_cb = (9.0_f64 / 32.0).log2(); + assert!( + (cb_log - expected_cb).abs() < 0.01, + "CB log(1-δ) = {}, expected {}", + cb_log, + expected_cb + ); +} + +// --------------------------------------------------------------------------- +// Lemma 4.25: OOD error formula +// --------------------------------------------------------------------------- + +#[test] +fn ood_error_formula_matches_lemma_4_25() { + // Lemma 4.25 (adapted from [ACFY24] Lemma 4.5): + // Pr[|Λ(CRS[...], f, δ)| > 1] ≤ (ℓ²/2) · (2^m / |F|)^s + // + // In log scale: -log₂(error) = s·field_size + 1 - 2·list_size_bits - s·log_degree + // + // The implementation's ood_error returns this in bits of security. + let sa = SecurityAssumption::JohnsonBound; + let log_degree = 20; + let log_inv_rate = 2; + let field_size_bits = EF::bits(); + let log_c = 4.0; + let ood_samples = 1; + + let security_bits = sa.ood_error(log_degree, log_inv_rate, field_size_bits, ood_samples, log_c); + + // Manually compute: s·field_size + 1 - 2·list_size_bits - s·log_degree + let list_size_bits = sa.list_size_bits(log_degree, log_inv_rate, log_c); + let expected = + (ood_samples * field_size_bits) as f64 + 1.0 - 2.0 * list_size_bits - (ood_samples * log_degree) as f64; + + assert!( + (security_bits - expected).abs() < 0.001, + "ood_error({} samples) = {}, expected {} (Lemma 4.25)", + ood_samples, + security_bits, + expected + ); +}