diff --git a/src/api/client/query.rs b/src/api/client/query.rs index 302db63..a90d2af 100644 --- a/src/api/client/query.rs +++ b/src/api/client/query.rs @@ -1,12 +1,17 @@ use std::str::FromStr; use async_trait::async_trait; -use futures::{Sink, SinkExt}; -use postgres_types::Oid; +use bytes::Bytes; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use postgres_types::{Oid, Type}; use crate::api::results::{FieldInfo, Tag}; -use crate::error::{ErrorInfo, PgWireClientError, PgWireClientResult}; -use crate::messages::data::{DataRow, RowDescription}; +use crate::error::{ErrorInfo, PgWireClientError, PgWireClientResult, PgWireError}; +use crate::messages::data::{DataRow, ParameterDescription, RowDescription}; +use crate::messages::extendedquery::{ + Bind, Close, Describe, Execute, Flush, Parse, Sync, TARGET_TYPE_BYTE_PORTAL, + TARGET_TYPE_BYTE_STATEMENT, +}; use crate::messages::response::{CommandComplete, EmptyQueryResponse, ReadyForQuery}; use crate::messages::simplequery::Query; use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; @@ -14,6 +19,30 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use super::result::DataRowsReader; use super::{ClientInfo, ReadyState}; +#[derive(Debug, Clone)] +pub struct PrepareResponse { + pub name: Option, + pub param_types: Vec, +} + +#[derive(Debug, Default)] +pub struct DescribeResponse { + pub param_types: Vec, + pub fields: Vec, +} + +#[derive(Debug)] +pub enum ExecuteResult { + Complete(T), + Suspended(T), +} + +#[derive(Debug, Clone, Copy)] +pub enum DescribeTarget<'a> { + Statement(Option<&'a str>), + Portal(Option<&'a str>), +} + #[async_trait] pub trait SimpleQueryHandler: Send { type QueryResponse; @@ -102,6 +131,65 @@ pub trait SimpleQueryHandler: Send { PgWireClientError: From<>::Error>; } +#[async_trait] +pub trait ExtendedQueryHandler: Send { + type QueryResponse; + + async fn parse(&mut self, client: &mut C, query: Parse) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn bind(&mut self, client: &mut C, bind: Bind) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn execute(&mut self, client: &mut C, execute: Execute) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn describe(&mut self, client: &mut C, describe: Describe) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn close(&mut self, client: &mut C, close: Close) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn sync(&mut self, client: &mut C, sync: Sync) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn flush(&mut self, client: &mut C, flush: Flush) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>; + + async fn on_parameter_description( + &mut self, + msg: ParameterDescription, + ) -> PgWireClientResult>; + + async fn on_row_description( + &mut self, + msg: RowDescription, + ) -> PgWireClientResult>; + + async fn on_data_row(&mut self, msg: DataRow) -> PgWireClientResult; + + async fn on_command_complete(&mut self, msg: CommandComplete) -> PgWireClientResult; + + async fn on_portal_suspended(&mut self) -> PgWireClientResult<()>; +} + +#[derive(Debug)] +pub struct ExtendedQueryState {} + #[derive(Debug)] pub enum Response { EmptyQuery, @@ -254,3 +342,348 @@ impl SimpleQueryHandler for DefaultSimpleQueryHandler { Ok(responses) } } + +pub struct ExtendedQueryClient<'a, C, H> { + client: &'a mut C, + handler: &'a mut H, +} + +impl<'a, C, H> ExtendedQueryClient<'a, C, H> +where + C: ClientInfo + + Sink + + Stream> + + Unpin + + Send, + H: ExtendedQueryHandler, +{ + pub fn new(client: &'a mut C, handler: &'a mut H) -> Self { + Self { client, handler } + } + + pub async fn prepare( + &mut self, + name: Option<&str>, + query: &str, + param_types: &[Oid], + ) -> PgWireClientResult { + let parse = Parse::new( + name.map(|n| n.to_owned()), + query.to_owned(), + param_types.to_vec(), + ); + self.handler.parse(self.client, parse).await?; + self.handler.sync(self.client, Sync::new()).await?; + + let mut param_type_result = Vec::new(); + let mut response = PrepareResponse { + name: name.map(|n| n.to_owned()), + param_types: Vec::new(), + }; + + while let Some(message_result) = self.client.next().await { + let message = message_result?; + match message { + PgWireBackendMessage::ParseComplete(_) => {} + PgWireBackendMessage::ParameterDescription(param_desc) => { + param_type_result = self.handler.on_parameter_description(param_desc).await?; + } + PgWireBackendMessage::RowDescription(row_desc) => { + let _ = self.handler.on_row_description(row_desc).await?; + } + PgWireBackendMessage::NoData(_) => {} + PgWireBackendMessage::ReadyForQuery(_) => { + response.param_types = param_type_result; + return Ok(response); + } + PgWireBackendMessage::ErrorResponse(error) => { + return Err(ErrorInfo::from(error).into()); + } + PgWireBackendMessage::NoticeResponse(_) => {} + _ => { + return Err(PgWireClientError::UnexpectedMessage(Box::new(message))); + } + } + } + + Err(PgWireClientError::UnexpectedEOF) + } + + pub async fn bind( + &mut self, + portal: Option<&str>, + statement: Option<&str>, + params: Vec>, + result_formats: Vec, + ) -> PgWireClientResult<()> { + let bind = Bind::new( + portal.map(|p| p.to_owned()), + statement.map(|s| s.to_owned()), + vec![], + params, + result_formats, + ); + self.handler.bind(self.client, bind).await?; + self.handler.sync(self.client, Sync::new()).await?; + + while let Some(message_result) = self.client.next().await { + let message = message_result?; + match message { + PgWireBackendMessage::BindComplete(_) => {} + PgWireBackendMessage::ReadyForQuery(_) => { + return Ok(()); + } + PgWireBackendMessage::ErrorResponse(error) => { + return Err(ErrorInfo::from(error).into()); + } + PgWireBackendMessage::NoticeResponse(_) => {} + _ => { + return Err(PgWireClientError::UnexpectedMessage(Box::new(message))); + } + } + } + + Err(PgWireClientError::UnexpectedEOF) + } + + pub async fn execute( + &mut self, + portal: Option<&str>, + max_rows: i32, + ) -> PgWireClientResult>> { + let execute = Execute::new(portal.map(|p| p.to_owned()), max_rows); + self.handler.execute(self.client, execute).await?; + self.handler.sync(self.client, Sync::new()).await?; + + let mut rows = Vec::new(); + let mut is_suspended = false; + + while let Some(message_result) = self.client.next().await { + let message = message_result?; + match message { + PgWireBackendMessage::DataRow(data_row) => { + let row = self.handler.on_data_row(data_row).await?; + rows.push(row); + } + PgWireBackendMessage::CommandComplete(command_complete) => { + self.handler.on_command_complete(command_complete).await?; + } + PgWireBackendMessage::PortalSuspended(_) => { + self.handler.on_portal_suspended().await?; + is_suspended = true; + } + PgWireBackendMessage::ReadyForQuery(_) => { + if is_suspended { + return Ok(ExecuteResult::Suspended(rows)); + } else { + return Ok(ExecuteResult::Complete(rows)); + } + } + PgWireBackendMessage::ErrorResponse(error) => { + return Err(ErrorInfo::from(error).into()); + } + PgWireBackendMessage::NoticeResponse(_) => {} + _ => { + return Err(PgWireClientError::UnexpectedMessage(Box::new(message))); + } + } + } + + Err(PgWireClientError::UnexpectedEOF) + } + + pub async fn describe( + &mut self, + target: DescribeTarget<'_>, + ) -> PgWireClientResult { + let (target_type, name) = match target { + DescribeTarget::Statement(name) => (TARGET_TYPE_BYTE_STATEMENT, name), + DescribeTarget::Portal(name) => (TARGET_TYPE_BYTE_PORTAL, name), + }; + let describe = Describe::new(target_type, name.map(|n| n.to_owned())); + self.handler.describe(self.client, describe).await?; + self.handler.sync(self.client, Sync::new()).await?; + + let mut response = DescribeResponse::default(); + + while let Some(message_result) = self.client.next().await { + let message = message_result?; + match message { + PgWireBackendMessage::ParameterDescription(param_desc) => { + response.param_types = + self.handler.on_parameter_description(param_desc).await?; + } + PgWireBackendMessage::RowDescription(row_desc) => { + response.fields = self.handler.on_row_description(row_desc).await?; + } + PgWireBackendMessage::NoData(_) => {} + PgWireBackendMessage::ReadyForQuery(_) => { + return Ok(response); + } + PgWireBackendMessage::ErrorResponse(error) => { + return Err(ErrorInfo::from(error).into()); + } + PgWireBackendMessage::NoticeResponse(_) => {} + _ => { + return Err(PgWireClientError::UnexpectedMessage(Box::new(message))); + } + } + } + + Err(PgWireClientError::UnexpectedEOF) + } + + pub async fn close(&mut self, target: DescribeTarget<'_>) -> PgWireClientResult<()> { + let (target_type, name) = match target { + DescribeTarget::Statement(name) => (TARGET_TYPE_BYTE_STATEMENT, name), + DescribeTarget::Portal(name) => (TARGET_TYPE_BYTE_PORTAL, name), + }; + let close = Close::new(target_type, name.map(|n| n.to_owned())); + self.handler.close(self.client, close).await?; + self.handler.sync(self.client, Sync::new()).await?; + + while let Some(message_result) = self.client.next().await { + let message = message_result?; + match message { + PgWireBackendMessage::CloseComplete(_) => {} + PgWireBackendMessage::ReadyForQuery(_) => { + return Ok(()); + } + PgWireBackendMessage::ErrorResponse(error) => { + return Err(ErrorInfo::from(error).into()); + } + PgWireBackendMessage::NoticeResponse(_) => {} + _ => { + return Err(PgWireClientError::UnexpectedMessage(Box::new(message))); + } + } + } + + Err(PgWireClientError::UnexpectedEOF) + } + + pub async fn query( + &mut self, + sql: &str, + param_types: &[Oid], + params: Vec>, + ) -> PgWireClientResult> { + self.prepare(None, sql, param_types).await?; + self.bind(None, None, params, vec![]).await?; + let result = self.execute(None, 0).await?; + + match result { + ExecuteResult::Complete(rows) => Ok(rows), + ExecuteResult::Suspended(rows) => { + self.close(DescribeTarget::Portal(None)).await?; + Ok(rows) + } + } + } +} + +#[derive(Default, new)] +pub struct DefaultExtendedQueryHandler { + #[new(default)] + current_row: Option, + #[new(default)] + current_fields: Vec, +} + +#[async_trait] +impl ExtendedQueryHandler for DefaultExtendedQueryHandler { + type QueryResponse = DataRow; + + async fn parse(&mut self, client: &mut C, query: Parse) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Parse(query)).await?; + Ok(()) + } + + async fn bind(&mut self, client: &mut C, bind: Bind) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Bind(bind)).await?; + Ok(()) + } + + async fn execute(&mut self, client: &mut C, execute: Execute) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Execute(execute)).await?; + Ok(()) + } + + async fn describe(&mut self, client: &mut C, describe: Describe) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client + .send(PgWireFrontendMessage::Describe(describe)) + .await?; + Ok(()) + } + + async fn close(&mut self, client: &mut C, close: Close) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Close(close)).await?; + Ok(()) + } + + async fn sync(&mut self, client: &mut C, _sync: Sync) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Sync(_sync)).await?; + Ok(()) + } + + async fn flush(&mut self, client: &mut C, _flush: Flush) -> PgWireClientResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + PgWireClientError: From<>::Error>, + { + client.send(PgWireFrontendMessage::Flush(_flush)).await?; + Ok(()) + } + + async fn on_parameter_description( + &mut self, + msg: ParameterDescription, + ) -> PgWireClientResult> { + Ok(msg.types.into_iter().filter_map(Type::from_oid).collect()) + } + + async fn on_row_description( + &mut self, + msg: RowDescription, + ) -> PgWireClientResult> { + self.current_fields = msg.fields.into_iter().map(|f| f.into()).collect(); + Ok(self.current_fields.clone()) + } + + async fn on_data_row(&mut self, msg: DataRow) -> PgWireClientResult { + self.current_row = Some(msg.clone()); + Ok(msg) + } + + async fn on_command_complete(&mut self, _msg: CommandComplete) -> PgWireClientResult { + Ok(_msg.tag.parse::()?) + } + + async fn on_portal_suspended(&mut self) -> PgWireClientResult<()> { + Ok(()) + } +} diff --git a/src/tokio/client.rs b/src/tokio/client.rs index cccdfa0..d2f5d1a 100644 --- a/src/tokio/client.rs +++ b/src/tokio/client.rs @@ -22,7 +22,7 @@ use tokio_util::codec::{Decoder, Encoder, Framed}; use super::TlsConnector; use crate::api::client::auth::StartupHandler; use crate::api::client::config::Host; -use crate::api::client::query::SimpleQueryHandler; +use crate::api::client::query::{ExtendedQueryClient, ExtendedQueryHandler, SimpleQueryHandler}; use crate::api::client::{ClientInfo, Config, ReadyState, ServerInformation}; use crate::error::{PgWireClientError, PgWireClientResult, PgWireError}; use crate::messages::{ @@ -194,6 +194,17 @@ impl PgWireClient { Err(PgWireClientError::UnexpectedEOF) } + + /// Create an extended query client for extended query subprotocol + pub fn extended_query<'a, H>( + &'a mut self, + handler: &'a mut H, + ) -> ExtendedQueryClient<'a, Self, H> + where + H: ExtendedQueryHandler, + { + ExtendedQueryClient::new(self, handler) + } } impl Stream for PgWireClient {