diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index e7aca714..fdc1a409 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -51,10 +51,7 @@ use allow_origin::AllowOriginFuture; use bytes::{BufMut, BytesMut}; -use http::{ - header::{self, HeaderName}, - HeaderMap, HeaderValue, Method, Request, Response, -}; +use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response}; use pin_project_lite::pin_project; use std::{ future::Future, @@ -99,6 +96,7 @@ pub struct CorsLayer { expose_headers: ExposeHeaders, max_age: MaxAge, vary: Vary, + is_vary_custom: bool, } #[allow(clippy::declare_interior_mutable_const)] @@ -122,6 +120,7 @@ impl CorsLayer { expose_headers: Default::default(), max_age: Default::default(), vary: Default::default(), + is_vary_custom: false, } } @@ -445,6 +444,7 @@ impl CorsLayer { T: Into, { self.vary = headers.into(); + self.is_vary_custom = true; self } } @@ -493,10 +493,34 @@ impl Layer for CorsLayer { fn layer(&self, inner: S) -> Self::Service { ensure_usable_cors_rules(self); - Cors { - inner, - layer: self.clone(), + // Clone the layer to modify Vary header logic + let mut layer = self.clone(); + + // Only set Vary if not custom + if !layer.is_vary_custom { + // If all origins, methods, and headers are allowed, omit Vary + let all_origins = layer.allow_origin.is_wildcard(); + let all_methods = layer.allow_methods.is_wildcard(); + let all_headers = layer.allow_headers.is_wildcard(); + if all_origins && all_methods && all_headers { + layer.vary = Vary::list([]); + } else { + // Otherwise, set Vary to the appropriate headers + let mut vary_headers = Vec::new(); + if !all_origins { + vary_headers.push(header::ORIGIN); + } + if !all_methods { + vary_headers.push(header::ACCESS_CONTROL_REQUEST_METHOD); + } + if !all_headers { + vary_headers.push(header::ACCESS_CONTROL_REQUEST_HEADERS); + } + layer.vary = Vary::list(vary_headers); + } } + + Cors { inner, layer } } } @@ -641,6 +665,28 @@ impl Cors { F: FnOnce(CorsLayer) -> CorsLayer, { self.layer = f(self.layer); + + // Centralize Vary header logic here as well + if !self.layer.is_vary_custom { + let all_origins = self.layer.allow_origin.is_wildcard(); + let all_methods = self.layer.allow_methods.is_wildcard(); + let all_headers = self.layer.allow_headers.is_wildcard(); + if all_origins && all_methods && all_headers { + self.layer.vary = Vary::list([]); + } else { + let mut vary_headers = Vec::new(); + if !all_origins { + vary_headers.push(header::ORIGIN); + } + if !all_methods { + vary_headers.push(header::ACCESS_CONTROL_REQUEST_METHOD); + } + if !all_headers { + vary_headers.push(header::ACCESS_CONTROL_REQUEST_HEADERS); + } + self.layer.vary = Vary::list(vary_headers); + } + } self } } diff --git a/tower-http/src/cors/tests.rs b/tower-http/src/cors/tests.rs index 8f3f4acb..f20030e8 100644 --- a/tower-http/src/cors/tests.rs +++ b/tower-http/src/cors/tests.rs @@ -1,37 +1,106 @@ use std::convert::Infallible; -use crate::test_helpers::Body; -use http::{header, HeaderValue, Request, Response}; +use crate::{cors::Vary, test_helpers::Body}; +use http::{header, HeaderName, HeaderValue, Request, Response}; use tower::{service_fn, util::ServiceExt, Layer}; use crate::cors::{AllowOrigin, CorsLayer}; +const INITIAL_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); +const ADDITIONAL_VARY_HEADERS: [HeaderName; 3] = [ + header::ORIGIN, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::ACCESS_CONTROL_REQUEST_HEADERS, +]; + +#[tokio::test] +#[allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const +)] +async fn permissive_vary_header_is_empty() { + let svc = CorsLayer::permissive().layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert!( + res.headers().get(header::VARY).is_none(), + "Vary header should be omitted for permissive config" + ); +} + #[tokio::test] #[allow( clippy::declare_interior_mutable_const, clippy::borrow_interior_mutable_const )] -async fn vary_set_by_inner_service() { - const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); +async fn include_custom_permissive_to_vary_set_by_inner_service() { const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static( "origin, access-control-request-method, access-control-request-headers", ); async fn inner_svc(_: Request) -> Result, Infallible> { Ok(Response::builder() - .header(header::VARY, CUSTOM_VARY_HEADERS) + .header(header::VARY, INITIAL_VARY_HEADERS) .body(Body::empty()) .unwrap()) } - let svc = CorsLayer::permissive().layer(service_fn(inner_svc)); + let svc = CorsLayer::permissive() + .vary(Vary::list(ADDITIONAL_VARY_HEADERS)) + .layer(service_fn(inner_svc)); + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut vary_headers = res.headers().get_all(header::VARY).into_iter(); - assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS)); + assert_eq!(vary_headers.next(), Some(&INITIAL_VARY_HEADERS)); assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS)); assert_eq!(vary_headers.next(), None); } +#[tokio::test] +async fn permissive_with_custom_vary_builder() { + let custom_vary = HeaderValue::from_static("x-foo"); + let svc = CorsLayer::permissive() + .vary(Vary::list([header::HeaderName::from_static("x-foo")])) + .layer(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let vary = res.headers().get(header::VARY); + assert_eq!(vary, Some(&custom_vary)); +} + +#[tokio::test] +async fn permissive_with_inner_and_builder_vary() { + let custom_vary = HeaderValue::from_static("x-foo"); + let inner_vary = HeaderValue::from_static("accept-encoding"); + let svc = CorsLayer::permissive() + .vary(Vary::list([header::HeaderName::from_static("x-foo")])) + .layer(service_fn(|_: Request| { + let inner_vary = inner_vary.clone(); + async move { + Ok::<_, Infallible>( + Response::builder() + .header(header::VARY, inner_vary) + .body(Body::empty()) + .unwrap(), + ) + } + })); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + let mut vary_headers = res.headers().get_all(header::VARY).iter(); + assert_eq!(vary_headers.next(), Some(&inner_vary)); + assert_eq!(vary_headers.next(), Some(&custom_vary)); + assert_eq!(vary_headers.next(), None); +} + #[tokio::test] async fn test_allow_origin_async_predicate() { #[derive(Clone)] diff --git a/tower-http/src/cors/vary.rs b/tower-http/src/cors/vary.rs index 3ebe4a27..21a3486d 100644 --- a/tower-http/src/cors/vary.rs +++ b/tower-http/src/cors/vary.rs @@ -1,6 +1,6 @@ use http::header::{self, HeaderName, HeaderValue}; -use super::preflight_request_headers; +use crate::cors::preflight_request_headers; /// Holds configuration for how to set the [`Vary`][mdn] header. ///