diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h
index 0f48c56d48..7b8dfe0b8b 100644
--- a/src/enclave/Enclave/ExpressionEvaluation.h
+++ b/src/enclave/Enclave/ExpressionEvaluation.h
@@ -1614,6 +1614,68 @@ class FlatbuffersExpressionEvaluator {
         result_is_null);
     }
 
+    case tuix::ExprUnion_NormalizeNaNAndZero:
+    {
+      auto normalize = static_cast<const tuix::NormalizeNaNAndZero *>(expr->expr());
+      auto child_offset = eval_helper(row, normalize->child());
+
+      const tuix::Field *value = flatbuffers::GetTemporaryPointer(builder, child_offset);
+
+      if (value->value_type() != tuix::FieldUnion_FloatField && value->value_type() != tuix::FieldUnion_DoubleField) {
+        throw std::runtime_error(
+          std::string("tuix::NormalizeNaNAndZero requires type Float or Double, not ")
+          + std::string(tuix::EnumNameFieldUnion(value->value_type())));
+      }
+
+      bool result_is_null = value->is_null();
+
+      if (value->value_type() == tuix::FieldUnion_FloatField) {
+        if (!result_is_null) {
+          float v = value->value_as_FloatField()->value();
+          if (isnan(v)) {
+            v = std::numeric_limits<float>::quiet_NaN();
+          } else if (v == -0.0f) {
+            v = 0.0f;
+          }
+
+          return tuix::CreateField(
+            builder,
+            tuix::FieldUnion_FloatField,
+            tuix::CreateFloatField(builder, v).Union(),
+            result_is_null);
+        }
+
+        return tuix::CreateField(
+          builder,
+          tuix::FieldUnion_FloatField,
+          tuix::CreateFloatField(builder, 0).Union(),
+          result_is_null);
+
+      } else {
+
+        if (!result_is_null) {
+          double v = value->value_as_DoubleField()->value();
+          if (isnan(v)) {
+            v = std::numeric_limits<double>::quiet_NaN();
+          } else if (v == -0.0) {
+            v = 0.0;
+          }
+
+          return tuix::CreateField(
+            builder,
+            tuix::FieldUnion_DoubleField,
+            tuix::CreateDoubleField(builder, v).Union(),
+            result_is_null);
+        }
+
+        return tuix::CreateField(
+          builder,
+          tuix::FieldUnion_DoubleField,
+          tuix::CreateDoubleField(builder, 0).Union(),
+          result_is_null);
+      }
+    }
+
     default:
       throw std::runtime_error(
         std::string("Can't evaluate expression of type ")
diff --git a/src/flatbuffers/Expr.fbs b/src/flatbuffers/Expr.fbs
index 4acce5e53d..a1e4d92aeb 100644
--- a/src/flatbuffers/Expr.fbs
+++ b/src/flatbuffers/Expr.fbs
@@ -36,6 +36,7 @@ union ExprUnion {
   VectorMultiply,
   DotProduct,
   Exp,
+  NormalizeNaNAndZero,
   ClosestPoint,
   CreateArray,
   Upper,
@@ -199,6 +200,10 @@ table CreateArray {
   children:[Expr];
 }
 
+table NormalizeNaNAndZero {
+  child:Expr;
+}
+
 // Opaque UDFs
 table VectorAdd {
   left:Expr;
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
index 4c6970e489..cbe2f944dc 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -61,6 +61,7 @@ import org.apache.spark.sql.catalyst.expressions.If
 import org.apache.spark.sql.catalyst.expressions.In
 import org.apache.spark.sql.catalyst.expressions.IsNotNull
 import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
 import org.apache.spark.sql.catalyst.expressions.LessThan
 import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual
 import org.apache.spark.sql.catalyst.expressions.Literal
@@ -91,6 +92,7 @@ import org.apache.spark.sql.catalyst.plans.NaturalJoin
 import org.apache.spark.sql.catalyst.plans.RightOuter
 import org.apache.spark.sql.catalyst.plans.UsingJoin
 import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.MapData
@@ -1169,6 +1171,15 @@ object Utils extends Logging {
         // TODO: Implement decimal serialization, followed by CheckOverflow
         childOffset
 
+      case (NormalizeNaNAndZero(child), Seq(childOffset)) =>
+        tuix.Expr.createExpr(
+          builder,
+          tuix.ExprUnion.NormalizeNaNAndZero,
+          tuix.NormalizeNaNAndZero.createNormalizeNaNAndZero(builder, childOffset))
+
+      case (KnownFloatingPointNormalized(NormalizeNaNAndZero(child)), Seq(childOffset)) =>
+        flatbuffersSerializeExpression(builder, NormalizeNaNAndZero(child), input)
+
       case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) =>
         val output = child.output(0)
         val dataType = output match {
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
index 14d71a1d0c..c235265624 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
 object TPCHBenchmark {
 
   // Add query numbers here once they are supported
-  val supportedQueries = Seq(1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
+  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
 
   def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
     val sqlStr = tpch.getQuery(queryNumber)
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
index a69894d13c..88a5550f17 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
@@ -344,6 +344,36 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
     df.collect
   }
 
+  testAgainstSpark("join on floats") { securityLevel =>
+    val p_data = for (i <- 0 to 16) yield (i, i.toFloat, i * 10)
+    val f_data = (0 until 256).map(x => {
+      if (x % 3 == 0)
+        (x, null.asInstanceOf[Float], x * 10)
+      else
+        (x, (x % 16).asInstanceOf[Float], x * 10)
+    }).toSeq
+
+    val p = makeDF(p_data, securityLevel, "id", "pk", "x")
+    val f = makeDF(f_data, securityLevel, "id", "fk", "x")
+    val df = p.join(f, $"pk" === $"fk")
+    df.collect.toSet
+  }
+
+  testAgainstSpark("join on doubles") { securityLevel =>
+    val p_data = for (i <- 0 to 16) yield (i, i.toDouble, i * 10)
+    val f_data = (0 until 256).map(x => {
+      if (x % 3 == 0)
+        (x, null.asInstanceOf[Double], x * 10)
+      else
+        (x, (x % 16).asInstanceOf[Double], x * 10)
+    }).toSeq
+
+    val p = makeDF(p_data, securityLevel, "id", "pk", "x")
+    val f = makeDF(f_data, securityLevel, "id", "fk", "x")
+    val df = p.join(f, $"pk" === $"fk")
+    df.collect.toSet
+  }
+
   def abc(i: Int): String = (i % 3) match {
     case 0 => "A"
     case 1 => "B"