diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 329edc5d2f..59090a5bd3 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -287,8 +287,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::IntegralDivide(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/533
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr_with_options(
                     expr.left.as_ref().unwrap(),
@@ -987,11 +985,12 @@ impl PhysicalPlanner {
            } else {
                "decimal_div"
            };
-           let fun_expr = create_comet_physical_fun(
+           let fun_expr = create_comet_physical_fun_with_eval_mode(
                func_name,
                data_type.clone(),
                &self.session_ctx.state(),
                None,
+               eval_mode,
            )?;
            Ok(Arc::new(ScalarFunctionExpr::new(
                func_name,
diff --git a/native/spark-expr/benches/decimal_div.rs b/native/spark-expr/benches/decimal_div.rs
index 4262e81238..3ca3e42eb5 100644
--- a/native/spark-expr/benches/decimal_div.rs
+++ b/native/spark-expr/benches/decimal_div.rs
@@ -20,7 +20,7 @@ use arrow::compute::cast;
 use arrow::datatypes::DataType;
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion::physical_plan::ColumnarValue;
-use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div, EvalMode};
 use std::hint::black_box;
 use std::sync::Arc;
 
@@ -48,6 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
@@ -57,6 +58,7 @@
             black_box(spark_decimal_integral_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 393f57662e..19fa11e641 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -137,13 +137,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
            make_comet_scalar_udf!("unhex", func, without data_type)
        }
        "decimal_div" => {
-           make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+           make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)
        }
        "decimal_integral_div" => {
            make_comet_scalar_udf!(
                "decimal_integral_div",
                spark_decimal_integral_div,
-               data_type
+               data_type,
+               eval_mode
            )
        }
        "checked_add" => {
diff --git a/native/spark-expr/src/math_funcs/div.rs b/native/spark-expr/src/math_funcs/div.rs
index 9fc6692c03..933b28c094 100644
--- a/native/spark-expr/src/math_funcs/div.rs
+++ b/native/spark-expr/src/math_funcs/div.rs
@@ -16,29 +16,33 @@
 // under the License.
 
 use crate::math_funcs::utils::get_precision_scale;
+use crate::{divide_by_zero_error, EvalMode};
 use arrow::array::{Array, Decimal128Array};
 use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::Decimal128Type,
 };
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
 use std::sync::Arc;
 
 pub fn spark_decimal_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, false)
+    spark_decimal_div_internal(args, data_type, false, eval_mode)
 }
 
 pub fn spark_decimal_integral_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, true)
+    spark_decimal_div_internal(args, data_type, true, eval_mode)
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +54,7 @@ fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     is_integral_div: bool,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
@@ -80,9 +85,12 @@ fn spark_decimal_div_internal(
         let r_mul = ten.pow(r_exp);
         let five = BigInt::from(5);
         let zero = BigInt::from(0);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = BigInt::from(l) * &l_mul;
             let r = BigInt::from(r) * &r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
             let res = if is_integral_div {
                 div
@@ -91,14 +99,17 @@
             } else {
                 div + &five
             } / &ten;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     } else {
         let l_mul = 10_i128.pow(l_exp);
         let r_mul = 10_i128.pow(r_exp);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = l * l_mul;
             let r = r * r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r == 0 { 0 } else { l / r };
             let res = if is_integral_div {
                 div
@@ -107,7 +118,7 @@
             } else {
                 div + 5
             } / 10;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 4507dc1073..eed4d7b16a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
 
 object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
 
-  override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: IntegralDivide,
       inputs: Seq[Attribute],
@@ -206,9 +198,9 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
       if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
       else Cast(expr.right, DecimalType(19, 0))
 
-    val rightExpr = nullIfWhenPrimitive(right)
+    val rightExpr = if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(right) else right
 
-    val dataType = (left.dataType, right.dataType) match {
+    val dataType = (left.dataType, rightExpr.dataType) match {
       case (l: DecimalType, r: DecimalType) =>
         // copy from IntegralDivide.resultDecimalType
        val intDig = l.precision - l.scale + r.scale
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index f391d52f78..f86a934e7b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -58,7 +58,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
     """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
   val DIVIDE_BY_ZERO_EXCEPTION_MSG =
-    """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+    """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
 
   test("compare true/false to negative zero") {
     Seq(false, true).foreach { dictionary =>
@@ -2948,7 +2948,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("ANSI support for divide (division by zero)") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Integer.MIN_VALUE, 0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2969,7 +2968,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("ANSI support for divide (division by zero) float division") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Float.MinPositiveValue, 0.0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2989,6 +2987,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for integral divide (division by zero)") {
+    val data = Seq((Integer.MAX_VALUE, 0))
+    Seq("true", "false").foreach { p =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> p) {
+        withParquetTable(data, "tbl") {
+          val res = spark.sql("""
+              |SELECT
+              | _1 div _2
+              | from tbl
+              | """.stripMargin)
+
+          checkSparkMaybeThrows(res) match {
+            case (Some(sparkException), Some(cometException)) =>
+              assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+              assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            case (None, None) => checkSparkAnswerAndOperator(res)
+            case (None, Some(ex)) =>
+              fail(
+                "Comet threw an exception but Spark did not. Comet exception: " + ex.getMessage)
+            case (Some(sparkException), None) =>
+              fail(
+                "Spark threw an exception but Comet did not. Spark exception: " +
+                  sparkException.getMessage)
+          }
+        }
+      }
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)