diff --git a/ml-dsa/src/sampling.rs b/ml-dsa/src/sampling.rs index 85ed73a1..1e92600e 100644 --- a/ml-dsa/src/sampling.rs +++ b/ml-dsa/src/sampling.rs @@ -5,6 +5,8 @@ use crate::{ }; use hybrid_array::Array; use module_lattice::{ArraySize, Field, Truncate}; +#[cfg(feature = "zeroize")] +use zeroize::Zeroize; // Algorithm 13 BytesToBits fn bit_set(z: &[u8], i: usize) -> bool { @@ -95,17 +97,38 @@ pub(crate) fn sample_in_ball(rho: &[u8], tau: usize) -> Polynomial { fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial { let mut j = 0; let mut ctx = G::default().absorb(rho).absorb(&[s]).absorb(&[r]); - let mut a = NttPolynomial::default(); - let mut s = [0u8; 3]; - while j < 256 { - ctx.squeeze(&mut s); - if let Some(x) = coeff_from_three_bytes(s) { + + // Squeeze 840 bytes (5 SHAKE128 blocks) in a single call rather than 3 bytes + // at a time. The rejection probability is ~0.098%, so the 280 candidates are + // almost always sufficient, and even the 768-byte minimum already spans 5 blocks. + let mut buf = [0u8; 840]; + ctx.squeeze(&mut buf); + + for chunk in buf.chunks_exact(3) { + if let Some(x) = coeff_from_three_bytes([chunk[0], chunk[1], chunk[2]]) { a.0[j] = x; j += 1; + if j == 256 { + break; + } } } + // Fallback: reached only if at least 25 of the 280 buffered candidates were rejected; negligible, but required for correctness.
+ let mut tmp = [0u8; 3]; + while j < 256 { + ctx.squeeze(&mut tmp); + if let Some(x) = coeff_from_three_bytes(tmp) { + a.0[j] = x; + j += 1; + } + } + #[cfg(feature = "zeroize")] + { + buf.zeroize(); + tmp.zeroize(); + } a } @@ -113,28 +136,51 @@ fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial { fn rej_bounded_poly(rho: &[u8], eta: Eta, r: u16) -> Polynomial { let mut j = 0; let mut ctx = H::default().absorb(rho).absorb(&r.to_le_bytes()); - let mut a = Polynomial::default(); - let mut z = [0u8]; - while j < 256 { - ctx.squeeze(&mut z); - let (z0, z1) = coeffs_from_byte(z[0], eta); - if let Some(z) = z0 { - a.0[j] = z; + // The reference implementation uses 136 bytes (1 SHAKE256 block) for eta=2 and 272 bytes (2 blocks) for eta=4; this code squeezes 272 bytes up front for either eta. + let mut buf = [0u8; 272]; + ctx.squeeze(&mut buf); + + for &byte in &buf { + let (z0, z1) = coeffs_from_byte(byte, eta); + if let Some(x) = z0 { + a.0[j] = x; j += 1; + if j == 256 { + break; + } } - - if j == 256 { - break; + if let Some(x) = z1 { + a.0[j] = x; + j += 1; + if j == 256 { + break; + } } + } - if let Some(z) = z1 { - a.0[j] = z; + // Fallback: reached only if the 272 buffered bytes (at most two coefficients each) yield fewer than 256 coefficients; presumably rare, but how rare depends on eta (TODO confirm). Required for correctness. + let mut tmp = [0u8; 1]; + while j < 256 { + ctx.squeeze(&mut tmp); + let (z0, z1) = coeffs_from_byte(tmp[0], eta); + if let Some(x) = z0 { + a.0[j] = x; j += 1; } + if j < 256 { + if let Some(x) = z1 { + a.0[j] = x; + j += 1; + } + } + } + #[cfg(feature = "zeroize")] + { + buf.zeroize(); + tmp.zeroize(); } - a }