diff --git a/postgres/src/client.rs b/postgres/src/client.rs index f5637cdbd..050c5b229 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -4,6 +4,7 @@ use crate::{ ToStatement, Transaction, TransactionBuilder, }; use std::task::Poll; +use std::time::Duration; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::types::{BorrowToSql, ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket}; @@ -413,6 +414,18 @@ impl Client { self.connection.block_on(self.client.simple_query(query)) } + /// Validates connection, timing out after specified duration. + pub fn is_valid(&mut self, timeout: Duration) -> Result<(), Error> { + let inner_client = &self.client; + self.connection.block_on(async { + let trivial_query = inner_client.simple_query(""); + tokio::time::timeout(timeout, trivial_query) + .await + .map_err(|_| Error::timeout())? + .map(|_| ()) + }) + } + /// Executes a sequence of SQL statements using the simple query protocol. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f9335cfe7..c5383df92 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -354,6 +354,7 @@ enum Kind { RowCount, #[cfg(feature = "runtime")] Connect, + Timeout, } struct ErrorInner { @@ -392,6 +393,7 @@ impl fmt::Display for Error { Kind::RowCount => fmt.write_str("query returned an unexpected number of rows")?, #[cfg(feature = "runtime")] Kind::Connect => fmt.write_str("error connecting to server")?, + Kind::Timeout => fmt.write_str("timeout waiting for server")?, }; if let Some(ref cause) = self.0.cause { write!(fmt, ": {}", cause)?; @@ -491,4 +493,9 @@ impl Error { pub(crate) fn connect(e: io::Error) -> Error { Error::new(Kind::Connect, Some(Box::new(e))) } + + #[doc(hidden)] + pub fn timeout() -> Error { + Error::new(Kind::Timeout, None) + } }