diff --git a/Cargo.lock b/Cargo.lock index 236039f0ff..ccbfa4787d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1377,9 +1377,9 @@ dependencies = [ [[package]] name = "flume" -version = "0.11.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" dependencies = [ "futures-core", "futures-sink", @@ -3938,6 +3938,7 @@ dependencies = [ name = "sqlx-postgres" version = "0.9.0-alpha.1" dependencies = [ + "async-lock", "atoi", "base64 0.22.1", "bigdecimal", diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 2943049f0b..85b6bdc70e 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -52,6 +52,7 @@ time = { workspace = true, optional = true } uuid = { workspace = true, optional = true } # Misc +async-lock = "3.4" atoi = "2.0" base64 = { version = "0.22.0", default-features = false, features = ["std"] } bitflags = { version = "2", default-features = false } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index d5db20ad05..00069e14f3 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -23,6 +23,8 @@ use sqlx_core::sql_str::SqlSafeStr; pub use self::stream::PgStream; +pub use sasl::ClientKeyCache; + #[cfg(feature = "offline")] mod describe; mod establish; diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 94fdfc689f..d7e3e0d921 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -1,6 +1,13 @@ +use std::sync::Arc; + +use async_lock::{RwLock, RwLockUpgradableReadGuard}; + use crate::connection::stream::PgStream; use crate::error::Error; -use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; +use crate::message::{ + Authentication, AuthenticationSasl, AuthenticationSaslContinue, SaslInitialResponse, + SaslResponse, +}; use crate::rt; use crate::PgConnectOptions; use hmac::{Hmac, Mac}; @@ -16,6 +23,89 @@ const USERNAME_ATTR: &str = "n"; const CLIENT_PROOF_ATTR: &str = "p"; const NONCE_ATTR: &str = "r"; +/// A single-entry cache for the client key derived from the password. +/// +/// Salting the password and deriving the client key can be expensive, so this cache stores the most +/// recently used client key along with the parameters used to derive it. +/// +/// The cache is keyed on the salt and iteration count, which are the server-provided parameters +/// that affect the HMAC result. The password is not included in the cache key because it can only +/// change via `&mut self` on `PgConnectOptions`, which replaces the cache instance. +/// +/// An async `RwLock` is used so that only one caller computes the key at a time; subsequent callers +/// wait and then read the cached result. +/// +/// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677): +/// +/// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash +/// > iteration-count is stable). +#[derive(Debug, Clone)] +pub struct ClientKeyCache { + inner: Arc>>, +} + +#[derive(Debug)] +struct CacheEntry { + // Keys + salt: Vec, + iterations: u32, + + // Values + salted_password: [u8; 32], + client_key: Hmac, +} + +impl CacheEntry { + fn matches(&self, cont: &AuthenticationSaslContinue) -> bool { + self.salt == cont.salt && self.iterations == cont.iterations + } +} + +impl ClientKeyCache { + pub fn new() -> Self { + ClientKeyCache { + inner: Arc::new(RwLock::new(None)), + } + } + + /// Returns the cached salted password and client key HMAC if the cache matches the given + /// salt and iteration count. Otherwise, computes and caches them. + async fn get_or_compute( + &self, + password: &str, + cont: &AuthenticationSaslContinue, + ) -> Result<([u8; 32], Hmac), Error> { + let guard = self.inner.upgradable_read().await; + + if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) { + return Ok((entry.salted_password, entry.client_key.clone())); + } + + let mut guard = RwLockUpgradableReadGuard::upgrade(guard).await; + + // Re-check after acquiring the write lock, in case another caller populated the cache. + if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) { + return Ok((entry.salted_password, entry.client_key.clone())); + } + + // SaltedPassword := Hi(Normalize(password), salt, i) + let salted_password = hi(password, &cont.salt, cont.iterations).await?; + + // ClientKey := HMAC(SaltedPassword, "Client Key") + let client_key = + Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; + + *guard = Some(CacheEntry { + salt: cont.salt.clone(), + iterations: cont.iterations, + salted_password, + client_key: client_key.clone(), + }); + + Ok((salted_password, client_key)) + } +} + pub(crate) async fn authenticate( stream: &mut PgStream, options: &PgConnectOptions, @@ -86,16 +176,11 @@ pub(crate) async fn authenticate( } }; - // SaltedPassword := Hi(Normalize(password), salt, i) - let salted_password = hi( - options.password.as_deref().unwrap_or_default(), - &cont.salt, - cont.iterations, - ) - .await?; + let (salted_password, mut mac) = options + .sasl_client_key_cache + .get_or_compute(options.password.as_deref().unwrap_or_default(), &cont) + .await?; - // ClientKey := HMAC(SaltedPassword, "Client Key") - let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Client Key"); let client_key = mac.finalize().into_bytes(); diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index e62f9bebb3..e7648f4419 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -30,7 +30,7 @@ mod startup; mod sync; mod terminate; -pub use authentication::{Authentication, AuthenticationSasl}; +pub use authentication::{Authentication, AuthenticationSasl, AuthenticationSaslContinue}; pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index 21e6628cae..0efd2e032c 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -5,7 +5,10 @@ use std::path::{Path, PathBuf}; pub use ssl_mode::PgSslMode; -use crate::{connection::LogSettings, net::tls::CertificateInput}; +use crate::{ + connection::{ClientKeyCache, LogSettings}, + net::tls::CertificateInput, +}; mod connect; mod parse; @@ -30,6 +33,7 @@ pub struct PgConnectOptions { pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, pub(crate) options: Option, + pub(crate) sasl_client_key_cache: ClientKeyCache, } impl Default for PgConnectOptions { @@ -97,6 +101,7 @@ impl PgConnectOptions { extra_float_digits: Some("2".into()), log_settings: Default::default(), options: var("PGOPTIONS").ok(), + sasl_client_key_cache: ClientKeyCache::new(), } } @@ -188,6 +193,8 @@ impl PgConnectOptions { /// ``` pub fn password(mut self, password: &str) -> Self { self.password = Some(password.to_owned()); + // Invalidate the cached SASL client key, since it was derived from the old password. + self.sasl_client_key_cache = ClientKeyCache::new(); self } @@ -274,7 +281,7 @@ impl PgConnectOptions { /// -----BEGIN CERTIFICATE----- /// /// -----END CERTIFICATE-----"; - /// + /// /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa)