diff --git a/tower-http/src/timeout/body.rs b/tower-http/src/timeout/body.rs index d44f35b8..db3aab94 100644 --- a/tower-http/src/timeout/body.rs +++ b/tower-http/src/timeout/body.rs @@ -59,6 +59,19 @@ pin_project! { } } +pin_project! { + /// Middleware that applies a total timeout to the entire body stream. + /// + /// Unlike `TimeoutBody`, this does NOT reset when a frame is received. + pub struct TotalTimeoutBody { + timeout: Duration, + #[pin] + sleep: Option, + #[pin] + body: B, + } +} + impl TimeoutBody { /// Creates a new [`TimeoutBody`]. pub fn new(timeout: Duration, body: B) -> Self { @@ -106,6 +119,53 @@ where } } +impl TotalTimeoutBody { + /// Creates a new [`TotalTimeoutBody`]. + pub fn new(timeout: Duration, body: B) -> Self { + Self { + timeout, + sleep: None, + body, + } + } +} + +impl Body for TotalTimeoutBody +where + B: Body, + B::Error: Into, +{ + type Data = B::Data; + type Error = Box; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let mut this = self.project(); + + // Initialize the timer once and keep it throughout the body's lifetime + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { + some + } else { + this.sleep.set(Some(sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() + }; + + // If the absolute deadline has passed, error out immediately + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Some(Err(Box::new(TotalTimeoutError(()))))); + } + + // Poll the underlying body for the next frame + match ready!(this.body.poll_frame(cx)) { + Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))), + Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), + None => Poll::Ready(None), + } + } +} + /// Error for [`TimeoutBody`]. #[derive(Debug)] pub struct TimeoutError(()); @@ -117,6 +177,17 @@ impl std::fmt::Display for TimeoutError { write!(f, "data was not received within the designated timeout") } } + +/// Error for [`TotalTimeoutBody`]. +#[derive(Debug)] +pub struct TotalTimeoutError(()); +impl std::error::Error for TotalTimeoutError {} +impl std::fmt::Display for TotalTimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "total body transfer time exceeded the limit") + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tower-http/src/timeout/mod.rs b/tower-http/src/timeout/mod.rs index e159b23c..91eb9b13 100644 --- a/tower-http/src/timeout/mod.rs +++ b/tower-http/src/timeout/mod.rs @@ -43,7 +43,7 @@ mod body; mod service; -pub use body::{TimeoutBody, TimeoutError}; +pub use body::{TimeoutBody, TimeoutError, TotalTimeoutBody, TotalTimeoutError}; pub use service::{ RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer, Timeout, TimeoutLayer,