Skip to content

Commit 49f997d

Browse files
committed
mssql: fix infinite recursion in parameter replacement, add tests, remove some duplication
1 parent 7c04f4c commit 49f997d

File tree

3 files changed

+64
-14
lines changed

3 files changed

+64
-14
lines changed

index.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ SELECT a.x || ' times ' || b.x as title,
100100
a.x || ' x ' || b.x || ' = ' || (a.x * b.x) as description,
101101
'This is basic math' as footer,
102102
'?x=' || a.x as link -- This is the interesting part. Each card has a link. When you click the card, the current page is reloaded with '?x=a' appended to the end of the URL
103-
FROM (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12)) as a(x),
104-
(VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12)) as b(x)
103+
FROM (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) as a(x),
104+
(VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) as b(x)
105105
WHERE -- The powerful thing is here
106106
$x IS NULL
107107
OR -- The syntax $x allows us to extract the value 'a' when the URL ends with '?x=a'. It will be null if the URL does not contain '?x='

mssql/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ USER 10001
1818
ENV SA_PASSWORD="Password123!"
1919
ENV ACCEPT_EULA="Y"
2020

21-
HEALTHCHECK --interval=10s --timeout=3s --start-period=10s --retries=10 \
21+
HEALTHCHECK --interval=10s --timeout=3s --start-period=15s --retries=10 \
2222
CMD sqlcmd -S localhost -U root -P "Password123!" -Q "SELECT 1" || exit 1
2323

2424
ENTRYPOINT ["/usr/config/entrypoint.sh"]

src/webserver/database/sql.rs

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use crate::file_cache::AsyncFromStrWithState;
44
use crate::{AppState, Database};
55
use async_trait::async_trait;
66
use sqlparser::ast::{
7-
DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Statement, Value,
8-
VisitMut, VisitorMut,
7+
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName,
8+
Statement, Value, VisitMut, VisitorMut,
99
};
1010
use sqlparser::dialect::GenericDialect;
1111
use sqlparser::parser::{Parser, ParserError};
@@ -182,6 +182,10 @@ struct ParameterExtractor {
182182
parameters: Vec<StmtParam>,
183183
}
184184

185+
const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 2] =
186+
[(AnyKind::Postgres, "$"), (AnyKind::Mssql, "@p")];
187+
const DEFAULT_PLACEHOLDER: &str = "?";
188+
185189
impl ParameterExtractor {
186190
fn extract_parameters(
187191
sql_ast: &mut sqlparser::ast::Statement,
@@ -200,7 +204,8 @@ impl ParameterExtractor {
200204
let data_type = match self.db_kind {
201205
// MySQL requires CAST(? AS CHAR) and does not understand CAST(? AS TEXT)
202206
AnyKind::MySql => DataType::Char(None),
203-
_ => DataType::Text,
207+
AnyKind::Postgres => DataType::Text,
208+
_ => DataType::Varchar(None),
204209
};
205210
let value = Expr::Value(Value::Placeholder(name));
206211
Expr::Cast {
@@ -220,6 +225,21 @@ impl ParameterExtractor {
220225
self.parameters.push(param);
221226
placeholder
222227
}
228+
229+
fn is_own_placeholder(&self, param: &str) -> bool {
230+
if let Some((_, prefix)) = PLACEHOLDER_PREFIXES
231+
.iter()
232+
.find(|(kind, _prefix)| *kind == self.db_kind)
233+
{
234+
if let Some(param) = param.strip_prefix(prefix) {
235+
if let Ok(index) = param.parse::<usize>() {
236+
return index <= self.parameters.len() + 1;
237+
}
238+
}
239+
return false;
240+
}
241+
param == DEFAULT_PLACEHOLDER
242+
}
223243
}
224244

225245
/** This is a helper struct to format a list of arguments for an error message. */
@@ -299,21 +319,20 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> {
299319

300320
#[inline]
301321
pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
302-
match db_kind {
303-
// Postgres only supports numbered parameters with $1, $2, etc.
304-
AnyKind::Postgres => format!("${arg_number}"),
305-
// MSSQL only supports named parameters with @p1, @p2, etc.
306-
AnyKind::Mssql => format!("@p{arg_number}"),
307-
_ => '?'.to_string(),
322+
if let Some((_, prefix)) = PLACEHOLDER_PREFIXES
323+
.iter()
324+
.find(|(kind, _)| *kind == db_kind)
325+
{
326+
return format!("{prefix}{arg_number}");
308327
}
328+
DEFAULT_PLACEHOLDER.to_string()
309329
}
310330

311331
impl VisitorMut for ParameterExtractor {
312332
type Break = ();
313333
fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow<Self::Break> {
314334
match value {
315-
Expr::Value(Value::Placeholder(param))
316-
if param.chars().nth(1).is_some_and(char::is_alphabetic) =>
335+
Expr::Value(Value::Placeholder(param)) if !self.is_own_placeholder(param) =>
317336
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
318337
{
319338
let new_expr = self.make_placeholder();
@@ -334,6 +353,26 @@ impl VisitorMut for ParameterExtractor {
334353
let arguments = std::mem::take(args);
335354
*value = self.handle_builtin_function(func_name, arguments);
336355
}
356+
// Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL
357+
Expr::BinaryOp {
358+
left,
359+
op: BinaryOperator::StringConcat,
360+
right,
361+
} if self.db_kind == AnyKind::Mssql => {
362+
let left = std::mem::replace(left.as_mut(), Expr::Value(Value::Null));
363+
let right = std::mem::replace(right.as_mut(), Expr::Value(Value::Null));
364+
*value = Expr::Function(Function {
365+
name: ObjectName(vec![Ident::new("CONCAT")]),
366+
args: vec![
367+
FunctionArg::Unnamed(FunctionArgExpr::Expr(left)),
368+
FunctionArg::Unnamed(FunctionArgExpr::Expr(right)),
369+
],
370+
over: None,
371+
distinct: false,
372+
special: false,
373+
order_by: vec![],
374+
});
375+
}
337376
_ => (),
338377
}
339378
ControlFlow::<()>::Continue(())
@@ -390,6 +429,17 @@ mod test {
390429
);
391430
}
392431

432+
#[test]
433+
fn test_mssql_statement_rewrite() {
434+
let mut ast = parse_stmt("select '' || $1 from t");
435+
let parameters = ParameterExtractor::extract_parameters(&mut ast, AnyKind::Mssql);
436+
assert_eq!(
437+
ast.to_string(),
438+
"SELECT CONCAT('', CAST(@p1 AS VARCHAR)) FROM t"
439+
);
440+
assert_eq!(parameters, [StmtParam::GetOrPost("1".to_string()),]);
441+
}
442+
393443
#[test]
394444
fn test_static_extract() {
395445
assert_eq!(

0 commit comments

Comments
 (0)