From c87f916707d075f780554c5a9cb388604568207d Mon Sep 17 00:00:00 2001
From: Eyal Leshem
Date: Sun, 2 Aug 2020 12:51:31 +0300
Subject: [PATCH] Enable dialect specific behaviours in the parser

`Parser` now borrows the `Dialect` it was created for, and the new
`dialect_of!` macro lets parsing code ask which dialect is active.

The AUTO_INCREMENT column option is accepted only for the MySQL and
generic dialects, and AUTOINCREMENT only for the SQLite and generic
dialects. The dialect is tested before `parse_keyword` is called, so
for any other dialect the keyword token is left in place and the
resulting "column option" error points at the offending keyword.

---
 src/dialect/mod.rs        | 59 ++++++++++++++++++++++++++++++++++++++-
 src/lib.rs                |  1 +
 src/parser.rs             | 29 +++++++++++--------
 src/test_utils.rs         |  6 ++--
 tests/sqlparser_common.rs |  7 +++--
 5 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index ff28314c8..91d69a33f 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -18,6 +18,7 @@ mod mysql;
 mod postgresql;
 mod sqlite;
 
+use std::any::{Any, TypeId};
 use std::fmt::Debug;
 
 pub use self::ansi::AnsiDialect;
@@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
 pub use self::postgresql::PostgreSqlDialect;
 pub use self::sqlite::SQLiteDialect;
 
-pub trait Dialect: Debug {
+/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
+/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
+macro_rules! dialect_of {
+    ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
+        ($($parsed_dialect.dialect.is::<$dialect_type>())||+)
+    };
+}
+
+pub trait Dialect: Debug + Any {
     /// Determine if a character starts a quoted identifier. The default
     /// implementation, accepting "double quoted" ids is both ANSI-compliant
     /// and appropriate for most dialects (with the notable exception of
@@ -41,3 +50,51 @@ pub trait Dialect: Debug {
     /// Determine if a character is a valid unquoted identifier character
     fn is_identifier_part(&self, ch: char) -> bool;
 }
+
+impl dyn Dialect {
+    #[inline]
+    pub fn is<T: Dialect>(&self) -> bool {
+        // borrowed from `Any` implementation
+        TypeId::of::<T>() == self.type_id()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ansi::AnsiDialect;
+    use super::generic::GenericDialect;
+    use super::*;
+
+    struct DialectHolder<'a> {
+        dialect: &'a dyn Dialect,
+    }
+
+    #[test]
+    fn test_is_dialect() {
+        let generic_dialect: &dyn Dialect = &GenericDialect {};
+        let ansi_dialect: &dyn Dialect = &AnsiDialect {};
+
+        let generic_holder = DialectHolder {
+            dialect: generic_dialect,
+        };
+        let ansi_holder = DialectHolder {
+            dialect: ansi_dialect,
+        };
+
+        assert_eq!(
+            dialect_of!(generic_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
+
+        assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
+            false
+        );
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index d25b24997..8c9b01702 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -35,6 +35,7 @@
 #![warn(clippy::all)]
 
 pub mod ast;
+#[macro_use]
 pub mod dialect;
 pub mod parser;
 pub mod tokenizer;
diff --git a/src/parser.rs b/src/parser.rs
index 5f77b6691..5f113a070 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -15,9 +15,8 @@
 use log::debug;
 
 use super::ast::*;
-use super::dialect::keywords;
 use super::dialect::keywords::Keyword;
-use super::dialect::Dialect;
+use super::dialect::*;
 use super::tokenizer::*;
 use std::error::Error;
 use std::fmt;
@@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
 
 impl Error for ParserError {}
 
-/// SQL Parser
-pub struct Parser {
+pub struct Parser<'a> {
     tokens: Vec<Token>,
     /// The index of the first unprocessed token in `self.tokens`
     index: usize,
+    dialect: &'a dyn Dialect,
 }
 
-impl Parser {
+impl<'a> Parser<'a> {
     /// Parse the specified tokens
-    pub fn new(tokens: Vec<Token>) -> Self {
-        Parser { tokens, index: 0 }
+    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
+        Parser {
+            tokens,
+            index: 0,
+            dialect,
+        }
     }
 
     /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
     pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
         let mut tokenizer = Tokenizer::new(dialect, &sql);
         let tokens = tokenizer.tokenize()?;
-        let mut parser = Parser::new(tokens);
+        let mut parser = Parser::new(tokens, dialect);
         let mut stmts = Vec::new();
         let mut expecting_statement_delimiter = false;
         debug!("Parsing sql '{}'...", sql);
@@ -950,7 +953,7 @@ impl Parser {
     /// Parse a comma-separated list of 1+ items accepted by `F`
     pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
     where
-        F: FnMut(&mut Parser) -> Result<T, ParserError>,
+        F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
     {
         let mut values = vec![];
         loop {
@@ -1285,10 +1288,14 @@ impl Parser {
             let expr = self.parse_expr()?;
             self.expect_token(&Token::RParen)?;
             ColumnOption::Check(expr)
-        } else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
+        } else if dialect_of!(self is MySqlDialect | GenericDialect)
+            && self.parse_keyword(Keyword::AUTO_INCREMENT)
+        {
             // Support AUTO_INCREMENT for MySQL
             ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
-        } else if self.parse_keyword(Keyword::AUTOINCREMENT) {
+        } else if dialect_of!(self is SQLiteDialect | GenericDialect)
+            && self.parse_keyword(Keyword::AUTOINCREMENT)
+        {
             // Support AUTOINCREMENT for SQLite
             ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
         } else {
diff --git a/src/test_utils.rs b/src/test_utils.rs
index 4d4d35616..848ea0508 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -53,7 +53,7 @@ impl TestedDialects {
         self.one_of_identical_results(|dialect| {
             let mut tokenizer = Tokenizer::new(dialect, sql);
             let tokens = tokenizer.tokenize().unwrap();
-            f(&mut Parser::new(tokens))
+            f(&mut Parser::new(tokens, dialect))
         })
     }
 
@@ -104,7 +104,9 @@ impl TestedDialects {
     /// Ensures that `sql` parses as an expression, and is not modified
     /// after a serialization round-trip.
     pub fn verified_expr(&self, sql: &str) -> Expr {
-        let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
+        let ast = self
+            .run_parser_method(sql, |parser| parser.parse_expr())
+            .unwrap();
         assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
         ast
     }
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index a96ed1838..f3234e999 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -22,7 +22,7 @@ use matches::assert_matches;
 
 use sqlparser::ast::*;
 use sqlparser::dialect::keywords::ALL_KEYWORDS;
-use sqlparser::parser::{Parser, ParserError};
+use sqlparser::parser::ParserError;
 use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
 
 #[test]
@@ -147,13 +147,14 @@ fn parse_update() {
 
 #[test]
 fn parse_invalid_table_name() {
-    let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
+    let ast = all_dialects()
+        .run_parser_method("db.public..customer", |parser| parser.parse_object_name());
     assert!(ast.is_err());
 }
 
 #[test]
 fn parse_no_table_name() {
-    let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
+    let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
     assert!(ast.is_err());
 }
 