diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index f38b04e4f..99968a8db 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -27,11 +27,10 @@ pub use self::query::{ Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins, }; +pub use self::sql_operator::{SQLBinaryOperator, SQLUnaryOperator}; pub use self::sqltype::SQLType; pub use self::value::{SQLDateTimeField, Value}; -pub use self::sql_operator::SQLOperator; - /// Like `vec.join(", ")`, but for any types implementing ToString. fn comma_separated_string(iter: I) -> String where @@ -89,12 +88,17 @@ pub enum ASTNode { low: Box, high: Box, }, - /// Binary expression e.g. `1 + 1` or `foo > bar` - SQLBinaryExpr { + /// Binary operation e.g. `1 + 1` or `foo > bar` + SQLBinaryOp { left: Box, - op: SQLOperator, + op: SQLBinaryOperator, right: Box, }, + /// Unary operation e.g. `NOT foo` + SQLUnaryOp { + op: SQLUnaryOperator, + expr: Box, + }, /// CAST an expression to a different data type e.g. `CAST(foo AS VARCHAR(123))` SQLCast { expr: Box, @@ -111,11 +115,6 @@ pub enum ASTNode { }, /// Nested expression e.g. `(foo > bar)` or `(1)` SQLNested(Box), - /// Unary expression - SQLUnary { - operator: SQLOperator, - expr: Box, - }, /// SQLValue SQLValue(Value), /// Scalar function call e.g. `LEFT(foo, 5)` @@ -179,12 +178,15 @@ impl ToString for ASTNode { low.to_string(), high.to_string() ), - ASTNode::SQLBinaryExpr { left, op, right } => format!( + ASTNode::SQLBinaryOp { left, op, right } => format!( "{} {} {}", left.as_ref().to_string(), op.to_string(), right.as_ref().to_string() ), + ASTNode::SQLUnaryOp { op, expr } => { + format!("{} {}", op.to_string(), expr.as_ref().to_string()) + } ASTNode::SQLCast { expr, data_type } => format!( "CAST({} AS {})", expr.as_ref().to_string(), @@ -199,9 +201,6 @@ impl ToString for ASTNode { collation.to_string() ), ASTNode::SQLNested(ast) => format!("({})", ast.as_ref().to_string()), - ASTNode::SQLUnary { operator, expr } => { - format!("{} {}", operator.to_string(), expr.as_ref().to_string()) - } ASTNode::SQLValue(v) => v.to_string(), ASTNode::SQLFunction(f) => f.to_string(), ASTNode::SQLCase { diff --git a/src/sqlast/sql_operator.rs b/src/sqlast/sql_operator.rs index d080c7b1b..4845d0623 100644 --- a/src/sqlast/sql_operator.rs +++ b/src/sqlast/sql_operator.rs @@ -10,9 +10,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -/// SQL Operator +/// Unary operators #[derive(Debug, Clone, PartialEq, Hash)] -pub enum SQLOperator { +pub enum SQLUnaryOperator { + Plus, + Minus, + Not, +} + +impl ToString for SQLUnaryOperator { + fn to_string(&self) -> String { + match self { + SQLUnaryOperator::Plus => "+".to_string(), + SQLUnaryOperator::Minus => "-".to_string(), + SQLUnaryOperator::Not => "NOT".to_string(), + } + } +} + +/// Binary operators +#[derive(Debug, Clone, PartialEq, Hash)] +pub enum SQLBinaryOperator { Plus, Minus, Multiply, @@ -26,30 +44,28 @@ pub enum SQLOperator { NotEq, And, Or, - Not, Like, NotLike, } -impl ToString for SQLOperator { +impl ToString for SQLBinaryOperator { fn to_string(&self) -> String { match self { - SQLOperator::Plus => "+".to_string(), - SQLOperator::Minus => "-".to_string(), - SQLOperator::Multiply => "*".to_string(), - SQLOperator::Divide => "/".to_string(), - SQLOperator::Modulus => "%".to_string(), - SQLOperator::Gt => ">".to_string(), - SQLOperator::Lt => "<".to_string(), - SQLOperator::GtEq => ">=".to_string(), - SQLOperator::LtEq => "<=".to_string(), - SQLOperator::Eq => "=".to_string(), - SQLOperator::NotEq => "<>".to_string(), - SQLOperator::And => "AND".to_string(), - SQLOperator::Or => "OR".to_string(), - SQLOperator::Not => "NOT".to_string(), - SQLOperator::Like => "LIKE".to_string(), - SQLOperator::NotLike => "NOT LIKE".to_string(), + SQLBinaryOperator::Plus => "+".to_string(), + SQLBinaryOperator::Minus => "-".to_string(), + SQLBinaryOperator::Multiply => "*".to_string(), + SQLBinaryOperator::Divide => "/".to_string(), + SQLBinaryOperator::Modulus => "%".to_string(), + SQLBinaryOperator::Gt => ">".to_string(), + SQLBinaryOperator::Lt => "<".to_string(), + SQLBinaryOperator::GtEq => ">=".to_string(), + SQLBinaryOperator::LtEq => "<=".to_string(), + SQLBinaryOperator::Eq => "=".to_string(), + SQLBinaryOperator::NotEq => "<>".to_string(), + SQLBinaryOperator::And => "AND".to_string(), + SQLBinaryOperator::Or => "OR".to_string(), + SQLBinaryOperator::Like => "LIKE".to_string(), + SQLBinaryOperator::NotLike => "NOT LIKE".to_string(), } } } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index e1907f23a..2836b8d15 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -183,8 +183,8 @@ impl Parser { "EXISTS" => self.parse_exists_expression(), "EXTRACT" => self.parse_extract_expression(), "INTERVAL" => self.parse_literal_interval(), - "NOT" => Ok(ASTNode::SQLUnary { - operator: SQLOperator::Not, + "NOT" => Ok(ASTNode::SQLUnaryOp { + op: SQLUnaryOperator::Not, expr: Box::new(self.parse_subexpr(Self::UNARY_NOT_PREC)?), }), "TIME" => Ok(ASTNode::SQLValue(Value::Time(self.parse_literal_string()?))), @@ -224,13 +224,13 @@ impl Parser { }, // End of Token::SQLWord Token::Mult => Ok(ASTNode::SQLWildcard), tok @ Token::Minus | tok @ Token::Plus => { - let operator = if tok == Token::Plus { - SQLOperator::Plus + let op = if tok == Token::Plus { + SQLUnaryOperator::Plus } else { - SQLOperator::Minus + SQLUnaryOperator::Minus }; - Ok(ASTNode::SQLUnary { - operator, + Ok(ASTNode::SQLUnaryOp { + op, expr: Box::new(self.parse_subexpr(Self::PLUS_MINUS_PREC)?), }) } @@ -513,24 +513,24 @@ impl Parser { let tok = self.next_token().unwrap(); // safe as EOF's precedence is the lowest let regular_binary_operator = match tok { - Token::Eq => Some(SQLOperator::Eq), - Token::Neq => Some(SQLOperator::NotEq), - Token::Gt => Some(SQLOperator::Gt), - Token::GtEq => Some(SQLOperator::GtEq), - Token::Lt => Some(SQLOperator::Lt), - Token::LtEq => Some(SQLOperator::LtEq), - Token::Plus => Some(SQLOperator::Plus), - Token::Minus => Some(SQLOperator::Minus), - Token::Mult => Some(SQLOperator::Multiply), - Token::Mod => Some(SQLOperator::Modulus), - Token::Div => Some(SQLOperator::Divide), + Token::Eq => Some(SQLBinaryOperator::Eq), + Token::Neq => Some(SQLBinaryOperator::NotEq), + Token::Gt => Some(SQLBinaryOperator::Gt), + Token::GtEq => Some(SQLBinaryOperator::GtEq), + Token::Lt => Some(SQLBinaryOperator::Lt), + Token::LtEq => Some(SQLBinaryOperator::LtEq), + Token::Plus => Some(SQLBinaryOperator::Plus), + Token::Minus => Some(SQLBinaryOperator::Minus), + Token::Mult => Some(SQLBinaryOperator::Multiply), + Token::Mod => Some(SQLBinaryOperator::Modulus), + Token::Div => Some(SQLBinaryOperator::Divide), Token::SQLWord(ref k) => match k.keyword.as_ref() { - "AND" => Some(SQLOperator::And), - "OR" => Some(SQLOperator::Or), - "LIKE" => Some(SQLOperator::Like), + "AND" => Some(SQLBinaryOperator::And), + "OR" => Some(SQLBinaryOperator::Or), + "LIKE" => Some(SQLBinaryOperator::Like), "NOT" => { if self.parse_keyword("LIKE") { - Some(SQLOperator::NotLike) + Some(SQLBinaryOperator::NotLike) } else { None } @@ -541,7 +541,7 @@ impl Parser { }; if let Some(op) = regular_binary_operator { - Ok(ASTNode::SQLBinaryExpr { + Ok(ASTNode::SQLBinaryOp { left: Box::new(expr), op, right: Box::new(self.parse_subexpr(precedence)?), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 594a3bda8..4b2ab958c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -169,7 +169,7 @@ fn parse_delete_statement() { #[test] fn parse_where_delete_statement() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "DELETE FROM foo WHERE name = 5"; match verified_stmt(sql) { @@ -181,7 +181,7 @@ fn parse_where_delete_statement() { assert_eq!(SQLObjectName(vec!["foo".to_string()]), table_name); assert_eq!( - SQLBinaryExpr { + SQLBinaryOp { left: Box::new(SQLIdentifier("name".to_string())), op: Eq, right: Box::new(SQLValue(Value::Long(5))), @@ -282,13 +282,13 @@ fn parse_column_aliases() { let sql = "SELECT a.col + 1 AS newname FROM foo AS a"; let select = verified_only_select(sql); if let SQLSelectItem::ExpressionWithAlias { - expr: ASTNode::SQLBinaryExpr { + expr: ASTNode::SQLBinaryOp { ref op, ref right, .. }, ref alias, } = only(&select.projection) { - assert_eq!(&SQLOperator::Plus, op); + assert_eq!(&SQLBinaryOperator::Plus, op); assert_eq!(&ASTNode::SQLValue(Value::Long(1)), right.as_ref()); assert_eq!("newname", alias); } else { @@ -336,8 +336,8 @@ fn parse_select_count_distinct() { assert_eq!( &ASTNode::SQLFunction(SQLFunction { name: SQLObjectName(vec!["COUNT".to_string()]), - args: vec![ASTNode::SQLUnary { - operator: SQLOperator::Plus, + args: vec![ASTNode::SQLUnaryOp { + op: SQLUnaryOperator::Plus, expr: Box::new(ASTNode::SQLIdentifier("x".to_string())) }], over: None, @@ -404,12 +404,12 @@ fn parse_projection_nested_type() { #[test] fn parse_escaped_single_quote_string_predicate() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "SELECT id, fname, lname FROM customer \ WHERE salary <> 'Jim''s salary'"; let ast = verified_only_select(sql); assert_eq!( - Some(SQLBinaryExpr { + Some(SQLBinaryOp { left: Box::new(SQLIdentifier("salary".to_string())), op: NotEq, right: Box::new(SQLValue(Value::SingleQuotedString( @@ -423,13 +423,13 @@ fn parse_escaped_single_quote_string_predicate() { #[test] fn parse_compound_expr_1() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "a + b * c"; assert_eq!( - SQLBinaryExpr { + SQLBinaryOp { left: Box::new(SQLIdentifier("a".to_string())), op: Plus, - right: Box::new(SQLBinaryExpr { + right: Box::new(SQLBinaryOp { left: Box::new(SQLIdentifier("b".to_string())), op: Multiply, right: Box::new(SQLIdentifier("c".to_string())) @@ -442,11 +442,11 @@ fn parse_compound_expr_1() { #[test] fn parse_compound_expr_2() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "a * b + c"; assert_eq!( - SQLBinaryExpr { - left: Box::new(SQLBinaryExpr { + SQLBinaryOp { + left: Box::new(SQLBinaryOp { left: Box::new(SQLIdentifier("a".to_string())), op: Multiply, right: Box::new(SQLIdentifier("b".to_string())) @@ -461,17 +461,16 @@ fn parse_compound_expr_2() { #[test] fn parse_unary_math() { use self::ASTNode::*; - use self::SQLOperator::*; let sql = "- a + - b"; assert_eq!( - SQLBinaryExpr { - left: Box::new(SQLUnary { - operator: Minus, + SQLBinaryOp { + left: Box::new(SQLUnaryOp { + op: SQLUnaryOperator::Minus, expr: Box::new(SQLIdentifier("a".to_string())), }), - op: Plus, - right: Box::new(SQLUnary { - operator: Minus, + op: SQLBinaryOperator::Plus, + right: Box::new(SQLUnaryOp { + op: SQLUnaryOperator::Minus, expr: Box::new(SQLIdentifier("b".to_string())), }), }, @@ -504,15 +503,15 @@ fn parse_not_precedence() { use self::ASTNode::*; // NOT has higher precedence than OR/AND, so the following must parse as (NOT true) OR true let sql = "NOT true OR true"; - assert_matches!(verified_expr(sql), SQLBinaryExpr { - op: SQLOperator::Or, + assert_matches!(verified_expr(sql), SQLBinaryOp { + op: SQLBinaryOperator::Or, .. }); // But NOT has lower precedence than comparison operators, so the following parses as NOT (a IS NULL) let sql = "NOT a IS NULL"; - assert_matches!(verified_expr(sql), SQLUnary { - operator: SQLOperator::Not, + assert_matches!(verified_expr(sql), SQLUnaryOp { + op: SQLUnaryOperator::Not, .. }); @@ -520,8 +519,8 @@ fn parse_not_precedence() { let sql = "NOT 1 NOT BETWEEN 1 AND 2"; assert_eq!( verified_expr(sql), - SQLUnary { - operator: SQLOperator::Not, + SQLUnaryOp { + op: SQLUnaryOperator::Not, expr: Box::new(SQLBetween { expr: Box::new(SQLValue(Value::Long(1))), low: Box::new(SQLValue(Value::Long(1))), @@ -535,11 +534,11 @@ fn parse_not_precedence() { let sql = "NOT 'a' NOT LIKE 'b'"; assert_eq!( verified_expr(sql), - SQLUnary { - operator: SQLOperator::Not, - expr: Box::new(SQLBinaryExpr { + SQLUnaryOp { + op: SQLUnaryOperator::Not, + expr: Box::new(SQLBinaryOp { left: Box::new(SQLValue(Value::SingleQuotedString("a".into()))), - op: SQLOperator::NotLike, + op: SQLBinaryOperator::NotLike, right: Box::new(SQLValue(Value::SingleQuotedString("b".into()))), }), }, @@ -549,8 +548,8 @@ fn parse_not_precedence() { let sql = "NOT a NOT IN ('a')"; assert_eq!( verified_expr(sql), - SQLUnary { - operator: SQLOperator::Not, + SQLUnaryOp { + op: SQLUnaryOperator::Not, expr: Box::new(SQLInList { expr: Box::new(SQLIdentifier("a".into())), list: vec![SQLValue(Value::SingleQuotedString("a".into()))], @@ -569,12 +568,12 @@ fn parse_like() { ); let select = verified_only_select(sql); assert_eq!( - ASTNode::SQLBinaryExpr { + ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLIdentifier("name".to_string())), op: if negated { - SQLOperator::NotLike + SQLBinaryOperator::NotLike } else { - SQLOperator::Like + SQLBinaryOperator::Like }, right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( "%a".to_string() @@ -591,12 +590,12 @@ fn parse_like() { ); let select = verified_only_select(sql); assert_eq!( - ASTNode::SQLIsNull(Box::new(ASTNode::SQLBinaryExpr { + ASTNode::SQLIsNull(Box::new(ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLIdentifier("name".to_string())), op: if negated { - SQLOperator::NotLike + SQLBinaryOperator::NotLike } else { - SQLOperator::Like + SQLBinaryOperator::Like }, right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( "%a".to_string() @@ -672,18 +671,18 @@ fn parse_between() { #[test] fn parse_between_with_expr() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "SELECT * FROM t WHERE 1 BETWEEN 1 + 2 AND 3 + 4 IS NULL"; let select = verified_only_select(sql); assert_eq!( ASTNode::SQLIsNull(Box::new(ASTNode::SQLBetween { expr: Box::new(ASTNode::SQLValue(Value::Long(1))), - low: Box::new(SQLBinaryExpr { + low: Box::new(SQLBinaryOp { left: Box::new(ASTNode::SQLValue(Value::Long(1))), op: Plus, right: Box::new(ASTNode::SQLValue(Value::Long(2))), }), - high: Box::new(SQLBinaryExpr { + high: Box::new(SQLBinaryOp { left: Box::new(ASTNode::SQLValue(Value::Long(3))), op: Plus, right: Box::new(ASTNode::SQLValue(Value::Long(4))), @@ -696,17 +695,17 @@ fn parse_between_with_expr() { let sql = "SELECT * FROM t WHERE 1 = 1 AND 1 + x BETWEEN 1 AND 2"; let select = verified_only_select(sql); assert_eq!( - ASTNode::SQLBinaryExpr { - left: Box::new(ASTNode::SQLBinaryExpr { + ASTNode::SQLBinaryOp { + left: Box::new(ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLValue(Value::Long(1))), - op: SQLOperator::Eq, + op: SQLBinaryOperator::Eq, right: Box::new(ASTNode::SQLValue(Value::Long(1))), }), - op: SQLOperator::And, + op: SQLBinaryOperator::And, right: Box::new(ASTNode::SQLBetween { - expr: Box::new(ASTNode::SQLBinaryExpr { + expr: Box::new(ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLValue(Value::Long(1))), - op: SQLOperator::Plus, + op: SQLBinaryOperator::Plus, right: Box::new(ASTNode::SQLIdentifier("x".to_string())), }), low: Box::new(ASTNode::SQLValue(Value::Long(1))), @@ -1365,17 +1364,17 @@ fn parse_delimited_identifiers() { #[test] fn parse_parens() { use self::ASTNode::*; - use self::SQLOperator::*; + use self::SQLBinaryOperator::*; let sql = "(a + b) - (c + d)"; assert_eq!( - SQLBinaryExpr { - left: Box::new(SQLNested(Box::new(SQLBinaryExpr { + SQLBinaryOp { + left: Box::new(SQLNested(Box::new(SQLBinaryOp { left: Box::new(SQLIdentifier("a".to_string())), op: Plus, right: Box::new(SQLIdentifier("b".to_string())) }))), op: Minus, - right: Box::new(SQLNested(Box::new(SQLBinaryExpr { + right: Box::new(SQLNested(Box::new(SQLBinaryOp { left: Box::new(SQLIdentifier("c".to_string())), op: Plus, right: Box::new(SQLIdentifier("d".to_string())) @@ -1388,20 +1387,20 @@ fn parse_parens() { #[test] fn parse_searched_case_expression() { let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; - use self::ASTNode::{SQLBinaryExpr, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; - use self::SQLOperator::*; + use self::ASTNode::{SQLBinaryOp, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; + use self::SQLBinaryOperator::*; let select = verified_only_select(sql); assert_eq!( &SQLCase { operand: None, conditions: vec![ SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), - SQLBinaryExpr { + SQLBinaryOp { left: Box::new(SQLIdentifier("bar".to_string())), op: Eq, right: Box::new(SQLValue(Value::Long(0))) }, - SQLBinaryExpr { + SQLBinaryOp { left: Box::new(SQLIdentifier("bar".to_string())), op: GtEq, right: Box::new(SQLValue(Value::Long(0))) @@ -1555,9 +1554,9 @@ fn parse_joins_on() { args: vec![], with_hints: vec![], }, - join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { + join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLIdentifier("c1".into())), - op: SQLOperator::Eq, + op: SQLBinaryOperator::Eq, right: Box::new(ASTNode::SQLIdentifier("c2".into())), })), } @@ -1920,8 +1919,8 @@ fn parse_multiple_statements() { fn parse_scalar_subqueries() { use self::ASTNode::*; let sql = "(SELECT 1) + (SELECT 2)"; - assert_matches!(verified_expr(sql), SQLBinaryExpr { - op: SQLOperator::Plus, .. + assert_matches!(verified_expr(sql), SQLBinaryOp { + op: SQLBinaryOperator::Plus, .. //left: box SQLSubquery { .. }, //right: box SQLSubquery { .. }, }); @@ -1940,8 +1939,8 @@ fn parse_exists_subquery() { let sql = "SELECT * FROM t WHERE NOT EXISTS (SELECT 1)"; let select = verified_only_select(sql); assert_eq!( - ASTNode::SQLUnary { - operator: SQLOperator::Not, + ASTNode::SQLUnaryOp { + op: SQLUnaryOperator::Not, expr: Box::new(ASTNode::SQLExists(Box::new(expected_inner))), }, select.selection.unwrap(),