diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 58bcb773f2..80475b877f 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -743,6 +743,50 @@ class FlatbuffersExpressionEvaluator { } + case tuix::ExprUnion_Concat: + { + //implementing this like string concat since each argument in already serialized + auto c = static_cast(expr->expr()); + size_t num_children = c->children()->size(); + + size_t total = 0; + + std::vector result; + + for (size_t i =0; i< num_children; i++){ + auto offset = eval_helper(row, (*c->children())[i]); + const tuix::Field *str = flatbuffers::GetTemporaryPointer(builder, offset); + if (str->value_type() != tuix::FieldUnion_StringField) { + throw std::runtime_error( + std::string("tuix::Concat requires serializable data types, not ") + + std::string(tuix::EnumNameFieldUnion(str->value_type())) + + std::string(". You do not need to provide the data as string but the data should be serialized into string before sent to concat")); + } + if (!str->is_null()){ + // skipping over the null input + auto str_field = static_cast(str->value()); + uint32_t start = 0; + uint32_t end = str_field ->length(); + total += end; + std::vector stringtoadd( + flatbuffers::VectorIterator(str_field->value()->Data(), + start), + flatbuffers::VectorIterator(str_field->value()->Data(), + end)); + result.insert(result.end(), stringtoadd.begin(), stringtoadd.end()); + } + + } + + return tuix::CreateField( + builder, + tuix::FieldUnion_StringField, + tuix::CreateStringFieldDirect( + builder, &result, static_cast(total)).Union(), + total==0); + + } + case tuix::ExprUnion_In: { auto c = static_cast(expr->expr()); diff --git a/src/flatbuffers/Expr.fbs b/src/flatbuffers/Expr.fbs index a9c0a09168..a96215b5a2 100644 --- a/src/flatbuffers/Expr.fbs +++ b/src/flatbuffers/Expr.fbs @@ -12,6 +12,7 @@ union ExprUnion { GreaterThanOrEqual, EqualTo, Contains, + Concat, In, Col, Literal, @@ -126,8 +127,12 @@ table Contains { right:Expr; } +table Concat { + children:[Expr]; +} + // Array expressions -table In{ +table In { children:[Expr]; } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 46c5325a8b..5a85154253 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -44,8 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.expressions.Contains - -import org.apache.spark.sql.catalyst.expressions.In +import org.apache.spark.sql.catalyst.expressions.Concat import org.apache.spark.sql.catalyst.expressions.DateAdd import org.apache.spark.sql.catalyst.expressions.DateAddInterval import org.apache.spark.sql.catalyst.expressions.Descending @@ -57,6 +56,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.GreaterThan import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.In import org.apache.spark.sql.catalyst.expressions.IsNotNull import org.apache.spark.sql.catalyst.expressions.IsNull import org.apache.spark.sql.catalyst.expressions.LessThan @@ -997,6 +997,12 @@ object Utils extends Logging { tuix.Contains.createContains( builder, leftOffset, rightOffset)) + case (Concat(child), childrenOffsets) => + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.Concat, + tuix.Concat.createConcat( + builder, tuix.Concat.createChildrenVector(builder, childrenOffsets.toArray))) case (In(left, right), childrenOffsets) => tuix.Expr.createExpr( @@ -1004,8 +1010,8 @@ object Utils extends Logging { tuix.ExprUnion.In, tuix.In.createIn( builder, tuix.In.createChildrenVector(builder, childrenOffsets.toArray))) - // Time expressions + // Time expressions case (Year(child), Seq(childOffset)) => tuix.Expr.createExpr( builder, diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index ef394d95b6..c8926c3df7 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -440,6 +440,22 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.filter($"word".contains(lit("1"))).collect } + testAgainstSpark("concat with string") { securityLevel => + val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toString) + val df = makeDF(data, securityLevel, "str", "x") + df.select(concat(col("str"),lit(","),col("x"))).collect + } + + testAgainstSpark("concat with other datatype") { securityLevel => + // float causes a formating issue where opaque outputs 1.000000 and spark produces 1.0 so the following line is commented out + // val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, 1.0f) + // you can't serialize date so that's not supported as well + // opaque doesn't support byte + val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, null.asInstanceOf[Int],"") + val df = makeDF(data, securityLevel, "str", "int","null","emptystring") + df.select(concat(col("str"),lit(","),col("int"),col("null"),col("emptystring"))).collect + } + testAgainstSpark("isin1") { securityLevel => val ids = Seq((1, 2, 2), (2, 3, 1)) val df = makeDF(ids, securityLevel, "x", "y", "id")