diff --git a/CONFIG.md b/CONFIG.md index 1a05a776..b36a190d 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -188,6 +188,22 @@ default: "admin_pass" Password to access the virtual administrative database +### dns_cache_enabled +``` +path: general.dns_cache_enabled +default: false +``` +When enabled, ip resolutions for server connections specified using hostnames will be cached +and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found +old ip connections are closed (gracefully) and new connections will start using new ip. + +### dns_max_ttl +``` +path: general.dns_max_ttl +default: 30 +``` +Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`). + ## `pools.` Section ### pool_mode diff --git a/Cargo.lock b/Cargo.lock index 7991667e..7641e9c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,27 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "async-trait" version = "0.1.68" @@ -212,6 +233,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "data-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" + [[package]] name = "digest" version = "0.10.6" @@ -223,6 +250,18 @@ dependencies = [ "subtle", ] +[[package]] +name = "enum-as-inner" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "env_logger" version = "0.10.0" @@ -275,6 +314,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.28" @@ -410,6 +458,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.2.6" @@ -434,6 +488,17 @@ dependencies = [ "digest", ] +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + [[package]] name = "http" version = "0.2.9" @@ -522,6 +587,27 @@ dependencies = [ "cxx-build", ] +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.2" @@ -542,6 +628,24 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "ipconfig" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be" +dependencies = [ + "socket2", + "widestring", + "winapi", + "winreg", +] + +[[package]] +name = "ipnet" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" + [[package]] name = "is-terminal" version = "0.4.4" @@ -589,6 +693,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.139" @@ -604,6 +714,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.1.4" @@ -629,6 +745,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + +[[package]] +name = "match_cfg" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" + +[[package]] +name = "matches" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" + [[package]] name = "md-5" version = "0.10.5" @@ -737,6 +874,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "percent-encoding" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + [[package]] name = "pgcat" version = "1.0.1" @@ -777,7 +920,9 @@ dependencies = [ "stringprep", "tokio", "tokio-rustls", + "tokio-test", "toml", + "trust-dns-resolver", "webpki-roots", ] @@ -888,6 +1033,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.26" @@ -953,6 +1104,16 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5" +[[package]] +name = "resolv-conf" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +dependencies = [ + "hostname", + "quick-error", +] + [[package]] name = "ring" version = "0.16.20" @@ -1191,6 +1352,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "time" version = "0.1.45" @@ -1258,6 +1439,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.7" @@ -1320,9 +1525,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "tracing-core" version = "0.1.30" @@ -1332,6 +1549,51 @@ dependencies = [ "once_cell", ] +[[package]] +name = "trust-dns-proto" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.2.3", + "ipnet", + "lazy_static", + "rand", + "smallvec", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "trust-dns-resolver" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" +dependencies = [ + "cfg-if", + "futures-util", + "ipconfig", + "lazy_static", + "lru-cache", + "parking_lot", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", + "trust-dns-proto", +] + [[package]] name = "try-lock" version = "0.2.4" @@ -1377,6 +1639,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "url" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +dependencies = [ + "form_urlencoded", + "idna 0.3.0", + "percent-encoding", +] + [[package]] name = "version_check" version = "0.9.4" @@ -1478,6 +1751,12 @@ dependencies = [ "rustls-webpki", ] +[[package]] +name = "widestring" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983" + [[package]] name = "winapi" version = "0.3.9" @@ -1583,3 +1862,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658" dependencies = [ "memchr", ] + +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] diff --git a/Cargo.toml b/Cargo.toml index 28e94a6d..436c3dd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,8 @@ fallible-iterator = "0.2" pin-project = "1" webpki-roots = "0.23" rustls = { version = "0.21", features = ["dangerous_configuration"] } +trust-dns-resolver = "0.22.0" +tokio-test = "0.4.2" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/pgcat.toml b/pgcat.toml index df2ba715..c844ce1f 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -146,6 +146,14 @@ idle_timeout = 40000 # Connect timeout can be overwritten in the pool connect_timeout = 3000 +# When enabled, ip resolutions for server connections specified using hostnames will be cached +# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found +# old ip connections are closed (gracefully) and new connections will start using new ip. +# dns_cache_enabled = false + +# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`). +# dns_max_ttl = 30 + # User configs are structured as pool..users. # This section holds the credentials for users that may connect to this cluster [pools.sharded_db.users.0] diff --git a/src/config.rs b/src/config.rs index 4af7beda..fd7d3912 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use tokio::fs::File; use tokio::io::AsyncReadExt; +use crate::dns_cache::CachedResolver; use crate::errors::Error; use crate::pool::{ClientServerMap, ConnectionPool}; use crate::sharding::ShardingFunction; @@ -255,6 +256,12 @@ pub struct General { #[serde(default)] // False pub log_client_disconnections: bool, + #[serde(default)] // False + pub dns_cache_enabled: bool, + + #[serde(default = "General::default_dns_max_ttl")] + pub dns_max_ttl: u64, + #[serde(default = "General::default_shutdown_timeout")] pub shutdown_timeout: u64, @@ -336,6 +343,10 @@ impl General { 60000 } + pub fn default_dns_max_ttl() -> u64 { + 30 + } + pub fn default_healthcheck_timeout() -> u64 { 1000 } @@ -378,6 +389,8 @@ impl Default for General { log_client_connections: false, log_client_disconnections: false, autoreload: None, + dns_cache_enabled: false, + dns_max_ttl: Self::default_dns_max_ttl(), tls_certificate: None, tls_private_key: None, server_tls: false, @@ -1119,6 +1132,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result (), + Err(err) => error!("DNS cache reinitialization error: {:?}", err), + }; if old_config.pools != new_config.pools { info!("Pool configuration changed"); diff --git a/src/dns_cache.rs b/src/dns_cache.rs new file mode 100644 index 00000000..5c2be5dc --- /dev/null +++ b/src/dns_cache.rs @@ -0,0 +1,410 @@ +use crate::config::get_config; +use crate::errors::Error; +use arc_swap::ArcSwap; +use log::{debug, error, info, warn}; +use once_cell::sync::Lazy; +use std::collections::{HashMap, HashSet}; +use std::io; +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::RwLock; +use tokio::time::{sleep, Duration}; +use trust_dns_resolver::error::{ResolveError, ResolveResult}; +use trust_dns_resolver::lookup_ip::LookupIp; +use trust_dns_resolver::TokioAsyncResolver; + +/// Cached Resolver Globally available +pub static CACHED_RESOLVER: Lazy> = + Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default())); + +// Ip addressed are returned as a set of addresses +// so we can compare. +#[derive(Clone, PartialEq, Debug)] +pub struct AddrSet { + set: HashSet, +} + +impl AddrSet { + fn new() -> AddrSet { + AddrSet { + set: HashSet::new(), + } + } +} + +impl From for AddrSet { + fn from(lookup_ip: LookupIp) -> Self { + let mut addr_set = AddrSet::new(); + for address in lookup_ip.iter() { + addr_set.set.insert(address); + } + addr_set + } +} + +/// +/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time. +/// +/// The system works as follows: +/// +/// When a host is to be resolved, if we have not resolved it before, a new resolution is +/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the +/// cache is refreshed. +/// +/// # Example: +/// +/// ``` +/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; +/// +/// # tokio_test::block_on(async { +/// let config = CachedResolverConfig::default(); +/// let resolver = CachedResolver::new(config, None).await.unwrap(); +/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap(); +/// # }) +/// ``` +/// +/// // Now the ip resolution is stored in local cache and subsequent +/// // calls will be returned from cache. Also, the cache is refreshed +/// // and updated every 10 seconds. +/// +/// // You can now check if an 'old' lookup differs from what it's currently +/// // store in cache by using `has_changed`. +/// resolver.has_changed("www.example.com.", addrset) +#[derive(Default)] +pub struct CachedResolver { + // The configuration of the cached_resolver. + config: CachedResolverConfig, + + // This is the hash that contains the hash. + data: Option>>, + + // The resolver to be used for DNS queries. + resolver: Option, + + // The RefreshLoop + refresh_loop: RwLock>>, +} + +/// +/// Configuration +#[derive(Clone, Debug, Default, PartialEq)] +pub struct CachedResolverConfig { + /// Amount of time in secods that a resolved dns address is considered stale. + dns_max_ttl: u64, + + /// Enabled or disabled? (this is so we can reload config) + enabled: bool, +} + +impl CachedResolverConfig { + fn new(dns_max_ttl: u64, enabled: bool) -> Self { + CachedResolverConfig { + dns_max_ttl, + enabled, + } + } +} + +impl From for CachedResolverConfig { + fn from(config: crate::config::Config) -> Self { + CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled) + } +} + +impl CachedResolver { + /// + /// Returns a new Arc based on passed configuration. + /// It also starts the loop that will refresh cache entries. + /// + /// # Arguments: + /// + /// * `config` - The `CachedResolverConfig` to be used to create the resolver. + /// + /// # Example: + /// + /// ``` + /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; + /// + /// # tokio_test::block_on(async { + /// let config = CachedResolverConfig::default(); + /// let resolver = CachedResolver::new(config, None).await.unwrap(); + /// # }) + /// ``` + /// + pub async fn new( + config: CachedResolverConfig, + data: Option>, + ) -> Result, io::Error> { + // Construct a new Resolver with default configuration options + let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?); + + let data = if let Some(hash) = data { + Some(RwLock::new(hash)) + } else { + Some(RwLock::new(HashMap::new())) + }; + + let instance = Arc::new(Self { + config, + resolver, + data, + refresh_loop: RwLock::new(None), + }); + + if instance.enabled() { + info!("Scheduling DNS refresh loop"); + let refresh_loop = tokio::task::spawn({ + let instance = instance.clone(); + async move { + instance.refresh_dns_entries_loop().await; + } + }); + *(instance.refresh_loop.write().unwrap()) = Some(refresh_loop); + } + + Ok(instance) + } + + pub fn enabled(&self) -> bool { + self.config.enabled + } + + // Schedules the refresher + async fn refresh_dns_entries_loop(&self) { + let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap(); + let interval = Duration::from_secs(self.config.dns_max_ttl); + loop { + debug!("Begin refreshing cached DNS addresses."); + // To minimize the time we hold the lock, we first create + // an array with keys. + let mut hostnames: Vec = Vec::new(); + { + if let Some(ref data) = self.data { + for hostname in data.read().unwrap().keys() { + hostnames.push(hostname.clone()); + } + } + } + + for hostname in hostnames.iter() { + let addrset = self + .fetch_from_cache(hostname.as_str()) + .expect("Could not obtain expected address from cache, this should not happen"); + + match resolver.lookup_ip(hostname).await { + Ok(lookup_ip) => { + let new_addrset = AddrSet::from(lookup_ip); + debug!( + "Obtained address for host ({}) -> ({:?})", + hostname, new_addrset + ); + + if addrset != new_addrset { + debug!( + "Addr changed from {:?} to {:?} updating cache.", + addrset, new_addrset + ); + self.store_in_cache(hostname, new_addrset); + } + } + Err(err) => { + error!( + "There was an error trying to resolv {}: ({}).", + hostname, err + ); + } + } + } + debug!("Finished refreshing cached DNS addresses."); + sleep(interval).await; + } + } + + /// Returns a `AddrSet` given the specified hostname. + /// + /// This method first tries to fetch the value from the cache, if it misses + /// then it is resolved and stored in the cache. TTL from records is ignored. + /// + /// # Arguments + /// + /// * `host` - A string slice referencing the hostname to be resolved. + /// + /// # Example: + /// + /// ``` + /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; + /// + /// # tokio_test::block_on(async { + /// let config = CachedResolverConfig::default(); + /// let resolver = CachedResolver::new(config, None).await.unwrap(); + /// let response = resolver.lookup_ip("www.google.com."); + /// # }) + /// ``` + /// + pub async fn lookup_ip(&self, host: &str) -> ResolveResult { + debug!("Lookup up {} in cache", host); + match self.fetch_from_cache(host) { + Some(addr_set) => { + debug!("Cache hit!"); + Ok(addr_set) + } + None => { + debug!("Not found, executing a dns query!"); + if let Some(ref resolver) = self.resolver { + let addr_set = AddrSet::from(resolver.lookup_ip(host).await?); + debug!("Obtained: {:?}", addr_set); + self.store_in_cache(host, addr_set.clone()); + Ok(addr_set) + } else { + Err(ResolveError::from("No resolver available")) + } + } + } + } + + // + // Returns true if the stored host resolution differs from the AddrSet passed. + pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool { + if let Some(fetched_addr_set) = self.fetch_from_cache(host) { + return fetched_addr_set != *addr_set; + } + false + } + + // Fetches an AddrSet from the inner cache adquiring the read lock. + fn fetch_from_cache(&self, key: &str) -> Option { + if let Some(ref hash) = self.data { + if let Some(addr_set) = hash.read().unwrap().get(key) { + return Some(addr_set.clone()); + } + } + None + } + + // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS + // cache. + pub async fn from_config() -> Result<(), Error> { + let cached_resolver = CACHED_RESOLVER.load(); + let desired_config = CachedResolverConfig::from(get_config()); + + if cached_resolver.config != desired_config { + if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) { + warn!("Killing Dnscache refresh loop as its configuration is being reloaded"); + refresh_loop.abort() + } + let new_resolver = if let Some(ref data) = cached_resolver.data { + let data = Some(data.read().unwrap().clone()); + CachedResolver::new(desired_config, data).await + } else { + CachedResolver::new(desired_config, None).await + }; + + match new_resolver { + Ok(ok) => { + CACHED_RESOLVER.store(ok); + Ok(()) + } + Err(err) => { + let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err); + Err(Error::DNSCachedError(message)) + } + } + } else { + Ok(()) + } + } + + // Stores the AddrSet in cache adquiring the write lock. + fn store_in_cache(&self, host: &str, addr_set: AddrSet) { + if let Some(ref data) = self.data { + data.write().unwrap().insert(host.to_string(), addr_set); + } else { + error!("Could not insert, Hash not initialized"); + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use trust_dns_resolver::error::ResolveError; + + #[tokio::test] + async fn new() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await; + assert!(resolver.is_ok()); + } + + #[tokio::test] + async fn lookup_ip() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await.unwrap(); + let response = resolver.lookup_ip("www.google.com.").await; + assert!(response.is_ok()); + } + + #[tokio::test] + async fn has_changed() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + } + + #[tokio::test] + async fn unknown_host() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await.unwrap(); + let hostname = "www.idontexists."; + let response = resolver.lookup_ip(hostname).await; + assert!(matches!(response, Err(ResolveError { .. }))); + } + + #[tokio::test] + async fn incorrect_address() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await.unwrap(); + let hostname = "w ww.idontexists."; + let response = resolver.lookup_ip(hostname).await; + assert!(matches!(response, Err(ResolveError { .. }))); + assert!(!resolver.has_changed(hostname, &AddrSet::new())); + } + + #[tokio::test] + // Ok, this test is based on the fact that google does DNS RR + // and does not responds with every available ip everytime, so + // if I cache here, it will miss after one cache iteration or two. + async fn thread() { + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; + let resolver = CachedResolver::new(config, None).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + let resolver_for_refresher = resolver.clone(); + let _thread_handle = tokio::task::spawn(async move { + resolver_for_refresher.refresh_dns_entries_loop().await; + }); + assert!(!resolver.has_changed(hostname, &addr_set)); + } +} diff --git a/src/errors.rs b/src/errors.rs index 0930ab8b..fb70c042 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -19,6 +19,7 @@ pub enum Error { ClientError(String), TlsError, StatementTimeout, + DNSCachedError(String), ShuttingDown, ParseBytesError(String), AuthError(String), diff --git a/src/lib.rs b/src/lib.rs index 2645cd42..3a58bb38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod auth_passthrough; pub mod config; pub mod constants; +pub mod dns_cache; pub mod errors; pub mod messages; pub mod mirrors; diff --git a/src/main.rs b/src/main.rs index b3265ed8..dc48dd58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ extern crate sqlparser; extern crate tokio; extern crate tokio_rustls; extern crate toml; +extern crate trust_dns_resolver; #[cfg(not(target_env = "msvc"))] use jemallocator::Jemalloc; @@ -65,6 +66,7 @@ mod auth_passthrough; mod client; mod config; mod constants; +mod dns_cache; mod errors; mod messages; mod mirrors; @@ -166,8 +168,14 @@ fn main() -> Result<(), Box> { // Statistics reporting. REPORTER.store(Arc::new(Reporter::default())); - // Connection pool that allows to query all shards and replicas. - match ConnectionPool::from_config(client_server_map.clone()).await { + // Starts (if enabled) dns cache before pools initialization + match dns_cache::CachedResolver::from_config().await { + Ok(_) => (), + Err(err) => error!("DNS cache initialization error: {:?}", err), + }; + + // Connection pool that allows to query all shards and replicas. + match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), Err(err) => { error!("Pool error: {:?}", err); diff --git a/src/server.rs b/src/server.rs index 5bcd5fb9..ff5ab20b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,6 +7,7 @@ use parking_lot::{Mutex, RwLock}; use postgres_protocol::message; use std::collections::HashMap; use std::io::Read; +use std::net::IpAddr; use std::sync::Arc; use std::time::SystemTime; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; @@ -16,6 +17,7 @@ use tokio_rustls::{client::TlsStream, TlsConnector}; use crate::config::{get_config, Address, User}; use crate::constants::*; +use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::{Error, ServerIdentifier}; use crate::messages::*; use crate::mirrors::MirroringManager; @@ -148,6 +150,9 @@ pub struct Server { last_activity: SystemTime, mirror_manager: Option, + + // Associated addresses used + addr_set: Option, } impl Server { @@ -161,6 +166,24 @@ impl Server { stats: Arc, auth_hash: Arc>>, ) -> Result { + let cached_resolver = CACHED_RESOLVER.load(); + let mut addr_set: Option = None; + + // If we are caching addresses and hostname is not an IP + if cached_resolver.enabled() && address.host.parse::().is_err() { + debug!("Resolving {}", &address.host); + addr_set = match cached_resolver.lookup_ip(&address.host).await { + Ok(ok) => { + debug!("Obtained: {:?}", ok); + Some(ok) + } + Err(err) => { + warn!("Error trying to resolve {}, ({:?})", &address.host, err); + None + } + } + }; + let mut stream = match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { Ok(stream) => stream, @@ -609,6 +632,7 @@ impl Server { bad: false, needs_cleanup: false, client_server_map, + addr_set, connected_at: chrono::offset::Utc::now().naive_utc(), stats, application_name: String::new(), @@ -849,7 +873,23 @@ impl Server { /// Server & client are out of sync, we must discard this connection. /// This happens with clients that misbehave. pub fn is_bad(&self) -> bool { - self.bad + if self.bad { + return self.bad; + }; + let cached_resolver = CACHED_RESOLVER.load(); + if cached_resolver.enabled() { + if let Some(addr_set) = &self.addr_set { + if cached_resolver.has_changed(self.address.host.as_str(), addr_set) { + warn!( + "DNS changed for {}, it was {:?}. Dropping server connection.", + self.address.host.as_str(), + addr_set + ); + return true; + } + } + } + false } /// Get server startup information to forward it to the client.