Float expressions #160

Merged 7 commits on Feb 23, 2021
62 changes: 62 additions & 0 deletions src/enclave/Enclave/ExpressionEvaluation.h
@@ -1614,6 +1614,68 @@ class FlatbuffersExpressionEvaluator {
result_is_null);
}

case tuix::ExprUnion_NormalizeNaNAndZero:
{
  auto normalize = static_cast<const tuix::NormalizeNaNAndZero *>(expr->expr());
  auto child_offset = eval_helper(row, normalize->child());

  const tuix::Field *value = flatbuffers::GetTemporaryPointer(builder, child_offset);

  if (value->value_type() != tuix::FieldUnion_FloatField
      && value->value_type() != tuix::FieldUnion_DoubleField) {
    throw std::runtime_error(
      std::string("tuix::NormalizeNaNAndZero requires type Float or Double, not ")
      + std::string(tuix::EnumNameFieldUnion(value->value_type())));
  }

  bool result_is_null = value->is_null();

  if (value->value_type() == tuix::FieldUnion_FloatField) {
    if (!result_is_null) {
      float v = value->value_as_FloatField()->value();
      if (isnan(v)) {
        // Collapse every NaN bit pattern into the canonical quiet NaN
        v = std::numeric_limits<float>::quiet_NaN();
      } else if (v == -0.0f) {
        // Fold negative zero into positive zero (+0.0f also matches this
        // comparison, but the assignment is then a no-op)
        v = 0.0f;
      }

      return tuix::CreateField(
        builder,
        tuix::FieldUnion_FloatField,
        tuix::CreateFloatField(builder, v).Union(),
        result_is_null);
    }

    return tuix::CreateField(
      builder,
      tuix::FieldUnion_FloatField,
      tuix::CreateFloatField(builder, 0).Union(),
      result_is_null);

  } else {
    if (!result_is_null) {
      double v = value->value_as_DoubleField()->value();
      if (isnan(v)) {
        v = std::numeric_limits<double>::quiet_NaN();
      } else if (v == -0.0) {
        v = 0.0;
      }

      return tuix::CreateField(
        builder,
        tuix::FieldUnion_DoubleField,
        tuix::CreateDoubleField(builder, v).Union(),
        result_is_null);
    }

    return tuix::CreateField(
      builder,
      tuix::FieldUnion_DoubleField,
      tuix::CreateDoubleField(builder, 0).Union(),
      result_is_null);
  }
}

default:
throw std::runtime_error(
std::string("Can't evaluate expression of type ")
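
For context (this note is not part of the diff): Spark normalizes NaN and -0.0 in comparison keys because IEEE 754 gives NaN many bit encodings and zero two, while grouping and join keys may be compared by their binary representation. A minimal plain-Scala sketch of the mismatch:

// -0.0f and 0.0f compare equal but differ bitwise.
val negZeroBits = java.lang.Float.floatToRawIntBits(-0.0f) // 0x80000000
val posZeroBits = java.lang.Float.floatToRawIntBits(0.0f)  // 0x00000000
assert(-0.0f == 0.0f && negZeroBits != posZeroBits)

// NaN has many encodings: 0x7fc00001 is a quiet NaN whose bits differ
// from the canonical quiet NaN 0x7fc00000 (what Float.NaN carries, and
// what quiet_NaN() produces on common platforms).
val oddNaN = java.lang.Float.intBitsToFloat(0x7fc00001)
assert(oddNaN.isNaN &&
  java.lang.Float.floatToRawIntBits(oddNaN) !=
    java.lang.Float.floatToRawIntBits(Float.NaN))
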
5 changes: 5 additions & 0 deletions src/flatbuffers/Expr.fbs
@@ -36,6 +36,7 @@ union ExprUnion {
VectorMultiply,
DotProduct,
Exp,
NormalizeNaNAndZero,
ClosestPoint,
CreateArray,
Upper,
@@ -199,6 +200,10 @@ table CreateArray {
children:[Expr];
}

table NormalizeNaNAndZero {
child:Expr;
}

// Opaque UDFs
table VectorAdd {
left:Expr;
11 changes: 11 additions & 0 deletions src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -61,6 +61,7 @@ import org.apache.spark.sql.catalyst.expressions.If
import org.apache.spark.sql.catalyst.expressions.In
import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.catalyst.expressions.IsNull
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
import org.apache.spark.sql.catalyst.expressions.LessThan
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -91,6 +92,7 @@ import org.apache.spark.sql.catalyst.plans.NaturalJoin
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.MapData
@@ -1169,6 +1171,15 @@ object Utils extends Logging {
// TODO: Implement decimal serialization, followed by CheckOverflow
childOffset

case (NormalizeNaNAndZero(child), Seq(childOffset)) =>
  tuix.Expr.createExpr(
    builder,
    tuix.ExprUnion.NormalizeNaNAndZero,
    tuix.NormalizeNaNAndZero.createNormalizeNaNAndZero(builder, childOffset))

// KnownFloatingPointNormalized is only a marker that its child is already
// normalized, so serialize the wrapped NormalizeNaNAndZero directly.
case (KnownFloatingPointNormalized(NormalizeNaNAndZero(child)), Seq(childOffset)) =>
  flatbuffersSerializeExpression(builder, NormalizeNaNAndZero(child), input)

case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) =>
val output = child.output(0)
val dataType = output match {
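
For reference (not part of the diff): both Catalyst nodes handled above are introduced by Spark's NormalizeFloatingNumbers optimizer rule, which wraps float and double equi-join keys. A spark-shell sketch, assuming hypothetical DataFrames left and right with FloatType columns pk and fk:

// The optimized plan rewrites the join keys as
// knownfloatingpointnormalized(normalizenanandzero(pk#...)),
// which is why the serializer needs both cases above.
val df = left.join(right, left("pk") === right("fk"))
println(df.queryExecution.optimizedPlan)
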
TPCHBenchmark.scala

@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
object TPCHBenchmark {

// Add query numbers here once they are supported
-  val supportedQueries = Seq(1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
+  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)

def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
val sqlStr = tpch.getQuery(queryNumber)
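
With query 2 added to the supported list, it runs through the same harness entry point as the other queries; a usage sketch, where tpch, sqlContext, and numPartitions come from the surrounding benchmark setup:

// Executes TPC-H query 2 with Opaque's encrypted operators.
TPCHBenchmark.query(2, tpch, sqlContext, numPartitions)
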
OpaqueOperatorTests.scala

@@ -344,6 +344,36 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
df.collect
}

testAgainstSpark("join on floats") { securityLevel =>
val p_data = for (i <- 0 to 16) yield (i, i.toFloat, i * 10)
val f_data = (0 until 256).map(x => {
if (x % 3 == 0)
(x, null.asInstanceOf[Float], x * 10)
else
(x, (x % 16).asInstanceOf[Float], x * 10)
}).toSeq

val p = makeDF(p_data, securityLevel, "id", "pk", "x")
val f = makeDF(f_data, securityLevel, "id", "fk", "x")
val df = p.join(f, $"pk" === $"fk")
df.collect.toSet
}

testAgainstSpark("join on doubles") { securityLevel =>
val p_data = for (i <- 0 to 16) yield (i, i.toDouble, i * 10)
val f_data = (0 until 256).map(x => {
if (x % 3 == 0)
(x, null.asInstanceOf[Double], x * 10)
else
(x, (x % 16).asInstanceOf[Double], x * 10)
}).toSeq

val p = makeDF(p_data, securityLevel, "id", "pk", "x")
val f = makeDF(f_data, securityLevel, "id", "fk", "x")
val df = p.join(f, $"pk" === $"fk")
df.collect.toSet
}
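
One subtlety in the fixtures above: null.asInstanceOf[Float] does not produce a SQL NULL; it unboxes to the primitive default value. A plain-Scala check:

val fk = null.asInstanceOf[Float]
assert(fk == 0.0f) // unboxes to 0.0f, not NULL
assert(!fk.isNaN)  // so every third row joins against pk == 0.0f
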

def abc(i: Int): String = (i % 3) match {
case 0 => "A"
case 1 => "B"