diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 58c95c94ba19..d2e30af079ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -113,7 +113,7 @@ private[codegen] case class NewFunctionSpec(
  * A context for codegen, tracking a list of objects that could be passed into generated Java
  * function.
  */
-class CodegenContext extends Logging {
+class CodegenContext(val disallowSwitchStatement: Boolean = false) extends Logging {
 
   import CodeGenerator._
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index bd190c3e5abc..c1959cf7faac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -454,7 +454,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold) {
+    if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold &&
+        !ctx.disallowSwitchStatement) {
       genCodeWithSwitch(ctx, ev)
     } else {
       genCodeWithSet(ctx, ev)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 0615324b8430..4e5cdb74aa8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -167,8 +167,9 @@ case class ExpandExec(
       }
     }
 
-    // Part 2: switch/case statements
-    val cases = projections.zipWithIndex.map { case (exprs, row) =>
+    // Part 2: switch/case statements (or if / else-if statements when switch is disallowed)
+
+    val updates = projections.map { exprs =>
       var updateCode = ""
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
@@ -178,27 +179,48 @@ case class ExpandExec(
             |${ev.code}
             |${outputColumns(col).isNull} = ${ev.isNull};
             |${outputColumns(col).value} = ${ev.value};
-          """.stripMargin
+           """.stripMargin
         }
       }
+      updateCode.trim
+    }
+
+    // the loop variable name needs to be known up front to build the conditions
+    val i = ctx.freshName("i")
+    val loopContent = if (!ctx.disallowSwitchStatement) {
+      val cases = updates.zipWithIndex.map { case (updateCode, row) =>
+        s"""
+           |case $row:
+           |  ${updateCode.trim}
+           |  break;
+         """.stripMargin
+      }
 
       s"""
-         |case $row:
-         |  ${updateCode.trim}
-         |  break;
+         |switch ($i) {
+         |  ${cases.mkString("\n").trim}
+         |}
       """.stripMargin
+    } else {
+      val conditions = updates.zipWithIndex.map { case (updateCode, row) =>
+        (if (row > 0) "else " else "") +
+          s"""
+             |if ($i == $row) {
+             |  ${updateCode.trim}
+             |}
+           """.stripMargin
+      }
+
+      conditions.mkString("\n").trim
     }
 
     val numOutput = metricTerm(ctx, "numOutputRows")
-    val i = ctx.freshName("i")
    // these column have to declared before the loop.
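For reference, the two shapes of generated Java that ExpandExec's `$loopContent` can now take — here for three projections — look roughly as follows. This is a hand-written sketch of the emitted code, not verbatim codegen output, and the variable name `expand_i` is illustrative:

    // Default shape (disallowSwitchStatement = false): a dense switch; this is the
    // shape that can trip Janino bug #113 inside very large generated methods.
    switch (expand_i) {
      case 0: /* update output columns for projection 0 */ break;
      case 1: /* update output columns for projection 1 */ break;
      case 2: /* update output columns for projection 2 */ break;
    }

    // Workaround shape (disallowSwitchStatement = true): an equivalent if / else-if chain.
    if (expand_i == 0) { /* update output columns for projection 0 */ }
    else if (expand_i == 1) { /* update output columns for projection 1 */ }
    else if (expand_i == 2) { /* update output columns for projection 2 */ }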
     val evaluate = evaluateVariables(outputColumns)
     s"""
        |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {
-       |  switch ($i) {
-       |    ${cases.mkString("\n").trim}
-       |  }
+       |  $loopContent
        |  $numOutput.add(1);
        |  ${consume(ctx, outputColumns)}
        |}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 10fe0f252322..ade3fd1b21ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -629,6 +629,10 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
    */
   def doCodeGen(): (CodegenContext, CodeAndComment) = {
     val ctx = new CodegenContext
+    (ctx, doCodeGen(ctx))
+  }
+
+  private def doCodeGen(ctx: CodegenContext): CodeAndComment = {
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
 
     // main next function.
@@ -647,7 +651,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
       }
 
       ${ctx.registerComment(
-        s"""Codegend pipeline for stage (id=$codegenStageId)
+        s"""Codegen pipeline for stage (id=$codegenStageId)
           |${this.treeString.trim}""".stripMargin,
          "wsc_codegenPipeline")}
       ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
@@ -679,7 +683,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
       new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))
 
     logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
-    (ctx, cleanedSource)
+    cleanedSource
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -688,11 +692,56 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
     child.executeColumnar()
   }
 
-  override def doExecute(): RDD[InternalRow] = {
+  private type CompileResult = (CodegenContext, CodeAndComment, GeneratedClass, ByteCodeStats)
+
+  /**
+   * NOTE: This method handles the known Janino bug:
+   * - https://github.com/janino-compiler/janino/issues/113
+   *
+   * It first generates and compiles code via the normal path. If the compilation fails due to
+   * the known bug, it regenerates the code with a workaround flag set in the CodegenContext and
+   * compiles again.
+   */
+  private def doGenCodeAndCompile(): CompileResult = {
+    def containsMsg(exception: Throwable, msg: String): Boolean = {
+      def contain(msg1: String, msg2: String): Boolean = {
+        msg1.toLowerCase(Locale.ROOT).contains(msg2.toLowerCase(Locale.ROOT))
+      }
+
+      var e = exception
+      var contains = contain(e.getMessage, msg)
+      while (e.getCause != null && !contains) {
+        e = e.getCause
+        contains = contain(e.getMessage, msg)
+      }
+      contains
+    }
+
     val (ctx, cleanedSource) = doCodeGen()
+    try {
+      val (genClass, maxCodeSize) = CodeGenerator.compile(cleanedSource)
+      (ctx, cleanedSource, genClass, maxCodeSize)
+    } catch {
+      case NonFatal(e) if cleanedSource.body.contains("switch") &&
+          containsMsg(e, "Operand stack inconsistent at offset") =>
+        // It might have hit the known Janino bug (https://github.com/janino-compiler/janino/issues/113).
+        // Try to disallow "switch" statements during codegen, and compile again.
+        // The log level matches the one used for compilation error messages in
+        // CodeGenerator.compile(), so that end users who see the compilation error
+        // also see this message.
+        logError("Generated code hits known Janino bug - applying workaround and recompiling...")
+
+        val newCtx = new CodegenContext(disallowSwitchStatement = true)
+        val newCleanedSource = doCodeGen(newCtx)
+        val (genClass, maxCodeSize) = CodeGenerator.compile(newCleanedSource)
+        (newCtx, newCleanedSource, genClass, maxCodeSize)
+    }
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
     // try to compile and fallback if it failed
-    val (_, compiledCodeStats) = try {
-      CodeGenerator.compile(cleanedSource)
+    val (ctx, cleanedSource, _, compiledCodeStats) = try {
+      doGenCodeAndCompile()
     } catch {
       case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback =>
         // We should already saw the error message
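For reference, the Janino failure typically surfaces wrapped in one or more compiler exceptions, which is why `containsMsg` above walks the cause chain instead of checking only the top-level message. A minimal sketch of its behavior, using a hypothetical exception chain:

    // Hypothetical: the Janino message is usually nested a few causes deep.
    val janinoError = new RuntimeException(
      "Operand stack inconsistent at offset 42: Previous size 1, now 0")
    val wrapped = new RuntimeException("failed to compile generated class", janinoError)

    containsMsg(wrapped, "Operand stack inconsistent at offset")  // true
    containsMsg(wrapped, "unrelated message")                     // false

If the recompilation in the catch block fails as well, the exception propagates to `doExecute()`, where the existing `codegenFallback` handling takes over.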
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 288f3dac3662..2049652c8f88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.DecimalData
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -957,4 +956,60 @@ class DataFrameAggregateSuite extends QueryTest
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  /**
+   * NOTE: The test code tries to control the size of the for/switch statement in
+   * expand_doConsume, as well as the overall size of expand_doConsume, so that the query
+   * triggers the known Janino bug - https://github.com/janino-compiler/janino/issues/113.
+   *
+   * The expected exception message from Janino when we use a switch statement for "ExpandExec":
+   * - "Operand stack inconsistent at offset xxx: Previous size 1, now 0"
+   * which does not happen when we use an if-else-if statement for "ExpandExec".
+   *
+   * "The number of fields" and "the number of distinct aggregation functions" are the major
+   * factors that increase the size of the generated code: these values should be large enough
+   * to trigger the Janino bug, but not too big; otherwise one of the exceptions below might
+   * be thrown:
+   * - "expand_doConsume would be beyond 64KB"
+   * - "java.lang.ClassFormatError: Too many arguments in method signature in class file"
+   */
+  test("SPARK-31115 Lots of columns and distinct aggregations shouldn't break code generation") {
+    withSQLConf(
+      (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true"),
+      (SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key, "10000"),
+      (SQLConf.CODEGEN_FALLBACK.key, "false"),
+      (SQLConf.CODEGEN_LOGGING_MAX_LINES.key, "-1")
+    ) {
+      var df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+
+      // The value was verified under commit "e807118eef9e0214170ff62c828524d237bd58e3":
+      // the query fails with the switch statement, whereas it passes with the if-else-if
+      // statement. Note that the value depends on Spark internals as well - different Spark
+      // versions may require a different value to make the test fail with the switch statement.
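For reference, the reason this query shape exercises ExpandExec at all: a query with multiple distinct aggregate groups is rewritten (via the `RewriteDistinctAggregates` rule) into a plan containing an Expand node with roughly one projection per distinct-aggregate group, and that node is where the oversized switch statement gets generated. A quick, illustrative way to observe this outside the test:

    // Two distinct aggregates over different columns already force an Expand node;
    // the test's dozens of distinct counts merely scale the same plan up until the
    // generated switch is large enough to hit the Janino bug.
    Seq(("x", 1, 2)).toDF("k", "c1", "c2")
      .groupBy("k")
      .agg(countDistinct("c1"), countDistinct("c2"))
      .explain()  // the physical plan contains an Expand operator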
+      val numNewFields = 100
+
+      df = df.withColumns(
+        (1 to numNewFields).map { idx => s"a$idx" },
+        (1 to numNewFields).map { idx =>
+          when(col("c").mod(lit(2)).===(lit(0)), lit(idx)).otherwise(col("c"))
+        }
+      )
+
+      val aggExprs: Array[Column] = Range(1, numNewFields).map { idx =>
+        if (idx % 2 == 0) {
+          coalesce(countDistinct(s"a$idx"), lit(0))
+        } else {
+          coalesce(count(s"a$idx"), lit(0))
+        }
+      }.toArray
+
+      val aggDf = df
+        .groupBy("a", "b")
+        .agg(aggExprs.head, aggExprs.tail: _*)
+
+      // We are only interested in whether the code compilation fails or not, so we skip
+      // verifying the output.
+      aggDf.collect()
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 888772c35d0e..d453896cbb57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -107,7 +107,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
           .queryExecution.executedPlan)
         assert(res.length == 2)
         assert(res.forall { case (_, code, _) =>
-          (code.contains("* Codegend pipeline") == flag) &&
+          (code.contains("* Codegen pipeline") == flag) &&
             (code.contains("// input[") == flag)
         })
       }