Skip to content

Commit 1a56d86

Browse files
committed
Enable dialect specific behaviours in the parser
1 parent 1cc3bf4 commit 1a56d86

File tree

10 files changed

+78
-12
lines changed

10 files changed

+78
-12
lines changed

src/dialect/ansi.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ impl Dialect for AnsiDialect {
2626
|| (ch >= '0' && ch <= '9')
2727
|| ch == '_'
2828
}
29+
30+
fn dialect_name(&self) -> &'static str {
31+
"ansi"
32+
}
2933
}

src/dialect/generic.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,8 @@ impl Dialect for GenericDialect {
2929
|| ch == '#'
3030
|| ch == '_'
3131
}
32+
33+
fn dialect_name(&self) -> &'static str {
34+
"generic"
35+
}
3236
}

src/dialect/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,37 @@ pub trait Dialect: Debug {
4040
fn is_identifier_start(&self, ch: char) -> bool;
4141
/// Determine if a character is a valid unquoted identifier character
4242
fn is_identifier_part(&self, ch: char) -> bool;
43+
44+
/// The name of the dialect
45+
fn dialect_name(&self) -> &'static str;
46+
47+
/// Enable the parser to implement dialect specific functionality.
48+
/// The input for this function a list dialect names.
49+
/// Function will return true if the current dialect is a subset of the input.
50+
///
51+
/// parser usage exmple:
52+
/// `if self.dialect.is_dialect(vec!["mssql"]) {
53+
/// // some special mssql behaviour
54+
/// } else {
55+
/// // defualt bahviour
56+
/// }`
57+
fn is_dialect(&self, dialects: Vec<&str>) -> bool {
58+
dialects.contains(&self.dialect_name())
59+
}
60+
}
61+
62+
#[cfg(test)]
63+
mod tests {
64+
use super::generic::GenericDialect;
65+
use super::*;
66+
67+
#[test]
68+
fn test_is_diaclect() {
69+
let generic_dailect = GenericDialect {};
70+
71+
assert_eq!(generic_dailect.is_dialect(vec!["generic"]), true);
72+
assert_eq!(generic_dailect.is_dialect(vec!["generic", "mssql"]), true);
73+
assert_eq!(generic_dailect.is_dialect(vec!["mssql"]), false);
74+
assert_eq!(generic_dailect.is_dialect(vec!["mssql", "mysql"]), false);
75+
}
4376
}

src/dialect/mssql.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ impl Dialect for MsSqlDialect {
3535
|| ch == '#'
3636
|| ch == '_'
3737
}
38+
39+
fn dialect_name(&self) -> &'static str {
40+
"mssql"
41+
}
3842
}

src/dialect/mysql.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,8 @@ impl Dialect for MySqlDialect {
3434
fn is_delimited_identifier_start(&self, ch: char) -> bool {
3535
ch == '`'
3636
}
37+
38+
fn dialect_name(&self) -> &'static str {
39+
"mysql"
40+
}
3741
}

src/dialect/postgresql.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ impl Dialect for PostgreSqlDialect {
3030
|| ch == '$'
3131
|| ch == '_'
3232
}
33+
34+
fn dialect_name(&self) -> &'static str {
35+
"postgresql"
36+
}
3337
}

src/dialect/sqlite.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ impl Dialect for SQLiteDialect {
3535
fn is_identifier_part(&self, ch: char) -> bool {
3636
self.is_identifier_start(ch) || (ch >= '0' && ch <= '9')
3737
}
38+
39+
fn dialect_name(&self) -> &'static str {
40+
"sqllite"
41+
}
3842
}

src/parser.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,30 @@ impl fmt::Display for ParserError {
8282

8383
impl Error for ParserError {}
8484

85-
/// SQL Parser
86-
pub struct Parser {
85+
//TODO: remove dead_code annotation when dialect will be in use
86+
#[allow(dead_code)]
87+
pub struct Parser<'a> {
8788
tokens: Vec<Token>,
8889
/// The index of the first unprocessed token in `self.tokens`
8990
index: usize,
91+
dialect: &'a dyn Dialect,
9092
}
9193

92-
impl Parser {
94+
impl<'a> Parser<'a> {
9395
/// Parse the specified tokens
94-
pub fn new(tokens: Vec<Token>) -> Self {
95-
Parser { tokens, index: 0 }
96+
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
97+
Parser {
98+
tokens,
99+
index: 0,
100+
dialect,
101+
}
96102
}
97103

98104
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
99105
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
100106
let mut tokenizer = Tokenizer::new(dialect, &sql);
101107
let tokens = tokenizer.tokenize()?;
102-
let mut parser = Parser::new(tokens);
108+
let mut parser = Parser::new(tokens, dialect);
103109
let mut stmts = Vec::new();
104110
let mut expecting_statement_delimiter = false;
105111
debug!("Parsing sql '{}'...", sql);
@@ -950,7 +956,7 @@ impl Parser {
950956
/// Parse a comma-separated list of 1+ items accepted by `F`
951957
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
952958
where
953-
F: FnMut(&mut Parser) -> Result<T, ParserError>,
959+
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
954960
{
955961
let mut values = vec![];
956962
loop {

src/test_utils.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl TestedDialects {
5353
self.one_of_identical_results(|dialect| {
5454
let mut tokenizer = Tokenizer::new(dialect, sql);
5555
let tokens = tokenizer.tokenize().unwrap();
56-
f(&mut Parser::new(tokens))
56+
f(&mut Parser::new(tokens, dialect))
5757
})
5858
}
5959

@@ -104,7 +104,9 @@ impl TestedDialects {
104104
/// Ensures that `sql` parses as an expression, and is not modified
105105
/// after a serialization round-trip.
106106
pub fn verified_expr(&self, sql: &str) -> Expr {
107-
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
107+
let ast = self
108+
.run_parser_method(sql, |parser| parser.parse_expr())
109+
.unwrap();
108110
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
109111
ast
110112
}

tests/sqlparser_common.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use matches::assert_matches;
2222

2323
use sqlparser::ast::*;
2424
use sqlparser::dialect::keywords::ALL_KEYWORDS;
25-
use sqlparser::parser::{Parser, ParserError};
25+
use sqlparser::parser::ParserError;
2626
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
2727

2828
#[test]
@@ -147,13 +147,14 @@ fn parse_update() {
147147

148148
#[test]
149149
fn parse_invalid_table_name() {
150-
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
150+
let ast = all_dialects()
151+
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
151152
assert!(ast.is_err());
152153
}
153154

154155
#[test]
155156
fn parse_no_table_name() {
156-
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
157+
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
157158
assert!(ast.is_err());
158159
}
159160

0 commit comments

Comments
 (0)