diff --git a/services/api/Cargo.toml b/services/api/Cargo.toml index bab4d7ee..a1c7a0fa 100644 --- a/services/api/Cargo.toml +++ b/services/api/Cargo.toml @@ -52,6 +52,7 @@ ipnet = "2" criterion = { version = "0.5", features = ["html_reports"] } testcontainers = { version = "0.23", features = ["tokio"] } testcontainers-modules = { version = "0.11", features = ["redis", "tokio"] } +wiremock = "0.6" [[bench]] name = "api_key_auth" diff --git a/services/api/src/email/service.rs b/services/api/src/email/service.rs index cbe3421f..253fd090 100644 --- a/services/api/src/email/service.rs +++ b/services/api/src/email/service.rs @@ -8,6 +8,7 @@ use validator::ValidateEmail; use crate::cache::RedisCache; use crate::config::Config; use crate::email::templates::EmailTemplateEngine; +use crate::metrics::Metrics; /// Configuration for email idempotency deduplication. #[derive(Clone, Debug)] @@ -87,6 +88,8 @@ pub struct EmailService { client: reqwest::Client, cache: Option, pub idempotency: IdempotencyConfig, + metrics: Option, + sendgrid_base_url: String, } impl EmailService { @@ -98,6 +101,15 @@ impl EmailService { config: Config, cache: Option, idempotency: IdempotencyConfig, + ) -> Result { + Self::with_cache_and_metrics(config, cache, idempotency, None) + } + + pub fn with_cache_and_metrics( + config: Config, + cache: Option, + idempotency: IdempotencyConfig, + metrics: Option, ) -> Result { let template_engine = EmailTemplateEngine::new()?; let client = reqwest::Client::builder() @@ -110,9 +122,17 @@ impl EmailService { client, cache, idempotency, + metrics, + sendgrid_base_url: "https://api.sendgrid.com".to_string(), }) } + #[cfg(test)] + pub fn with_base_url(mut self, base_url: String) -> Self { + self.sendgrid_base_url = base_url; + self + } + /// Send an email using SendGrid pub async fn send_email( &self, @@ -218,38 +238,78 @@ impl EmailService { } }); - // Send via SendGrid - let response = self - .client - .post("https://api.sendgrid.com/v3/mail/send") - .bearer_auth(api_key) - .json(&payload) - .send() - .await - .context("Failed to send email via SendGrid")?; + // Send via SendGrid with retry (max 3 attempts, exp backoff + jitter) + const MAX_ATTEMPTS: u32 = 3; + let mut last_error = String::new(); + + for attempt in 0..MAX_ATTEMPTS { + let response = self + .client + .post(format!("{}/v3/mail/send", self.sendgrid_base_url)) + .bearer_auth(api_key) + .json(&payload) + .send() + .await + .context("Failed to send email via SendGrid")?; - if !response.status().is_success() { let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("SendGrid API error {}: {}", status, body); - } - // Extract message ID from response headers - let message_id = response - .headers() - .get("x-message-id") - .and_then(|v| v.to_str().ok()) - .unwrap_or("unknown") - .to_string(); - - tracing::info!( - "Email sent successfully to {} using template {} (message_id: {})", - recipient, - template_name, - message_id - ); + if status.is_success() { + let message_id = response + .headers() + .get("x-message-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown") + .to_string(); - Ok(message_id) + tracing::info!( + "Email sent successfully to {} using template {} (message_id: {})", + recipient, + template_name, + message_id + ); + return Ok(message_id); + } + + let should_retry = status.as_u16() == 429 || status.is_server_error(); + let retry_after_header = response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + if !should_retry || attempt + 1 == MAX_ATTEMPTS { + last_error = format!("SendGrid API error {}", status); + break; + } + + let reason = if status.as_u16() == 429 { "rate_limited" } else { "server_error" }; + if let Some(m) = &self.metrics { + m.observe_sendgrid_retry(reason); + } + + // Respect Retry-After (seconds) if present, else exp backoff + jitter + let delay_ms: u64 = if let Some(secs) = retry_after_header { + secs * 1_000 + } else { + let jitter = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_millis() % 100) as u64; + (1u64 << attempt) * 100 + jitter + }; + + tracing::warn!( + attempt = attempt + 1, + delay_ms, + reason, + "SendGrid transient error {}, retrying", + status + ); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } + + anyhow::bail!(last_error); } /// Preview email without sending (for testing/development) @@ -399,6 +459,61 @@ mod tests { ); } + /// Two 429s followed by a 202: the service should succeed on the third attempt. + #[tokio::test] + async fn retry_succeeds_after_two_429s() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let mock_server = MockServer::start().await; + + // First two calls return 429, third returns 202. + Mock::given(method("POST")) + .and(path("/v3/mail/send")) + .respond_with(ResponseTemplate::new(429)) + .up_to_n_times(2) + .mount(&mock_server) + .await; + + Mock::given(method("POST")) + .and(path("/v3/mail/send")) + .respond_with( + ResponseTemplate::new(202) + .insert_header("x-message-id", "test-msg-id"), + ) + .mount(&mock_server) + .await; + + let mut config = Config::from_env(); + config.sendgrid_api_key = Some("test-key".to_string()); + config.from_email = Some("from@example.com".to_string()); + + let metrics = crate::metrics::Metrics::new().unwrap(); + let service = EmailService::with_cache_and_metrics( + config, + None, + IdempotencyConfig::default(), + Some(metrics.clone()), + ) + .unwrap() + .with_base_url(mock_server.uri()); + + let data = serde_json::json!({"confirm_url": "https://example.com/confirm?token=abc"}); + let result = service + .send_email("user@example.com", "newsletter_confirmation", &data) + .await; + + assert!(result.is_ok(), "expected success after retries, got: {:?}", result); + assert_eq!(result.unwrap(), "test-msg-id"); + + // Verify the retry counter was incremented twice (one per 429) + let rendered = metrics.render().unwrap(); + assert!( + rendered.contains(r#"sendgrid_retries_total{reason="rate_limited"} 2"#), + "expected 2 rate_limited retries in metrics:\n{rendered}" + ); + } + #[test] fn valid_address_passes() { assert!(sanitize_email("user@example.com").is_ok()); diff --git a/services/api/src/metrics.rs b/services/api/src/metrics.rs index e9c8b542..4c4e369e 100644 --- a/services/api/src/metrics.rs +++ b/services/api/src/metrics.rs @@ -18,6 +18,7 @@ pub struct Metrics { db_pool_connections_idle: IntGaugeVec, db_pool_acquire_duration: HistogramVec, rate_limit_rejections: IntCounterVec, + sendgrid_retries: IntCounterVec, } impl Metrics { @@ -120,6 +121,12 @@ impl Metrics { ) .context("rate_limit_rejections metric")?; + let sendgrid_retries = IntCounterVec::new( + prometheus::Opts::new("sendgrid_retries_total", "SendGrid send retries by reason"), + &["reason"], + ) + .context("sendgrid_retries metric")?; + registry.register(Box::new(cache_hits.clone()))?; registry.register(Box::new(cache_misses.clone()))?; registry.register(Box::new(invalidations.clone()))?; @@ -132,6 +139,7 @@ impl Metrics { registry.register(Box::new(db_pool_connections_idle.clone()))?; registry.register(Box::new(db_pool_acquire_duration.clone()))?; registry.register(Box::new(rate_limit_rejections.clone()))?; + registry.register(Box::new(sendgrid_retries.clone()))?; Ok(Self { registry, @@ -147,6 +155,7 @@ impl Metrics { db_pool_connections_idle, db_pool_acquire_duration, rate_limit_rejections, + sendgrid_retries, }) } @@ -226,6 +235,12 @@ impl Metrics { .inc(); } + /// Increment the SendGrid retry counter. + /// `reason` should be "rate_limited" (429) or "server_error" (5xx). + pub fn observe_sendgrid_retry(&self, reason: &str) { + self.sendgrid_retries.with_label_values(&[reason]).inc(); + } + pub fn render(&self) -> anyhow::Result { let mut buffer = vec![]; let encoder = TextEncoder::new(); diff --git a/services/api/src/validation.rs b/services/api/src/validation.rs index 6eac517d..cf84778d 100644 --- a/services/api/src/validation.rs +++ b/services/api/src/validation.rs @@ -8,11 +8,174 @@ //! //! This is a defence-in-depth layer; the frontend MUST also escape output. -use axum::http::StatusCode; +use axum::body::Body; +use axum::extract::Request; +use axum::http::{Method, StatusCode}; +use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use axum::Json; use serde::Serialize; +// ── Request body size limit ────────────────────────────────────────────────── + +/// Default request body size limit: 1 MiB. +pub const DEFAULT_REQUEST_BODY_MAX_BYTES: usize = 1_048_576; + +/// Parse `REQUEST_BODY_MAX_BYTES` from an optional env-var string. +/// Returns the default on missing, zero, or unparseable values. +pub fn parse_request_body_max_bytes(val: Option<&str>) -> usize { + val.and_then(|s| s.trim().parse::().ok()) + .filter(|&n| n > 0) + .unwrap_or(DEFAULT_REQUEST_BODY_MAX_BYTES) +} + +fn body_limit() -> usize { + parse_request_body_max_bytes(std::env::var("REQUEST_BODY_MAX_BYTES").ok().as_deref()) +} + +#[derive(Serialize)] +struct PayloadTooLargeError { + error: &'static str, + message: String, + limit_bytes: usize, +} + +/// Tower middleware that enforces a request body size limit. +/// +/// Fast-path: rejects immediately when `Content-Length` exceeds the limit. +/// Slow-path: buffers the stream and rejects once accumulated bytes exceed limit. +pub async fn request_size_validation_middleware( + req: Request, + next: Next, +) -> Response { + let limit = body_limit(); + + // Fast path: Content-Length header present + if let Some(cl) = req.headers().get("content-length") { + if let Ok(s) = cl.to_str() { + if let Ok(n) = s.parse::() { + if n > limit { + return payload_too_large(limit); + } + } + } + } + + // Slow path: buffer stream up to limit+1 bytes. + // axum::body::to_bytes returns Err when body exceeds the cap — treat that as 413. + let (parts, body) = req.into_parts(); + let bytes = match axum::body::to_bytes(body, limit + 1).await { + Ok(b) => b, + Err(_) => return payload_too_large(limit), + }; + if bytes.len() > limit { + return payload_too_large(limit); + } + + let req = Request::from_parts(parts, Body::from(bytes)); + next.run(req).await +} + +fn payload_too_large(limit: usize) -> Response { + ( + StatusCode::PAYLOAD_TOO_LARGE, + Json(PayloadTooLargeError { + error: "payload_too_large", + message: format!( + "Request body exceeds the maximum allowed size of {} bytes.", + limit + ), + limit_bytes: limit, + }), + ) + .into_response() +} + +// ── Content-Type validation ─────────────────────────────────────────────────── + +const JSON_REQUIRED_METHODS: &[Method] = &[Method::POST, Method::PUT, Method::PATCH]; + +#[derive(Serialize)] +struct UnsupportedMediaTypeError { + error: &'static str, + message: String, + required: &'static str, + received: String, +} + +/// Reject POST/PUT/PATCH requests whose `Content-Type` is not `application/json`. +pub async fn content_type_validation_middleware(req: Request, next: Next) -> Response { + if JSON_REQUIRED_METHODS.contains(req.method()) { + let ct = req + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + if !ct.starts_with("application/json") { + return ( + StatusCode::UNSUPPORTED_MEDIA_TYPE, + Json(UnsupportedMediaTypeError { + error: "unsupported_media_type", + message: "Content-Type must be application/json for POST, PUT, and PATCH \ + requests." + .to_string(), + required: "application/json", + received: if ct.is_empty() { + "not set".to_string() + } else { + ct.to_string() + }, + }), + ) + .into_response(); + } + } + next.run(req).await +} + +// ── Query / path validation ─────────────────────────────────────────────────── + +static SUSPICIOUS_QUERY_PATTERNS: &[&str] = &[ + "' or", "\" or", "1=1", "or 1=1", "drop table", "select ", "insert ", + "delete ", "update ", "union ", "--", "/*", "*/", "xp_", "exec(", +]; + +static SUSPICIOUS_PATH_PATTERNS: &[&str] = &["//", "../", "..\\", "%2e%2e"]; + +/// Reject requests with SQL-injection or path-traversal patterns in query / path. +pub async fn request_validation_middleware(req: Request, next: Next) -> Response { + let uri = req.uri(); + + if let Some(query) = uri.query() { + let lower = query.to_lowercase(); + if SUSPICIOUS_QUERY_PATTERNS.iter().any(|p| lower.contains(p)) { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "message": "Request contains disallowed query patterns." + })), + ) + .into_response(); + } + } + + let path = uri.path(); + let lower_path = path.to_lowercase(); + if SUSPICIOUS_PATH_PATTERNS.iter().any(|p| lower_path.contains(p)) { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "message": "Request path contains disallowed patterns." + })), + ) + .into_response(); + } + + next.run(req).await +} + #[derive(Debug, Serialize)] pub struct ValidationError { pub error: &'static str, diff --git a/services/api/tests/request_validation_middleware_tests.rs b/services/api/tests/request_validation_middleware_tests.rs index 068c998b..f3b1a86f 100644 --- a/services/api/tests/request_validation_middleware_tests.rs +++ b/services/api/tests/request_validation_middleware_tests.rs @@ -460,6 +460,24 @@ mod tests { assert_eq!(DEFAULT_REQUEST_BODY_MAX_BYTES, 1_048_576); } + #[tokio::test] + async fn two_mb_body_returns_413() { + let two_mb = vec![b'a'; 2 * 1_048_576]; + let response = app() + .oneshot( + Request::builder() + .method("POST") + .uri("/safe-post") + .header("content-type", "application/json") + .body(Body::from(two_mb)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); + } + // ── Admin route auth ────────────────────────────────────────────────────── #[tokio::test]