From 7d78be3a4a579f273f8d8ca1e79f1d2b7add7f7e Mon Sep 17 00:00:00 2001 From: andrej-db Date: Wed, 23 Oct 2024 15:19:13 +0200 Subject: [PATCH 01/28] core --- .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 66 +++++++++++++++++++ .../catalyst/util/V2ExpressionBuilder.scala | 21 +++++- .../spark/sql/jdbc/MsSqlServerDialect.scala | 61 ++++++++++++++--- 3 files changed, 135 insertions(+), 13 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index d884ad4c6246..766088970a6f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -146,4 +146,70 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |""".stripMargin) assert(df.collect().length == 2) } + + test("SPARK-50087: SqlServer handle booleans in IF in SELECT test") { + // This doesn't compile on SqlServer unless result boolean expressions + // in IF / CASE WHEN are wrapped with an IIF(<>, 1, 0). + val df = sql( + s"""|WITH dummy AS ( + | SELECT + | DISTINCT name AS full_name, + | UPPER(name) AS test_type, + | name, + | IF( + | LOWER(name) = 'adfsaef' OR LOWER(name) = 'agadg', + | 'agfagff', + | IF( + | LOWER(name) = 'adgfda' OR LOWER(name) = 'ssadf', + | 'sxzvfvxf', + | IF( + | LOWER(name) = 'sdfadsf' OR LOWER(name) = 'sadfgvad', + | 'sAFvadsfvcds', + | LOWER(name) + | ) + | ) + | ) AS test_type_name + | FROM $catalogName.employee + |), + |dummy_new AS ( + | SELECT * + | FROM dummy WHERE test_type_name = 'safcdfz' + |) + |SELECT * FROM dummy_new limit 1""".stripMargin + ) + df.explain("formatted") + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN name = '1' THEN name = 'barxxyz' ELSE NOT (name = 'barxxyz') END + |""".stripMargin + ) + df.explain("formatted") + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'barxxyz') THEN (name = 'barx') ELSE (1=1) END + |""".stripMargin + ) + df.explain("formatted") + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'barxxyz') THEN + | CASE WHEN (name = '5') THEN (name = 'barx') ELSE (name = 'aweda') END + | ELSE (name = '1') END + |""".stripMargin + ) + df.explain("formatted") + df.collect() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 61a26d7a4fbd..3b631fb0b452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -221,8 +221,18 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_)) - val elseExprOpt = elseValue.flatMap(generateExpression(_)) + val values = branches.map(_._2).flatMap(child => + generateExpression( + child, + isPredicate && child.dataType.isInstanceOf[BooleanType] + ) + ) + val elseExprOpt = elseValue.flatMap(child => + generateExpression( + child, + isPredicate && child.dataType.isInstanceOf[BooleanType] + ) + ) if (conditions.length == branches.length && values.length == branches.length && elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => @@ -421,7 +431,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L children: Seq[Expression], dataType: DataType, isPredicate: Boolean): Option[V2Expression] = { - val childrenExpressions = children.flatMap(generateExpression(_)) + val childrenExpressions = children.flatMap(child => + generateExpression( + child, + isPredicate && child.dataType.isInstanceOf[BooleanType] + ) + ) if (childrenExpressions.length == children.length) { if (isPredicate && dataType.isInstanceOf[BooleanType]) { Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 7d476d43e5c7..6bed30353287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -73,6 +73,36 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr } } + override def visitCaseWhen(children: Array[String]): String = { + // Since MsSqlServer cannot handle boolean expressions inside + // a CASE WHEN, it is necessary to convert those to an IIF + // expression that will return 1 or 0 depending on the result. + // Example: + // In: ... CASE WHEN a = b THEN c = d ... + // Out: ... CASE WHEN a = b THEN IIF(c = d, 1, 0) ... + val sb = new StringBuilder("CASE") + var i = 0 + while (i < children.length) { + val c = children(i) + val j = i + 1 + if (j < children.length) { + val v = children(j) + sb.append(" WHEN ") + sb.append(c) + sb.append(" THEN ") + sb.append(MsSqlServerDialect.wrapPredicateWithIIF(v)) + } + else { + sb.append(" ELSE ") + sb.append(MsSqlServerDialect.wrapPredicateWithIIF(c)) + } + + i += 2 + } + sb.append(" END") + sb.toString + } + override def dialectFunctionName(funcName: String): String = funcName match { case "VAR_POP" => "VARP" case "VAR_SAMP" => "VAR" @@ -85,16 +115,21 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // MsSqlServer does not support boolean comparison using standard comparison operators // We shouldn't propagate these queries to MsSqlServer expr match { - case e: Predicate => e.name() match { - case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => - val Array(l, r) = e.children().map { - case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" - case o => inputToSQL(o) - } - visitBinaryComparison(e.name(), l, r) - case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1" - case _ => super.build(expr) - } + case e: Predicate if (e.name() match { + case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => true + case _ => false + }) => + val Array(l, r) = e.children().map { + case p: Predicate + if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => + s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" + case o => inputToSQL(o) + } + visitBinaryComparison(e.name(), l, r) + // If a CASE WHEN is a predicate, appending the "= 1" + // because it cannot return a boolean + case e: Predicate if e.name() == "CASE_WHEN" => + visitCaseWhen(expressionsToStringArray(e.children())) + " = 1 " case _ => super.build(expr) } } @@ -250,4 +285,10 @@ private object MsSqlServerDialect { // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java final val GEOMETRY = -157 final val GEOGRAPHY = -158 + + def wrapPredicateWithIIF(expr: String): String = { + if (expr != "0" && expr != "1") { + s"""IIF($expr, 1, 0)""" + } else expr + } } From 7839e1d128ed150f9eb82b805fea14aa647e53ff Mon Sep 17 00:00:00 2001 From: andrej-db Date: Wed, 23 Oct 2024 16:47:27 +0200 Subject: [PATCH 02/28] revert MsSqlServerDialect build method --- .../spark/sql/jdbc/MsSqlServerDialect.scala | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 6bed30353287..6098ee60417b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -115,21 +115,16 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // MsSqlServer does not support boolean comparison using standard comparison operators // We shouldn't propagate these queries to MsSqlServer expr match { - case e: Predicate if (e.name() match { - case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => true - case _ => false - }) => - val Array(l, r) = e.children().map { - case p: Predicate - if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => - s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" - case o => inputToSQL(o) - } - visitBinaryComparison(e.name(), l, r) - // If a CASE WHEN is a predicate, appending the "= 1" - // because it cannot return a boolean - case e: Predicate if e.name() == "CASE_WHEN" => - visitCaseWhen(expressionsToStringArray(e.children())) + " = 1 " + case e: Predicate => e.name() match { + case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => + val Array(l, r) = e.children().map { + case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" + case o => inputToSQL(o) + } + visitBinaryComparison(e.name(), l, r) + case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1" + case _ => super.build(expr) + } case _ => super.build(expr) } } @@ -285,10 +280,4 @@ private object MsSqlServerDialect { // https://github.com/microsoft/mssql-jdbc/blob/v9.4.1/src/main/java/microsoft/sql/Types.java final val GEOMETRY = -157 final val GEOGRAPHY = -158 - - def wrapPredicateWithIIF(expr: String): String = { - if (expr != "0" && expr != "1") { - s"""IIF($expr, 1, 0)""" - } else expr - } } From 0172c49340f1e3240de6d867ba6e89fd158ffcab Mon Sep 17 00:00:00 2001 From: andrej-db Date: Wed, 23 Oct 2024 20:19:06 +0200 Subject: [PATCH 03/28] remove visitCaseWhen, more intuitive predicate wrapping --- .../util/V2ExpressionSQLBuilder.java | 4 ++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 55 +++++++------------ 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index bd2dec9e27be..aa9de2cc6a14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -219,6 +219,10 @@ protected String inputToSQL(Expression input) { } } + protected String inputToCaseWhenSQL(Expression input) { + return "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0"; + } + protected String visitBinaryComparison(String name, String l, String r) { if (name.equals("<=>")) { return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r + ") " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 6098ee60417b..4d8a3e130e36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale - import scala.util.control.NonFatal - import org.apache.spark.SparkThrowable import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc +import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.connector.ExpressionWithToString import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY} import org.apache.spark.sql.types._ @@ -73,36 +73,6 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr } } - override def visitCaseWhen(children: Array[String]): String = { - // Since MsSqlServer cannot handle boolean expressions inside - // a CASE WHEN, it is necessary to convert those to an IIF - // expression that will return 1 or 0 depending on the result. - // Example: - // In: ... CASE WHEN a = b THEN c = d ... - // Out: ... CASE WHEN a = b THEN IIF(c = d, 1, 0) ... - val sb = new StringBuilder("CASE") - var i = 0 - while (i < children.length) { - val c = children(i) - val j = i + 1 - if (j < children.length) { - val v = children(j) - sb.append(" WHEN ") - sb.append(c) - sb.append(" THEN ") - sb.append(MsSqlServerDialect.wrapPredicateWithIIF(v)) - } - else { - sb.append(" ELSE ") - sb.append(MsSqlServerDialect.wrapPredicateWithIIF(c)) - } - - i += 2 - } - sb.append(" END") - sb.toString - } - override def dialectFunctionName(funcName: String): String = funcName match { case "VAR_POP" => "VARP" case "VAR_SAMP" => "VAR" @@ -122,7 +92,24 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr case o => inputToSQL(o) } visitBinaryComparison(e.name(), l, r) - case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1" + case "CASE_WHEN" => + // Since MsSqlServer cannot handle boolean expressions inside + // a CASE WHEN, it is necessary to convert those to an IIF + // expression that will return 1 or 0 depending on the result. + // Example: + // In: ... CASE WHEN a = b THEN c = d ... + // Out: ... CASE WHEN a = b THEN IIF(c = d, 1, 0) ... + + // grouped turns Array[Expression] to Array[Array[Expression]] + // with a len of max 2 (final one will have only one) + val stringArray = e.children().grouped(2).flatMap { arr => + arr.dropRight(1).map(inputToSQL) :+ + (arr.last match { + case p: Predicate => inputToCaseWhenSQL(p) + case p => inputToSQL(p) + }) + } + visitCaseWhen(stringArray.toArray) + " = 1" case _ => super.build(expr) } case _ => super.build(expr) From 7ae73974a95947dbf3ed15ae49a5036b15f677b0 Mon Sep 17 00:00:00 2001 From: andrej-db Date: Wed, 23 Oct 2024 20:33:03 +0200 Subject: [PATCH 04/28] imports --- .../org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 4d8a3e130e36..c690ceff1ba1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale + import scala.util.control.NonFatal + import org.apache.spark.SparkThrowable import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc -import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} +import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.connector.ExpressionWithToString import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY} import org.apache.spark.sql.types._ From 49d742e2ccd533c8f98fd9622edcdfe7c59b68fa Mon Sep 17 00:00:00 2001 From: andrej-db Date: Thu, 24 Oct 2024 00:11:38 +0200 Subject: [PATCH 05/28] fix --- .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ---- .../spark/sql/connector/util/V2ExpressionSQLBuilder.java | 2 +- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 5 +++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 766088970a6f..6ffebdd2ce57 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -177,7 +177,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |) |SELECT * FROM dummy_new limit 1""".stripMargin ) - df.explain("formatted") df.collect() } @@ -187,7 +186,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |WHERE CASE WHEN name = '1' THEN name = 'barxxyz' ELSE NOT (name = 'barxxyz') END |""".stripMargin ) - df.explain("formatted") df.collect() } @@ -197,7 +195,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |WHERE CASE WHEN (name = 'barxxyz') THEN (name = 'barx') ELSE (1=1) END |""".stripMargin ) - df.explain("formatted") df.collect() } @@ -209,7 +206,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD | ELSE (name = '1') END |""".stripMargin ) - df.explain("formatted") df.collect() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index aa9de2cc6a14..bf00b923f549 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -220,7 +220,7 @@ protected String inputToSQL(Expression input) { } protected String inputToCaseWhenSQL(Expression input) { - return "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0"; + return "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END"; } protected String visitBinaryComparison(String name, String l, String r) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index c690ceff1ba1..fa578be389d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -88,7 +88,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => val Array(l, r) = e.children().map { - case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" + case p: Predicate => inputToCaseWhenSQL(p) case o => inputToSQL(o) } visitBinaryComparison(e.name(), l, r) @@ -105,7 +105,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr val stringArray = e.children().grouped(2).flatMap { arr => arr.dropRight(1).map(inputToSQL) :+ (arr.last match { - case p: Predicate => inputToCaseWhenSQL(p) + case p: Predicate if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => + inputToCaseWhenSQL(p) case p => inputToSQL(p) }) } From c2001d931d87eaaf82cf650a87689e2712813472 Mon Sep 17 00:00:00 2001 From: andrej-db Date: Thu, 24 Oct 2024 10:32:20 +0200 Subject: [PATCH 06/28] MsSqlServerDialect: comment --- .../org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index fa578be389d6..67f4ea493ed5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -94,11 +94,13 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr visitBinaryComparison(e.name(), l, r) case "CASE_WHEN" => // Since MsSqlServer cannot handle boolean expressions inside - // a CASE WHEN, it is necessary to convert those to an IIF - // expression that will return 1 or 0 depending on the result. + // a CASE WHEN, it is necessary to convert those to another + // CASE WHEN expression that will return 1 or 0 depending on + // the result. Exceptions are TRUE and FALSE, which already + // get translated to 1 and 0. // Example: - // In: ... CASE WHEN a = b THEN c = d ... - // Out: ... CASE WHEN a = b THEN IIF(c = d, 1, 0) ... + // In: ... CASE WHEN a = b THEN c = d ... END + // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 // grouped turns Array[Expression] to Array[Array[Expression]] // with a len of max 2 (final one will have only one) From 8b1b2da51b3206b4707279ac651ebf7c312b6bd7 Mon Sep 17 00:00:00 2001 From: andrej-db Date: Fri, 25 Oct 2024 17:47:44 +0200 Subject: [PATCH 07/28] JdbcDialects: move aux here MsSqlServerDialect: refactor V2ExpressionSQLBuilder: remove aux --- .../connector/util/V2ExpressionSQLBuilder.java | 4 ---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 12 ++++++++++++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 18 +++++++----------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index bf00b923f549..bd2dec9e27be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -219,10 +219,6 @@ protected String inputToSQL(Expression input) { } } - protected String inputToCaseWhenSQL(Expression input) { - return "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END"; - } - protected String visitBinaryComparison(String name, String l, String r) { if (name.equals("<=>")) { return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r + ") " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 3bf1390cb664..d23c43e87d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils} @@ -377,6 +378,17 @@ abstract class JdbcDialect extends Serializable with Logging { } private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { + def inputToCaseWhenSQL(input: Expression): String = + "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END" + + def inputToIntSQL(expr: Expression): String = { + expr match { + case p: Predicate if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => + inputToCaseWhenSQL(p) + case p => inputToSQL(p) + } + } + override def visitLiteral(literal: Literal[_]): String = { Option(literal.value()).map(v => compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 67f4ea493ed5..cc22c0c83800 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -101,18 +101,14 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // Example: // In: ... CASE WHEN a = b THEN c = d ... END // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 + val stringArray = e.children().grouped(2).flatMap { + case Array(whenExpression, thenExpression) => + Array(inputToSQL(whenExpression), inputToIntSQL(thenExpression)) + case Array(elseExpression) => + Array(inputToIntSQL(elseExpression)) + }.toArray - // grouped turns Array[Expression] to Array[Array[Expression]] - // with a len of max 2 (final one will have only one) - val stringArray = e.children().grouped(2).flatMap { arr => - arr.dropRight(1).map(inputToSQL) :+ - (arr.last match { - case p: Predicate if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => - inputToCaseWhenSQL(p) - case p => inputToSQL(p) - }) - } - visitCaseWhen(stringArray.toArray) + " = 1" + visitCaseWhen(stringArray) + " = 1" case _ => super.build(expr) } case _ => super.build(expr) From 1f01b7730598e06fd4e5e2e7e5bb80b4870af035 Mon Sep 17 00:00:00 2001 From: andrej-db Date: Fri, 25 Oct 2024 18:03:24 +0200 Subject: [PATCH 08/28] V2ExpressionBuilder: refactor --- .../catalyst/util/V2ExpressionBuilder.scala | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 3b631fb0b452..b0ce2bb4293e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -221,18 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(child => - generateExpression( - child, - isPredicate && child.dataType.isInstanceOf[BooleanType] - ) - ) - val elseExprOpt = elseValue.flatMap(child => - generateExpression( - child, - isPredicate && child.dataType.isInstanceOf[BooleanType] - ) - ) + val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate)) + val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate)) if (conditions.length == branches.length && values.length == branches.length && elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => @@ -431,12 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L children: Seq[Expression], dataType: DataType, isPredicate: Boolean): Option[V2Expression] = { - val childrenExpressions = children.flatMap(child => - generateExpression( - child, - isPredicate && child.dataType.isInstanceOf[BooleanType] - ) - ) + val childrenExpressions = children.flatMap(generateExpression(_, isPredicate)) if (childrenExpressions.length == children.length) { if (isPredicate && dataType.isInstanceOf[BooleanType]) { Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) From 7e360297990a3871702ae7b2e352174590492dfe Mon Sep 17 00:00:00 2001 From: andrej-db Date: Tue, 29 Oct 2024 17:31:37 +0100 Subject: [PATCH 09/28] nit --- .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 6ffebdd2ce57..2751434d8ddd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -157,14 +157,14 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD | UPPER(name) AS test_type, | name, | IF( - | LOWER(name) = 'adfsaef' OR LOWER(name) = 'agadg', - | 'agfagff', + | LOWER(name) = 'legolas' OR LOWER(name) = 'elrond', + | 'Elf', | IF( - | LOWER(name) = 'adgfda' OR LOWER(name) = 'ssadf', - | 'sxzvfvxf', + | LOWER(name) = 'gimli' OR LOWER(name) = 'thorin', + | 'Dwarf', | IF( - | LOWER(name) = 'sdfadsf' OR LOWER(name) = 'sadfgvad', - | 'sAFvadsfvcds', + | LOWER(name) = 'gandalf' OR LOWER(name) = 'radagast', + | 'Wizard', | LOWER(name) | ) | ) @@ -173,7 +173,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |), |dummy_new AS ( | SELECT * - | FROM dummy WHERE test_type_name = 'safcdfz' + | FROM dummy WHERE test_type_name = 'Wizard' |) |SELECT * FROM dummy_new limit 1""".stripMargin ) @@ -183,7 +183,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { val df = sql( s"""|SELECT * FROM $catalogName.employee - |WHERE CASE WHEN name = '1' THEN name = 'barxxyz' ELSE NOT (name = 'barxxyz') END + |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END |""".stripMargin ) df.collect() @@ -192,7 +192,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") { val df = sql( s"""|SELECT * FROM $catalogName.employee - |WHERE CASE WHEN (name = 'barxxyz') THEN (name = 'barx') ELSE (1=1) END + |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END |""".stripMargin ) df.collect() @@ -201,9 +201,20 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") { val df = sql( s"""|SELECT * FROM $catalogName.employee - |WHERE CASE WHEN (name = 'barxxyz') THEN - | CASE WHEN (name = '5') THEN (name = 'barx') ELSE (name = 'aweda') END - | ELSE (name = '1') END + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END + | ELSE (name = 'Sauron') END + |""".stripMargin + ) + df.collect() + } + + test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END + | ELSE 'Sauron' END |""".stripMargin ) df.collect() From ab79afea7fcc4f3cb95251219d45ec6a70e104b8 Mon Sep 17 00:00:00 2001 From: andrej-gobeljic_data Date: Mon, 18 Nov 2024 17:08:57 +0100 Subject: [PATCH 10/28] nit --- .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 2 +- .../main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 6 +++--- .../org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 2751434d8ddd..29d759713576 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -149,7 +149,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD test("SPARK-50087: SqlServer handle booleans in IF in SELECT test") { // This doesn't compile on SqlServer unless result boolean expressions - // in IF / CASE WHEN are wrapped with an IIF(<>, 1, 0). + // in IF / CASE WHEN are wrapped with a CASE WHEN(<>, 1, 0). val df = sql( s"""|WITH dummy AS ( | SELECT diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index d23c43e87d4b..a588a3c5a7fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -378,13 +378,13 @@ abstract class JdbcDialect extends Serializable with Logging { } private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { - def inputToCaseWhenSQL(input: Expression): String = + def predicateToCaseWhenSQL(input: Expression): String = "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END" - def inputToIntSQL(expr: Expression): String = { + def predicateToIntSQL(expr: Expression): String = { expr match { case p: Predicate if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => - inputToCaseWhenSQL(p) + predicateToCaseWhenSQL(p) case p => inputToSQL(p) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index cc22c0c83800..1a5f0bac3c74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -88,7 +88,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => val Array(l, r) = e.children().map { - case p: Predicate => inputToCaseWhenSQL(p) + case p: Predicate => predicateToCaseWhenSQL(p) case o => inputToSQL(o) } visitBinaryComparison(e.name(), l, r) @@ -103,9 +103,9 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 val stringArray = e.children().grouped(2).flatMap { case Array(whenExpression, thenExpression) => - Array(inputToSQL(whenExpression), inputToIntSQL(thenExpression)) + Array(inputToSQL(whenExpression), predicateToIntSQL(thenExpression)) case Array(elseExpression) => - Array(inputToIntSQL(elseExpression)) + Array(predicateToIntSQL(elseExpression)) }.toArray visitCaseWhen(stringArray) + " = 1" From de38a8c712b0c97ec26c6865d03b09a31702b5ee Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 10:16:47 +0800 Subject: [PATCH 11/28] Update JdbcDialects.scala --- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a588a3c5a7fa..778ec168eac7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -378,17 +378,11 @@ abstract class JdbcDialect extends Serializable with Logging { } private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { - def predicateToCaseWhenSQL(input: Expression): String = + // Some dialects do not support boolean type and this convenient util function is + // provided to generate a SQL statement to convert predicate to integer. + protected def predicateToIntSQL(input: Expression): String = "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END" - def predicateToIntSQL(expr: Expression): String = { - expr match { - case p: Predicate if p.name() != "ALWAYS_TRUE" && p.name() != "ALWAYS_FALSE" => - predicateToCaseWhenSQL(p) - case p => inputToSQL(p) - } - } - override def visitLiteral(literal: Literal[_]): String = { Option(literal.value()).map(v => compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString) From 56cea3c84afc483f23292d790f83c82da068ce19 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 10:32:04 +0800 Subject: [PATCH 12/28] Update MsSqlServerDialect.scala --- .../spark/sql/jdbc/MsSqlServerDialect.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 1a5f0bac3c74..9e4dd6740eb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -87,29 +87,28 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr expr match { case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => - val Array(l, r) = e.children().map { - case p: Predicate => predicateToCaseWhenSQL(p) - case o => inputToSQL(o) - } + val Array(l, r) = e.children().map(inputToSQL) visitBinaryComparison(e.name(), l, r) case "CASE_WHEN" => // Since MsSqlServer cannot handle boolean expressions inside // a CASE WHEN, it is necessary to convert those to another // CASE WHEN expression that will return 1 or 0 depending on - // the result. Exceptions are TRUE and FALSE, which already - // get translated to 1 and 0. + // the result. // Example: // In: ... CASE WHEN a = b THEN c = d ... END // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 val stringArray = e.children().grouped(2).flatMap { case Array(whenExpression, thenExpression) => - Array(inputToSQL(whenExpression), predicateToIntSQL(thenExpression)) + Array(super.build(whenExpression), inputToSQL(thenExpression)) case Array(elseExpression) => - Array(predicateToIntSQL(elseExpression)) + Array(inputToSQL(elseExpression)) }.toArray visitCaseWhen(stringArray) + " = 1" - case _ => super.build(expr) + // MsSqlServerDialect translates boolean literals to 1/0, no need to rewrite them. + case "ALWAYS_TRUE" | "ALWAYS_FALSE" => + super.build(expr) + case _ => predicateToIntSQL(e) } case _ => super.build(expr) } From 21ec6224573d192f2d7ee30cbcf27fd1f55ba936 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 15:10:19 +0800 Subject: [PATCH 13/28] Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala --- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 778ec168eac7..02d8220d3bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -380,7 +380,7 @@ abstract class JdbcDialect extends Serializable with Logging { private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { // Some dialects do not support boolean type and this convenient util function is // provided to generate a SQL statement to convert predicate to integer. - protected def predicateToIntSQL(input: Expression): String = + protected def predicateToIntSQL(input: Predicate): String = "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END" override def visitLiteral(literal: Literal[_]): String = { From 468eb89a8316dcab74ccb2d7a6581bd4e8f1a9a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 20:52:09 +0800 Subject: [PATCH 14/28] Update JdbcDialects.scala --- .../main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 02d8220d3bfa..54d5754566c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc -import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils} @@ -380,8 +379,8 @@ abstract class JdbcDialect extends Serializable with Logging { private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { // Some dialects do not support boolean type and this convenient util function is // provided to generate a SQL statement to convert predicate to integer. - protected def predicateToIntSQL(input: Predicate): String = - "CASE WHEN " + inputToSQL(input) + " THEN 1 ELSE 0 END" + protected def predicateToIntSQL(input: String): String = + "CASE WHEN " + input + " THEN 1 ELSE 0 END" override def visitLiteral(literal: Literal[_]): String = { Option(literal.value()).map(v => From 1bc39ee536ad04d854ec8f70aae834fc06a4f715 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 20:52:43 +0800 Subject: [PATCH 15/28] Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala --- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 9e4dd6740eb0..6a8f1ba16bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -108,7 +108,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // MsSqlServerDialect translates boolean literals to 1/0, no need to rewrite them. case "ALWAYS_TRUE" | "ALWAYS_FALSE" => super.build(expr) - case _ => predicateToIntSQL(e) + case _ => predicateToIntSQL(super.build(e)) } case _ => super.build(expr) } From 6a221c67c9da976b0072eb2033f3ec946dd075cb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 21:39:41 +0800 Subject: [PATCH 16/28] Update JdbcDialects.scala --- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 54d5754566c5..c65f5c6d29b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils} @@ -378,7 +379,14 @@ abstract class JdbcDialect extends Serializable with Logging { private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { // Some dialects do not support boolean type and this convenient util function is - // provided to generate a SQL statement to convert predicate to integer. + // provided to generate SQL string without boolean values. + protected def inputToSQLNoPredicate(input: Expression): String = input match { + case p: Predicate if p.name() == "ALWAYS_TRUE" => "1" + case p: Predicate if p.name() == "ALWAYS_FALSE" => "0" + case p: Predicate => predicateToIntSQL(inputToSQL(p)) + case _ => inputToSQL(p) + } + protected def predicateToIntSQL(input: String): String = "CASE WHEN " + input + " THEN 1 ELSE 0 END" From ef1bcc8c85c7f77778872e43c06464aa281be2ef Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Nov 2024 21:42:48 +0800 Subject: [PATCH 17/28] Update MsSqlServerDialect.scala --- .../apache/spark/sql/jdbc/MsSqlServerDialect.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 6a8f1ba16bb8..811e1056805d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -87,7 +87,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr expr match { case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => - val Array(l, r) = e.children().map(inputToSQL) + val Array(l, r) = e.children().map(inputToSQLNoPredicate) visitBinaryComparison(e.name(), l, r) case "CASE_WHEN" => // Since MsSqlServer cannot handle boolean expressions inside @@ -99,16 +99,13 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 val stringArray = e.children().grouped(2).flatMap { case Array(whenExpression, thenExpression) => - Array(super.build(whenExpression), inputToSQL(thenExpression)) + Array(inputToSQL(whenExpression), inputToSQLNoPredicate(thenExpression)) case Array(elseExpression) => - Array(inputToSQL(elseExpression)) + Array(inputToSQLNoPredicate(elseExpression)) }.toArray visitCaseWhen(stringArray) + " = 1" - // MsSqlServerDialect translates boolean literals to 1/0, no need to rewrite them. - case "ALWAYS_TRUE" | "ALWAYS_FALSE" => - super.build(expr) - case _ => predicateToIntSQL(super.build(e)) + case _ => super.build(e) } case _ => super.build(expr) } From 5d94fa164cf867b4e08f88379c58478078301e5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Tue, 19 Nov 2024 16:20:51 +0100 Subject: [PATCH 18/28] Update JdbcDialects.scala --- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c65f5c6d29b7..2e3fedb91d9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -384,7 +384,7 @@ abstract class JdbcDialect extends Serializable with Logging { case p: Predicate if p.name() == "ALWAYS_TRUE" => "1" case p: Predicate if p.name() == "ALWAYS_FALSE" => "0" case p: Predicate => predicateToIntSQL(inputToSQL(p)) - case _ => inputToSQL(p) + case p => inputToSQL(p) } protected def predicateToIntSQL(input: String): String = From f94f8e5ff5d94bd98e83baa9e6f9ad9e48f4fbb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Tue, 19 Nov 2024 16:30:10 +0100 Subject: [PATCH 19/28] Update MsSqlServerDialect.scala --- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 811e1056805d..8909f19e3ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -105,7 +105,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr }.toArray visitCaseWhen(stringArray) + " = 1" - case _ => super.build(e) + case _ => super.build(expr) } case _ => super.build(expr) } From 55084a3f4515ce05d47f813e226587bc99f7ffe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Tue, 19 Nov 2024 17:53:53 +0100 Subject: [PATCH 20/28] Update MsSqlServerIntegrationSuite.scala --- .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 29d759713576..ac1effda140a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -177,6 +177,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |) |SELECT * FROM dummy_new limit 1""".stripMargin ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT TOP (1) "name" FROM "employee" WHERE (CASE WHEN ((LOWER("name") = 'legolas') OR (LOWER("name") = 'elrond')) THEN 0 ELSE IIF((CASE WHEN ((LOWER("name") = 'gimli') OR (LOWER("name") = 'thorin')) THEN 0 ELSE IIF((CASE WHEN ((LOWER("name") = 'gandalf') OR (LOWER("name") = 'radagast')) THEN 1 ELSE IIF((LOWER("name") = 'Wizard'), 1, 0) END = 1), 1, 0) END = 1), 1, 0) END = 1) """ + ) + // scalastyle:on df.collect() } @@ -186,6 +192,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END |""".stripMargin ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ + ) + // scalastyle:on df.collect() } @@ -195,6 +207,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END |""".stripMargin ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ + ) + // scalastyle:on df.collect() } @@ -206,6 +224,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD | ELSE (name = 'Sauron') END |""".stripMargin ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ + ) + // scalastyle:on df.collect() } @@ -214,9 +238,15 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD s"""|SELECT * FROM $catalogName.employee |WHERE CASE WHEN (name = 'Legolas') THEN | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END - | ELSE 'Sauron' END + | ELSE 'Sauron' END = name |""".stripMargin ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ + ) + // scalastyle:on df.collect() } } From 7fb6b29b5a3336fe59ffc64472577f3cb773dfc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Tue, 19 Nov 2024 17:55:35 +0100 Subject: [PATCH 21/28] Update MsSqlServerDialect.scala --- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8909f19e3ec3..7630a1eb7762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -59,6 +59,9 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr supportedFunctions.contains(funcName) class MsSqlServerSQLBuilder extends JDBCSQLBuilder { + override protected def predicateToIntSQL(input: String): String = + "IIF(" + input + ", 1, 0)" + override def visitSortOrder( sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = { (sortDirection, nullOrdering) match { From ca0545a00d1675c700c2f77affad24bea2610c7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Tue, 19 Nov 2024 17:58:46 +0100 Subject: [PATCH 22/28] Update MsSqlServerIntegrationSuite.scala --- .../sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index ac1effda140a..b9436ccb6da2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{FilterExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker import org.apache.spark.sql.types._ @@ -36,6 +40,17 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + + def getExternalEngineQuery(executedPlan: SparkPlan): String = { + getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery + } + + def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = { + val queryNode = executedPlan.collect { case r: RowDataSourceScanExec => + r + }.head + queryNode.rdd + } override def excluded: Seq[String] = Seq( "simple scan with OFFSET", From df2fe41a00a003f45588e42ea1bcfe08c7dce70b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Nov 2024 10:04:09 +0800 Subject: [PATCH 23/28] Apply suggestions from code review --- .../main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 4 ++-- .../org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 2e3fedb91d9a..7f9be3a59778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -380,11 +380,11 @@ abstract class JdbcDialect extends Serializable with Logging { private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { // Some dialects do not support boolean type and this convenient util function is // provided to generate SQL string without boolean values. - protected def inputToSQLNoPredicate(input: Expression): String = input match { + protected def inputToSQLNoBool(input: Expression): String = input match { case p: Predicate if p.name() == "ALWAYS_TRUE" => "1" case p: Predicate if p.name() == "ALWAYS_FALSE" => "0" case p: Predicate => predicateToIntSQL(inputToSQL(p)) - case p => inputToSQL(p) + case _ => inputToSQL(input) } protected def predicateToIntSQL(input: String): String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 7630a1eb7762..d54bfd59bf86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -90,7 +90,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr expr match { case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => - val Array(l, r) = e.children().map(inputToSQLNoPredicate) + val Array(l, r) = e.children().map(inputToSQLNoBool) visitBinaryComparison(e.name(), l, r) case "CASE_WHEN" => // Since MsSqlServer cannot handle boolean expressions inside @@ -102,9 +102,9 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 val stringArray = e.children().grouped(2).flatMap { case Array(whenExpression, thenExpression) => - Array(inputToSQL(whenExpression), inputToSQLNoPredicate(thenExpression)) + Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression)) case Array(elseExpression) => - Array(inputToSQLNoPredicate(elseExpression)) + Array(inputToSQLNoBool(elseExpression)) }.toArray visitCaseWhen(stringArray) + " = 1" From 53a42200e0738074ec6d2327e8083e44d54dc9f0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Nov 2024 11:15:15 +0800 Subject: [PATCH 24/28] Update connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala --- .../apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index b9436ccb6da2..5a04ff954f51 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{FilterExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker From cbef9a59135e7a4d7cb08f5b68c73b4aad0aec52 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Nov 2024 16:11:29 +0800 Subject: [PATCH 25/28] Update MsSqlServerIntegrationSuite.scala --- .../apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 5a04ff954f51..9b6e54bd24e3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { - + def getExternalEngineQuery(executedPlan: SparkPlan): String = { getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery } From c990ec680987204e2ab7a00c84a43d7f91754d83 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Nov 2024 20:52:58 +0800 Subject: [PATCH 26/28] Update JdbcDialects.scala --- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7f9be3a59778..81ad1a6d38bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -386,7 +386,7 @@ abstract class JdbcDialect extends Serializable with Logging { case p: Predicate => predicateToIntSQL(inputToSQL(p)) case _ => inputToSQL(input) } - + protected def predicateToIntSQL(input: String): String = "CASE WHEN " + input + " THEN 1 ELSE 0 END" From ee4d4fb71c3345b0f55f94509578cb544536881d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Nov 2024 23:39:43 +0800 Subject: [PATCH 27/28] Update sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala --- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index d54bfd59bf86..7d339a90db8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -61,7 +61,6 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr class MsSqlServerSQLBuilder extends JDBCSQLBuilder { override protected def predicateToIntSQL(input: String): String = "IIF(" + input + ", 1, 0)" - override def visitSortOrder( sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = { (sortDirection, nullOrdering) match { From 6cff9f56f6f1b9d9d00d2bd95c862ea733deeb3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Gobelji=C4=87?= Date: Thu, 21 Nov 2024 09:15:36 +0100 Subject: [PATCH 28/28] Update MsSqlServerIntegrationSuite.scala --- .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 9b6e54bd24e3..fd7efb1efb76 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -162,45 +162,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD assert(df.collect().length == 2) } - test("SPARK-50087: SqlServer handle booleans in IF in SELECT test") { - // This doesn't compile on SqlServer unless result boolean expressions - // in IF / CASE WHEN are wrapped with a CASE WHEN(<>, 1, 0). - val df = sql( - s"""|WITH dummy AS ( - | SELECT - | DISTINCT name AS full_name, - | UPPER(name) AS test_type, - | name, - | IF( - | LOWER(name) = 'legolas' OR LOWER(name) = 'elrond', - | 'Elf', - | IF( - | LOWER(name) = 'gimli' OR LOWER(name) = 'thorin', - | 'Dwarf', - | IF( - | LOWER(name) = 'gandalf' OR LOWER(name) = 'radagast', - | 'Wizard', - | LOWER(name) - | ) - | ) - | ) AS test_type_name - | FROM $catalogName.employee - |), - |dummy_new AS ( - | SELECT * - | FROM dummy WHERE test_type_name = 'Wizard' - |) - |SELECT * FROM dummy_new limit 1""".stripMargin - ) - - // scalastyle:off - assert(getExternalEngineQuery(df.queryExecution.executedPlan) == - """SELECT TOP (1) "name" FROM "employee" WHERE (CASE WHEN ((LOWER("name") = 'legolas') OR (LOWER("name") = 'elrond')) THEN 0 ELSE IIF((CASE WHEN ((LOWER("name") = 'gimli') OR (LOWER("name") = 'thorin')) THEN 0 ELSE IIF((CASE WHEN ((LOWER("name") = 'gandalf') OR (LOWER("name") = 'radagast')) THEN 1 ELSE IIF((LOWER("name") = 'Wizard'), 1, 0) END = 1), 1, 0) END = 1), 1, 0) END = 1) """ - ) - // scalastyle:on - df.collect() - } - test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { val df = sql( s"""|SELECT * FROM $catalogName.employee