Skip to content

Commit 3c28b5f

Browse files
authored
Float expressions (#160)
This PR adds float normalization expressions [implemented in Spark](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala#L170). TPC-H query 2 also passes.
1 parent 96e6285 commit 3c28b5f

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,68 @@ class FlatbuffersExpressionEvaluator {
16141614
result_is_null);
16151615
}
16161616

1617+
case tuix::ExprUnion_NormalizeNaNAndZero:
{
  // Canonicalizes a Float/Double child value so that equality-based operators
  // (joins, aggregations) compare consistently: every NaN bit pattern becomes
  // the canonical quiet NaN, and -0.0 becomes +0.0. Mirrors Spark's
  // NormalizeFloatingNumbers optimizer rule (see PR #160).
  //
  // Returns a new FloatField/DoubleField with the same null flag as the child;
  // a NULL input yields a NULL field carrying a 0 payload.
  // Throws std::runtime_error for any non-floating-point child type.
  auto normalize = static_cast<const tuix::NormalizeNaNAndZero *>(expr->expr());
  auto child_offset = eval_helper(row, normalize->child());

  const tuix::Field *value = flatbuffers::GetTemporaryPointer(builder, child_offset);

  if (value->value_type() != tuix::FieldUnion_FloatField &&
      value->value_type() != tuix::FieldUnion_DoubleField) {
    throw std::runtime_error(
      std::string("tuix::NormalizeNaNAndZero requires type Float or Double, not ")
      + std::string(tuix::EnumNameFieldUnion(value->value_type())));
  }

  bool result_is_null = value->is_null();

  if (value->value_type() == tuix::FieldUnion_FloatField) {
    float v = 0.0f;  // payload used when the result is NULL
    if (!result_is_null) {
      v = value->value_as_FloatField()->value();
      if (std::isnan(v)) {
        // Collapse every NaN bit pattern to the canonical quiet NaN.
        v = std::numeric_limits<float>::quiet_NaN();
      } else if (v == -0.0f) {
        // -0.0f == 0.0f under IEEE 754; this rewrites the sign bit so the
        // serialized bit patterns match too.
        v = 0.0f;
      }
    }
    return tuix::CreateField(
      builder,
      tuix::FieldUnion_FloatField,
      tuix::CreateFloatField(builder, v).Union(),
      result_is_null);
  } else {
    double v = 0.0;  // payload used when the result is NULL
    if (!result_is_null) {
      v = value->value_as_DoubleField()->value();
      if (std::isnan(v)) {
        v = std::numeric_limits<double>::quiet_NaN();
      } else if (v == -0.0) {
        // NOTE: the original used the non-standard literal suffix `0.0d`;
        // plain `0.0` is the correct C++ double literal.
        v = 0.0;
      }
    }
    return tuix::CreateField(
      builder,
      tuix::FieldUnion_DoubleField,
      tuix::CreateDoubleField(builder, v).Union(),
      result_is_null);
  }
}
1678+
16171679
default:
16181680
throw std::runtime_error(
16191681
std::string("Can't evaluate expression of type ")

src/flatbuffers/Expr.fbs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ union ExprUnion {
3636
VectorMultiply,
3737
DotProduct,
3838
Exp,
39+
NormalizeNaNAndZero,
3940
ClosestPoint,
4041
CreateArray,
4142
Upper,
@@ -199,6 +200,10 @@ table CreateArray {
199200
children:[Expr];
200201
}
201202

203+
// Normalizes the floating-point result of `child`: canonicalizes NaN bit
// patterns and rewrites -0.0 to +0.0. Evaluated in ExpressionEvaluation.h;
// mirrors Spark's NormalizeFloatingNumbers optimizer rule.
table NormalizeNaNAndZero {
  child:Expr;
}
206+
202207
// Opaque UDFs
203208
table VectorAdd {
204209
left:Expr;

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ import org.apache.spark.sql.catalyst.expressions.If
6161
import org.apache.spark.sql.catalyst.expressions.In
6262
import org.apache.spark.sql.catalyst.expressions.IsNotNull
6363
import org.apache.spark.sql.catalyst.expressions.IsNull
64+
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
6465
import org.apache.spark.sql.catalyst.expressions.LessThan
6566
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual
6667
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -91,6 +92,7 @@ import org.apache.spark.sql.catalyst.plans.NaturalJoin
9192
import org.apache.spark.sql.catalyst.plans.RightOuter
9293
import org.apache.spark.sql.catalyst.plans.UsingJoin
9394
import org.apache.spark.sql.catalyst.trees.TreeNode
95+
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
9496
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
9597
import org.apache.spark.sql.catalyst.util.ArrayData
9698
import org.apache.spark.sql.catalyst.util.MapData
@@ -1169,6 +1171,15 @@ object Utils extends Logging {
11691171
// TODO: Implement decimal serialization, followed by CheckOverflow
11701172
childOffset
11711173

1174+
case (NormalizeNaNAndZero(child), Seq(childOffset)) =>
1175+
tuix.Expr.createExpr(
1176+
builder,
1177+
tuix.ExprUnion.NormalizeNaNAndZero,
1178+
tuix.NormalizeNaNAndZero.createNormalizeNaNAndZero(builder, childOffset))
1179+
1180+
case (KnownFloatingPointNormalized(NormalizeNaNAndZero(child)), Seq(childOffset)) =>
1181+
flatbuffersSerializeExpression(builder, NormalizeNaNAndZero(child), input)
1182+
11721183
case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) =>
11731184
val output = child.output(0)
11741185
val dataType = output match {

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
2424
object TPCHBenchmark {
2525

2626
// Add query numbers here once they are supported
27-
val supportedQueries = Seq(1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
27+
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
2828

2929
def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
3030
val sqlStr = tpch.getQuery(queryNumber)

src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,36 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
344344
df.collect
345345
}
346346

347+
testAgainstSpark("join on floats") { securityLevel =>
  // Primary side: 17 rows with distinct float keys 0.0f .. 16.0f.
  val p_data = for (i <- 0 to 16) yield (i, i.toFloat, i * 10)
  // Foreign side: 256 rows keyed by x % 16.
  // NOTE(review): `null.asInstanceOf[Float]` evaluates to 0.0f (primitive
  // Float cannot hold null), so every third row joins against pk == 0.0f
  // rather than being a SQL NULL. Confirm this fixture is intentional; use
  // java.lang.Float if genuine NULL keys are wanted.
  val f_data = (0 until 256).map(x => {
    if (x % 3 == 0)
      (x, null.asInstanceOf[Float], x * 10)
    else
      (x, (x % 16).toFloat, x * 10) // .toFloat: idiomatic numeric conversion
  }).toSeq

  val p = makeDF(p_data, securityLevel, "id", "pk", "x")
  val f = makeDF(f_data, securityLevel, "id", "fk", "x")
  // Equi-join on float keys; exercises the float normalization added in this
  // commit, compared against vanilla Spark's result.
  val df = p.join(f, $"pk" === $"fk")
  df.collect.toSet
}
361+
362+
testAgainstSpark("join on doubles") { securityLevel =>
  // Primary side: 17 rows with distinct double keys 0.0 .. 16.0.
  val p_data = for (i <- 0 to 16) yield (i, i.toDouble, i * 10)
  // Foreign side: 256 rows keyed by x % 16.
  // NOTE(review): `null.asInstanceOf[Double]` evaluates to 0.0 (primitive
  // Double cannot hold null), so every third row joins against pk == 0.0
  // rather than being a SQL NULL. Confirm this fixture is intentional; use
  // java.lang.Double if genuine NULL keys are wanted.
  val f_data = (0 until 256).map(x => {
    if (x % 3 == 0)
      (x, null.asInstanceOf[Double], x * 10)
    else
      (x, (x % 16).toDouble, x * 10) // .toDouble: idiomatic numeric conversion
  }).toSeq

  val p = makeDF(p_data, securityLevel, "id", "pk", "x")
  val f = makeDF(f_data, securityLevel, "id", "fk", "x")
  // Equi-join on double keys; exercises the double normalization added in
  // this commit, compared against vanilla Spark's result.
  val df = p.join(f, $"pk" === $"fk")
  df.collect.toSet
}
376+
347377
def abc(i: Int): String = (i % 3) match {
348378
case 0 => "A"
349379
case 1 => "B"

0 commit comments

Comments
 (0)