Skip to content

Commit 2643193

Browse files
committed
Support IN
1 parent bed03ab commit 2643193

File tree

3 files changed

+116
-15
lines changed

3 files changed

+116
-15
lines changed

src/sqlast/mod.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ pub enum ASTNode {
5252
SQLIsNull(Box<ASTNode>),
5353
/// `IS NOT NULL` expression
5454
SQLIsNotNull(Box<ASTNode>),
55+
/// `[ NOT ] IN (val1, val2, ...)`
56+
SQLInList {
57+
expr: Box<ASTNode>,
58+
list: Vec<ASTNode>,
59+
negated: bool,
60+
},
61+
/// `[ NOT ] IN (SELECT ...)`
62+
SQLInSubquery {
63+
expr: Box<ASTNode>,
64+
subquery: Box<SQLQuery>,
65+
negated: bool,
66+
},
5567
/// Binary expression e.g. `1 + 1` or `foo > bar`
5668
SQLBinaryExpr {
5769
left: Box<ASTNode>,
@@ -96,6 +108,29 @@ impl ToString for ASTNode {
96108
ASTNode::SQLCompoundIdentifier(s) => s.join("."),
97109
ASTNode::SQLIsNull(ast) => format!("{} IS NULL", ast.as_ref().to_string()),
98110
ASTNode::SQLIsNotNull(ast) => format!("{} IS NOT NULL", ast.as_ref().to_string()),
111+
ASTNode::SQLInList {
112+
expr,
113+
list,
114+
negated,
115+
} => format!(
116+
"{} {}IN ({})",
117+
expr.as_ref().to_string(),
118+
if *negated { "NOT " } else { "" },
119+
list.iter()
120+
.map(|a| a.to_string())
121+
.collect::<Vec<String>>()
122+
.join(", ")
123+
),
124+
ASTNode::SQLInSubquery {
125+
expr,
126+
subquery,
127+
negated,
128+
} => format!(
129+
"{} {}IN ({})",
130+
expr.as_ref().to_string(),
131+
if *negated { "NOT " } else { "" },
132+
subquery.to_string()
133+
),
99134
ASTNode::SQLBinaryExpr { left, op, right } => format!(
100135
"{} {} {}",
101136
left.as_ref().to_string(),

src/sqlparser.rs

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -289,14 +289,6 @@ impl Parser {
289289
})
290290
}
291291

292-
/// Parse a postgresql casting style which is in the form of `expr::datatype`
293-
pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
294-
Ok(ASTNode::SQLCast {
295-
expr: Box::new(expr),
296-
data_type: self.parse_data_type()?,
297-
})
298-
}
299-
300292
/// Parse an expression infix (typically an operator)
301293
pub fn parse_infix(&mut self, expr: ASTNode, precedence: u8) -> Result<ASTNode, ParserError> {
302294
debug!("parsing infix");
@@ -308,24 +300,30 @@ impl Parser {
308300
} else if self.parse_keywords(vec!["NOT", "NULL"]) {
309301
Ok(ASTNode::SQLIsNotNull(Box::new(expr)))
310302
} else {
311-
parser_err!("Invalid tokens after IS")
303+
parser_err!(format!(
304+
"Expected NULL or NOT NULL after IS, found {:?}",
305+
self.peek_token()
306+
))
312307
}
313308
}
314309
Token::SQLWord(ref k) if k.keyword == "NOT" => {
315-
if self.parse_keywords(vec!["LIKE"]) {
310+
if self.parse_keyword("IN") {
311+
self.parse_in(expr, true)
312+
} else if self.parse_keyword("LIKE") {
316313
Ok(ASTNode::SQLBinaryExpr {
317314
left: Box::new(expr),
318315
op: SQLOperator::NotLike,
319316
right: Box::new(self.parse_subexpr(precedence)?),
320317
})
321318
} else {
322-
parser_err!("Invalid tokens after NOT")
319+
parser_err!(format!(
320+
"Expected IN or LIKE after NOT, found {:?}",
321+
self.peek_token()
322+
))
323323
}
324324
}
325-
Token::DoubleColon => {
326-
let pg_cast = self.parse_pg_cast(expr)?;
327-
Ok(pg_cast)
328-
}
325+
Token::SQLWord(ref k) if k.keyword == "IN" => self.parse_in(expr, false),
326+
Token::DoubleColon => self.parse_pg_cast(expr),
329327
Token::SQLWord(_)
330328
| Token::Eq
331329
| Token::Neq
@@ -350,6 +348,35 @@ impl Parser {
350348
}
351349
}
352350

351+
/// Parses the parens following the `[ NOT ] IN` operator
352+
pub fn parse_in(&mut self, expr: ASTNode, negated: bool) -> Result<ASTNode, ParserError> {
353+
self.expect_token(&Token::LParen)?;
354+
let in_op = if self.parse_keyword("SELECT") || self.parse_keyword("WITH") {
355+
self.prev_token();
356+
ASTNode::SQLInSubquery {
357+
expr: Box::new(expr),
358+
subquery: Box::new(self.parse_query()?),
359+
negated,
360+
}
361+
} else {
362+
ASTNode::SQLInList {
363+
expr: Box::new(expr),
364+
list: self.parse_expr_list()?,
365+
negated,
366+
}
367+
};
368+
self.expect_token(&Token::RParen)?;
369+
Ok(in_op)
370+
}
371+
372+
/// Parse a postgresql casting style which is in the form of `expr::datatype`
373+
pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
374+
Ok(ASTNode::SQLCast {
375+
expr: Box::new(expr),
376+
data_type: self.parse_data_type()?,
377+
})
378+
}
379+
353380
/// Convert a token operator to an AST operator
354381
pub fn to_sql_operator(&self, tok: &Token) -> Result<SQLOperator, ParserError> {
355382
match tok {
@@ -390,6 +417,7 @@ impl Parser {
390417
&Token::SQLWord(ref k) if k.keyword == "AND" => Ok(10),
391418
&Token::SQLWord(ref k) if k.keyword == "NOT" => Ok(15),
392419
&Token::SQLWord(ref k) if k.keyword == "IS" => Ok(17),
420+
&Token::SQLWord(ref k) if k.keyword == "IN" => Ok(20),
393421
&Token::SQLWord(ref k) if k.keyword == "LIKE" => Ok(20),
394422
&Token::Eq | &Token::Lt | &Token::LtEq | &Token::Neq | &Token::Gt | &Token::GtEq => {
395423
Ok(20)

tests/sqlparser_generic.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,44 @@ fn parse_not_like() {
250250
);
251251
}
252252

253+
#[test]
254+
fn parse_in_list() {
255+
fn chk(negated: bool) {
256+
let sql = &format!(
257+
"SELECT * FROM customers WHERE segment {}IN ('HIGH', 'MED')",
258+
if negated { "NOT " } else { "" }
259+
);
260+
let select = verified_only_select(sql);
261+
assert_eq!(
262+
ASTNode::SQLInList {
263+
expr: Box::new(ASTNode::SQLIdentifier("segment".to_string())),
264+
list: vec![
265+
ASTNode::SQLValue(Value::SingleQuotedString("HIGH".to_string())),
266+
ASTNode::SQLValue(Value::SingleQuotedString("MED".to_string())),
267+
],
268+
negated,
269+
},
270+
select.selection.unwrap()
271+
);
272+
}
273+
chk(false);
274+
chk(true);
275+
}
276+
277+
#[test]
278+
fn parse_in_subquery() {
279+
let sql = "SELECT * FROM customers WHERE segment IN (SELECT segm FROM bar)";
280+
let select = verified_only_select(sql);
281+
assert_eq!(
282+
ASTNode::SQLInSubquery {
283+
expr: Box::new(ASTNode::SQLIdentifier("segment".to_string())),
284+
subquery: Box::new(verified_query("SELECT segm FROM bar")),
285+
negated: false,
286+
},
287+
select.selection.unwrap()
288+
);
289+
}
290+
253291
#[test]
254292
fn parse_select_order_by() {
255293
fn chk(sql: &str) {

0 commit comments

Comments
 (0)