Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions rust_http_proxy/src/forward_proxy_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,26 @@ where
}
}

#[allow(unused)]
pub async fn send_request_no_cache(
&self, req: Request<B>, access_label: &AccessLabel, ipv6_first: Option<bool>,
stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO<EitherTlsStream, LabelImpl<AccessLabel>>,
) -> Result<Response<body::Incoming>, std::io::Error> {
// Make a new connection
let mut c = match HttpConnection::connect(access_label, ipv6_first, stream_map_func).await {
Ok(c) => c,
Err(err) => {
error!("failed to connect to host: {}, error: {}", &access_label.target, err);
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
}
};

trace!("HTTP making request to host: {access_label}, request: {req:?}");
let response = c.send_request(req).await.map_err(io::Error::other)?;
trace!("HTTP received response from host: {access_label}, response: {response:?}");
Ok(response)
}

/// Make HTTP requests
#[inline]
pub async fn send_request(
Expand Down Expand Up @@ -97,7 +117,7 @@ where
None
}

async fn send_request_conn(
pub(crate) async fn send_request_conn(
&self, access_label: &AccessLabel, mut c: HttpConnection<B>, req: Request<B>,
) -> hyper::Result<Response<body::Incoming>> {
trace!("HTTP making request to host: {access_label}, request: {req:?}");
Expand Down Expand Up @@ -164,7 +184,7 @@ fn get_keep_alive_val(values: header::GetAll<HeaderValue>) -> Option<bool> {
}

#[allow(dead_code)]
enum HttpConnection<B> {
pub(crate) enum HttpConnection<B> {
Http1(http1::SendRequest<B>),
}

Expand All @@ -174,7 +194,7 @@ where
B::Data: Send,
B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
{
async fn connect(
pub(crate) async fn connect(
access_label: &AccessLabel, ipv6_first: Option<bool>,
stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO<EitherTlsStream, LabelImpl<AccessLabel>>,
) -> io::Result<HttpConnection<B>> {
Expand Down
165 changes: 165 additions & 0 deletions rust_http_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,131 @@ impl ProxyHandler {
// 默认为正向代理
}

/// 处理 WebSocket 升级请求(正向代理场景)
async fn handle_websocket_upgrade_forward(
&self, mut req: Request<Incoming>, traffic_label: AccessLabel,
) -> Result<Response<BoxBody<Bytes, io::Error>>, io::Error> {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

// 直接建立到上游的 TCP 连接
let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?;
info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target);

let mut upstream_io =
CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone()));

// 构建 HTTP 请求行和头部
let mut request_bytes = Vec::new();
request_bytes.extend_from_slice(
format!(
"{} {} {:?}\r\n",
req.method(),
req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"),
req.version()
)
.as_bytes(),
);

// 添加所有请求头
for (name, value) in req.headers() {
request_bytes.extend_from_slice(name.as_str().as_bytes());
request_bytes.extend_from_slice(b": ");
request_bytes.extend_from_slice(value.as_bytes());
request_bytes.extend_from_slice(b"\r\n");
}
request_bytes.extend_from_slice(b"\r\n");

// 发送请求到上游
upstream_io.write_all(&request_bytes).await?;
upstream_io.flush().await?;

info!("[forward] WebSocket upgrade request sent to upstream");

// 读取上游响应
// 将 upstream_io 包装在 BufReader 中,以便按行高效读取 HTTP 状态行和头部。
// 在完成 HTTP 响应解析后,我们会通过 reader.into_inner() 取回底层的 upstream_io,
// 继续将同一个 TCP 连接用于 WebSocket 隧道的数据转发。
let mut reader = tokio::io::BufReader::new(upstream_io);
let mut response_line = String::new();
reader.read_line(&mut response_line).await?;

// 检查响应状态码
let status_code = response_line.split_whitespace().nth(1).unwrap_or("");
if status_code.is_empty() {
warn!("[forward] Failed to parse status code from upstream response: {}", response_line);
return Err(io::Error::other(format!(
"Failed to parse status code from upstream response: {}",
response_line
)));
}
if status_code != "101" {
warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line);
return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line)));
}

info!("[forward] WebSocket upgrade successful, status: {}", status_code);

// 读取并保存响应头
let mut response_headers = Vec::new();
loop {
let mut header_line = String::new();
reader.read_line(&mut header_line).await?;
if header_line == "\r\n" || header_line == "\n" {
break;
}
response_headers.push(header_line);
}

// 从BufReader中取回原始stream
let upstream_io = reader.into_inner();

// 构造 101 响应给客户端,并添加上游返回的响应头
let mut response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS);

// 添加上游返回的所有响应头
for header_line in response_headers {
if let Some((name, value)) = header_line.trim_end().split_once(':') {
let name = name.trim();
let value = value.trim();
if let Ok(header_value) = HeaderValue::from_str(value) {
response_builder = response_builder.header(name, header_value);
}
}
}

let client_response = response_builder
.body(http_body_util::Empty::<Bytes>::new().map_err(|e| match e {}).boxed())
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;

// 启动异步任务进行双向数据转发
tokio::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(client_upgraded) => {
if let Err(e) = Self::tunnel_websocket_forward(upstream_io, client_upgraded).await {
warn!("[forward] WebSocket tunnel error: {e:?}");
}
}
Err(e) => {
warn!("[forward] WebSocket client upgrade error: {e:?}");
}
}
});

Ok(client_response)
}

/// WebSocket 双向数据转发(正向代理场景)
async fn tunnel_websocket_forward(
mut upstream_io: CounterIO<TcpStream, LabelImpl<AccessLabel>>, client: Upgraded,
) -> io::Result<()> {
let mut client_io = TokioIo::new(client);

// 双向数据转发
let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?;

Ok(())
}

/// 代理普通请求
/// HTTP/1.1 GET/POST/PUT/DELETE/HEAD
async fn simple_proxy(
Expand All @@ -309,7 +434,27 @@ impl ProxyHandler {
username,
relay_over_tls: None,
};

// 先检测是否是 WebSocket 升级请求(在 request 被消费之前)
let is_websocket = req
.headers()
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

mod_http1_proxy_req(&mut req)?;
if is_websocket {
info!(
"[forward] WebSocket upgrade request: {:^35} ==> {} {:?}",
client_socket_addr.to_string(),
req.method(),
req.uri(),
);

return self.handle_websocket_upgrade_forward(req, access_label).await;
}

match self
.forward_proxy_client
.send_request(
Expand Down Expand Up @@ -344,6 +489,15 @@ impl ProxyHandler {
username,
relay_over_tls: Some(forward_bypass_config.is_https),
};

// 先检测是否是 WebSocket 升级请求(在 request 被消费之前)
let is_websocket = req
.headers()
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

// 如果配置了 username 和 password,添加 Proxy-Authorization 头
if let (Some(username), Some(password)) = (&forward_bypass_config.username, &forward_bypass_config.password) {
let credentials = format!("{}:{}", username, password);
Expand All @@ -363,6 +517,17 @@ impl ProxyHandler {
info!("change host header: {origin:?} -> {host_header:?}");
}

if is_websocket {
info!(
"[forward_bypass] WebSocket upgrade request: {:^35} ==> {} {:?}",
client_socket_addr.to_string(),
req.method(),
req.uri(),
);

return self.handle_websocket_upgrade_forward(req, access_label).await;
}

warn!("bypass {:?} {} {}", req.version(), req.method(), req.uri());

match self
Expand Down