Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tower-http/src/timeout/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ pin_project! {
}
}

pin_project! {
/// Middleware that applies a total timeout to the entire body stream.
///
/// Unlike `TimeoutBody`, this does NOT reset when a frame is received.
pub struct TotalTimeoutBody<B> {
timeout: Duration,
#[pin]
sleep: Option<Sleep>,
#[pin]
body: B,
}
}

impl<B> TimeoutBody<B> {
/// Creates a new [`TimeoutBody`].
pub fn new(timeout: Duration, body: B) -> Self {
Expand Down Expand Up @@ -106,6 +119,53 @@ where
}
}

impl<B> TotalTimeoutBody<B> {
/// Creates a new [`TotalTimeoutBody`].
pub fn new(timeout: Duration, body: B) -> Self {
Self {
timeout,
sleep: None,
body,
}
}
}

impl<B> Body for TotalTimeoutBody<B>
where
B: Body,
B::Error: Into<BoxError>,
{
type Data = B::Data;
type Error = Box<dyn std::error::Error + Send + Sync>;

fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
let mut this = self.project();

// Initialize the timer once and keep it throughout the body's lifetime
let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
some
} else {
this.sleep.set(Some(sleep(*this.timeout)));
this.sleep.as_mut().as_pin_mut().unwrap()
};

// If the absolute deadline has passed, error out immediately
if let Poll::Ready(()) = sleep_pinned.poll(cx) {
return Poll::Ready(Some(Err(Box::new(TotalTimeoutError(())))));
}

// Poll the underlying body for the next frame
match ready!(this.body.poll_frame(cx)) {
Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
None => Poll::Ready(None),
}
}
}

/// Error for [`TimeoutBody`].
#[derive(Debug)]
pub struct TimeoutError(());
Expand All @@ -117,6 +177,17 @@ impl std::fmt::Display for TimeoutError {
write!(f, "data was not received within the designated timeout")
}
}

/// Error for [`TotalTimeoutBody`].
#[derive(Debug)]
pub struct TotalTimeoutError(());
impl std::error::Error for TotalTimeoutError {}
impl std::fmt::Display for TotalTimeoutError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "total body transfer time exceeded the limit")
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion tower-http/src/timeout/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
mod body;
mod service;

pub use body::{TimeoutBody, TimeoutError};
pub use body::{TimeoutBody, TimeoutError, TotalTimeoutBody, TotalTimeoutError};
pub use service::{
RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer,
Timeout, TimeoutLayer,
Expand Down