diff --git a/rust_http_proxy/src/forward_proxy_client.rs b/rust_http_proxy/src/forward_proxy_client.rs index b1946c4..b63d5f0 100644 --- a/rust_http_proxy/src/forward_proxy_client.rs +++ b/rust_http_proxy/src/forward_proxy_client.rs @@ -46,26 +46,6 @@ where } } - #[allow(unused)] - pub async fn send_request_no_cache( - &self, req: Request, access_label: &AccessLabel, ipv6_first: Option, - stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, - ) -> Result, std::io::Error> { - // Make a new connection - let mut c = match HttpConnection::connect(access_label, ipv6_first, stream_map_func).await { - Ok(c) => c, - Err(err) => { - error!("failed to connect to host: {}, error: {}", &access_label.target, err); - return Err(io::Error::new(io::ErrorKind::InvalidData, err)); - } - }; - - trace!("HTTP making request to host: {access_label}, request: {req:?}"); - let response = c.send_request(req).await.map_err(io::Error::other)?; - trace!("HTTP received response from host: {access_label}, response: {response:?}"); - Ok(response) - } - /// Make HTTP requests #[inline] pub async fn send_request( diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index 6a69176..87493ea 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -296,27 +296,77 @@ impl ProxyHandler { // 默认为正向代理 } + /// 检测是否是 WebSocket 升级请求 + fn is_websocket_upgrade(req: &Request) -> bool { + req.headers() + .get(http::header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + } + /// 处理 WebSocket 升级请求(正向代理场景) async fn handle_websocket_upgrade_forward( &self, mut req: Request, traffic_label: AccessLabel, ) -> Result>, io::Error> { use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; + // 检查是否需要 TLS (wss://) + let is_secure = is_schema_secure(req.uri()); + // 直接建立到上游的 TCP 连接 - let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?; - info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target); + let tcp_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?; + info!( + "[forward] WebSocket TCP connection established to {} (secure: {})", + &traffic_label.target, is_secure + ); + + // 根据是否安全连接决定是否需要 TLS + let stream = if is_secure { + // 建立 TLS 连接 + let connector = build_tls_connector(); + // 从 URI 中提取主机名用于 TLS SNI + let host = req + .uri() + .host() + .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "Missing host in URI"))?; + let server_name = pki_types::ServerName::try_from(host) + .map_err(|e| io::Error::new(ErrorKind::InvalidInput, format!("Invalid DNS name: {}", e)))? + .to_owned(); + + match connector.connect(server_name, tcp_stream).await { + Ok(tls_stream) => { + info!("[forward] WebSocket TLS handshake successful"); + EitherTlsStream::Tls { stream: tls_stream } + } + Err(e) => { + warn!("[forward] WebSocket TLS handshake failed: {}", e); + return Err(io::Error::new(ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", e))); + } + } + } else { + EitherTlsStream::Tcp { stream: tcp_stream } + }; let mut upstream_io = - CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())); + CounterIO::new(stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())); // 构建 HTTP 请求行和头部 let mut request_bytes = Vec::new(); + let version_str = match req.version() { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2", + Version::HTTP_3 => "HTTP/3", + _ => "HTTP/1.1", // fallback to HTTP/1.1 + }; request_bytes.extend_from_slice( format!( - "{} {} {:?}\r\n", + "{} {} {}\r\n", req.method(), req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"), - req.version() + version_str ) .as_bytes(), ); @@ -342,7 +392,10 @@ impl ProxyHandler { reader.read_line(&mut response_line).await?; // 检查响应状态码 - let status_code = response_line.split_whitespace().nth(1).unwrap_or(""); + let status_code = response_line + .split_whitespace() + .nth(1) + .ok_or_else(|| io::Error::other(format!("Invalid HTTP response line: {}", response_line)))?; if status_code != "101" { warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line); return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line))); @@ -355,6 +408,7 @@ impl ProxyHandler { loop { let mut header_line = String::new(); reader.read_line(&mut header_line).await?; + // Check if we've reached the end of headers (CRLF or LF) if header_line == "\r\n" || header_line == "\n" { break; } @@ -369,9 +423,10 @@ impl ProxyHandler { // 添加上游返回的所有响应头 for header_line in response_headers { - if let Some((name, value)) = header_line.trim_end().split_once(':') { - let name = name.trim(); - let value = value.trim(); + // Split on first colon only to handle values with colons + if let Some(colon_pos) = header_line.find(':') { + let name = header_line[..colon_pos].trim(); + let value = header_line[colon_pos + 1..].trim(); if let Ok(header_value) = HeaderValue::from_str(value) { response_builder = response_builder.header(name, header_value); } @@ -401,12 +456,12 @@ impl ProxyHandler { /// WebSocket 双向数据转发(正向代理场景) async fn tunnel_websocket_forward( - mut upstream_io: CounterIO>, client: Upgraded, + mut upstream_io: CounterIO>, client: Upgraded, ) -> io::Result<()> { let mut client_io = TokioIo::new(client); // 双向数据转发 - let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?; + tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?; Ok(()) } @@ -426,12 +481,7 @@ impl ProxyHandler { }; // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) - let is_websocket = req - .headers() - .get(http::header::UPGRADE) - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); + let is_websocket = Self::is_websocket_upgrade(&req); mod_http1_proxy_req(&mut req)?; if is_websocket { @@ -481,12 +531,7 @@ impl ProxyHandler { }; // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) - let is_websocket = req - .headers() - .get(http::header::UPGRADE) - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); + let is_websocket = Self::is_websocket_upgrade(&req); // 如果配置了 username 和 password,添加 Proxy-Authorization 头 if let (Some(username), Some(password)) = (&forward_bypass_config.username, &forward_bypass_config.password) {