Skip to content

Commit 4ac4832

Browse files
committed
Update client cache name if pool renames prepared statement
1 parent db70499 commit 4ac4832

File tree

2 files changed

+42
-15
lines changed

2 files changed

+42
-15
lines changed

src/client.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,7 +1316,7 @@ where
13161316
{
13171317
match protocol_data {
13181318
ExtendedProtocolData::Parse { data, metadata } => {
1319-
let (parse, hash) = match metadata {
1319+
let (client_given_name, parse, hash) = match metadata {
13201320
Some(metadata) => metadata,
13211321
None => {
13221322
let first_char_in_name = *data.get(5).unwrap_or(&0);
@@ -1350,7 +1350,13 @@ where
13501350
// TODO: Consider adding the close logic that this function can send for eviction to the client buffer instead
13511351
// In this case we don't want to send the parse message to the server since the client is sending it
13521352
self.register_parse_to_server_cache(
1353-
false, hash, &parse, &pool, server, &address,
1353+
false,
1354+
client_given_name,
1355+
hash,
1356+
parse,
1357+
&pool,
1358+
server,
1359+
&address,
13541360
)
13551361
.await?;
13561362

@@ -1651,24 +1657,32 @@ where
16511657
/// Makes sure the the checked out server has the prepared statement and sends it to the server if it doesn't
16521658
async fn ensure_prepared_statement_is_on_server(
16531659
&mut self,
1654-
client_name: String,
1660+
client_given_name: String,
16551661
pool: &ConnectionPool,
16561662
server: &mut Server,
16571663
address: &Address,
16581664
) -> Result<(), Error> {
1659-
match self.prepared_statements.get(&client_name) {
1665+
match self.prepared_statements.get(&client_given_name) {
16601666
Some((parse, hash)) => {
16611667
debug!("Prepared statement `{}` found in cache", parse.name);
16621668
// In this case we want to send the parse message to the server
16631669
// since pgcat is initiating the prepared statement on this specific server
1664-
self.register_parse_to_server_cache(true, *hash, parse, pool, server, address)
1665-
.await?;
1670+
self.register_parse_to_server_cache(
1671+
true,
1672+
client_given_name,
1673+
*hash,
1674+
parse.clone(),
1675+
pool,
1676+
server,
1677+
address,
1678+
)
1679+
.await?;
16661680
}
16671681

16681682
None => {
16691683
return Err(Error::ClientError(format!(
16701684
"prepared statement `{}` not found",
1671-
client_name
1685+
client_given_name
16721686
)))
16731687
}
16741688
};
@@ -1679,21 +1693,34 @@ where
16791693
/// Register the parse to the server cache and send it to the server if requested (ie. requested by pgcat)
16801694
///
16811695
/// Also updates the pool LRU that this parse was used recently
1696+
#[allow(clippy::too_many_arguments)]
16821697
async fn register_parse_to_server_cache(
1683-
&self,
1698+
&mut self,
16841699
should_send_parse_to_server: bool,
1700+
client_given_name: String,
16851701
hash: u64,
1686-
parse: &Arc<Parse>,
1702+
mut parse: Arc<Parse>,
16871703
pool: &ConnectionPool,
16881704
server: &mut Server,
16891705
address: &Address,
16901706
) -> Result<(), Error> {
16911707
// We want to update this in the LRU to know this was recently used and add it if it isn't there already
16921708
// This could be the case if it was evicted or if doesn't exist (ie. we reloaded and it got removed)
1693-
pool.register_parse_to_cache(hash, parse);
1709+
if let Some(new_parse) = pool.register_parse_to_cache(hash, &parse) {
1710+
// If the pool has renamed this parse, we need to update the client cache with the new name
1711+
if parse.name != new_parse.name {
1712+
warn!(
1713+
"Pool renamed prepared statement prepared statement `{}` to `{}` saving new name to client cache",
1714+
parse.name, new_parse.name
1715+
);
1716+
}
1717+
self.prepared_statements
1718+
.insert(client_given_name.clone(), (new_parse.clone(), hash));
1719+
parse = new_parse;
1720+
};
16941721

16951722
if let Err(err) = server
1696-
.register_prepared_statement(parse, should_send_parse_to_server)
1723+
.register_prepared_statement(&parse, should_send_parse_to_server)
16971724
.await
16981725
{
16991726
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
@@ -1741,12 +1768,12 @@ where
17411768
);
17421769

17431770
self.prepared_statements
1744-
.insert(client_given_name, (new_parse.clone(), hash));
1771+
.insert(client_given_name.clone(), (new_parse.clone(), hash));
17451772

17461773
self.extended_protocol_data_buffer
17471774
.push_back(ExtendedProtocolData::create_new_parse(
17481775
new_parse.as_ref().try_into()?,
1749-
Some((new_parse.clone(), hash)),
1776+
Some((client_given_name, new_parse.clone(), hash)),
17501777
));
17511778

17521779
Ok(())

src/messages.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ impl BytesMutReader for BytesMut {
770770
pub enum ExtendedProtocolData {
771771
Parse {
772772
data: BytesMut,
773-
metadata: Option<(Arc<Parse>, u64)>,
773+
metadata: Option<(String, Arc<Parse>, u64)>,
774774
},
775775
Bind {
776776
data: BytesMut,
@@ -786,7 +786,7 @@ pub enum ExtendedProtocolData {
786786
}
787787

788788
impl ExtendedProtocolData {
789-
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
789+
pub fn create_new_parse(data: BytesMut, metadata: Option<(String, Arc<Parse>, u64)>) -> Self {
790790
Self::Parse { data, metadata }
791791
}
792792

0 commit comments

Comments
 (0)