Skip to content

Commit e8b1dd5

Browse files
committed
Enable dialect specific behaviours in the parser
1 parent 1cc3bf4 commit e8b1dd5

File tree

5 files changed

+72
-17
lines changed

5 files changed

+72
-17
lines changed

src/dialect/mod.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ mod mysql;
1818
mod postgresql;
1919
mod sqlite;
2020

21+
use std::any::{Any, TypeId};
2122
use std::fmt::Debug;
2223

2324
pub use self::ansi::AnsiDialect;
@@ -27,7 +28,23 @@ pub use self::mysql::MySqlDialect;
2728
pub use self::postgresql::PostgreSqlDialect;
2829
pub use self::sqlite::SQLiteDialect;
2930

30-
pub trait Dialect: Debug {
31+
/// Determine if the specified dialect matched to
32+
/// the parsed dialect , used for dialect spefic behaviour
33+
macro_rules! is_dialect {
34+
($dialect :expr,$t: ty) => {
35+
$dialect.is_dialect(TypeId::of::<$t>())
36+
};
37+
}
38+
39+
// Same is is_dialect but return true
40+
// also when Genric dialect is used
41+
macro_rules! is_dialect_or_genric {
42+
($dialect :expr,$t: ty) => {
43+
(is_dialect!($dialect, $t) || is_dialect!($dialect, GenericDialect))
44+
};
45+
}
46+
47+
pub trait Dialect: Debug + Any {
3148
/// Determine if a character starts a quoted identifier. The default
3249
/// implementation, accepting "double quoted" ids is both ANSI-compliant
3350
/// and appropriate for most dialects (with the notable exception of
@@ -40,4 +57,30 @@ pub trait Dialect: Debug {
4057
fn is_identifier_start(&self, ch: char) -> bool;
4158
/// Determine if a character is a valid unquoted identifier character
4259
fn is_identifier_part(&self, ch: char) -> bool;
60+
61+
fn is_dialect(&self, type_id: TypeId) -> bool {
62+
type_id == self.type_id()
63+
}
64+
}
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use super::ansi::AnsiDialect;
69+
use super::generic::GenericDialect;
70+
use super::*;
71+
72+
#[test]
73+
fn test_is_diaclect() {
74+
let generic_dailect = GenericDialect {};
75+
let ansi_dialect = AnsiDialect {};
76+
let mssql_dialect = MySqlDialect {};
77+
78+
assert_eq!(is_dialect!(generic_dailect, GenericDialect), true);
79+
assert_eq!(is_dialect!(generic_dailect, AnsiDialect), false);
80+
assert_eq!(is_dialect!(ansi_dialect, AnsiDialect), true);
81+
82+
assert_eq!(is_dialect_or_genric!(ansi_dialect, AnsiDialect), true);
83+
assert_eq!(is_dialect_or_genric!(mssql_dialect, AnsiDialect), false);
84+
assert_eq!(is_dialect_or_genric!(generic_dailect, AnsiDialect), true);
85+
}
4386
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#![warn(clippy::all)]
3636

3737
pub mod ast;
38+
#[macro_use]
3839
pub mod dialect;
3940
pub mod parser;
4041
pub mod tokenizer;

src/parser.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
use log::debug;
1616

1717
use super::ast::*;
18-
use super::dialect::keywords;
1918
use super::dialect::keywords::Keyword;
20-
use super::dialect::Dialect;
19+
use super::dialect::*;
2120
use super::tokenizer::*;
21+
use std::any::TypeId;
2222
use std::error::Error;
2323
use std::fmt;
2424

@@ -82,24 +82,28 @@ impl fmt::Display for ParserError {
8282

8383
impl Error for ParserError {}
8484

85-
/// SQL Parser
86-
pub struct Parser {
85+
pub struct Parser<'a> {
8786
tokens: Vec<Token>,
8887
/// The index of the first unprocessed token in `self.tokens`
8988
index: usize,
89+
dialect: &'a dyn Dialect,
9090
}
9191

92-
impl Parser {
92+
impl<'a> Parser<'a> {
9393
/// Parse the specified tokens
94-
pub fn new(tokens: Vec<Token>) -> Self {
95-
Parser { tokens, index: 0 }
94+
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
95+
Parser {
96+
tokens,
97+
index: 0,
98+
dialect,
99+
}
96100
}
97101

98102
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
99103
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
100104
let mut tokenizer = Tokenizer::new(dialect, &sql);
101105
let tokens = tokenizer.tokenize()?;
102-
let mut parser = Parser::new(tokens);
106+
let mut parser = Parser::new(tokens, dialect);
103107
let mut stmts = Vec::new();
104108
let mut expecting_statement_delimiter = false;
105109
debug!("Parsing sql '{}'...", sql);
@@ -950,7 +954,7 @@ impl Parser {
950954
/// Parse a comma-separated list of 1+ items accepted by `F`
951955
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
952956
where
953-
F: FnMut(&mut Parser) -> Result<T, ParserError>,
957+
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
954958
{
955959
let mut values = vec![];
956960
loop {
@@ -1285,10 +1289,14 @@ impl Parser {
12851289
let expr = self.parse_expr()?;
12861290
self.expect_token(&Token::RParen)?;
12871291
ColumnOption::Check(expr)
1288-
} else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
1292+
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
1293+
&& is_dialect_or_genric!(self.dialect, MySqlDialect)
1294+
{
12891295
// Support AUTO_INCREMENT for MySQL
12901296
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
1291-
} else if self.parse_keyword(Keyword::AUTOINCREMENT) {
1297+
} else if self.parse_keyword(Keyword::AUTOINCREMENT)
1298+
&& is_dialect_or_genric!(self.dialect, SQLiteDialect)
1299+
{
12921300
// Support AUTOINCREMENT for SQLite
12931301
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
12941302
} else {

src/test_utils.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl TestedDialects {
5353
self.one_of_identical_results(|dialect| {
5454
let mut tokenizer = Tokenizer::new(dialect, sql);
5555
let tokens = tokenizer.tokenize().unwrap();
56-
f(&mut Parser::new(tokens))
56+
f(&mut Parser::new(tokens, dialect))
5757
})
5858
}
5959

@@ -104,7 +104,9 @@ impl TestedDialects {
104104
/// Ensures that `sql` parses as an expression, and is not modified
105105
/// after a serialization round-trip.
106106
pub fn verified_expr(&self, sql: &str) -> Expr {
107-
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
107+
let ast = self
108+
.run_parser_method(sql, |parser| parser.parse_expr())
109+
.unwrap();
108110
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
109111
ast
110112
}

tests/sqlparser_common.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use matches::assert_matches;
2222

2323
use sqlparser::ast::*;
2424
use sqlparser::dialect::keywords::ALL_KEYWORDS;
25-
use sqlparser::parser::{Parser, ParserError};
25+
use sqlparser::parser::ParserError;
2626
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
2727

2828
#[test]
@@ -147,13 +147,14 @@ fn parse_update() {
147147

148148
#[test]
149149
fn parse_invalid_table_name() {
150-
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
150+
let ast = all_dialects()
151+
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
151152
assert!(ast.is_err());
152153
}
153154

154155
#[test]
155156
fn parse_no_table_name() {
156-
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
157+
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
157158
assert!(ast.is_err());
158159
}
159160

0 commit comments

Comments
 (0)