Skip to content
44 changes: 44 additions & 0 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,50 @@ class FlatbuffersExpressionEvaluator {
}


case tuix::ExprUnion_Concat:
{
  // Concatenate the string serializations of all child expressions, in order.
  // Each child must already evaluate to a tuix::StringField (callers serialize
  // other data types to strings before reaching this operator).
  //
  // Null semantics: Spark's `concat` is null-intolerant — if ANY input is
  // null, the result is null. The original diff instead skipped null inputs
  // (concat_ws-like behavior); this version propagates the null.
  // NOTE(review): confirm against the Spark comparison tests — the existing
  // test's `null.asInstanceOf[Int]` is 0, not a SQL NULL, so it does not
  // exercise this path.
  //
  // Concatenating only empty (non-null) strings yields an empty, NON-null
  // string; the original flagged the result null whenever the total length
  // was zero, which diverges from Spark.
  auto c = static_cast<const tuix::Concat *>(expr->expr());
  size_t num_children = c->children()->size();

  std::vector<uint8_t> result;
  bool result_is_null = false;

  for (size_t i = 0; i < num_children; i++) {
    auto offset = eval_helper(row, (*c->children())[i]);
    const tuix::Field *str = flatbuffers::GetTemporaryPointer(builder, offset);
    if (str->value_type() != tuix::FieldUnion_StringField) {
      throw std::runtime_error(
        std::string("tuix::Concat requires serializable data types, not ")
        + std::string(tuix::EnumNameFieldUnion(str->value_type()))
        + std::string(". You do not need to provide the data as string but the data should be serialized into string before sent to concat"));
    }
    if (str->is_null()) {
      // One null input nullifies the whole result (Spark `concat` contract).
      result_is_null = true;
    } else {
      auto str_field = static_cast<const tuix::StringField *>(str->value());
      uint32_t len = str_field->length();
      // Append the first `len` bytes of the child's buffer directly;
      // avoids the per-child intermediate std::vector copy of the original.
      const uint8_t *data = str_field->value()->Data();
      result.insert(result.end(), data, data + len);
    }
  }

  return tuix::CreateField(
    builder,
    tuix::FieldUnion_StringField,
    tuix::CreateStringFieldDirect(
      builder, &result, static_cast<uint32_t>(result.size())).Union(),
    result_is_null);
}

case tuix::ExprUnion_In:
{
auto c = static_cast<const tuix::In *>(expr->expr());
Expand Down
7 changes: 6 additions & 1 deletion src/flatbuffers/Expr.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ union ExprUnion {
GreaterThanOrEqual,
EqualTo,
Contains,
Concat,
In,
Col,
Literal,
Expand Down Expand Up @@ -126,8 +127,12 @@ table Contains {
right:Expr;
}

// Variadic string concatenation: each child expression is evaluated and its
// string serialization appended in order (handled by ExprUnion_Concat in
// ExpressionEvaluation.h).
table Concat {
children:[Expr];
}

// Array expressions
table In{
table In {
children:[Expr];
}

Expand Down
12 changes: 9 additions & 3 deletions src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.expressions.Contains

import org.apache.spark.sql.catalyst.expressions.In
import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.expressions.DateAdd
import org.apache.spark.sql.catalyst.expressions.DateAddInterval
import org.apache.spark.sql.catalyst.expressions.Descending
Expand All @@ -57,6 +56,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GreaterThan
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual
import org.apache.spark.sql.catalyst.expressions.If
import org.apache.spark.sql.catalyst.expressions.In
import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.catalyst.expressions.IsNull
import org.apache.spark.sql.catalyst.expressions.LessThan
Expand Down Expand Up @@ -997,15 +997,21 @@ object Utils extends Logging {
tuix.Contains.createContains(
builder, leftOffset, rightOffset))

case (Concat(child), childrenOffsets) =>
tuix.Expr.createExpr(
builder,
tuix.ExprUnion.Concat,
tuix.Concat.createConcat(
builder, tuix.Concat.createChildrenVector(builder, childrenOffsets.toArray)))

case (In(left, right), childrenOffsets) =>
tuix.Expr.createExpr(
builder,
tuix.ExprUnion.In,
tuix.In.createIn(
builder, tuix.In.createChildrenVector(builder, childrenOffsets.toArray)))

// Time expressions
case (Year(child), Seq(childOffset)) =>
tuix.Expr.createExpr(
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,22 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
df.filter($"word".contains(lit("1"))).collect
}

testAgainstSpark("concat with string") { securityLevel =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case for empty string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combined with "concat with other datatype test"

val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toString)
val df = makeDF(data, securityLevel, "str", "x")
df.select(concat(col("str"),lit(","),col("x"))).collect
}

// Exercises concat over columns of mixed types that Opaque serializes to
// strings, including an empty-string column.
testAgainstSpark("concat with other datatype") { securityLevel =>
// Float is excluded: it causes a formatting mismatch where Opaque outputs
// 1.000000 and Spark produces 1.0, so the following line is commented out.
// val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, 1.0f)
// Date is not supported (cannot be serialized to string here), and Opaque
// does not support Byte, so neither type is covered.
// NOTE(review): null.asInstanceOf[Int] evaluates to 0 for a primitive Int,
// so the "null" column holds 0, not a SQL NULL — this does not test
// concat's null-propagation behavior. Consider java.lang.Integer/Option.
val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, null.asInstanceOf[Int],"")
val df = makeDF(data, securityLevel, "str", "int","null","emptystring")
df.select(concat(col("str"),lit(","),col("int"),col("null"),col("emptystring"))).collect
}

testAgainstSpark("isin1") { securityLevel =>
val ids = Seq((1, 2, 2), (2, 3, 1))
val df = makeDF(ids, securityLevel, "x", "y", "id")
Expand Down