@@ -4,8 +4,8 @@ use crate::file_cache::AsyncFromStrWithState;
4
4
use crate :: { AppState , Database } ;
5
5
use async_trait:: async_trait;
6
6
use sqlparser:: ast:: {
7
- DataType , Expr , Function , FunctionArg , FunctionArgExpr , Ident , ObjectName , Statement , Value ,
8
- VisitMut , VisitorMut ,
7
+ BinaryOperator , DataType , Expr , Function , FunctionArg , FunctionArgExpr , Ident , ObjectName ,
8
+ Statement , Value , VisitMut , VisitorMut ,
9
9
} ;
10
10
use sqlparser:: dialect:: GenericDialect ;
11
11
use sqlparser:: parser:: { Parser , ParserError } ;
@@ -182,6 +182,10 @@ struct ParameterExtractor {
182
182
parameters : Vec < StmtParam > ,
183
183
}
184
184
185
+ const PLACEHOLDER_PREFIXES : [ ( AnyKind , & str ) ; 2 ] =
186
+ [ ( AnyKind :: Postgres , "$" ) , ( AnyKind :: Mssql , "@p" ) ] ;
187
+ const DEFAULT_PLACEHOLDER : & str = "?" ;
188
+
185
189
impl ParameterExtractor {
186
190
fn extract_parameters (
187
191
sql_ast : & mut sqlparser:: ast:: Statement ,
@@ -200,7 +204,8 @@ impl ParameterExtractor {
200
204
let data_type = match self . db_kind {
201
205
// MySQL requires CAST(? AS CHAR) and does not understand CAST(? AS TEXT)
202
206
AnyKind :: MySql => DataType :: Char ( None ) ,
203
- _ => DataType :: Text ,
207
+ AnyKind :: Postgres => DataType :: Text ,
208
+ _ => DataType :: Varchar ( None ) ,
204
209
} ;
205
210
let value = Expr :: Value ( Value :: Placeholder ( name) ) ;
206
211
Expr :: Cast {
@@ -220,6 +225,21 @@ impl ParameterExtractor {
220
225
self . parameters . push ( param) ;
221
226
placeholder
222
227
}
228
+
229
+ fn is_own_placeholder ( & self , param : & str ) -> bool {
230
+ if let Some ( ( _, prefix) ) = PLACEHOLDER_PREFIXES
231
+ . iter ( )
232
+ . find ( |( kind, _prefix) | * kind == self . db_kind )
233
+ {
234
+ if let Some ( param) = param. strip_prefix ( prefix) {
235
+ if let Ok ( index) = param. parse :: < usize > ( ) {
236
+ return index <= self . parameters . len ( ) + 1 ;
237
+ }
238
+ }
239
+ return false ;
240
+ }
241
+ param == DEFAULT_PLACEHOLDER
242
+ }
223
243
}
224
244
225
245
/** This is a helper struct to format a list of arguments for an error message. */
@@ -299,21 +319,20 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> {
299
319
300
320
#[ inline]
301
321
pub fn make_placeholder ( db_kind : AnyKind , arg_number : usize ) -> String {
302
- match db_kind {
303
- // Postgres only supports numbered parameters with $1, $2, etc.
304
- AnyKind :: Postgres => format ! ( "${arg_number}" ) ,
305
- // MSSQL only supports named parameters with @p1, @p2, etc.
306
- AnyKind :: Mssql => format ! ( "@p{arg_number}" ) ,
307
- _ => '?' . to_string ( ) ,
322
+ if let Some ( ( _, prefix) ) = PLACEHOLDER_PREFIXES
323
+ . iter ( )
324
+ . find ( |( kind, _) | * kind == db_kind)
325
+ {
326
+ return format ! ( "{prefix}{arg_number}" ) ;
308
327
}
328
+ DEFAULT_PLACEHOLDER . to_string ( )
309
329
}
310
330
311
331
impl VisitorMut for ParameterExtractor {
312
332
type Break = ( ) ;
313
333
fn pre_visit_expr ( & mut self , value : & mut Expr ) -> ControlFlow < Self :: Break > {
314
334
match value {
315
- Expr :: Value ( Value :: Placeholder ( param) )
316
- if param. chars ( ) . nth ( 1 ) . is_some_and ( char:: is_alphabetic) =>
335
+ Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
317
336
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
318
337
{
319
338
let new_expr = self . make_placeholder ( ) ;
@@ -334,6 +353,26 @@ impl VisitorMut for ParameterExtractor {
334
353
let arguments = std:: mem:: take ( args) ;
335
354
* value = self . handle_builtin_function ( func_name, arguments) ;
336
355
}
356
+ // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL
357
+ Expr :: BinaryOp {
358
+ left,
359
+ op : BinaryOperator :: StringConcat ,
360
+ right,
361
+ } if self . db_kind == AnyKind :: Mssql => {
362
+ let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
363
+ let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
364
+ * value = Expr :: Function ( Function {
365
+ name : ObjectName ( vec ! [ Ident :: new( "CONCAT" ) ] ) ,
366
+ args : vec ! [
367
+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( left) ) ,
368
+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( right) ) ,
369
+ ] ,
370
+ over : None ,
371
+ distinct : false ,
372
+ special : false ,
373
+ order_by : vec ! [ ] ,
374
+ } ) ;
375
+ }
337
376
_ => ( ) ,
338
377
}
339
378
ControlFlow :: < ( ) > :: Continue ( ( ) )
@@ -390,6 +429,17 @@ mod test {
390
429
) ;
391
430
}
392
431
432
+ #[ test]
433
+ fn test_mssql_statement_rewrite ( ) {
434
+ let mut ast = parse_stmt ( "select '' || $1 from t" ) ;
435
+ let parameters = ParameterExtractor :: extract_parameters ( & mut ast, AnyKind :: Mssql ) ;
436
+ assert_eq ! (
437
+ ast. to_string( ) ,
438
+ "SELECT CONCAT('', CAST(@p1 AS VARCHAR)) FROM t"
439
+ ) ;
440
+ assert_eq ! ( parameters, [ StmtParam :: GetOrPost ( "1" . to_string( ) ) , ] ) ;
441
+ }
442
+
393
443
#[ test]
394
444
fn test_static_extract ( ) {
395
445
assert_eq ! (
0 commit comments