Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions rust_http_proxy/src/forward_proxy_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,6 @@ where
}
}

#[allow(unused)]
pub async fn send_request_no_cache(
&self, req: Request<B>, access_label: &AccessLabel, ipv6_first: Option<bool>,
stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO<EitherTlsStream, LabelImpl<AccessLabel>>,
) -> Result<Response<body::Incoming>, std::io::Error> {
// Make a new connection
let mut c = match HttpConnection::connect(access_label, ipv6_first, stream_map_func).await {
Ok(c) => c,
Err(err) => {
error!("failed to connect to host: {}, error: {}", &access_label.target, err);
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
}
};

trace!("HTTP making request to host: {access_label}, request: {req:?}");
let response = c.send_request(req).await.map_err(io::Error::other)?;
trace!("HTTP received response from host: {access_label}, response: {response:?}");
Ok(response)
}

/// Make HTTP requests
#[inline]
pub async fn send_request(
Expand Down
91 changes: 68 additions & 23 deletions rust_http_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,27 +296,77 @@ impl ProxyHandler {
// 默认为正向代理
}

/// 检测是否是 WebSocket 升级请求
fn is_websocket_upgrade(req: &Request<Incoming>) -> bool {
req.headers()
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
}

/// 处理 WebSocket 升级请求(正向代理场景)
async fn handle_websocket_upgrade_forward(
&self, mut req: Request<Incoming>, traffic_label: AccessLabel,
) -> Result<Response<BoxBody<Bytes, io::Error>>, io::Error> {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

// 检查是否需要 TLS (wss://)
let is_secure = is_schema_secure(req.uri());

// 直接建立到上游的 TCP 连接
let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?;
info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target);
let tcp_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?;
info!(
"[forward] WebSocket TCP connection established to {} (secure: {})",
&traffic_label.target, is_secure
);

// 根据是否安全连接决定是否需要 TLS
let stream = if is_secure {
// 建立 TLS 连接
let connector = build_tls_connector();
// 从 URI 中提取主机名用于 TLS SNI
let host = req
.uri()
.host()
.ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "Missing host in URI"))?;
let server_name = pki_types::ServerName::try_from(host)
.map_err(|e| io::Error::new(ErrorKind::InvalidInput, format!("Invalid DNS name: {}", e)))?
.to_owned();

match connector.connect(server_name, tcp_stream).await {
Ok(tls_stream) => {
info!("[forward] WebSocket TLS handshake successful");
EitherTlsStream::Tls { stream: tls_stream }
}
Err(e) => {
warn!("[forward] WebSocket TLS handshake failed: {}", e);
return Err(io::Error::new(ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", e)));
}
}
} else {
EitherTlsStream::Tcp { stream: tcp_stream }
};

let mut upstream_io =
CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone()));
CounterIO::new(stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone()));

// 构建 HTTP 请求行和头部
let mut request_bytes = Vec::new();
let version_str = match req.version() {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2",
Version::HTTP_3 => "HTTP/3",
_ => "HTTP/1.1", // fallback to HTTP/1.1
};
request_bytes.extend_from_slice(
format!(
"{} {} {:?}\r\n",
"{} {} {}\r\n",
req.method(),
req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"),
req.version()
version_str
)
.as_bytes(),
);
Expand All @@ -342,7 +392,10 @@ impl ProxyHandler {
reader.read_line(&mut response_line).await?;

// 检查响应状态码
let status_code = response_line.split_whitespace().nth(1).unwrap_or("");
let status_code = response_line
.split_whitespace()
.nth(1)
.ok_or_else(|| io::Error::other(format!("Invalid HTTP response line: {}", response_line)))?;
if status_code != "101" {
warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line);
return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line)));
Expand All @@ -355,6 +408,7 @@ impl ProxyHandler {
loop {
let mut header_line = String::new();
reader.read_line(&mut header_line).await?;
// Check if we've reached the end of headers (CRLF or LF)
if header_line == "\r\n" || header_line == "\n" {
break;
}
Expand All @@ -369,9 +423,10 @@ impl ProxyHandler {

// 添加上游返回的所有响应头
for header_line in response_headers {
if let Some((name, value)) = header_line.trim_end().split_once(':') {
let name = name.trim();
let value = value.trim();
// Split on first colon only to handle values with colons
if let Some(colon_pos) = header_line.find(':') {
let name = header_line[..colon_pos].trim();
let value = header_line[colon_pos + 1..].trim();
if let Ok(header_value) = HeaderValue::from_str(value) {
response_builder = response_builder.header(name, header_value);
}
Expand Down Expand Up @@ -401,12 +456,12 @@ impl ProxyHandler {

/// WebSocket 双向数据转发(正向代理场景)
async fn tunnel_websocket_forward(
mut upstream_io: CounterIO<TcpStream, LabelImpl<AccessLabel>>, client: Upgraded,
mut upstream_io: CounterIO<EitherTlsStream, LabelImpl<AccessLabel>>, client: Upgraded,
) -> io::Result<()> {
let mut client_io = TokioIo::new(client);

// 双向数据转发
let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?;
tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?;

Ok(())
}
Expand All @@ -426,12 +481,7 @@ impl ProxyHandler {
};

// 先检测是否是 WebSocket 升级请求(在 request 被消费之前)
let is_websocket = req
.headers()
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
let is_websocket = Self::is_websocket_upgrade(&req);

mod_http1_proxy_req(&mut req)?;
if is_websocket {
Expand Down Expand Up @@ -481,12 +531,7 @@ impl ProxyHandler {
};

// 先检测是否是 WebSocket 升级请求(在 request 被消费之前)
let is_websocket = req
.headers()
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
let is_websocket = Self::is_websocket_upgrade(&req);

// 如果配置了 username 和 password,添加 Proxy-Authorization 头
if let (Some(username), Some(password)) = (&forward_bypass_config.username, &forward_bypass_config.password) {
Expand Down