Commit 7c89dd1

matching in strategies.scala
set up class thing cleanup added test cases for non-equi left anti join rename to serializeEquiJoinExpression added isEncrypted condition set up keys JoinExpr now has condition rename serialization does not throw compile error for BNLJ split up added condition in ExpressionEvaluation.h zipPartitions cpp put in place typo added func to header two loops in place update tests condition fixed scala loop interchange rows added tags ensure cached == match working comparison decoupling in ExpressionEvaluation save compiles and condition works is printing fix swap outer/inner o_i_match show() has the same result tests pass test cleanup added test cases for different condition BuildLeft works optional keys in scala started C++ passes the operator tests comments, cleanup attempting to do it the ~right~ way comments to distinguish between primary/secondary, operator tests pass cleanup comments, about to begin implementation for distinct agg ops is_distinct added test case serializing with isDistinct is_distinct in ExpressionEvaluation.h removed unused code from join implementation remove RowWriter/Reader in condition evaluation (join) easier test serialization done correct checking in Scala set is set up spaghetti but it finally works function for clearing values condition_eval instead of condition goto comment remove explain from test, need to fix distinct aggregation for >1 partitions started impl of multiple partitions fix added rangepartitionexec that runs partitioning cleanup serialization properly comments, generalization for > 1 distinct function comments about to refactor into logical.Aggregation the new case has distinct in result expressions need to match on distinct removed new case (doesn't make difference?) works

Upgrade to OE 0.12 (mc2-project#153)

Update README.md

Support for scalar subquery (mc2-project#157)

This PR implements the scalar subquery expression, which is triggered whenever a subquery returns a scalar value.
There were two main problems to solve. First, the scalar subquery expression has to be matched. Spark implements this by wrapping a SparkPlan within the expression, calling executeCollect, and constructing a literal from the resulting value. This is problematic for us: the value is an intermediate result, so it should not be decrypted by the driver and serialized into an expression. The second problem is therefore supporting an encrypted literal. This PR implements it by serializing the ciphertext into a base64-encoded string and wrapping a Decrypt expression around it. The expression is then evaluated in the enclave and returns a literal. Note that, in order to test our implementation, we also implement a Decrypt expression in Scala. It should never be evaluated on the driver side and serialized into a plaintext literal, because Decrypt is designated as a Nondeterministic expression and will therefore always be evaluated on the workers.
match remove RangePartitionExec inefficient implementation refined

Add TPC-H Benchmarks (mc2-project#139)

* logic decoupling in TPCH.scala for easier benchmarking
* added TPCHBenchmark.scala
* Benchmark.scala rewrite
* done adding all support TPC-H query benchmarks
* changed commandline arguments that benchmark takes
* TPCHBenchmark takes in parameters
* fixed issue with spark conf
* size error handling, --help flag
* add Utils.force, break cluster mode
* comment out logistic regression benchmark
* ensureCached right before temp view created/replaced
* upgrade to 3.0.1
* upgrade to 3.0.1
* 10 scale factor
* persistData
* almost done refactor
* more cleanup
* compiles
* 9 passes
* cleanup
* collect instead of force, sf_none
* remove sf_none
* defaultParallelism
* no removing trailing/leading whitespace
* add sf_med
* hdfs works in local case
* cleanup, added new CLI argument
* added newly supported tpch queries
* function for running all supported tests

complete instead of partial -> final removed traces of join cleanup
1 parent 96e6285 commit 7c89dd1

6 files changed

+124
-30
lines changed
src/enclave/Enclave/Aggregate.cpp

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ void non_oblivious_aggregate(
       count += 1;
     }
 
-  // Skip outputting the final row if the number of input rows is 0 AND
-  // 1. It's a grouping aggregation, OR
+  // Skip outputting the final row if:
+  // 1. The number of input rows is 0 AND it's a grouping aggregation, OR
   // 2. It's a global aggregation, the mode is final
   if (!(count == 0 && (agg_op_eval.get_num_grouping_keys() > 0 || (agg_op_eval.get_num_grouping_keys() == 0 && !is_partial)))) {
     w.append(agg_op_eval.evaluate());
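The condition in this hunk can be read as a standalone predicate. The sketch below is illustrative only: `should_emit_final_row` is a hypothetical helper that does not exist in the commit, and its parameters stand in for `count`, `agg_op_eval.get_num_grouping_keys()` and `is_partial`. Note that in the code as written, `count == 0` guards both cases, even though the reworded comment attaches it only to case 1.

```cpp
#include <cstddef>

// Hypothetical helper (not part of the commit): returns true when the final
// aggregate row should be written, mirroring the `if` in the hunk above.
bool should_emit_final_row(std::size_t count,
                           std::size_t num_grouping_keys,
                           bool is_partial) {
  const bool is_grouping = num_grouping_keys > 0;
  const bool is_global_final = (num_grouping_keys == 0) && !is_partial;
  // The row is skipped only when no input rows were seen AND the aggregation
  // is either a grouping aggregation or a global aggregation in final mode.
  return !(count == 0 && (is_grouping || is_global_final));
}
```

Under this reading, a global aggregation in partial mode still emits a row for empty input, while every other empty-input case emits nothing.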

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 30 additions & 0 deletions
@@ -1811,6 +1811,9 @@ class AggregateExpressionEvaluator {
         std::unique_ptr<FlatbuffersExpressionEvaluator>(
           new FlatbuffersExpressionEvaluator(eval_expr)));
     }
+    is_distinct = expr->is_distinct();
+    value_selector = std::unique_ptr<FlatbuffersExpressionEvaluator>(
+      new FlatbuffersExpressionEvaluator(expr->value_selector()));
   }
 
   std::vector<const tuix::Field *> initial_values(const tuix::Row *unused) {
@@ -1824,6 +1827,15 @@ class AggregateExpressionEvaluator {
   std::vector<const tuix::Field *> update(const tuix::Row *concat) {
     std::vector<const tuix::Field *> result;
     for (auto&& e : update_evaluators) {
+      if (is_distinct) {
+        std::string value = to_string(value_selector->eval(concat));
+        /* Check to see if this distinct value has already been counted */
+        if (observed_values.count(value)) {
+          std::vector<const tuix::Field *> vect(1, nullptr);
+          return vect;
+        }
+        observed_values.insert(value);
+      }
       result.push_back(e->eval(concat));
     }
     return result;
@@ -1837,11 +1849,18 @@ class AggregateExpressionEvaluator {
     return result;
   }
 
+  void clear_observed_values() {
+    observed_values.clear();
+  }
+
 private:
   flatbuffers::FlatBufferBuilder builder;
   std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> initial_value_evaluators;
   std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> update_evaluators;
   std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> evaluate_evaluators;
+  bool is_distinct;
+  std::unique_ptr<FlatbuffersExpressionEvaluator> value_selector;
+  std::set<std::string> observed_values;
 };
 
 class FlatbuffersAggOpEvaluator {
@@ -1880,6 +1899,7 @@ class FlatbuffersAggOpEvaluator {
     // Write initial values to a
     std::vector<flatbuffers::Offset<tuix::Field>> init_fields;
     for (auto&& e : aggregate_evaluators) {
+      e->clear_observed_values();
       for (auto f : e->initial_values(nullptr)) {
         init_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
       }
@@ -1901,6 +1921,7 @@ class FlatbuffersAggOpEvaluator {
   void aggregate(const tuix::Row *row) {
     builder.Clear();
     flatbuffers::Offset<tuix::Row> concat;
+    int a_length = a->field_values()->size();
 
     std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
     // concat row to a
@@ -1918,9 +1939,18 @@ class FlatbuffersAggOpEvaluator {
     std::vector<flatbuffers::Offset<tuix::Field>> output_fields;
     for (auto&& e : aggregate_evaluators) {
       for (auto f : e->update(concat_ptr)) {
+        if (f == nullptr) { // Only triggered on EXPR(distinct expr ...)
+          output_fields.clear();
+          for (int i = 0; i < a_length; i++) {
+            auto f = concat_ptr->field_values()->Get(i);
+            output_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
+          }
+          goto save_a;
+        }
         output_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
       }
     }
+  save_a:
     a = flatbuffers::GetTemporaryPointer<tuix::Row>(
       builder2, tuix::CreateRowDirect(builder2, &output_fields));
   }
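The distinct bookkeeping added in these hunks boils down to a per-group set of already-seen values. The following is a minimal sketch of that idea, not the commit's code: `DistinctCounter` and its methods are hypothetical names, values are plain strings rather than `tuix::Field` objects, and a `bool` return replaces the `nullptr` sentinel and `goto save_a` that the diff uses to leave the running aggregate unchanged.

```cpp
#include <set>
#include <string>

// Hypothetical stand-in for the distinct tracking in the diff; it is not the
// commit's AggregateExpressionEvaluator.
class DistinctCounter {
public:
  // Returns true if `value` is new in the current group and was counted,
  // false if it was already observed (the diff signals this case by
  // returning a vector holding a single nullptr).
  bool update(const std::string &value) {
    if (!observed_values_.insert(value).second) {
      return false;  // duplicate: leave the running count untouched
    }
    ++count_;
    return true;
  }

  // Resets per-group state, analogous to clear_observed_values() being
  // called when the aggregate row is re-initialized.
  void start_group() {
    observed_values_.clear();
    count_ = 0;
  }

  long count() const { return count_; }

private:
  std::set<std::string> observed_values_;
  long count_ = 0;
};
```

A caller would invoke `start_group()` at each group boundary and `update()` once per row, reading `count()` when the group ends. Keying the set on a string form of the value follows the diff's use of `to_string`; whether that is the best key representation is outside the scope of this sketch.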

src/flatbuffers/operators.fbs

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ table AggregateExpr {
   initial_values: [Expr];
   update_exprs: [Expr];
   evaluate_exprs: [Expr];
+  // Items below are used for EXPR(distinct col_name ...)
+  is_distinct: bool;
+  value_selector: Expr;
 }
 // Supported: Average, Count, First, Last, Max, Min, Sum
 

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 41 additions & 9 deletions
@@ -1371,13 +1371,25 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
         )
 
       case c @ Count(children) =>
         val count = c.aggBufferAttributes(0)
         // COUNT(*) should count NULL values
         // COUNT(expr) should return the number or rows for which the supplied expressions are non-NULL
+        // COUNT(distinct expr ...) should return the number of rows that contain UNIQUE values of expr
+
+        val ar = e.aggregateFunction.children(0)
+        val colNum = concatSchema.indexWhere(_.semanticEquals(ar))
+        val (isDistinct, valueSelector) = (e.isDistinct, colNum) match {
+          case (true, x) if x >= 0 => // If colNum < 0, then the given schema does not contain the attribute
+            (true, flatbuffersSerializeExpression(builder, ar, concatSchema))
+          case _ =>
+            (false, 0)
+        }
 
         val (updateExprs: Seq[Expression], evaluateExprs: Seq[Expression]) = e.mode match {
           case Partial => {
@@ -1396,7 +1408,7 @@ object Utils extends Logging {
             val countUpdateExpr = Add(count, Literal(1L))
             (Seq(countUpdateExpr), Seq(count))
           }
-          case _ =>
+          case _ =>
         }
 
         tuix.AggregateExpr.createAggregateExpr(
@@ -1410,7 +1422,9 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          isDistinct,
+          valueSelector
         )
 
       case f @ First(child, false) =>
@@ -1449,7 +1463,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
 
       case l @ Last(child, false) =>
         val last = l.aggBufferAttributes(0)
@@ -1487,7 +1504,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
 
       case m @ Max(child) =>
         val max = m.aggBufferAttributes(0)
@@ -1520,7 +1540,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
 
       case m @ Min(child) =>
         val min = m.aggBufferAttributes(0)
@@ -1553,7 +1576,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
 
       case s @ Sum(child) =>
         val sum = s.aggBufferAttributes(0)
@@ -1591,7 +1617,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
 
       case vs @ ScalaUDAF(Seq(child), _: VectorSum, _, _) =>
         val sum = vs.aggBufferAttributes(0)
@@ -1626,7 +1655,10 @@ object Utils extends Logging {
             updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
           tuix.AggregateExpr.createEvaluateExprsVector(
             builder,
-            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
+            evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
+          false,
+          0
+        )
     }
   }
 

src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala

Lines changed: 41 additions & 19 deletions
@@ -109,25 +109,47 @@ object OpaqueOperators extends Strategy {
         if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>
 
       val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression])
-
-      if (groupingExpressions.size == 0) {
-        // Global aggregation
-        val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
-        val partialOutput = partialAggregate.output
-        val (projSchema, tag) = tagForGlobalAggregate(partialOutput)
-
-        EncryptedProjectExec(resultExpressions,
-          EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
-            EncryptedProjectExec(partialOutput,
-              EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
-                EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
-      } else {
-        // Grouping aggregation
-        EncryptedProjectExec(resultExpressions,
-          EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
-            EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
-              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
-                EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
+      val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct)
+
+      functionsWithDistinct.size match {
+        case size if size == 0 => // No distinct aggregate operations
+          if (groupingExpressions.size == 0) {
+            // Global aggregation
+            val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
+            val partialOutput = partialAggregate.output
+            val (projSchema, tag) = tagForGlobalAggregate(partialOutput)
+
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
+                EncryptedProjectExec(partialOutput,
+                  EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
+                    EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
+          } else {
+            // Grouping aggregation
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
+                EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
+                  EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
+                    EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
+          }
+        case size if size == 1 => // One distinct aggregate operation
+          if (groupingExpressions.size == 0) {
+            // Global aggregation
+            val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
+            val partialOutput = partialAggregate.output
+            val (projSchema, tag) = tagForGlobalAggregate(partialOutput)
+
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
+                EncryptedProjectExec(partialOutput,
+                  EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
+                    EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
+          } else {
+            // Grouping aggregation
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Complete,
+                EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), true, planLater(child)))) :: Nil
+          }
       }
 
     case p @ Union(Seq(left, right)) if isEncrypted(p) =>
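For the grouping case with one distinct aggregate, the plan above sorts by the grouping keys and then runs a single `Complete` aggregation instead of a partial/final pair. A rough sketch of why sorting first makes a one-pass distinct count possible is given below. It is an illustration under assumptions, not the operator's implementation: `Row` and `count_distinct_by_group` are hypothetical names, rows are plain (category, price) pairs as in the new test, and every row of a group is assumed to be visible to the same pass (the commit message itself notes that distinct aggregation over more than one partition needed further work).

```cpp
#include <algorithm>
#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical row type: (category, price).
using Row = std::pair<std::string, int>;

// Sketch only: sort by the grouping key, then count distinct prices per group
// in one pass, clearing the observed set at each group boundary.
std::vector<std::pair<std::string, long>> count_distinct_by_group(std::vector<Row> rows) {
  std::sort(rows.begin(), rows.end(),
            [](const Row &a, const Row &b) { return a.first < b.first; });

  std::vector<std::pair<std::string, long>> out;
  std::set<int> observed;  // prices already counted in the current group
  for (std::size_t i = 0; i < rows.size(); ++i) {
    const bool new_group = (i == 0) || rows[i].first != rows[i - 1].first;
    if (new_group) {
      observed.clear();
      out.emplace_back(rows[i].first, 0L);
    }
    if (observed.insert(rows[i].second).second) {
      ++out.back().second;  // first time this price appears in the group
    }
  }
  return out;
}
```

Because the rows of each group are contiguous after the sort, the observed set never has to hold values from more than one group at a time.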

src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala

Lines changed: 7 additions & 0 deletions
@@ -377,6 +377,13 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
       .collect.sortBy { case Row(category: String, _) => category }
   }
 
+  testAgainstSpark("aggregate count - distinct") { securityLevel =>
+    val data = (0 until 32).map{ i => (abc(i), i % 8)}.toSeq
+    val words = makeDF(data, securityLevel, "category", "price")
+    words.groupBy("category").agg(countDistinct("price").as("distinctPrices"))
+      .collect.sortBy { case Row(category: String, _) => category }
+  }
+
   testAgainstSpark("aggregate first") { securityLevel =>
     val data = for (i <- 0 until 256) yield (i, abc(i), 1)
     val words = makeDF(data, securityLevel, "id", "category", "price")
