From 4bd42278d96450bce8ee148f20fccef81b2e423b Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 16 Nov 2024 23:47:47 +0100 Subject: [PATCH 1/3] add support for sqlite's OR clauses in update statements fix https://github.com/apache/datafusion-sqlparser-rs/issues/1529 --- src/ast/mod.rs | 9 ++++++++- src/parser/mod.rs | 39 +++++++++++++++++++++------------------ tests/sqlparser_common.rs | 33 +++++++++++++++++++++++++++++++++ tests/sqlparser_mysql.rs | 1 + tests/sqlparser_sqlite.rs | 1 + 5 files changed, 64 insertions(+), 19 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 39c742153..0565f8d4b 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2396,6 +2396,8 @@ pub enum Statement { selection: Option, /// RETURNING returning: Option>, + /// SQLite-specific conflict resolution clause + or: Option, }, /// ```sql /// DELETE @@ -3691,8 +3693,13 @@ impl fmt::Display for Statement { from, selection, returning, + or, } => { - write!(f, "UPDATE {table}")?; + write!(f, "UPDATE ")?; + if let Some(or) = or { + write!(f, "{or} ")?; + } + write!(f, "{table}")?; if !assignments.is_empty() { write!(f, " SET {}", display_comma_separated(assignments))?; } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a583112a7..dcfd39848 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11032,24 +11032,7 @@ impl<'a> Parser<'a> { /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { - let or = if !dialect_of!(self is SQLiteDialect) { - None - } else if self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]) { - Some(SqliteOnConflict::Replace) - } else if self.parse_keywords(&[Keyword::OR, Keyword::ROLLBACK]) { - Some(SqliteOnConflict::Rollback) - } else if self.parse_keywords(&[Keyword::OR, Keyword::ABORT]) { - Some(SqliteOnConflict::Abort) - } else if self.parse_keywords(&[Keyword::OR, Keyword::FAIL]) { - Some(SqliteOnConflict::Fail) - } else if self.parse_keywords(&[Keyword::OR, Keyword::IGNORE]) { - Some(SqliteOnConflict::Ignore) - } else if self.parse_keyword(Keyword::REPLACE) { - Some(SqliteOnConflict::Replace) - } else { - None - }; - + let or = self.parse_conflict_clause(); let priority = if !dialect_of!(self is MySqlDialect | GenericDialect) { None } else if self.parse_keyword(Keyword::LOW_PRIORITY) { @@ -11208,6 +11191,24 @@ impl<'a> Parser<'a> { } } + fn parse_conflict_clause(&mut self) -> Option { + if self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]) { + Some(SqliteOnConflict::Replace) + } else if self.parse_keywords(&[Keyword::OR, Keyword::ROLLBACK]) { + Some(SqliteOnConflict::Rollback) + } else if self.parse_keywords(&[Keyword::OR, Keyword::ABORT]) { + Some(SqliteOnConflict::Abort) + } else if self.parse_keywords(&[Keyword::OR, Keyword::FAIL]) { + Some(SqliteOnConflict::Fail) + } else if self.parse_keywords(&[Keyword::OR, Keyword::IGNORE]) { + Some(SqliteOnConflict::Ignore) + } else if self.parse_keyword(Keyword::REPLACE) { + Some(SqliteOnConflict::Replace) + } else { + None + } + } + pub fn parse_insert_partition(&mut self) -> Result>, ParserError> { if self.parse_keyword(Keyword::PARTITION) { self.expect_token(&Token::LParen)?; @@ -11243,6 +11244,7 @@ impl<'a> Parser<'a> { } pub fn parse_update(&mut self) -> Result { + let or = self.parse_conflict_clause(); let table = self.parse_table_and_joins()?; self.expect_keyword(Keyword::SET)?; let assignments = self.parse_comma_separated(Parser::parse_assignment)?; @@ -11269,6 +11271,7 @@ impl<'a> Parser<'a> { from, selection, returning, + or, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 2ffb5f44b..a4e987291 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -443,6 +443,7 @@ fn parse_update_set_from() { ])), }), returning: None, + or: None, } ); } @@ -457,6 +458,7 @@ fn parse_update_with_table_alias() { from: _from, selection, returning, + or: None, } => { assert_eq!( TableWithJoins { @@ -505,6 +507,37 @@ fn parse_update_with_table_alias() { } } +#[test] +fn parse_update_or() { + let dialect = SQLiteDialect {}; + + let check = |sql: &str, expected_action: Option| match Parser::parse_sql( + &dialect, sql, + ) + .unwrap() + .pop() + .unwrap() + { + Statement::Update { or, .. } => assert_eq!(or, expected_action), + _ => panic!("{}", sql), + }; + + let sql = "UPDATE OR REPLACE t SET n = n + 1"; + check(sql, Some(SqliteOnConflict::Replace)); + + let sql = "UPDATE OR ROLLBACK t SET n = n + 1"; + check(sql, Some(SqliteOnConflict::Rollback)); + + let sql = "UPDATE OR ABORT t SET n = n + 1"; + check(sql, Some(SqliteOnConflict::Abort)); + + let sql = "UPDATE OR FAIL t SET n = n + 1"; + check(sql, Some(SqliteOnConflict::Fail)); + + let sql = "UPDATE OR IGNORE t SET n = n + 1"; + check(sql, Some(SqliteOnConflict::Ignore)); +} + #[test] fn parse_select_with_table_alias_as() { // AS is optional diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 8269eadc0..2a876cff2 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1970,6 +1970,7 @@ fn parse_update_with_joins() { from: _from, selection, returning, + or: None, } => { assert_eq!( TableWithJoins { diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 6f8bbb2d8..6f8e654dc 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -465,6 +465,7 @@ fn parse_update_tuple_row_values() { assert_eq!( sqlite().verified_stmt("UPDATE x SET (a, b) = (1, 2)"), Statement::Update { + or: None, assignments: vec![Assignment { target: AssignmentTarget::Tuple(vec![ ObjectName(vec![Ident::new("a"),]), From 40713cb19403f0dae244792ea3c8c89d4473f692 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 16 Nov 2024 23:55:10 +0100 Subject: [PATCH 2/3] simplify test --- tests/sqlparser_common.rs | 40 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a4e987291..0f4e66943 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -509,33 +509,21 @@ fn parse_update_with_table_alias() { #[test] fn parse_update_or() { - let dialect = SQLiteDialect {}; - - let check = |sql: &str, expected_action: Option| match Parser::parse_sql( - &dialect, sql, - ) - .unwrap() - .pop() - .unwrap() - { - Statement::Update { or, .. } => assert_eq!(or, expected_action), - _ => panic!("{}", sql), + let expect_or_clause = |sql: &str, expected_action: SqliteOnConflict| match verified_stmt(sql) { + Statement::Update { or, .. } => assert_eq!(or, Some(expected_action)), + other => unreachable!("Expected update with or, got {:?}", other), }; - - let sql = "UPDATE OR REPLACE t SET n = n + 1"; - check(sql, Some(SqliteOnConflict::Replace)); - - let sql = "UPDATE OR ROLLBACK t SET n = n + 1"; - check(sql, Some(SqliteOnConflict::Rollback)); - - let sql = "UPDATE OR ABORT t SET n = n + 1"; - check(sql, Some(SqliteOnConflict::Abort)); - - let sql = "UPDATE OR FAIL t SET n = n + 1"; - check(sql, Some(SqliteOnConflict::Fail)); - - let sql = "UPDATE OR IGNORE t SET n = n + 1"; - check(sql, Some(SqliteOnConflict::Ignore)); + expect_or_clause( + "UPDATE OR REPLACE t SET n = n + 1", + SqliteOnConflict::Replace, + ); + expect_or_clause( + "UPDATE OR ROLLBACK t SET n = n + 1", + SqliteOnConflict::Rollback, + ); + expect_or_clause("UPDATE OR ABORT t SET n = n + 1", SqliteOnConflict::Abort); + expect_or_clause("UPDATE OR FAIL t SET n = n + 1", SqliteOnConflict::Fail); + expect_or_clause("UPDATE OR IGNORE t SET n = n + 1", SqliteOnConflict::Ignore); } #[test] From f4b0ac66c501727b2e6e8399f4d80955fa2063d5 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sun, 17 Nov 2024 00:04:33 +0100 Subject: [PATCH 3/3] deduplicate formatting of OR in SqliteOnConflict --- src/ast/dml.rs | 4 ++-- src/ast/mod.rs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ast/dml.rs b/src/ast/dml.rs index 2932fafb5..22309c8f8 100644 --- a/src/ast/dml.rs +++ b/src/ast/dml.rs @@ -505,8 +505,8 @@ impl Display for Insert { self.table_name.to_string() }; - if let Some(action) = self.or { - write!(f, "INSERT OR {action} INTO {table_name} ")?; + if let Some(on_conflict) = self.or { + write!(f, "INSERT {on_conflict} INTO {table_name} ")?; } else { write!( f, diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 0565f8d4b..7782acf42 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -6311,11 +6311,11 @@ impl fmt::Display for SqliteOnConflict { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use SqliteOnConflict::*; match self { - Rollback => write!(f, "ROLLBACK"), - Abort => write!(f, "ABORT"), - Fail => write!(f, "FAIL"), - Ignore => write!(f, "IGNORE"), - Replace => write!(f, "REPLACE"), + Rollback => write!(f, "OR ROLLBACK"), + Abort => write!(f, "OR ABORT"), + Fail => write!(f, "OR FAIL"), + Ignore => write!(f, "OR IGNORE"), + Replace => write!(f, "OR REPLACE"), } } }