Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqlx-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ time = { workspace = true, optional = true }
uuid = { workspace = true, optional = true }

# Misc
async-lock = "3.4"
atoi = "2.0"
base64 = { version = "0.22.0", default-features = false, features = ["std"] }
bitflags = { version = "2", default-features = false }
Expand Down
2 changes: 2 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use sqlx_core::sql_str::SqlSafeStr;

pub use self::stream::PgStream;

pub use sasl::ClientKeyCache;

#[cfg(feature = "offline")]
mod describe;
mod establish;
Expand Down
105 changes: 95 additions & 10 deletions sqlx-postgres/src/connection/sasl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use std::sync::Arc;

use async_lock::{RwLock, RwLockUpgradableReadGuard};

use crate::connection::stream::PgStream;
use crate::error::Error;
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
use crate::message::{
Authentication, AuthenticationSasl, AuthenticationSaslContinue, SaslInitialResponse,
SaslResponse,
};
use crate::rt;
use crate::PgConnectOptions;
use hmac::{Hmac, Mac};
Expand All @@ -16,6 +23,89 @@ const USERNAME_ATTR: &str = "n";
const CLIENT_PROOF_ATTR: &str = "p";
const NONCE_ATTR: &str = "r";

/// A single-entry cache for the client key derived from the password.
///
/// Salting the password and deriving the client key can be expensive, so this cache stores the most
/// recently used client key along with the parameters used to derive it.
///
/// The cache is keyed on the salt and iteration count, which are the server-provided parameters
/// that affect the HMAC result. The password is not included in the cache key because it can only
/// change via `&mut self` on `PgConnectOptions`, which replaces the cache instance.
///
/// An async `RwLock` is used so that only one caller computes the key at a time; subsequent callers
/// wait and then read the cached result.
///
/// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677):
///
/// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash
/// > iteration-count is stable).
#[derive(Debug, Clone)]
pub struct ClientKeyCache {
inner: Arc<RwLock<Option<CacheEntry>>>,
}

#[derive(Debug)]
struct CacheEntry {
// Keys
salt: Vec<u8>,
iterations: u32,

// Values
salted_password: [u8; 32],
client_key: Hmac<Sha256>,
}

impl CacheEntry {
fn matches(&self, cont: &AuthenticationSaslContinue) -> bool {
self.salt == cont.salt && self.iterations == cont.iterations
}
}

impl ClientKeyCache {
pub fn new() -> Self {
ClientKeyCache {
inner: Arc::new(RwLock::new(None)),
}
}

/// Returns the cached salted password and client key HMAC if the cache matches the given
/// salt and iteration count. Otherwise, computes and caches them.
async fn get_or_compute(
&self,
password: &str,
cont: &AuthenticationSaslContinue,
) -> Result<([u8; 32], Hmac<Sha256>), Error> {
let guard = self.inner.upgradable_read().await;

if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) {
return Ok((entry.salted_password, entry.client_key.clone()));
}

let mut guard = RwLockUpgradableReadGuard::upgrade(guard).await;

// Re-check after acquiring the write lock, in case another caller populated the cache.
if let Some(entry) = guard.as_ref().filter(|e| e.matches(cont)) {
return Ok((entry.salted_password, entry.client_key.clone()));
}

// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(password, &cont.salt, cont.iterations).await?;

// ClientKey := HMAC(SaltedPassword, "Client Key")
let client_key =
Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;

*guard = Some(CacheEntry {
salt: cont.salt.clone(),
iterations: cont.iterations,
salted_password,
client_key: client_key.clone(),
});

Ok((salted_password, client_key))
}
}

pub(crate) async fn authenticate(
stream: &mut PgStream,
options: &PgConnectOptions,
Expand Down Expand Up @@ -86,16 +176,11 @@ pub(crate) async fn authenticate(
}
};

// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(
options.password.as_deref().unwrap_or_default(),
&cont.salt,
cont.iterations,
)
.await?;
let (salted_password, mut mac) = options
.sasl_client_key_cache
.get_or_compute(options.password.as_deref().unwrap_or_default(), &cont)
.await?;

// ClientKey := HMAC(SaltedPassword, "Client Key")
let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
mac.update(b"Client Key");

let client_key = mac.finalize().into_bytes();
Expand Down
2 changes: 1 addition & 1 deletion sqlx-postgres/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod startup;
mod sync;
mod terminate;

pub use authentication::{Authentication, AuthenticationSasl};
pub use authentication::{Authentication, AuthenticationSasl, AuthenticationSaslContinue};
pub use backend_key_data::BackendKeyData;
pub use bind::Bind;
pub use close::Close;
Expand Down
11 changes: 9 additions & 2 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::path::{Path, PathBuf};

pub use ssl_mode::PgSslMode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{
connection::{ClientKeyCache, LogSettings},
net::tls::CertificateInput,
};

mod connect;
mod parse;
Expand All @@ -30,6 +33,7 @@ pub struct PgConnectOptions {
pub(crate) log_settings: LogSettings,
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
pub(crate) options: Option<String>,
pub(crate) sasl_client_key_cache: ClientKeyCache,
}

impl Default for PgConnectOptions {
Expand Down Expand Up @@ -97,6 +101,7 @@ impl PgConnectOptions {
extra_float_digits: Some("2".into()),
log_settings: Default::default(),
options: var("PGOPTIONS").ok(),
sasl_client_key_cache: ClientKeyCache::new(),
}
}

Expand Down Expand Up @@ -188,6 +193,8 @@ impl PgConnectOptions {
/// ```
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.to_owned());
// Invalidate the cached SASL client key, since it was derived from the old password.
self.sasl_client_key_cache = ClientKeyCache::new();
self
}

Expand Down Expand Up @@ -274,7 +281,7 @@ impl PgConnectOptions {
/// -----BEGIN CERTIFICATE-----
/// <Certificate data here.>
/// -----END CERTIFICATE-----";
///
///
/// let options = PgConnectOptions::new()
/// // Providing a CA certificate with less than VerifyCa is pointless
/// .ssl_mode(PgSslMode::VerifyCa)
Expand Down
Loading