diff --git a/src/ast/query.rs b/src/ast/query.rs index 73477b126..5f07c5547 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -235,11 +235,11 @@ pub enum TableFactor { subquery: Box, alias: Option, }, - /// Represents a parenthesized table factor. The SQL spec only allows a - /// join expression (`(foo bar [ baz ... ])`) to be nested, - /// possibly several times, but the parser also accepts the non-standard - /// nesting of bare tables (`table_with_joins.joins.is_empty()`), so the - /// name `NestedJoin` is a bit of misnomer. + /// The inner `TableWithJoins` can have no joins only if its + /// `relation` is itself a `TableFactor::NestedJoin`. + /// Some dialects allow nesting lone `Table`/`Derived` in parens, + /// e.g. `FROM (mytable)`, but we don't expose the presence of these + /// extraneous parens in the AST. NestedJoin(Box), } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index c9ddbedd3..a64547d9a 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -16,7 +16,7 @@ pub mod keywords; mod mssql; mod mysql; mod postgresql; - +mod snowflake; use std::fmt::Debug; pub use self::ansi::AnsiDialect; @@ -24,6 +24,7 @@ pub use self::generic::GenericDialect; pub use self::mssql::MsSqlDialect; pub use self::mysql::MySqlDialect; pub use self::postgresql::PostgreSqlDialect; +pub use self::snowflake::SnowflakeDialect; pub trait Dialect: Debug { /// Determine if a character starts a quoted identifier. 
The default @@ -38,4 +39,8 @@ pub trait Dialect: Debug { fn is_identifier_start(&self, ch: char) -> bool; /// Determine if a character is a valid unquoted identifier character fn is_identifier_part(&self, ch: char) -> bool; + + fn alllow_single_table_in_parenthesis(&self) -> bool { + false + } } diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs new file mode 100644 index 000000000..e6acd61eb --- /dev/null +++ b/src/dialect/snowflake.rs @@ -0,0 +1,26 @@ +use crate::dialect::Dialect; + +#[derive(Debug, Default)] +pub struct SnowflakeDialect; + +impl Dialect for SnowflakeDialect { + //Revisit: currently copied from Generic dialect + fn is_identifier_start(&self, ch: char) -> bool { + (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' || ch == '#' || ch == '@' + } + + //Revisit: currently copied from Generic dialect + fn is_identifier_part(&self, ch: char) -> bool { + (ch >= 'a' && ch <= 'z') + || (ch >= 'A' && ch <= 'Z') + || (ch >= '0' && ch <= '9') + || ch == '@' + || ch == '$' + || ch == '#' + || ch == '_' + } + + fn alllow_single_table_in_parenthesis(&self) -> bool { + true + } +} diff --git a/src/parser.rs b/src/parser.rs index b58bdc5c6..b69dd8d59 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -83,23 +83,28 @@ impl fmt::Display for ParserError { impl Error for ParserError {} /// SQL Parser -pub struct Parser { +pub struct Parser<'a> { tokens: Vec, /// The index of the first unprocessed token in `self.tokens` index: usize, + dialect: &'a dyn Dialect, } -impl Parser { +impl<'a> Parser<'a> { /// Parse the specified tokens - pub fn new(tokens: Vec) -> Self { - Parser { tokens, index: 0 } + pub fn new(tokens: Vec, dialect: &'a dyn Dialect) -> Self { + Parser { + tokens, + index: 0, + dialect, + } } /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result, ParserError> { let mut tokenizer = Tokenizer::new(dialect, &sql); let tokens = tokenizer.tokenize()?; - let mut 
parser = Parser::new(tokens); + let mut parser = Parser::new(tokens, dialect); let mut stmts = Vec::new(); let mut expecting_statement_delimiter = false; debug!("Parsing sql '{}'...", sql); @@ -950,7 +955,7 @@ impl Parser { /// Parse a comma-separated list of 1+ items accepted by `F` pub fn parse_comma_separated(&mut self, mut f: F) -> Result, ParserError> where - F: FnMut(&mut Parser) -> Result, + F: FnMut(&mut Parser<'a>) -> Result, { let mut values = vec![]; loop { @@ -2056,6 +2061,7 @@ impl Parser { }; joins.push(join); } + Ok(TableWithJoins { relation, joins }) } @@ -2098,14 +2104,56 @@ impl Parser { // recently consumed does not start a derived table (cases 1, 2, or 4). // `maybe_parse` will ignore such an error and rewind to be after the opening '('. - // Inside the parentheses we expect to find a table factor - // followed by some joins or another level of nesting. - let table_and_joins = self.parse_table_and_joins()?; - self.expect_token(&Token::RParen)?; - // The SQL spec prohibits derived and bare tables from appearing - // alone in parentheses. We don't enforce this as some databases - // (e.g. Snowflake) allow such syntax. - Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) + // Inside the parentheses we expect to find an (A) table factor + // followed by some joins or (B) another level of nesting. + let mut table_and_joins = self.parse_table_and_joins()?; + + if !table_and_joins.joins.is_empty() { + self.expect_token(&Token::RParen)?; + Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) // (A) + } else if let TableFactor::NestedJoin(_) = &table_and_joins.relation { + // (B): `table_and_joins` (what we found inside the parentheses) + // is a nested join `(foo JOIN bar)`, not followed by other joins. 
+ self.expect_token(&Token::RParen)?; + Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) + } else if self.dialect.alllow_single_table_in_parenthesis() { + // Dialect-specific behavior: Snowflake diverges from the + // standard and most other implementations by allowing + // extra parentheses not only around a join (B), but also around + // lone table names (e.g. `FROM (mytable [AS alias])`) and + // around derived tables (e.g. `FROM ((SELECT ...) [AS alias])`) + // as well. + self.expect_token(&Token::RParen)?; + + if let Some(outer_alias) = + self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)? + { + // Snowflake also allows specifying an alias *after* parens + // e.g. `FROM (mytable) AS alias` + match &mut table_and_joins.relation { + TableFactor::Derived { alias, .. } | TableFactor::Table { alias, .. } => { + // but not `FROM (mytable AS alias1) AS alias2`. + if let Some(inner_alias) = alias { + return Err(ParserError::ParserError(format!( + "duplicate alias {}", + inner_alias + ))); + } + // Act as if the alias was specified normally next + // to the table name: `(mytable) AS alias` -> + // `(mytable AS alias)` + alias.replace(outer_alias); + } + TableFactor::NestedJoin(_) => unreachable!(), + }; + } + // Do not store the extra set of parens in the AST + Ok(table_and_joins.relation) + } else { + // The SQL spec prohibits derived tables and bare tables from + // appearing alone in parentheses (e.g. 
`FROM (mytable)`) + self.expected("joined table", self.peek_token()) + } } else { let name = self.parse_object_name()?; // Postgres, MSSQL: table-valued functions: diff --git a/src/test_utils.rs b/src/test_utils.rs index 4d4d35616..848ea0508 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -53,7 +53,7 @@ impl TestedDialects { self.one_of_identical_results(|dialect| { let mut tokenizer = Tokenizer::new(dialect, sql); let tokens = tokenizer.tokenize().unwrap(); - f(&mut Parser::new(tokens)) + f(&mut Parser::new(tokens, dialect)) }) } @@ -104,7 +104,9 @@ impl TestedDialects { /// Ensures that `sql` parses as an expression, and is not modified /// after a serialization round-trip. pub fn verified_expr(&self, sql: &str) -> Expr { - let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap(); + let ast = self + .run_parser_method(sql, |parser| parser.parse_expr()) + .unwrap(); assert_eq!(sql, &ast.to_string(), "round-tripping without changes"); ast } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 1a1c3ad87..1e8e8d13a 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -22,7 +22,7 @@ use matches::assert_matches; use sqlparser::ast::*; use sqlparser::dialect::keywords::ALL_KEYWORDS; -use sqlparser::parser::{Parser, ParserError}; +use sqlparser::parser::ParserError; use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only}; #[test] @@ -147,13 +147,14 @@ fn parse_update() { #[test] fn parse_invalid_table_name() { - let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name); + let ast = all_dialects() + .run_parser_method("db.public..customer", |parser| parser.parse_object_name()); assert!(ast.is_err()); } #[test] fn parse_no_table_name() { - let ast = all_dialects().run_parser_method("", Parser::parse_object_name); + let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name()); assert!(ast.is_err()); } @@ -2273,19 +2274,12 @@ fn 
parse_join_nesting() { vec![join(nest!(nest!(nest!(table("b"), table("c")))))] ); - // Parenthesized table names are non-standard, but supported in Snowflake SQL - let sql = "SELECT * FROM (a NATURAL JOIN (b))"; - let select = verified_only_select(sql); - let from = only(select.from); - - assert_eq!(from.relation, nest!(table("a"), nest!(table("b")))); - - // Double parentheses around table names are non-standard, but supported in Snowflake SQL - let sql = "SELECT * FROM (a NATURAL JOIN ((b)))"; - let select = verified_only_select(sql); - let from = only(select.from); - - assert_eq!(from.relation, nest!(table("a"), nest!(nest!(table("b"))))); + // Nesting a subquery in parentheses is non-standard (Snowflake-only), so it is rejected here + let res = parse_sql_statements("SELECT * FROM ((SELECT 1) AS t)"); + assert_eq!( + ParserError::ParserError("Expected joined table, found: )".to_string()), + res.unwrap_err() + ); } #[test] @@ -2427,26 +2421,6 @@ fn parse_derived_tables() { }], })) ); - - // Nesting a subquery in parentheses is non-standard, but supported in Snowflake SQL - let sql = "SELECT * FROM ((SELECT 1) AS t)"; - let select = verified_only_select(sql); - let from = only(select.from); - - assert_eq!( - from.relation, - TableFactor::NestedJoin(Box::new(TableWithJoins { - relation: TableFactor::Derived { - lateral: false, - subquery: Box::new(verified_query("SELECT 1")), - alias: Some(TableAlias { - name: "t".into(), - columns: vec![], - }) - }, - joins: Vec::new(), - })) - ); } #[test] diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs new file mode 100644 index 000000000..6d43f273b --- /dev/null +++ b/tests/sqlparser_snowflake.rs @@ -0,0 +1,152 @@ +use sqlparser::ast::*; +use sqlparser::dialect::SnowflakeDialect; +use sqlparser::parser::ParserError; +use sqlparser::test_utils::*; + +fn table_alias(alias: &str) -> TableAlias { + TableAlias { + name: Ident { + value: alias.to_owned(), + quote_style: None, + }, + columns: Vec::new(), + } +} + +fn 
table(name: impl Into, alias: Option) -> TableFactor { + TableFactor::Table { + name: ObjectName(vec![Ident::new(name.into())]), + alias, + args: vec![], + with_hints: vec![], + } +} + +fn join(relation: TableFactor) -> Join { + Join { + relation, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + } +} + +macro_rules! nest { + ($base:expr $(, $join:expr)*) => { + TableFactor::NestedJoin(Box::new(TableWithJoins { + relation: $base, + joins: vec![$(join($join)),*] + })) + }; +} + +fn sf() -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(SnowflakeDialect {})], + } +} + +fn get_from_section_from_select_query(query: &str) -> Vec { + let statement = sf().parse_sql_statements(query).unwrap()[0].clone(); + + let query = match statement { + Statement::Query(query) => query, + _ => panic!("Not a query"), + }; + + let select = match query.body { + SetExpr::Select(select) => select, + _ => panic!("not a select query"), + }; + + select.from.clone() +} + +#[test] +fn test_sf_derives_single_table_in_parenthesis() { + let from = get_from_section_from_select_query("SELECT * FROM (((SELECT 1) AS t))"); + + assert_eq!( + from[0].relation, + TableFactor::Derived { + lateral: false, + subquery: Box::new(sf().verified_query("SELECT 1")), + alias: Some(TableAlias { + name: "t".into(), + columns: vec![], + }) + } + ); +} + +#[test] +fn test_single_table_in_parenthesis() { + //Parenthesized table names are non-standard, but supported in Snowflake SQL + let from = get_from_section_from_select_query("SELECT * FROM (a NATURAL JOIN (b))"); + + assert_eq!(from[0].relation, nest!(table("a", None), table("b", None))); + + let from = get_from_section_from_select_query("SELECT * FROM (a NATURAL JOIN ((b)))"); + assert_eq!(from[0].relation, nest!(table("a", None), table("b", None))); +} + +#[test] +fn test_single_table_in_parenthesis_with_alias() { + let sql = "SELECT * FROM (a NATURAL JOIN (b) c )"; + let table_with_joins = 
get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!(table("a", None), table("b", Some(table_alias("c")))) + ); + + let sql = "SELECT * FROM (a NATURAL JOIN ((b)) c )"; + let table_with_joins = get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!(table("a", None), table("b", Some(table_alias("c")))) + ); + + let sql = "SELECT * FROM (a NATURAL JOIN ( (b) c ) )"; + let table_with_joins = get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!(table("a", None), table("b", Some(table_alias("c")))) + ); + + let sql = "SELECT * FROM (a NATURAL JOIN ( (b) as c ) )"; + let table_with_joins = get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!(table("a", None), table("b", Some(table_alias("c")))) + ); + + let sql = "SELECT * FROM (a alias1 NATURAL JOIN ( (b) c ) )"; + let table_with_joins = get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!( + table("a", Some(table_alias("alias1"))), + table("b", Some(table_alias("c"))) + ) + ); + + let sql = "SELECT * FROM (a as alias1 NATURAL JOIN ( (b) as c ) )"; + let table_with_joins = get_from_section_from_select_query(sql)[0].clone(); + assert_eq!( + table_with_joins.relation, + nest!( + table("a", Some(table_alias("alias1"))), + table("b", Some(table_alias("c"))) + ) + ); + + let res = sf().parse_sql_statements("SELECT * FROM (a NATURAL JOIN b) c"); + assert_eq!( + ParserError::ParserError("Expected end of statement, found: c".to_string()), + res.unwrap_err() + ); + + let res = sf().parse_sql_statements("SELECT * FROM (a b) c"); + assert_eq!( + ParserError::ParserError("duplicate alias b".to_string()), + res.unwrap_err() + ); +}