
Commit b78b4a4
Authored by Chenyu-Shi, Ubuntu, and wzheng

Separate Concat PR (#125)

Implementation of the CONCAT expression.
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>

1 parent 2fec4ad

File tree: 4 files changed, +75 −4 lines

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 44 additions & 0 deletions
@@ -743,6 +743,50 @@ class FlatbuffersExpressionEvaluator {
     }


+    case tuix::ExprUnion_Concat:
+    {
+      // Implemented like string concat, since each argument is already serialized
+      auto c = static_cast<const tuix::Concat *>(expr->expr());
+      size_t num_children = c->children()->size();
+
+      size_t total = 0;
+
+      std::vector<uint8_t> result;
+
+      for (size_t i = 0; i < num_children; i++) {
+        auto offset = eval_helper(row, (*c->children())[i]);
+        const tuix::Field *str = flatbuffers::GetTemporaryPointer(builder, offset);
+        if (str->value_type() != tuix::FieldUnion_StringField) {
+          throw std::runtime_error(
+            std::string("tuix::Concat requires serializable data types, not ")
+            + std::string(tuix::EnumNameFieldUnion(str->value_type()))
+            + std::string(". The data need not be a string originally, but it must be serialized to a string before being passed to Concat."));
+        }
+        if (!str->is_null()) {
+          // Skip over null inputs
+          auto str_field = static_cast<const tuix::StringField *>(str->value());
+          uint32_t start = 0;
+          uint32_t end = str_field->length();
+          total += end;
+          std::vector<uint8_t> stringtoadd(
+            flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
+                                                          start),
+            flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
+                                                          end));
+          result.insert(result.end(), stringtoadd.begin(), stringtoadd.end());
+        }
+
+      }
+
+      return tuix::CreateField(
+        builder,
+        tuix::FieldUnion_StringField,
+        tuix::CreateStringFieldDirect(
+          builder, &result, static_cast<uint32_t>(total)).Union(),
+        total == 0);
+
+    }
+
     case tuix::ExprUnion_In:
     {
       auto c = static_cast<const tuix::In *>(expr->expr());
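In short, the new case evaluates each child to an already-serialized StringField, appends the raw bytes in order while skipping null inputs, and marks the result null only when nothing was appended. A minimal Scala sketch of those semantics (hypothetical helper, not part of this commit):

    def concatSemantics(children: Seq[Option[String]]): Option[String] = {
      // Null children contribute nothing; bytes are appended left to right.
      val joined = children.flatten.mkString
      // Mirrors the `total == 0` check: an all-null (or all-empty) input yields null.
      if (joined.isEmpty) None else Some(joined)
    }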

src/flatbuffers/Expr.fbs

Lines changed: 6 additions & 1 deletion
@@ -12,6 +12,7 @@ union ExprUnion {
   GreaterThanOrEqual,
   EqualTo,
   Contains,
+  Concat,
   In,
   Col,
   Literal,
@@ -126,8 +127,12 @@ table Contains {
   right:Expr;
 }

+table Concat {
+  children:[Expr];
+}
+
 // Array expressions
-table In{
+table In {
   children:[Expr];
 }
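The new table mirrors In: a single vector of child expressions. For reference, the accessors FlatBuffers generates for such a table would be used roughly like this on the JVM side (a sketch; the names follow standard FlatBuffers codegen conventions and are not code from this commit):

    // buf: java.nio.ByteBuffer holding a serialized tuix.Concat table
    val c: tuix.Concat = tuix.Concat.getRootAsConcat(buf)
    for (i <- 0 until c.childrenLength()) {
      val child: tuix.Expr = c.children(i)  // i-th argument expression
      // ... recurse into child, as the C++ evaluator above does
    }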

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 9 additions & 3 deletions
@@ -44,8 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.expressions.Contains
-
-import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.Concat
 import org.apache.spark.sql.catalyst.expressions.DateAdd
 import org.apache.spark.sql.catalyst.expressions.DateAddInterval
 import org.apache.spark.sql.catalyst.expressions.Descending
@@ -57,6 +56,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.GreaterThan
 import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual
 import org.apache.spark.sql.catalyst.expressions.If
+import org.apache.spark.sql.catalyst.expressions.In
 import org.apache.spark.sql.catalyst.expressions.IsNotNull
 import org.apache.spark.sql.catalyst.expressions.IsNull
 import org.apache.spark.sql.catalyst.expressions.LessThan
@@ -997,15 +997,21 @@ object Utils extends Logging {
         tuix.Contains.createContains(
           builder, leftOffset, rightOffset))

+      case (Concat(child), childrenOffsets) =>
+        tuix.Expr.createExpr(
+          builder,
+          tuix.ExprUnion.Concat,
+          tuix.Concat.createConcat(
+            builder, tuix.Concat.createChildrenVector(builder, childrenOffsets.toArray)))

       case (In(left, right), childrenOffsets) =>
         tuix.Expr.createExpr(
           builder,
           tuix.ExprUnion.In,
           tuix.In.createIn(
             builder, tuix.In.createChildrenVector(builder, childrenOffsets.toArray)))
-      // Time expressions

+      // Time expressions
       case (Year(child), Seq(childOffset)) =>
         tuix.Expr.createExpr(
           builder,
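With this case in place, Catalyst Concat expressions are serialized for the enclave: the children are serialized recursively into childrenOffsets, then wrapped in a tuix.Concat node. A usage sketch (assuming an encrypted DataFrame df with string columns "str" and "x", as in the tests below):

    import org.apache.spark.sql.functions.{col, concat, lit}
    // Non-string inputs are expected to arrive already serialized as strings,
    // per the enclave-side StringField check above.
    val result = df.select(concat(col("str"), lit(","), col("x")))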

src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala

Lines changed: 16 additions & 0 deletions
@@ -440,6 +440,22 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
     df.filter($"word".contains(lit("1"))).collect
   }

+  testAgainstSpark("concat with string") { securityLevel =>
+    val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toString)
+    val df = makeDF(data, securityLevel, "str", "x")
+    df.select(concat(col("str"), lit(","), col("x"))).collect
+  }
+
+  testAgainstSpark("concat with other datatype") { securityLevel =>
+    // Float is commented out: Opaque formats 1.0f as "1.000000" while Spark produces "1.0".
+    // val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, 1.0f)
+    // Date is not supported either, since it cannot be serialized to a string.
+    // Opaque does not support Byte.
+    val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, null.asInstanceOf[Int], "")
+    val df = makeDF(data, securityLevel, "str", "int", "null", "emptystring")
+    df.select(concat(col("str"), lit(","), col("int"), col("null"), col("emptystring"))).collect
+  }
+
   testAgainstSpark("isin1") { securityLevel =>
     val ids = Seq((1, 2, 2), (2, 3, 1))
     val df = makeDF(ids, securityLevel, "x", "y", "id")
