diff --git a/Cargo.lock b/Cargo.lock index 236039f0ff..581c3eab76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1377,9 +1377,9 @@ dependencies = [ [[package]] name = "flume" -version = "0.11.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" dependencies = [ "futures-core", "futures-sink", diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..6265a94a34 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -2,5 +2,6 @@ mod socket; pub mod tls; pub use socket::{ - connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, + connect_tcp, connect_uds, connect_with, BufferedSocket, Socket, SocketIntoBox, WithSocket, + WriteBuffer, }; diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0f9aae61b4..30f1d86a20 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -249,6 +249,16 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result(socket: S, with_socket: Ws) -> Ws::Output { + with_socket.with_socket(socket).await +} + /// Connect a Unix Domain Socket at the given path. /// /// Returns an error if Unix Domain Sockets are not supported on this platform. diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..c1da84d43d 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -20,8 +20,31 @@ impl MySqlConnection { None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, }; - let stream = handshake?; + Self::from_stream(options, handshake?) + } + + /// Establish a connection over a pre-connected socket. + /// + /// The provided socket must already be connected to a MySQL-compatible + /// server. The MySQL handshake and authentication will be performed + /// using the credentials from `options`. + /// + /// This enables custom transports such as in-memory pipes, simulation + /// frameworks (e.g. turmoil), SSH tunnels, or SOCKS proxies. + /// + /// Note: this only performs the low-level handshake and authentication. + /// Use [`MySqlConnectOptions::connect_with_socket()`] for a fully + /// initialized connection (with `SET NAMES`, `sql_mode`, etc.). + pub async fn connect_with_socket( + options: &MySqlConnectOptions, + socket: S, + ) -> Result { + let do_handshake = DoHandshake::new(options)?; + let stream = crate::net::connect_with(socket, do_handshake).await?; + Self::from_stream(options, stream) + } + fn from_stream(options: &MySqlConnectOptions, stream: MySqlStream) -> Result { Ok(Self { inner: Box::new(MySqlConnectionInner { stream, diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index f3b0492781..816dee9bd8 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -1,54 +1,68 @@ use crate::connection::ConnectOptions; use crate::error::Error; use crate::executor::Executor; +use crate::net::Socket; use crate::{MySqlConnectOptions, MySqlConnection}; use log::LevelFilter; use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::Url; use std::time::Duration; -impl ConnectOptions for MySqlConnectOptions { - type Connection = MySqlConnection; - - fn from_url(url: &Url) -> Result { - Self::parse_from_url(url) - } - - fn to_url_lossy(&self) -> Url { - self.build_url() +impl MySqlConnectOptions { + /// Establish a fully initialized connection over a pre-connected socket. + /// + /// This performs the MySQL handshake, authentication, and post-connect + /// initialization (`SET NAMES`, `sql_mode`, `time_zone`) over the + /// provided socket. + /// + /// The socket must already be connected to a MySQL-compatible server. + /// This enables custom transports such as in-memory pipes, simulation + /// frameworks (e.g. turmoil), SSH tunnels, or SOCKS proxies. + /// + /// # Example + /// + /// ```rust,ignore + /// use sqlx::mysql::MySqlConnectOptions; + /// + /// let options = MySqlConnectOptions::new() + /// .username("root") + /// .database("mydb"); + /// + /// let stream = tokio::net::TcpStream::connect("127.0.0.1:3306").await?; + /// let conn = options.connect_with_socket(stream).await?; + /// ``` + pub async fn connect_with_socket( + &self, + socket: S, + ) -> Result { + let mut conn = MySqlConnection::connect_with_socket(self, socket).await?; + self.after_connect(&mut conn).await?; + Ok(conn) } - async fn connect(&self) -> Result - where - Self::Connection: Sized, - { - let mut conn = MySqlConnection::establish(self).await?; - - // After the connection is established, we initialize by configuring a few - // connection parameters - - // https://mariadb.com/kb/en/sql-mode/ - - // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. - // This means that "A" || "B" can be used in place of CONCAT("A", "B"). - - // NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is - // not available, a warning is given and the default storage - // engine is used instead. - - // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. - - // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. - - // -- - - // Setting the time zone allows us to assume that the output - // from a TIMESTAMP field is UTC - - // -- - - // https://mathiasbynens.be/notes/mysql-utf8mb4 - + /// Post-connection initialization shared between `connect()` and + /// `connect_with_socket()`. + /// + /// After the connection is established, we initialize by configuring a few + /// connection parameters: + /// + /// - + /// + /// - `PIPES_AS_CONCAT` - Allows using the pipe character (ASCII 124) as string concatenation + /// operator. This means that "A" || "B" can be used in place of CONCAT("A", "B"). + /// + /// - `NO_ENGINE_SUBSTITUTION` - If not set, if the available storage engine specified by a + /// CREATE TABLE is not available, a warning is given and the default storage engine is used + /// instead. + /// + /// - `NO_ZERO_DATE` - Don't allow '0000-00-00'. This is invalid in Rust. + /// + /// - `NO_ZERO_IN_DATE` - Don't allow 'YYYY-00-00'. This is invalid in Rust. + /// + /// Setting the time zone allows us to assume that the output from a TIMESTAMP field is UTC. + /// + /// - + async fn after_connect(&self, conn: &mut MySqlConnection) -> Result<(), Error> { let mut sql_mode = Vec::new(); if self.pipes_as_concat { sql_mode.push(r#"PIPES_AS_CONCAT"#); @@ -64,11 +78,9 @@ impl ConnectOptions for MySqlConnectOptions { sql_mode.join(",") )); } - if let Some(timezone) = &self.timezone { options.push(format!(r#"time_zone='{}'"#, timezone)); } - if self.set_names { // As it turns out, we don't _have_ to set a collation if we don't want to. // We can let the server choose the default collation for the charset. @@ -88,6 +100,27 @@ impl ConnectOptions for MySqlConnectOptions { .await?; } + Ok(()) + } +} + +impl ConnectOptions for MySqlConnectOptions { + type Connection = MySqlConnection; + + fn from_url(url: &Url) -> Result { + Self::parse_from_url(url) + } + + fn to_url_lossy(&self) -> Url { + self.build_url() + } + + async fn connect(&self) -> Result + where + Self::Connection: Sized, + { + let mut conn = MySqlConnection::establish(self).await?; + self.after_connect(&mut conn).await?; Ok(conn) } diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 634b71de4b..47bcc70d6b 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -7,6 +7,7 @@ use crate::io::StatementId; use crate::message::{ Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, }; +use crate::net::Socket; use crate::{PgConnectOptions, PgConnection}; use super::PgConnectionInner; @@ -16,9 +17,35 @@ use super::PgConnectionInner; impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { - // Upgrade to TLS if we were asked to and the server supports it - let mut stream = PgStream::connect(options).await?; + let stream = PgStream::connect(options).await?; + Self::establish_with_stream(options, stream).await + } + + /// Establish a connection over a pre-connected socket. + /// + /// The provided socket must already be connected to a + /// PostgreSQL-compatible server. The startup handshake, TLS upgrade + /// (if configured), and authentication will be performed over this + /// socket. + /// + /// This enables custom transports such as in-memory pipes, simulation + /// frameworks (e.g. turmoil), SSH tunnels, or SOCKS proxies. + /// + /// Note: this only performs the low-level handshake and authentication. + /// Use [`PgConnectOptions::connect_with_socket()`] for a fully + /// initialized connection. + pub async fn connect_with_socket( + options: &PgConnectOptions, + socket: S, + ) -> Result { + let stream = PgStream::connect_with_socket(options, socket).await?; + Self::establish_with_stream(options, stream).await + } + async fn establish_with_stream( + options: &PgConnectOptions, + mut stream: PgStream, + ) -> Result { // To begin a session, a frontend opens a connection to the server // and sends a startup message. diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index e8a1aedc47..824bce8297 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -57,6 +57,24 @@ impl PgStream { }) } + /// Create a stream from a pre-connected socket. + /// + /// The socket must already be connected to a PostgreSQL server. + /// TLS upgrade will be attempted if configured in `options`. + pub(super) async fn connect_with_socket( + options: &PgConnectOptions, + socket: S, + ) -> Result { + let socket = net::connect_with(socket, MaybeUpgradeTls(options)).await?; + + Ok(Self { + inner: BufferedSocket::new(socket), + notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, + }) + } + #[inline(always)] pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { self.write(EncodeMessage(message)) diff --git a/sqlx-postgres/src/options/connect.rs b/sqlx-postgres/src/options/connect.rs index a80ea2165f..dd5b87915d 100644 --- a/sqlx-postgres/src/options/connect.rs +++ b/sqlx-postgres/src/options/connect.rs @@ -1,11 +1,39 @@ use crate::connection::ConnectOptions; use crate::error::Error; +use crate::net::Socket; use crate::{PgConnectOptions, PgConnection}; use log::LevelFilter; use sqlx_core::Url; use std::future::Future; use std::time::Duration; +impl PgConnectOptions { + /// Establish a connection over a pre-connected socket. + /// + /// This performs the PostgreSQL startup handshake, TLS upgrade + /// (if configured), and authentication over the provided socket. + /// + /// The socket must already be connected to a PostgreSQL-compatible server. + /// This enables custom transports such as in-memory pipes, simulation + /// frameworks (e.g. turmoil), SSH tunnels, or SOCKS proxies. + /// + /// # Example + /// + /// ```rust,ignore + /// use sqlx::postgres::PgConnectOptions; + /// + /// let options = PgConnectOptions::new() + /// .username("postgres") + /// .database("mydb"); + /// + /// let stream = tokio::net::TcpStream::connect("127.0.0.1:5432").await?; + /// let conn = options.connect_with_socket(stream).await?; + /// ``` + pub async fn connect_with_socket(&self, socket: S) -> Result { + PgConnection::connect_with_socket(self, socket).await + } +} + impl ConnectOptions for PgConnectOptions { type Connection = PgConnection; diff --git a/src/lib.rs b/src/lib.rs index 438463210d..a5b31cf873 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,18 @@ pub use sqlx_core::types::Type; pub use sqlx_core::value::{Value, ValueRef}; pub use sqlx_core::Either; +/// Networking primitives used by SQLx database drivers. +/// +/// The [`Socket`][net::Socket] trait allows implementing custom transports +/// for database connections (e.g. in-memory pipes, simulation frameworks, +/// SSH tunnels, SOCKS proxies). Use with +/// [`MySqlConnectOptions::connect_with_socket()`][crate::mysql::MySqlConnectOptions::connect_with_socket] +/// or [`PgConnectOptions::connect_with_socket()`][crate::postgres::PgConnectOptions::connect_with_socket]. +pub mod net { + pub use sqlx_core::io::ReadBuf; + pub use sqlx_core::net::{connect_with, Socket}; +} + #[doc(inline)] pub use sqlx_core::error::{self, Error, Result};