Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pgdog/src/backend/protocol/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ impl ExecutionCode {
fn extended(&self) -> bool {
matches!(self, Self::ParseComplete | Self::BindComplete)
}

fn done(&self) -> bool {
matches!(self, Self::ReadyForQuery)
}
}

impl From<char> for ExecutionCode {
Expand Down Expand Up @@ -94,15 +98,15 @@ impl ProtocolState {
///
pub(crate) fn add_ignore(&mut self, code: impl Into<ExecutionCode>) {
let code = code.into();
self.extended = self.extended || code.extended();
self.extended = (self.extended || code.extended()) && !code.done();
self.queue.push_back(ExecutionItem::Ignore(code));
}

/// Add a message to the execution queue. We expect this message
/// to be returned by the server.
pub(crate) fn add(&mut self, code: impl Into<ExecutionCode>) {
let code = code.into();
self.extended = self.extended || code.extended();
self.extended = (self.extended || code.extended()) && !code.done();
self.queue.push_back(ExecutionItem::Code(code))
}

Expand Down Expand Up @@ -163,6 +167,14 @@ impl ProtocolState {
ExecutionCode::ReadyForQuery => {
self.out_of_sync = false;
}
ExecutionCode::Copy => {
if self.extended {
// Remove any RFQ messages from the queue
// in case the client sent Sync during copy mode.
self.queue
.retain(|item| item != &ExecutionItem::Code(ExecutionCode::ReadyForQuery));
}
}
_ => (),
};
let in_queue = self.queue.pop_front().ok_or(Error::ProtocolOutOfSync)?;
Expand Down
131 changes: 130 additions & 1 deletion pgdog/src/backend/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,9 @@ impl Server {

/// Flush all pending messages making sure they are sent to the server immediately.
pub async fn flush(&mut self) -> Result<(), Error> {
trace!("😳 [{}]", self.addr());

if let Err(err) = self.stream().flush().await {
trace!("😳");
self.stats.state(State::Error);
Err(err.into())
} else {
Expand Down Expand Up @@ -2173,6 +2174,134 @@ pub mod test {
assert!(server.in_sync());
}

#[tokio::test]
async fn test_extended_set_back_to_normal_when_done() {
crate::logger();
let mut server = test_server().await;
server
.send(
&vec![
Parse::new_anonymous("SET statement_timeout TO '1s'").into(),
Bind::new_statement("").into(),
Execute::new().into(),
Sync.into(),
]
.into(),
)
.await
.unwrap();

for c in ['1', '2', 'C', 'Z'] {
let msg = server.read().await.unwrap();
assert_eq!(c, msg.code());
}

assert!(server.done());

server
.send(
&vec![
Query::new("COPY public.sharded FROM STDIN").into(),
CopyDone.into(),
]
.into(),
)
.await
.unwrap();

for c in ['G', 'C', 'Z'] {
let msg = server.read().await.unwrap();
assert_eq!(c, msg.code());
}

assert!(server.done());

server
.send(
&vec![
Parse::new_anonymous("SET statement_timeout TO '1s'").into(),
Bind::new_statement("").into(),
Execute::new().into(),
Sync.into(),
]
.into(),
)
.await
.unwrap();

for c in ['1', '2', 'C', 'Z'] {
let msg = server.read().await.unwrap();
assert_eq!(c, msg.code());
}

assert!(server.done());
}

#[tokio::test]
async fn test_copy_protocol_extended() {
crate::logger();

let mut server = test_server().await;
server.execute("BEGIN").await.unwrap();
server
.execute("CREATE TABLE test_copy_protocol_extended (id BIGINT)")
.await
.unwrap();

server
.send(
&vec![
Parse::new_anonymous("COPY test_copy_protocol_extended FROM STDIN").into(),
Bind::new_statement("").into(),
Execute::new().into(),
Sync.into(),
]
.into(),
)
.await
.unwrap();
for c in ['1', '2', 'G'] {
let msg = server.read().await.unwrap();
assert_eq!(msg.code(), c);
}

server
.send_one(&CopyData::new("1\n".as_bytes()).into())
.await
.unwrap();

server.send_one(&CopyDone.into()).await.unwrap();
server.send_one(&Sync.into()).await.unwrap();
server.flush().await.unwrap();

for c in ['C', 'Z'] {
let msg = server.read().await.unwrap();
assert_eq!(msg.code(), c);
}

assert!(server.prepared_statements().done());

server
.send(
&vec![
Parse::new_anonymous("ROLLBACK").into(),
Bind::new_statement("").into(),
Execute::new().into(),
Sync.into(),
]
.into(),
)
.await
.unwrap();

for c in ['1', '2', 'C', 'Z'] {
let msg = server.read().await.unwrap();
assert_eq!(msg.code(), c);
}

assert!(server.done());
}

#[tokio::test]
async fn test_copy_client_fail() {
let mut server = test_server().await;
Expand Down
Loading