Commit 5dee741

Merge branch 'main' into pr/parse_column_def

2 parents 8106d25 + 1b46e82
File tree: 12 files changed, +264 -44 lines changed


.github/dependabot.yml

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+version: 2
+updates:
+- package-ecosystem: cargo
+  directory: "/"
+  schedule:
+    interval: daily
+  open-pull-requests-limit: 10
+- package-ecosystem: cargo
+  directory: "/sqlparser_bench"
+  schedule:
+    interval: daily
+  open-pull-requests-limit: 10

.github/workflows/rust.yml

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@ jobs:
           # it's an unstable feature.
           rust-version: nightly
       - uses: actions/checkout@v2
-      - run: cargo fmt -- --check --config-path <(echo 'license_template_path = "HEADER"')
+      - run: cargo +nightly fmt -- --check --config-path <(echo 'license_template_path = "HEADER"')
 
   lint:
     runs-on: ubuntu-latest
@@ -49,7 +49,7 @@ jobs:
         uses: actions-rs/install@v0.1
         with:
           crate: cargo-tarpaulin
-          version: 0.13.3
+          version: 0.14.2
           use-tool-cache: true
       - name: Checkout
        uses: actions/checkout@v2
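The `+nightly` override matters here: rustfmt's `license_template_path` option is unstable, so the format check only works when `cargo fmt` explicitly runs on the nightly toolchain rather than on whatever the runner's default toolchain happens to be.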

Cargo.toml

Lines changed: 3 additions & 3 deletions

@@ -23,14 +23,14 @@ path = "src/lib.rs"
 json_example = ["serde_json", "serde"]
 
 [dependencies]
-bigdecimal = { version = "0.1.0", features = ["serde"], optional = true }
-log = "0.4.5"
+bigdecimal = { version = "0.1", features = ["serde"], optional = true }
+log = "0.4"
 serde = { version = "1.0", features = ["derive"], optional = true }
 # serde_json is only used in examples/cli, but we have to put it outside
 # of dev-dependencies because of
 # https://github.com/rust-lang/cargo/issues/1596
 serde_json = { version = "1.0", optional = true }
 
 [dev-dependencies]
-simple_logger = "1.0.1"
+simple_logger = "1.6"
 matches = "0.1"
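For context: Cargo reads all of these as caret requirements, so `log = "0.4"` loosens the floor from 0.4.5 to any 0.4.x release, while `simple_logger = "1.6"` raises the minimum from 1.0.1 to 1.6; neither change moves the semver-compatible upper bound.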

src/ast/mod.rs

Lines changed: 17 additions & 1 deletion

@@ -873,12 +873,28 @@ impl fmt::Display for Assignment {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub enum FunctionArg {
+    Named { name: Ident, arg: Expr },
+    Unnamed(Expr),
+}
+
+impl fmt::Display for FunctionArg {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            FunctionArg::Named { name, arg } => write!(f, "{} => {}", name, arg),
+            FunctionArg::Unnamed(unnamed_arg) => write!(f, "{}", unnamed_arg),
+        }
+    }
+}
+
 /// A function call
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct Function {
     pub name: ObjectName,
-    pub args: Vec<Expr>,
+    pub args: Vec<FunctionArg>,
     pub over: Option<WindowSpec>,
     // aggregate functions may specify eg `COUNT(DISTINCT x)`
     pub distinct: bool,
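A quick sketch (not part of the diff) of what the new `Display` impl produces; it assumes the crate's public `Ident::new` helper and `Expr::Identifier` variant:

```rust
use sqlparser::ast::{Expr, FunctionArg, Ident};

fn main() {
    // Named arguments serialize with the `=>` separator...
    let named = FunctionArg::Named {
        name: Ident::new("key"),
        arg: Expr::Identifier(Ident::new("value")),
    };
    assert_eq!(named.to_string(), "key => value");

    // ...while unnamed arguments print as the bare expression.
    let unnamed = FunctionArg::Unnamed(Expr::Identifier(Ident::new("value")));
    assert_eq!(unnamed.to_string(), "value");
}
```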

src/ast/query.rs

Lines changed: 13 additions & 1 deletion

@@ -226,7 +226,7 @@ pub enum TableFactor {
         /// Arguments of a table-valued function, as supported by Postgres
         /// and MSSQL. Note that deprecated MSSQL `FROM foo (NOLOCK)` syntax
         /// will also be parsed as `args`.
-        args: Vec<Expr>,
+        args: Vec<FunctionArg>,
         /// MSSQL-specific `WITH (...)` hints such as NOLOCK.
         with_hints: Vec<Expr>,
     },
@@ -235,6 +235,11 @@
         subquery: Box<Query>,
         alias: Option<TableAlias>,
     },
+    /// `TABLE(<expr>)[ AS <alias> ]`
+    TableFunction {
+        expr: Expr,
+        alias: Option<TableAlias>,
+    },
     /// Represents a parenthesized table factor. The SQL spec only allows a
     /// join expression (`(foo <JOIN> bar [ <JOIN> baz ... ])`) to be nested,
     /// possibly several times, but the parser also accepts the non-standard
@@ -278,6 +283,13 @@ impl fmt::Display for TableFactor {
                 }
                 Ok(())
             }
+            TableFactor::TableFunction { expr, alias } => {
+                write!(f, "TABLE({})", expr)?;
+                if let Some(alias) = alias {
+                    write!(f, " AS {}", alias)?;
+                }
+                Ok(())
+            }
             TableFactor::NestedJoin(table_reference) => write!(f, "({})", table_reference),
         }
     }
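A usage sketch (not from this commit; `producer` is a hypothetical table function) showing the new variant round-tripping through `Parser::parse_sql` and the `Display` impl above:

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

fn main() {
    let sql = "SELECT * FROM TABLE(producer(123)) AS t";
    let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
    // Display writes `TABLE(<expr>)` plus the optional ` AS <alias>`
    // back out, so serialization reproduces the input exactly.
    assert_eq!(stmts[0].to_string(), sql);
}
```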

src/dialect/mod.rs

Lines changed: 58 additions & 1 deletion

@@ -18,6 +18,7 @@ mod mysql;
 mod postgresql;
 mod sqlite;
 
+use std::any::{Any, TypeId};
 use std::fmt::Debug;
 
 pub use self::ansi::AnsiDialect;
@@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
 pub use self::postgresql::PostgreSqlDialect;
 pub use self::sqlite::SQLiteDialect;
 
-pub trait Dialect: Debug {
+/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
+/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
+macro_rules! dialect_of {
+    ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
+        ($($parsed_dialect.dialect.is::<$dialect_type>())||+)
+    };
+}
+
+pub trait Dialect: Debug + Any {
     /// Determine if a character starts a quoted identifier. The default
     /// implementation, accepting "double quoted" ids is both ANSI-compliant
     /// and appropriate for most dialects (with the notable exception of
@@ -41,3 +50,51 @@
     /// Determine if a character is a valid unquoted identifier character
     fn is_identifier_part(&self, ch: char) -> bool;
 }
+
+impl dyn Dialect {
+    #[inline]
+    pub fn is<T: Dialect>(&self) -> bool {
+        // borrowed from `Any` implementation
+        TypeId::of::<T>() == self.type_id()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ansi::AnsiDialect;
+    use super::generic::GenericDialect;
+    use super::*;
+
+    struct DialectHolder<'a> {
+        dialect: &'a dyn Dialect,
+    }
+
+    #[test]
+    fn test_is_dialect() {
+        let generic_dialect: &dyn Dialect = &GenericDialect {};
+        let ansi_dialect: &dyn Dialect = &AnsiDialect {};
+
+        let generic_holder = DialectHolder {
+            dialect: generic_dialect,
+        };
+        let ansi_holder = DialectHolder {
+            dialect: ansi_dialect,
+        };
+
+        assert_eq!(
+            dialect_of!(generic_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
+
+        assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
+            false
+        );
+    }
+}
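The mechanism is worth spelling out: adding `Any` as a supertrait lets a `&dyn Dialect` compare the `TypeId` of its concrete type against a candidate type, which is all `dialect_of!` needs. A self-contained sketch of the same trick (hypothetical minimal types, std only):

```rust
use std::any::{Any, TypeId};

trait Dialect: Any {}

impl dyn Dialect {
    fn is<T: Dialect>(&self) -> bool {
        // `type_id()` dispatches through the `Any` supertrait, so it
        // reports the concrete type behind the trait object.
        TypeId::of::<T>() == self.type_id()
    }
}

struct MySqlDialect;
impl Dialect for MySqlDialect {}

struct SQLiteDialect;
impl Dialect for SQLiteDialect {}

fn main() {
    let d: &dyn Dialect = &MySqlDialect;
    assert!(d.is::<MySqlDialect>());
    assert!(!d.is::<SQLiteDialect>());
}
```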

src/lib.rs

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
 #![warn(clippy::all)]
 
 pub mod ast;
+#[macro_use]
 pub mod dialect;
 pub mod parser;
 pub mod tokenizer;
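`#[macro_use]` only exposes a module's `macro_rules!` macros to code that comes after the declaration, so attaching it to `dialect` (which precedes `parser` here) is what lets `parser.rs` invoke `dialect_of!` without an explicit import.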

src/parser.rs

Lines changed: 43 additions & 20 deletions

@@ -15,9 +15,8 @@
 use log::debug;
 
 use super::ast::*;
-use super::dialect::keywords;
 use super::dialect::keywords::Keyword;
-use super::dialect::Dialect;
+use super::dialect::*;
 use super::tokenizer::*;
 use std::error::Error;
 use std::fmt;
@@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
 
 impl Error for ParserError {}
 
-/// SQL Parser
-pub struct Parser {
+pub struct Parser<'a> {
     tokens: Vec<Token>,
     /// The index of the first unprocessed token in `self.tokens`
     index: usize,
+    dialect: &'a dyn Dialect,
 }
 
-impl Parser {
+impl<'a> Parser<'a> {
     /// Parse the specified tokens
-    pub fn new(tokens: Vec<Token>) -> Self {
-        Parser { tokens, index: 0 }
+    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
+        Parser {
+            tokens,
+            index: 0,
+            dialect,
+        }
     }
 
     /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
     pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
         let mut tokenizer = Tokenizer::new(dialect, &sql);
         let tokens = tokenizer.tokenize()?;
-        let mut parser = Parser::new(tokens);
+        let mut parser = Parser::new(tokens, dialect);
         let mut stmts = Vec::new();
         let mut expecting_statement_delimiter = false;
         debug!("Parsing sql '{}'...", sql);
@@ -950,7 +953,7 @@ impl Parser {
     /// Parse a comma-separated list of 1+ items accepted by `F`
     pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
     where
-        F: FnMut(&mut Parser) -> Result<T, ParserError>,
+        F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
     {
         let mut values = vec![];
         loop {
@@ -1289,12 +1292,14 @@ impl Parser {
             let expr = self.parse_expr()?;
             self.expect_token(&Token::RParen)?;
             Ok(Some(ColumnOption::Check(expr)))
-        } else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
+        } else if self.parse_keyword(Keyword::AUTO_INCREMENT)
+            && dialect_of!(self is MySqlDialect | GenericDialect)
+        {
             // Support AUTO_INCREMENT for MySQL
-            Ok(Some(ColumnOption::DialectSpecific(vec![
-                Token::make_keyword("AUTO_INCREMENT"),
-            ])))
-        } else if self.parse_keyword(Keyword::AUTOINCREMENT) {
+            Ok(Some(ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])))
+        } else if self.parse_keyword(Keyword::AUTOINCREMENT)
+            && dialect_of!(self is SQLiteDialect | GenericDialect)
+        {
             // Support AUTOINCREMENT for SQLite
             Ok(Some(ColumnOption::DialectSpecific(vec![
                 Token::make_keyword("AUTOINCREMENT"),
@@ -2072,10 +2077,15 @@ impl Parser {
             if !self.consume_token(&Token::LParen) {
                 self.expected("subquery after LATERAL", self.peek_token())?;
             }
-            return self.parse_derived_table_factor(Lateral);
-        }
-
-        if self.consume_token(&Token::LParen) {
+            self.parse_derived_table_factor(Lateral)
+        } else if self.parse_keyword(Keyword::TABLE) {
+            // parse table function (SELECT * FROM TABLE (<expr>) [ AS <alias> ])
+            self.expect_token(&Token::LParen)?;
+            let expr = self.parse_expr()?;
+            self.expect_token(&Token::RParen)?;
+            let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
+            Ok(TableFactor::TableFunction { expr, alias })
+        } else if self.consume_token(&Token::LParen) {
             // A left paren introduces either a derived table (i.e., a subquery)
            // or a nested join. It's nearly impossible to determine ahead of
            // time which it is... so we just try to parse both.
@@ -2209,11 +2219,24 @@ impl Parser {
         Ok(Assignment { id, value })
     }
 
-    pub fn parse_optional_args(&mut self) -> Result<Vec<Expr>, ParserError> {
+    fn parse_function_args(&mut self) -> Result<FunctionArg, ParserError> {
+        if self.peek_nth_token(1) == Token::RArrow {
+            let name = self.parse_identifier()?;
+
+            self.expect_token(&Token::RArrow)?;
+            let arg = self.parse_expr()?;
+
+            Ok(FunctionArg::Named { name, arg })
+        } else {
+            Ok(FunctionArg::Unnamed(self.parse_expr()?))
+        }
+    }
+
+    pub fn parse_optional_args(&mut self) -> Result<Vec<FunctionArg>, ParserError> {
         if self.consume_token(&Token::RParen) {
             Ok(vec![])
         } else {
-            let args = self.parse_comma_separated(Parser::parse_expr)?;
+            let args = self.parse_comma_separated(Parser::parse_function_args)?;
             self.expect_token(&Token::RParen)?;
             Ok(args)
         }
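A hedged end-to-end sketch (not from the commit) of the named-argument path: `parse_function_args` peeks one token ahead for `=>` to choose `Named` over `Unnamed`, and both shapes round-trip:

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

fn main() {
    // `a => 1` becomes FunctionArg::Named, `2` becomes FunctionArg::Unnamed.
    let sql = "SELECT FUN(a => 1, 2) FROM foo";
    let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
    assert_eq!(stmts[0].to_string(), sql);
}
```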

src/test_utils.rs

Lines changed: 4 additions & 2 deletions

@@ -53,7 +53,7 @@ impl TestedDialects {
         self.one_of_identical_results(|dialect| {
             let mut tokenizer = Tokenizer::new(dialect, sql);
             let tokens = tokenizer.tokenize().unwrap();
-            f(&mut Parser::new(tokens))
+            f(&mut Parser::new(tokens, dialect))
         })
     }
 
@@ -104,7 +104,9 @@ impl TestedDialects {
     /// Ensures that `sql` parses as an expression, and is not modified
     /// after a serialization round-trip.
     pub fn verified_expr(&self, sql: &str) -> Expr {
-        let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
+        let ast = self
+            .run_parser_method(sql, |parser| parser.parse_expr())
+            .unwrap();
         assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
         ast
     }
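The switch from passing `Parser::parse_expr` by path to wrapping it in a closure is most likely forced by the new lifetime parameter: a bare method reference on `Parser<'a>` does not coerce to the higher-ranked `FnMut(&mut Parser<'a>)` bound the way a closure does.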

src/tokenizer.rs

Lines changed: 27 additions & 1 deletion

@@ -99,6 +99,8 @@ pub enum Token {
     LBrace,
     /// Right brace `}`
     RBrace,
+    /// Right Arrow `=>`
+    RArrow,
 }
 
 impl fmt::Display for Token {
@@ -139,6 +141,7 @@ impl fmt::Display for Token {
             Token::Pipe => f.write_str("|"),
             Token::LBrace => f.write_str("{"),
             Token::RBrace => f.write_str("}"),
+            Token::RArrow => f.write_str("=>"),
         }
     }
 }
@@ -400,7 +403,13 @@
                     _ => Ok(Some(Token::Pipe)),
                 }
             }
-            '=' => self.consume_and_return(chars, Token::Eq),
+            '=' => {
+                chars.next(); // consume
+                match chars.peek() {
+                    Some('>') => self.consume_and_return(chars, Token::RArrow),
+                    _ => Ok(Some(Token::Eq)),
+                }
+            }
             '.' => self.consume_and_return(chars, Token::Period),
             '!' => {
                 chars.next(); // consume
@@ -766,6 +775,23 @@ mod tests {
         compare(expected, tokens);
     }
 
+    #[test]
+    fn tokenize_right_arrow() {
+        let sql = String::from("FUNCTION(key=>value)");
+        let dialect = GenericDialect {};
+        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let tokens = tokenizer.tokenize().unwrap();
+        let expected = vec![
+            Token::make_word("FUNCTION", None),
+            Token::LParen,
+            Token::make_word("key", None),
+            Token::RArrow,
+            Token::make_word("value", None),
+            Token::RParen,
+        ];
+        compare(expected, tokens);
+    }
+
     #[test]
     fn tokenize_is_null() {
         let sql = String::from("a IS NULL");
