diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d00f28a55..38b169b18 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1208,20 +1208,18 @@ impl<'a> Parser<'a> { Ok(Expr::Value(self.parse_value()?)) } Token::LParen => { - let expr = - if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) { - self.prev_token(); - Expr::Subquery(self.parse_boxed_query()?) - } else if let Some(lambda) = self.try_parse_lambda() { - return Ok(lambda); - } else { - let exprs = self.parse_comma_separated(Parser::parse_expr)?; - match exprs.len() { - 0 => unreachable!(), // parse_comma_separated ensures 1 or more - 1 => Expr::Nested(Box::new(exprs.into_iter().next().unwrap())), - _ => Expr::Tuple(exprs), - } - }; + let expr = if let Some(expr) = self.try_parse_expr_sub_query()? { + expr + } else if let Some(lambda) = self.try_parse_lambda() { + return Ok(lambda); + } else { + let exprs = self.parse_comma_separated(Parser::parse_expr)?; + match exprs.len() { + 0 => unreachable!(), // parse_comma_separated ensures 1 or more + 1 => Expr::Nested(Box::new(exprs.into_iter().next().unwrap())), + _ => Expr::Tuple(exprs), + } + }; self.expect_token(&Token::RParen)?; if !self.consume_token(&Token::Period) { Ok(expr) @@ -1263,6 +1261,18 @@ impl<'a> Parser<'a> { } } + fn try_parse_expr_sub_query(&mut self) -> Result, ParserError> { + if self + .parse_one_of_keywords(&[Keyword::SELECT, Keyword::WITH]) + .is_none() + { + return Ok(None); + } + self.prev_token(); + + Ok(Some(Expr::Subquery(self.parse_boxed_query()?))) + } + fn try_parse_lambda(&mut self) -> Option { if !self.dialect.supports_lambda_functions() { return None; @@ -8699,7 +8709,9 @@ impl<'a> Parser<'a> { let mut values = vec![]; loop { - let value = if let Ok(expr) = self.parse_expr() { + let value = if let Some(expr) = self.try_parse_expr_sub_query()? { + expr + } else if let Ok(expr) = self.parse_expr() { expr } else { self.expected("variable value", self.peek_token())? diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 1adda149e..b026bb13e 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7123,9 +7123,39 @@ fn parse_set_variable() { _ => unreachable!(), } + // Subquery expression + for (sql, canonical) in [ + ( + "SET (a) = (SELECT 22 FROM tbl1)", + "SET (a) = ((SELECT 22 FROM tbl1))", + ), + ( + "SET (a) = (SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2))", + "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", + ), + ( + "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", + "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", + ), + ( + "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), SELECT 33 FROM tbl3)", + "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), (SELECT 33 FROM tbl3))", + ), + ] { + multi_variable_dialects.one_statement_parses_to(sql, canonical); + } + let error_sqls = [ ("SET (a, b, c) = (1, 2, 3", "Expected: ), found: EOF"), ("SET (a, b, c) = 1, 2, 3", "Expected: (, found: 1"), + ( + "SET (a) = ((SELECT 22 FROM tbl1)", + "Expected: ), found: EOF", + ), + ( + "SET (a) = ((SELECT 22 FROM tbl1) (SELECT 22 FROM tbl1))", + "Expected: ), found: (", + ), ]; for (sql, error) in error_sqls { assert_eq!(