Skip to content

Commit a96abc5

Browse files
wzheng and octaviansima authored
Move join condition handling for equi-joins into enclave code (#164)
* Add in TPC-H 21
* Add condition processing in enclave code
* Code clean up
* Enable query 18
* WIP
* Local tests pass
* Apply suggestions from code review

Co-authored-by: octaviansima <[email protected]>

* WIP
* Address comments
* q21.sql

Co-authored-by: octaviansima <[email protected]>
1 parent a4a6ff9 commit a96abc5

File tree

7 files changed

+136
-74
lines changed

7 files changed

+136
-74
lines changed

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,6 +1789,7 @@ class FlatbuffersJoinExprEvaluator {
17891789
const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
17901790

17911791
join_type = join_expr->join_type();
1792+
condition_eval = nullptr;
17921793
if (join_expr->condition() != NULL) {
17931794
condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
17941795
new FlatbuffersExpressionEvaluator(join_expr->condition()));
@@ -1797,9 +1798,6 @@ class FlatbuffersJoinExprEvaluator {
17971798

17981799
if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
17991800
is_equi_join = true;
1800-
if (join_expr->condition() != NULL) {
1801-
throw std::runtime_error("Equi join cannot have condition");
1802-
}
18031801
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
18041802
throw std::runtime_error("Mismatched join key lengths");
18051803
}
@@ -1835,14 +1833,12 @@ class FlatbuffersJoinExprEvaluator {
18351833
return is_primary(row1) ? row1 : row2;
18361834
}
18371835

1838-
/** Return true if the two rows satisfy the join condition. */
1839-
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
1836+
/** Return true if the two rows are from the same join group, i.e. all of their join keys compare equal.
1837+
* Since the function calls `is_primary`, the rows must have been tagged in Scala */
1838+
bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
18401839
builder.Clear();
18411840
bool row1_equals_row2;
18421841

1843-
/** Check equality for equi joins. If it is a non-equi join,
1844-
* the key evaluators will be empty, so the code never enters the for loop.
1845-
*/
18461842
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
18471843
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
18481844
for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
@@ -1855,9 +1851,8 @@ class FlatbuffersJoinExprEvaluator {
18551851
auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);
18561852

18571853
flatbuffers::Offset<tuix::Field> comparison = eval_binary_comparison<tuix::EqualTo, std::equal_to>(
1858-
builder,
1859-
row1_field,
1860-
row2_field);
1854+
builder, row1_field, row2_field);
1855+
18611856
row1_equals_row2 =
18621857
static_cast<const tuix::BooleanField *>(
18631858
flatbuffers::GetTemporaryPointer<tuix::Field>(
@@ -1868,9 +1863,12 @@ class FlatbuffersJoinExprEvaluator {
18681863
return false;
18691864
}
18701865
}
1866+
return true;
1867+
}
18711868

1872-
/* Check condition for non-equi joins */
1873-
if (!is_equi_join) {
1869+
/** Evaluate condition on the two input rows */
1870+
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
1871+
if (condition_eval != nullptr) {
18741872
std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
18751873
for (auto field : *row1->field_values()) {
18761874
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
@@ -1880,11 +1878,13 @@ class FlatbuffersJoinExprEvaluator {
18801878
}
18811879
flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
18821880
const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);
1883-
18841881
const tuix::Field *condition_result = condition_eval->eval(concat_ptr);
1885-
18861882
return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
18871883
}
1884+
1885+
// The `condition_eval` can only be empty when it's an equi-join.
1886+
// Since `condition_eval` is an extra predicate used to filter out *matched* rows in an equi-join, an empty
1887+
// condition means the matched row should not be filtered out; hence the default return value of true
18881888
return true;
18891889
}
18901890

src/enclave/Enclave/NonObliviousSortMergeJoin.cpp

Lines changed: 94 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,53 @@
55
#include "FlatbuffersWriters.h"
66
#include "common.h"
77

8-
/** C++ implementation of a non-oblivious sort merge join.
8+
/**
9+
* C++ implementation of a non-oblivious sort merge join.
910
* Rows MUST be tagged primary or secondary for this to work.
1011
*/
12+
13+
void test_rows_same_group(FlatbuffersJoinExprEvaluator &join_expr_eval,
14+
const tuix::Row *primary,
15+
const tuix::Row *current) {
16+
if (!join_expr_eval.is_same_group(primary, current)) {
17+
throw std::runtime_error(
18+
std::string("Invariant violation: rows of primary_group "
19+
"are not of the same group: ")
20+
+ to_string(primary)
21+
+ std::string(" vs ")
22+
+ to_string(current));
23+
}
24+
}
25+
26+
void write_output_rows(RowWriter &group, RowWriter &w) {
27+
auto group_buffer = group.output_buffer();
28+
RowReader group_reader(group_buffer.view());
29+
30+
while (group_reader.has_next()) {
31+
const tuix::Row *row = group_reader.next();
32+
w.append(row);
33+
}
34+
}
35+
36+
/**
37+
* Sort merge equi join algorithm
38+
* Input: the rows are unioned from both the primary (or left) table and the non-primary (or right) table
39+
*
40+
* Outer loop: iterate over all input rows
41+
*
42+
* If it's a row from the left table
43+
* - If it belongs to the same group as the previous left row, add it to the current group
44+
* - Otherwise start a new group
45+
* - When a new group starts in a left semi/anti join, first output the previous group's primary_matched_rows (semi) or primary_unmatched_rows (anti)
46+
*
47+
* If it's a row from the right table
48+
* - Inner join: iterate over current left group, output the joined row only if the condition is satisfied
49+
* - Left semi/anti join: iterate over `primary_unmatched_rows`, add a matched row to `primary_matched_rows`
50+
* and remove from `primary_unmatched_rows`
51+
*
52+
* After the loop: for a left semi/anti join, output the rows of the last group
53+
*/
54+
1155
void non_oblivious_sort_merge_join(
1256
uint8_t *join_expr, size_t join_expr_length,
1357
uint8_t *input_rows, size_t input_rows_length,
@@ -20,81 +64,84 @@ void non_oblivious_sort_merge_join(
2064

2165
RowWriter primary_group;
2266
FlatbuffersTemporaryRow last_primary_of_group;
23-
24-
bool pk_fk_match = false;
67+
RowWriter primary_matched_rows, primary_unmatched_rows; // This is only used for left semi/anti join
2568

2669
while (r.has_next()) {
2770
const tuix::Row *current = r.next();
2871

2972
if (join_expr_eval.is_primary(current)) {
3073
if (last_primary_of_group.get()
31-
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
74+
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
75+
3276
// Add this primary row to the current group
77+
// If this is a left semi/anti join, also add the rows to primary_unmatched_rows
3378
primary_group.append(current);
79+
if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
80+
primary_unmatched_rows.append(current);
81+
}
3482
last_primary_of_group.set(current);
83+
3584
} else {
3685
// If a new primary group is encountered
37-
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
38-
auto primary_group_buffer = primary_group.output_buffer();
39-
RowReader primary_group_reader(primary_group_buffer.view());
40-
41-
while (primary_group_reader.has_next()) {
42-
const tuix::Row *primary = primary_group_reader.next();
43-
w.append(primary);
44-
}
86+
if (join_type == tuix::JoinType_LeftSemi) {
87+
write_output_rows(primary_matched_rows, w);
88+
} else if (join_type == tuix::JoinType_LeftAnti) {
89+
write_output_rows(primary_unmatched_rows, w);
4590
}
4691

4792
primary_group.clear();
93+
primary_unmatched_rows.clear();
94+
primary_matched_rows.clear();
95+
4896
primary_group.append(current);
97+
primary_unmatched_rows.append(current);
4998
last_primary_of_group.set(current);
50-
51-
pk_fk_match = false;
5299
}
53100
} else {
54-
// Output the joined rows resulting from this foreign row
55101
if (last_primary_of_group.get()
56-
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
57-
auto primary_group_buffer = primary_group.output_buffer();
58-
RowReader primary_group_reader(primary_group_buffer.view());
59-
while (primary_group_reader.has_next()) {
60-
const tuix::Row *primary = primary_group_reader.next();
102+
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
103+
if (join_type == tuix::JoinType_Inner) {
104+
auto primary_group_buffer = primary_group.output_buffer();
105+
RowReader primary_group_reader(primary_group_buffer.view());
106+
while (primary_group_reader.has_next()) {
107+
const tuix::Row *primary = primary_group_reader.next();
108+
test_rows_same_group(join_expr_eval, primary, current);
61109

62-
if (!join_expr_eval.eval_condition(primary, current)) {
63-
throw std::runtime_error(
64-
std::string("Invariant violation: rows of primary_group "
65-
"are not of the same group: ")
66-
+ to_string(primary)
67-
+ std::string(" vs ")
68-
+ to_string(current));
110+
if (join_expr_eval.eval_condition(primary, current)) {
111+
w.append(primary, current);
112+
}
69113
}
114+
} else if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
115+
auto primary_unmatched_rows_buffer = primary_unmatched_rows.output_buffer();
116+
RowReader primary_unmatched_rows_reader(primary_unmatched_rows_buffer.view());
117+
RowWriter new_primary_unmatched_rows;
70118

71-
if (join_type == tuix::JoinType_Inner) {
72-
w.append(primary, current);
73-
} else if (join_type == tuix::JoinType_LeftSemi) {
74-
// Only output the pk group ONCE
75-
if (!pk_fk_match) {
76-
w.append(primary);
119+
while (primary_unmatched_rows_reader.has_next()) {
120+
const tuix::Row *primary = primary_unmatched_rows_reader.next();
121+
test_rows_same_group(join_expr_eval, primary, current);
122+
if (join_expr_eval.eval_condition(primary, current)) {
123+
primary_matched_rows.append(primary);
124+
} else {
125+
new_primary_unmatched_rows.append(primary);
77126
}
78127
}
128+
129+
// Reset primary_unmatched_rows
130+
primary_unmatched_rows.clear();
131+
auto new_primary_unmatched_rows_buffer = new_primary_unmatched_rows.output_buffer();
132+
RowReader new_primary_unmatched_rows_reader(new_primary_unmatched_rows_buffer.view());
133+
while (new_primary_unmatched_rows_reader.has_next()) {
134+
primary_unmatched_rows.append(new_primary_unmatched_rows_reader.next());
135+
}
79136
}
80-
81-
pk_fk_match = true;
82-
} else {
83-
// If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet
84-
// It will be processed when the code advances to the next pk group
85-
pk_fk_match &= true;
86137
}
87138
}
88139
}
89140

90-
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
91-
auto primary_group_buffer = primary_group.output_buffer();
92-
RowReader primary_group_reader(primary_group_buffer.view());
93-
94-
while (primary_group_reader.has_next()) {
95-
const tuix::Row *primary = primary_group_reader.next();
96-
w.append(primary);
97-
}
141+
if (join_type == tuix::JoinType_LeftSemi) {
142+
write_output_rows(primary_matched_rows, w);
143+
} else if (join_type == tuix::JoinType_LeftAnti) {
144+
write_output_rows(primary_unmatched_rows, w);
98145
}
99146

100147
w.output_buffer(output_rows, output_rows_length);

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,7 @@ object Utils extends Logging {
12591259
def serializeJoinExpression(
12601260
joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]],
12611261
leftSchema: Seq[Attribute], rightSchema: Seq[Attribute],
1262-
condition: Option[Expression] = None): Array[Byte] = {
1262+
condition: Option[Expression]): Array[Byte] = {
12631263
val builder = new FlatBufferBuilder
12641264
builder.finish(
12651265
tuix.JoinExpr.createJoinExpr(
@@ -1298,7 +1298,7 @@ object Utils extends Logging {
12981298
condition match {
12991299
case Some(condition) =>
13001300
flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema)
1301-
case _ => 0
1301+
case None => 0
13021302
}))
13031303
builder.sizedByteArray()
13041304
}

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
2424
object TPCHBenchmark {
2525

2626
// Add query numbers here once they are supported
27-
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
27+
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22)
2828

2929
def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
3030
val sqlStr = tpch.getQuery(queryNumber)

src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ case class EncryptedSortMergeJoinExec(
281281
rightKeys: Seq[Expression],
282282
leftSchema: Seq[Attribute],
283283
rightSchema: Seq[Attribute],
284+
condition: Option[Expression],
284285
child: SparkPlan)
285286
extends UnaryExecNode with OpaqueOperatorExec {
286287

@@ -293,7 +294,7 @@ case class EncryptedSortMergeJoinExec(
293294

294295
override def executeBlocked(): RDD[Block] = {
295296
val joinExprSer = Utils.serializeJoinExpression(
296-
joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema)
297+
joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema, condition)
297298

298299
timeOperator(
299300
child.asInstanceOf[OpaqueOperatorExec].executeBlocked(),

src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,15 @@ object OpaqueOperators extends Strategy {
9898
rightKeysProj,
9999
leftProjSchema.map(_.toAttribute),
100100
rightProjSchema.map(_.toAttribute),
101+
condition,
101102
sorted)
102103

103104
val tagsDropped = joinType match {
104105
case Inner => EncryptedProjectExec(dropTags(left.output, right.output), joined)
105106
case LeftSemi | LeftAnti => EncryptedProjectExec(left.output, joined)
106107
}
107108

108-
val filtered = condition match {
109-
case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
110-
case None => tagsDropped
111-
}
112-
113-
filtered :: Nil
109+
tagsDropped :: Nil
114110

115111
// Used to match non-equi joins
116112
case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) =>

src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,20 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
321321
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
322322
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
323323
val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x")
324-
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x")
324+
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "y")
325325
val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1")
326326
df.collect
327327
}
328328

329+
testAgainstSpark("left semi join with condition") { securityLevel =>
330+
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
331+
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
332+
val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x")
333+
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "y")
334+
val df = p.join(f, $"join_col_1" === $"join_col_2" && $"x" > $"y", "left_semi").sort($"join_col_1", $"id1")
335+
df.collect
336+
}
337+
329338
testAgainstSpark("non-equi left semi join") { securityLevel =>
330339
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
331340
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
@@ -344,7 +353,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
344353
df.collect
345354
}
346355

347-
testAgainstSpark("left anti join 1") { securityLevel =>
356+
testAgainstSpark("left anti join") { securityLevel =>
348357
val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10)
349358
val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10)
350359
val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x")
@@ -353,6 +362,15 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
353362
df.collect
354363
}
355364

365+
testAgainstSpark("left anti join with condition") { securityLevel =>
366+
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
367+
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
368+
val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x")
369+
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "y")
370+
val df = p.join(f, $"join_col_1" === $"join_col_2" && $"x" > $"y", "left_anti").sort($"join_col_1", $"id1")
371+
df.collect
372+
}
373+
356374
testAgainstSpark("non-equi left anti join 1") { securityLevel =>
357375
val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10)
358376
val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10)

0 commit comments

Comments
 (0)