diff --git a/pgdog/src/backend/protocol/state.rs b/pgdog/src/backend/protocol/state.rs index ca86915d2..73aaf6ecf 100644 --- a/pgdog/src/backend/protocol/state.rs +++ b/pgdog/src/backend/protocol/state.rs @@ -36,6 +36,10 @@ impl ExecutionCode { fn extended(&self) -> bool { matches!(self, Self::ParseComplete | Self::BindComplete) } + + fn done(&self) -> bool { + matches!(self, Self::ReadyForQuery) + } } impl From for ExecutionCode { @@ -94,7 +98,7 @@ impl ProtocolState { /// pub(crate) fn add_ignore(&mut self, code: impl Into) { let code = code.into(); - self.extended = self.extended || code.extended(); + self.extended = (self.extended || code.extended()) && !code.done(); self.queue.push_back(ExecutionItem::Ignore(code)); } @@ -102,7 +106,7 @@ impl ProtocolState { /// to be returned by the server. pub(crate) fn add(&mut self, code: impl Into) { let code = code.into(); - self.extended = self.extended || code.extended(); + self.extended = (self.extended || code.extended()) && !code.done(); self.queue.push_back(ExecutionItem::Code(code)) } @@ -163,6 +167,14 @@ impl ProtocolState { ExecutionCode::ReadyForQuery => { self.out_of_sync = false; } + ExecutionCode::Copy => { + if self.extended { + // Remove any RFQ messages from the queue + // in case the client sent Sync during copy mode. + self.queue + .retain(|item| item != &ExecutionItem::Code(ExecutionCode::ReadyForQuery)); + } + } _ => (), }; let in_queue = self.queue.pop_front().ok_or(Error::ProtocolOutOfSync)?; diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index faa708b9d..3f74ac85d 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -408,8 +408,9 @@ impl Server { /// Flush all pending messages making sure they are sent to the server immediately. pub async fn flush(&mut self) -> Result<(), Error> { + trace!("😳 [{}]", self.addr()); + if let Err(err) = self.stream().flush().await { - trace!("😳"); self.stats.state(State::Error); Err(err.into()) } else { @@ -2173,6 +2174,134 @@ pub mod test { assert!(server.in_sync()); } + #[tokio::test] + async fn test_extended_set_back_to_normal_when_done() { + crate::logger(); + let mut server = test_server().await; + server + .send( + &vec![ + Parse::new_anonymous("SET statement_timeout TO '1s'").into(), + Bind::new_statement("").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + Query::new("COPY public.sharded FROM STDIN").into(), + CopyDone.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['G', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + + server + .send( + &vec![ + Parse::new_anonymous("SET statement_timeout TO '1s'").into(), + Bind::new_statement("").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(c, msg.code()); + } + + assert!(server.done()); + } + + #[tokio::test] + async fn test_copy_protocol_extended() { + crate::logger(); + + let mut server = test_server().await; + server.execute("BEGIN").await.unwrap(); + server + .execute("CREATE TABLE test_copy_protocol_extended (id BIGINT)") + .await + .unwrap(); + + server + .send( + &vec![ + Parse::new_anonymous("COPY test_copy_protocol_extended FROM STDIN").into(), + Bind::new_statement("").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + for c in ['1', '2', 'G'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + server + .send_one(&CopyData::new("1\n".as_bytes()).into()) + .await + .unwrap(); + + server.send_one(&CopyDone.into()).await.unwrap(); + server.send_one(&Sync.into()).await.unwrap(); + server.flush().await.unwrap(); + + for c in ['C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.prepared_statements().done()); + + server + .send( + &vec![ + Parse::new_anonymous("ROLLBACK").into(), + Bind::new_statement("").into(), + Execute::new().into(), + Sync.into(), + ] + .into(), + ) + .await + .unwrap(); + + for c in ['1', '2', 'C', 'Z'] { + let msg = server.read().await.unwrap(); + assert_eq!(msg.code(), c); + } + + assert!(server.done()); + } + #[tokio::test] async fn test_copy_client_fail() { let mut server = test_server().await;