diff --git a/rust_http_proxy/src/forward_proxy_client.rs b/rust_http_proxy/src/forward_proxy_client.rs index 723638c..4c91d6b 100644 --- a/rust_http_proxy/src/forward_proxy_client.rs +++ b/rust_http_proxy/src/forward_proxy_client.rs @@ -46,6 +46,26 @@ where } } + #[allow(unused)] + pub async fn send_request_no_cache( + &self, req: Request, access_label: &AccessLabel, ipv6_first: Option, + stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, + ) -> Result, std::io::Error> { + // Make a new connection + let mut c = match HttpConnection::connect(access_label, ipv6_first, stream_map_func).await { + Ok(c) => c, + Err(err) => { + error!("failed to connect to host: {}, error: {}", &access_label.target, err); + return Err(io::Error::new(io::ErrorKind::InvalidData, err)); + } + }; + + trace!("HTTP making request to host: {access_label}, request: {req:?}"); + let response = c.send_request(req).await.map_err(io::Error::other)?; + trace!("HTTP received response from host: {access_label}, response: {response:?}"); + Ok(response) + } + /// Make HTTP requests #[inline] pub async fn send_request( @@ -97,7 +117,7 @@ where None } - async fn send_request_conn( + pub(crate) async fn send_request_conn( &self, access_label: &AccessLabel, mut c: HttpConnection, req: Request, ) -> hyper::Result> { trace!("HTTP making request to host: {access_label}, request: {req:?}"); @@ -164,7 +184,7 @@ fn get_keep_alive_val(values: header::GetAll) -> Option { } #[allow(dead_code)] -enum HttpConnection { +pub(crate) enum HttpConnection { Http1(http1::SendRequest), } @@ -174,7 +194,7 @@ where B::Data: Send, B::Error: Into>, { - async fn connect( + pub(crate) async fn connect( access_label: &AccessLabel, ipv6_first: Option, stream_map_func: impl FnOnce(EitherTlsStream, AccessLabel) -> CounterIO>, ) -> io::Result> { diff --git a/rust_http_proxy/src/proxy.rs b/rust_http_proxy/src/proxy.rs index c4825e9..97f7fc5 100644 --- a/rust_http_proxy/src/proxy.rs +++ b/rust_http_proxy/src/proxy.rs @@ -296,6 +296,131 @@ impl ProxyHandler { // 默认为正向代理 } + /// 处理 WebSocket 升级请求(正向代理场景) + async fn handle_websocket_upgrade_forward( + &self, mut req: Request, traffic_label: AccessLabel, + ) -> Result>, io::Error> { + use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; + + // 直接建立到上游的 TCP 连接 + let upstream_stream = connect_with_preference(&traffic_label.target, self.config.ipv6_first).await?; + info!("[forward] WebSocket TCP connection established to {}", &traffic_label.target); + + let mut upstream_io = + CounterIO::new(upstream_stream, METRICS.proxy_traffic.clone(), LabelImpl::new(traffic_label.clone())); + + // 构建 HTTP 请求行和头部 + let mut request_bytes = Vec::new(); + request_bytes.extend_from_slice( + format!( + "{} {} {:?}\r\n", + req.method(), + req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("/"), + req.version() + ) + .as_bytes(), + ); + + // 添加所有请求头 + for (name, value) in req.headers() { + request_bytes.extend_from_slice(name.as_str().as_bytes()); + request_bytes.extend_from_slice(b": "); + request_bytes.extend_from_slice(value.as_bytes()); + request_bytes.extend_from_slice(b"\r\n"); + } + request_bytes.extend_from_slice(b"\r\n"); + + // 发送请求到上游 + upstream_io.write_all(&request_bytes).await?; + upstream_io.flush().await?; + + info!("[forward] WebSocket upgrade request sent to upstream"); + + // 读取上游响应 + // 将 upstream_io 包装在 BufReader 中,以便按行高效读取 HTTP 状态行和头部。 + // 在完成 HTTP 响应解析后,我们会通过 reader.into_inner() 取回底层的 upstream_io, + // 继续将同一个 TCP 连接用于 WebSocket 隧道的数据转发。 + let mut reader = tokio::io::BufReader::new(upstream_io); + let mut response_line = String::new(); + reader.read_line(&mut response_line).await?; + + // 检查响应状态码 + let status_code = response_line.split_whitespace().nth(1).unwrap_or(""); + if status_code.is_empty() { + warn!("[forward] Failed to parse status code from upstream response: {}", response_line); + return Err(io::Error::other(format!( + "Failed to parse status code from upstream response: {}", + response_line + ))); + } + if status_code != "101" { + warn!("[forward] WebSocket upgrade failed, upstream returned: {}", response_line); + return Err(io::Error::other(format!("WebSocket upgrade failed: {}", response_line))); + } + + info!("[forward] WebSocket upgrade successful, status: {}", status_code); + + // 读取并保存响应头 + let mut response_headers = Vec::new(); + loop { + let mut header_line = String::new(); + reader.read_line(&mut header_line).await?; + if header_line == "\r\n" || header_line == "\n" { + break; + } + response_headers.push(header_line); + } + + // 从BufReader中取回原始stream + let upstream_io = reader.into_inner(); + + // 构造 101 响应给客户端,并添加上游返回的响应头 + let mut response_builder = Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS); + + // 添加上游返回的所有响应头 + for header_line in response_headers { + if let Some((name, value)) = header_line.trim_end().split_once(':') { + let name = name.trim(); + let value = value.trim(); + if let Ok(header_value) = HeaderValue::from_str(value) { + response_builder = response_builder.header(name, header_value); + } + } + } + + let client_response = response_builder + .body(http_body_util::Empty::::new().map_err(|e| match e {}).boxed()) + .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; + + // 启动异步任务进行双向数据转发 + tokio::spawn(async move { + match hyper::upgrade::on(&mut req).await { + Ok(client_upgraded) => { + if let Err(e) = Self::tunnel_websocket_forward(upstream_io, client_upgraded).await { + warn!("[forward] WebSocket tunnel error: {e:?}"); + } + } + Err(e) => { + warn!("[forward] WebSocket client upgrade error: {e:?}"); + } + } + }); + + Ok(client_response) + } + + /// WebSocket 双向数据转发(正向代理场景) + async fn tunnel_websocket_forward( + mut upstream_io: CounterIO>, client: Upgraded, + ) -> io::Result<()> { + let mut client_io = TokioIo::new(client); + + // 双向数据转发 + let _ = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await?; + + Ok(()) + } + /// 代理普通请求 /// HTTP/1.1 GET/POST/PUT/DELETE/HEAD async fn simple_proxy( @@ -309,7 +434,27 @@ impl ProxyHandler { username, relay_over_tls: None, }; + + // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) + let is_websocket = req + .headers() + .get(http::header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + mod_http1_proxy_req(&mut req)?; + if is_websocket { + info!( + "[forward] WebSocket upgrade request: {:^35} ==> {} {:?}", + client_socket_addr.to_string(), + req.method(), + req.uri(), + ); + + return self.handle_websocket_upgrade_forward(req, access_label).await; + } + match self .forward_proxy_client .send_request( @@ -344,6 +489,15 @@ impl ProxyHandler { username, relay_over_tls: Some(forward_bypass_config.is_https), }; + + // 先检测是否是 WebSocket 升级请求(在 request 被消费之前) + let is_websocket = req + .headers() + .get(http::header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + // 如果配置了 username 和 password,添加 Proxy-Authorization 头 if let (Some(username), Some(password)) = (&forward_bypass_config.username, &forward_bypass_config.password) { let credentials = format!("{}:{}", username, password); @@ -363,6 +517,17 @@ impl ProxyHandler { info!("change host header: {origin:?} -> {host_header:?}"); } + if is_websocket { + info!( + "[forward_bypass] WebSocket upgrade request: {:^35} ==> {} {:?}", + client_socket_addr.to_string(), + req.method(), + req.uri(), + ); + + return self.handle_websocket_upgrade_forward(req, access_label).await; + } + warn!("bypass {:?} {} {}", req.version(), req.method(), req.uri()); match self