Skip to content

Configurable routing for queries without routing comment #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
docker:
- image: ghcr.io/levkk/pgcat-ci:1.67
- image: ghcr.io/levkk/pgcat-ci:latest
environment:
RUST_LOG: info
LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
FROM cimg/rust:1.67.1
COPY --from=sclevine/yj /bin/yj /bin/yj
RUN /bin/yj -h
RUN sudo apt-get update && \
sudo apt-get install -y \
psmisc postgresql-contrib-14 postgresql-client-14 libpq-dev \
Expand Down
11 changes: 8 additions & 3 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,17 @@ query_parser_read_write_splitting = true
# queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true

# Allow sharding commands to be passed as statement comments instead of
# separate commands. If these are unset this functionality is disabled.
# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
# shard_id_regex = '/\* shard_id: (\d+) \*/'
# shard_id_regex = '/\*shard_id:(\d+)\*/'
# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements

# Defines the behavior when no shard_id or sharding_key are specified for a query against
# a sharded system with either sharding_key_regex or shard_id_regex specified.
# `random`: picks a shard at random
# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
# `shard_<number>`: e.g. shard_0, shard_4, etc. always picks that specific shard
# no_shard_specified_behavior = "random"

# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
# Current options:
Expand Down
65 changes: 65 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use arc_swap::ArcSwap;
use log::{error, info};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserializer, Serializer};
use serde_derive::{Deserialize, Serialize};

use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
Expand Down Expand Up @@ -101,6 +104,9 @@ pub struct Address {

/// Address stats
pub stats: Arc<AddressStats>,

/// Number of errors encountered
pub error_count: Arc<AtomicU64>,
}

impl Default for Address {
Expand All @@ -118,6 +124,7 @@ impl Default for Address {
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
}
}
}
Expand Down Expand Up @@ -182,6 +189,18 @@ impl Address {
),
}
}

pub fn error_count(&self) -> u64 {
self.error_count.load(Ordering::Relaxed)
}

pub fn increment_error_count(&self) {
self.error_count.fetch_add(1, Ordering::Relaxed);
}

pub fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
}
}

/// PostgreSQL user.
Expand Down Expand Up @@ -539,6 +558,7 @@ pub struct Pool {
pub sharding_key_regex: Option<String>,
pub shard_id_regex: Option<String>,
pub regex_search_limit: Option<usize>,
pub no_shard_specified_behavior: Option<NoShardSpecifiedHandling>,

pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
Expand Down Expand Up @@ -693,6 +713,7 @@ impl Default for Pool {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
no_shard_specified_behavior: None,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
Expand All @@ -711,6 +732,50 @@ pub struct ServerConfig {
pub role: Role,
}

/// Strategy for choosing a shard when a query does not specify one.
///
/// Represented in serialized form as `shard_<number>`, `random` or
/// `random_healthy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NoShardSpecifiedHandling {
    /// Always use the shard with the given index (serialized as `shard_<index>`).
    Shard(usize),
    /// Pick a shard at random (serialized as `random`).
    Random,
    /// Pick a shard at random, favoring shards with fewer recent errors
    /// (serialized as `random_healthy`).
    RandomHealthy,
}

impl Default for NoShardSpecifiedHandling {
    /// The default strategy is to always use shard 0.
    fn default() -> Self {
        Self::Shard(0)
    }
}
/// Serializes the strategy as a plain string: `shard_<index>`, `random`
/// or `random_healthy`.
impl serde::Serialize for NoShardSpecifiedHandling {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match *self {
            Self::Shard(index) => serializer.serialize_str(&format!("shard_{}", index)),
            Self::Random => serializer.serialize_str("random"),
            Self::RandomHealthy => serializer.serialize_str("random_healthy"),
        }
    }
}
impl<'de> serde::Deserialize<'de> for NoShardSpecifiedHandling {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if s.starts_with("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(NoShardSpecifiedHandling::Shard(shard));
}

match s.as_str() {
"random" => Ok(NoShardSpecifiedHandling::Random),
"random_healthy" => Ok(NoShardSpecifiedHandling::RandomHealthy),
_ => Err(serde::de::Error::custom(
"invalid value for no_shard_specified_behavior",
)),
}
}
}

#[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)]
pub struct MirrorServerConfig {
pub host: String,
Expand Down
56 changes: 54 additions & 2 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::sync::atomic::AtomicU64;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -18,7 +19,8 @@ use std::time::Instant;
use tokio::sync::Notify;

use crate::config::{
get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
get_config, Address, General, LoadBalancingMode, NoShardSpecifiedHandling, Plugins, PoolMode,
Role, User,
};
use crate::errors::Error;

Expand Down Expand Up @@ -140,6 +142,9 @@ pub struct PoolSettings {
// Regex for searching for the shard id in SQL statements
pub shard_id_regex: Option<Regex>,

// What to do when no shard is specified in the SQL statement
pub no_shard_specified_behavior: Option<NoShardSpecifiedHandling>,

// Limit how much of each query is searched for a potential shard regex match
pub regex_search_limit: usize,

Expand Down Expand Up @@ -173,6 +178,7 @@ impl Default for PoolSettings {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: 1000,
no_shard_specified_behavior: None,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
Expand Down Expand Up @@ -299,6 +305,7 @@ impl ConnectionPool {
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
});
address_id += 1;
}
Expand All @@ -317,6 +324,7 @@ impl ConnectionPool {
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
};

address_id += 1;
Expand Down Expand Up @@ -482,6 +490,9 @@ impl ConnectionPool {
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
no_shard_specified_behavior: pool_config
.no_shard_specified_behavior
.clone(),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
Expand Down Expand Up @@ -651,7 +662,10 @@ impl ConnectionPool {
.get()
.await
{
Ok(conn) => conn,
Ok(conn) => {
address.reset_error_count();
conn
}
Err(err) => {
error!(
"Connection checkout error for instance {:?}, error: {:?}",
Expand Down Expand Up @@ -766,6 +780,18 @@ impl ConnectionPool {
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
// Count the number of errors since the last successful checkout
// This is used to determine if the shard is down
match reason {
BanReason::FailedHealthCheck
| BanReason::FailedCheckout
| BanReason::MessageSendFailed
| BanReason::MessageReceiveFailed => {
address.increment_error_count();
}
_ => (),
};

// Primary can never be banned
if address.role == Role::Primary {
return;
Expand Down Expand Up @@ -920,6 +946,32 @@ impl ConnectionPool {
self.original_server_parameters.read().clone()
}

pub fn get_random_healthy_shard_id(&self) -> usize {
let mut shards = Vec::new();
for shard in 0..self.shards() {
shards.push(shard);
}

// Shuffle to avoid always picking the same shard when error counts are equal
shards.shuffle(&mut thread_rng());
shards.sort_by(|a, b| {
let err_count_a = self.addresses[*a]
.iter()
.fold(0, |acc, address| acc + address.error_count());

let err_count_b = self.addresses[*b]
.iter()
.fold(0, |acc, address| acc + address.error_count());

err_count_a.partial_cmp(&err_count_b).unwrap()
});

match shards.first() {
Some(shard) => *shard,
None => 0,
}
}

fn busy_connection_count(&self, address: &Address) -> u32 {
let state = self.pool_state(address.shard, address.address_index);
let idle = state.idle_connections;
Expand Down
Loading