Skip to content

delegate statement preparation to sqlx #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- 18 new icons available (see https://github.com/tabler/tabler-icons/releases/tag/v2.40.0)
- Support multiple statements in [`on_connect.sql`](./configuration.md) in MySQL.
- Randomize postgres prepared statement names to avoid name collisions. This should fix a bug where SQLPage would report errors like `prepared statement "sqlx_s_3" already exists` when using a connection pooler in front of a PostgreSQL database.
- Delegate statement preparation to sqlx. The logic of preparing statements and caching them for later reuse is now entirely delegated to the sql driver library (sqlx). This simplifies the code and logic inside sqlpage itself. More importantly, statements are now prepared in a streaming fashion when a file is first loaded, instead of all at once, which allows referencing a temporary table created at the start of a file in a later statement in the same file.

## 0.15.2 (2023-11-12)

Expand Down
4 changes: 4 additions & 0 deletions mssql/setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ GO

CREATE LOGIN root WITH PASSWORD = 'Password123!';
CREATE USER root FOR LOGIN root;
GO

GRANT CREATE TABLE TO root;
GRANT ALTER, DELETE, INSERT, SELECT, UPDATE ON SCHEMA::dbo TO root;
GO
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl AppState {
let file_system = FileSystem::init(&config.web_root, &db).await;
sql_file_cache.add_static(
PathBuf::from("index.sql"),
ParsedSqlFile::new(&db, include_str!("../index.sql")).await,
ParsedSqlFile::new(&db, include_str!("../index.sql")),
);
Ok(AppState {
db,
Expand Down
63 changes: 42 additions & 21 deletions src/webserver/database/execute_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashMap;

use super::sql::{ParsedSQLStatement, ParsedSqlFile};
use super::sql::{ParsedSqlFile, ParsedStatement, StmtWithParams};
use crate::webserver::database::sql_pseudofunctions::extract_req_param;
use crate::webserver::http::{RequestInfo, SingleOrVec};

use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo};
use sqlx::pool::PoolConnection;
use sqlx::query::Query;
use sqlx::{AnyConnection, Arguments, Either, Executor, Row, Statement};
use sqlx::{Any, AnyConnection, Arguments, Either, Executor, Row, Statement};

use super::sql_pseudofunctions::StmtParam;
use super::sql_to_json::sql_to_json;
use super::{highlight_sql_error, Database, DbItem, PreparedStatement};
use super::{highlight_sql_error, Database, DbItem};

impl Database {
pub(crate) async fn prepare_with(
Expand All @@ -41,33 +40,33 @@ pub fn stream_query_results<'a>(
let mut connection_opt = None;
for res in &sql_file.statements {
match res {
ParsedSQLStatement::Statement(stmt) => {
ParsedStatement::StmtWithParams(stmt) => {
let query = bind_parameters(stmt, request).await?;
let connection = take_connection(db, &mut connection_opt).await?;
let mut stream = query.fetch_many(connection);
let mut stream = connection.fetch_many(query);
while let Some(elem) = stream.next().await {
let is_err = elem.is_err();
yield parse_single_sql_result(stmt, elem);
yield parse_single_sql_result(&stmt.query, elem);
if is_err {
break;
}
}
},
ParsedSQLStatement::SetVariable { variable, value} => {
ParsedStatement::SetVariable { variable, value} => {
let query = bind_parameters(value, request).await?;
let connection = take_connection(db, &mut connection_opt).await?;
let row = query.fetch_optional(connection).await?;
let row = connection.fetch_optional(query).await?;
let (vars, name) = vars_and_name(request, variable)?;
if let Some(row) = row {
vars.insert(name.clone(), row_to_varvalue(&row));
} else {
vars.remove(&name);
}
},
ParsedSQLStatement::StaticSimpleSelect(value) => {
ParsedStatement::StaticSimpleSelect(value) => {
yield DbItem::Row(value.clone().into())
}
ParsedSQLStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)),
ParsedStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)),
}
}
}
Expand Down Expand Up @@ -132,10 +131,7 @@ async fn take_connection<'a, 'b>(
}

#[inline]
fn parse_single_sql_result(
stmt: &PreparedStatement,
res: sqlx::Result<Either<AnyQueryResult, AnyRow>>,
) -> DbItem {
fn parse_single_sql_result(sql: &str, res: sqlx::Result<Either<AnyQueryResult, AnyRow>>) -> DbItem {
match res {
Ok(Either::Right(r)) => DbItem::Row(super::sql_to_json::row_to_json(&r)),
Ok(Either::Left(res)) => {
Expand All @@ -144,7 +140,7 @@ fn parse_single_sql_result(
}
Err(err) => DbItem::Error(highlight_sql_error(
"Failed to execute SQL statement",
stmt.statement.sql(),
sql,
err,
)),
}
Expand All @@ -159,18 +155,43 @@ fn clone_anyhow_err(err: &anyhow::Error) -> anyhow::Error {
}

async fn bind_parameters<'a>(
stmt: &'a PreparedStatement,
stmt: &'a StmtWithParams,
request: &'a RequestInfo,
) -> anyhow::Result<Query<'a, sqlx::Any, AnyArguments<'a>>> {
) -> anyhow::Result<StatementWithParams<'a>> {
let sql = stmt.query.as_str();
let mut arguments = AnyArguments::default();
for param in &stmt.parameters {
for param in &stmt.params {
let argument = extract_req_param(param, request).await?;
log::debug!("Binding value {:?} in statement {}", &argument, stmt);
log::debug!("Binding value {:?} in statement {}", &argument, stmt.query);
match argument {
None => arguments.add(None::<String>),
Some(Cow::Owned(s)) => arguments.add(s),
Some(Cow::Borrowed(v)) => arguments.add(v),
}
}
Ok(stmt.statement.query_with(arguments))
Ok(StatementWithParams { sql, arguments })
}

pub struct StatementWithParams<'a> {
sql: &'a str,
arguments: AnyArguments<'a>,
}

impl<'q> sqlx::Execute<'q, Any> for StatementWithParams<'q> {
fn sql(&self) -> &'q str {
self.sql
}

fn statement(&self) -> Option<&<Any as sqlx::database::HasStatement<'q>>::Statement> {
None
}

fn take_arguments(&mut self) -> Option<<Any as sqlx::database::HasArguments<'q>>::Arguments> {
Some(std::mem::take(&mut self.arguments))
}

fn persistent(&self) -> bool {
// Let sqlx create a prepared statement the first time it is executed, and then reuse it.
true
}
}
12 changes: 0 additions & 12 deletions src/webserver/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ pub enum DbItem {
Error(anyhow::Error),
}

struct PreparedStatement {
statement: sqlx::any::AnyStatement<'static>,
parameters: Vec<sql_pseudofunctions::StmtParam>,
}

impl std::fmt::Display for PreparedStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use sqlx::Statement;
write!(f, "{}", self.statement.sql())
}
}

#[must_use]
pub fn highlight_sql_error(
context: &str,
Expand Down
80 changes: 11 additions & 69 deletions src/webserver/database/sql.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::sql_pseudofunctions::{func_call_to_param, StmtParam};
use super::PreparedStatement;
use crate::file_cache::AsyncFromStrWithState;
use crate::utils::add_value_to_map;
use crate::{AppState, Database};
Expand All @@ -13,101 +12,51 @@ use sqlparser::dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect,
use sqlparser::parser::{Parser, ParserError};
use sqlparser::tokenizer::Token::{SemiColon, EOF};
use sqlparser::tokenizer::Tokenizer;
use sqlx::any::{AnyKind, AnyTypeInfo};
use sqlx::Postgres;
use sqlx::any::AnyKind;
use std::fmt::Write;
use std::ops::ControlFlow;

#[derive(Default)]
pub struct ParsedSqlFile {
pub(super) statements: Vec<ParsedSQLStatement>,
}

pub(super) enum ParsedSQLStatement {
Statement(PreparedStatement),
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
Error(anyhow::Error),
SetVariable {
variable: StmtParam,
value: PreparedStatement,
},
pub(super) statements: Vec<ParsedStatement>,
}

impl ParsedSqlFile {
pub async fn new(db: &Database, sql: &str) -> ParsedSqlFile {
#[must_use]
pub fn new(db: &Database, sql: &str) -> ParsedSqlFile {
let dialect = dialect_for_db(db.connection.any_kind());
let parsed_statements = match parse_sql(dialect.as_ref(), sql) {
Ok(parsed) => parsed,
Err(err) => return Self::from_err(err),
};
let mut statements = Vec::with_capacity(8);
for parsed in parsed_statements {
statements.push(match parsed {
ParsedStatement::StaticSimpleSelect(s) => ParsedSQLStatement::StaticSimpleSelect(s),
ParsedStatement::Error(e) => ParsedSQLStatement::Error(e),
ParsedStatement::StmtWithParams(stmt_with_params) => {
prepare_query_with_params(db, stmt_with_params).await
}
ParsedStatement::SetVariable { variable, value } => {
match prepare_query_with_params(db, value).await {
ParsedSQLStatement::Statement(value) => {
ParsedSQLStatement::SetVariable { variable, value }
}
err => err,
}
}
});
}
statements.shrink_to_fit();
let statements = parsed_statements.collect();
ParsedSqlFile { statements }
}

fn from_err(e: impl Into<anyhow::Error>) -> Self {
Self {
statements: vec![ParsedSQLStatement::Error(
statements: vec![ParsedStatement::Error(
e.into().context("SQLPage could not parse the SQL file"),
)],
}
}
}

async fn prepare_query_with_params(
db: &Database,
StmtWithParams { query, params }: StmtWithParams,
) -> ParsedSQLStatement {
let param_types = get_param_types(&params);
match db.prepare_with(&query, &param_types).await {
Ok(statement) => {
log::debug!("Successfully prepared SQL statement '{query}'");
ParsedSQLStatement::Statement(PreparedStatement {
statement,
parameters: params,
})
}
Err(err) => {
log::warn!("Failed to prepare {query:?}: {err:#}");
ParsedSQLStatement::Error(err.context(format!(
"The database returned an error when preparing this SQL statement: {query}"
)))
}
}
}

#[async_trait(? Send)]
impl AsyncFromStrWithState for ParsedSqlFile {
async fn from_str_with_state(app_state: &AppState, source: &str) -> anyhow::Result<Self> {
Ok(ParsedSqlFile::new(&app_state.db, source).await)
Ok(ParsedSqlFile::new(&app_state.db, source))
}
}

#[derive(Debug, PartialEq)]
struct StmtWithParams {
query: String,
params: Vec<StmtParam>,
pub(super) struct StmtWithParams {
pub query: String,
pub params: Vec<StmtParam>,
}

#[derive(Debug)]
enum ParsedStatement {
pub(super) enum ParsedStatement {
StmtWithParams(StmtWithParams),
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
SetVariable {
Expand Down Expand Up @@ -201,13 +150,6 @@ fn kind_of_dialect(dialect: &dyn Dialect) -> AnyKind {
}
}

fn get_param_types(parameters: &[StmtParam]) -> Vec<AnyTypeInfo> {
parameters
.iter()
.map(|_p| <str as sqlx::Type<Postgres>>::type_info().into())
.collect()
}

fn map_param(mut name: String) -> StmtParam {
if name.is_empty() {
return StmtParam::GetOrPost(name);
Expand Down
7 changes: 7 additions & 0 deletions tests/sql_test_files/it_works_create_table.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
drop table if exists my_tmp_store;
create table my_tmp_store(x varchar(100));

insert into my_tmp_store(x) values ('It works !');

select 'card' as component;
select x as description from my_tmp_store;