
Commit 697644b

andrewlawhh (Andrew Law), eric-feng-2011 (Eric Feng), and chester-leung authored
Merge move join (#191)
* Support for multiple branched CaseWhen
* Interval (#116)
* add date_add; interval sql still running into issues
* Add Interval SQL support
* uncomment the other tests
* resolve comments
* change interval equality

Co-authored-by: Eric Feng <[email protected]>

* Remove partition ID argument from enclaves
* Fix comments
* updates
* Modifications to integrate crumb, log-mac, and all-outputs-mac, wip
* Store log mac after each output buffer, add all-outputs-mac to each EncryptedBlocks, wip
* Add all_outputs_mac to all EncryptedBlocks once all log_macs have been generated
* Almost builds
* cpp builds
* Use ubyte for all_outputs_mac
* use Mac for all_outputs_mac
* Hopefully this works for flatbuffers all_outputs_mac mutation; cpp builds
* Scala builds now too, running into error with union
* Stuff builds; error with all-outputs-mac serialization. This commit uses all_outputs_mac as a Mac table
* Fixed bug; basic encryption / show works
* All single partition tests pass; multiple partition passes until tpch-9
* All tests pass except tpch-9 and skew join
* comment tpch back in
* Check same number of ecalls per partition - exception for scanCollectLastPrimary(?)
* First attempt at constructing executed DAG
* Fix typos
* Rework graph
* Add log macs to graph nodes
* Construct expected DAG and refactor JobNode. Refactor construction of executed DAG.
* Implement 'paths to sink' for a DAG
* add crumb for last ecall
* Fix NULL handling for aggregation (#130)
* Modify COUNT and SUM to correctly handle NULL values
* Change average to support NULL values
* Fix
* Changing operator matching from logical to physical (#129)
* WIP
* Fix
* Unapply change
* Aggregation rewrite (#132)
* updated build/sbt file (#135)
* Travis update (#137)
* update breeze (#138)
* TPC-H test suite added (#136)
* added tpch sql files
* functions updated to save temp view
* main function skeleton done
* load and clear done
* fix clear
* performQuery done
* import cleanup, use OPAQUE_HOME
* TPC-H 9 refactored to use SQL rather than DF operations
* removed ": Unit", unused imports
* added TestUtils.scala
* moved all common initialization to TestUtils
* update name
* begin rewriting TPCH.scala to store persistent tables
* invalid table name error
* TPCH conversion to class started
* compiles
* added second case, cleared up names
* added TPC-H 6 to check that persistent state has no issues
* added functions for the last two tables
* addressed most logic changes
* DataFrame only loaded once
* apply method in companion object
* full test suite added
* added testFunc parameter to testAgainstSpark
* ignore #18
* Separate IN PR (#124)
* finishing the IN expression: adding more tests and null support. Need confirmation on null behavior; I also wonder why an integer field is sufficient for string
* adding additional test
* adding additional test
* saving concat implementation; it's passing basic functionality tests
* adding type-aware comparison and a better error message for the IN operator
* adding null checking for the concat operator and adding one additional test
* cleaning up IN & Concat PR
* deleting concat and prepping the in branch for the IN PR
* fixing null behavior: now it's only null when there's no match and there's null input
* Build failed

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>

* Merge new aggregate
* Uncomment log_mac_lst clear
* Clean up comments
* Separate Concat PR (#125)

  Implementation of the CONCAT expression.

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>

* Clean up comments in other files
* Update pathsEqual to be less conservative
* Remove print statements from unit tests
* Removed calls to toSet in TPC-H tests (#140)
* removed calls to toSet
* added calls to toSet back where queries are unordered
* Documentation update (#148)
* Cluster Remote Attestation Fix (#146)

  The existing code only had RA working when run locally. This PR adds a sleep of 5 seconds to make sure that all executors are spun up successfully before attestation begins. Closes #147

* upgrade to 3.0.1 (#144)
* Update two TPC-H queries (#149)

  Tests for TPC-H 12 and 19 pass.

* TPC-H 20 Fix (#142)
* string to StringType error
* tpch 20 passes
* cleanup
* implemented changes
* Decimal.toFloat

Co-authored-by: Wenting Zheng <[email protected]>

* Add expected operator DAG generation from executedPlan string
* Rebase
* Join update (#145)
* Merge join update
* Integrate new join
* Add expected operator for SortExec
* Merge comp-integrity with join update
* Remove some print statements
* Migrate from Travis CI to GitHub Actions (#156)
* Upgrade to OE 0.12 (#153)
* Update README.md
* Support for scalar subquery (#157)

  This PR implements the scalar subquery expression, which is triggered whenever a subquery returns a scalar value. There were two main problems that needed to be solved. First, support for matching the scalar subquery expression is necessary. Spark implements this by wrapping a SparkPlan within the expression and calling executeCollect, then constructing a literal with that value. However, this is problematic for us because that value should not be decrypted by the driver and serialized into an expression, since it's an intermediate value. Therefore, the second issue addressed here is supporting an encrypted literal. This is implemented in this PR by serializing an encrypted ciphertext into a base64-encoded string and wrapping a Decrypt expression on top of it. This expression is then evaluated in the enclave and returns a literal. Note that, in order to test our implementation, we also implement a Decrypt expression in Scala. However, this should never be evaluated on the driver side and serialized into a plaintext literal, because Decrypt is designated as a Nondeterministic expression and will therefore always be evaluated on the workers.

* Add TPC-H Benchmarks (#139)
* logic decoupling in TPCH.scala for easier benchmarking
* added TPCHBenchmark.scala
* Benchmark.scala rewrite
* done adding all supported TPC-H query benchmarks
* changed command-line arguments that benchmark takes
* TPCHBenchmark takes in parameters
* fixed issue with Spark conf
* size error handling, --help flag
* add Utils.force, break cluster mode
* comment out logistic regression benchmark
* ensureCached right before temp view created/replaced
* upgrade to 3.0.1
* upgrade to 3.0.1
* 10 scale factor
* persistData
* almost done refactor
* more cleanup
* compiles
* 9 passes
* cleanup
* collect instead of force, sf_none
* remove sf_none
* defaultParallelism
* no removing trailing/leading whitespace
* add sf_med
* hdfs works in local case
* cleanup, added new CLI argument
* added newly supported tpch queries
* function for running all supported tests
* Construct expected DAG from dataframe physical plan
* Refactor collect and add integrity checking helper function to OpaqueOperatorTest
* Float expressions (#160)

  This PR adds float normalization expressions [implemented in Spark](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala#L170). TPC-H query 2 also passes.

* Broadcast Nested Loop Join - Left Anti and Left Semi (#159)

  This PR is the first of two parts towards making TPC-H 16 work; the other will be implementing `is_distinct` for aggregate operations. `BroadcastNestedLoopJoin` is Spark's "catch all" for non-equi joins. It works by first picking a side to broadcast, then iterating through every possible row combination and checking the non-equi condition against the pair.

* Move join condition handling for equi-joins into enclave code (#164)
* Add in TPC-H 21
* Add condition processing in enclave code
* Code clean up
* Enable query 18
* WIP
* Local tests pass
* Apply suggestions from code review

Co-authored-by: octaviansima <[email protected]>

* WIP
* Address comments
* q21.sql

Co-authored-by: octaviansima <[email protected]>

* Remove addExpectedOperator from JobVerificationEngine, add comments
* Implement expected DAG construction by doing graph manipulation on dataframe field instead of string parsing
* Fix merge errors in the test cases

Co-authored-by: Andrew Law <[email protected]>
Co-authored-by: Eric Feng <[email protected]>
Co-authored-by: Eric Feng <[email protected]>
Co-authored-by: Chester Leung <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>
Co-authored-by: octaviansima <[email protected]>
Co-authored-by: Chenyu Shi <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>
1 parent f98f344 commit 697644b
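As a rough illustration of the `BroadcastNestedLoopJoin` strategy described in the commit message for #159: a minimal Scala sketch, where `Row`, the `condition` signature, and the in-memory sequences are simplified stand-ins for Opaque's encrypted operators, not the actual API.

// Minimal sketch of a broadcast nested loop join, the "catch all" for
// non-equi joins: one (small) side is broadcast to every worker, and every
// possible row pair is checked against the join condition.
case class Row(fields: Seq[Any])

def broadcastNestedLoopJoin(
    streamed: Seq[Row],                      // the side that stays partitioned
    broadcast: Seq[Row],                     // the side shipped to every worker
    condition: (Row, Row) => Boolean): Seq[Row] =
  for {
    left <- streamed
    right <- broadcast
    if condition(left, right)                // non-equi predicate on the pair
  } yield Row(left.fields ++ right.fields)

// Example: an inner non-equi join on left column < right column.
val joined = broadcastNestedLoopJoin(
  Seq(Row(Seq(1)), Row(Seq(5))),
  Seq(Row(Seq(3))),
  (l, r) => l.fields.head.asInstanceOf[Int] < r.fields.head.asInstanceOf[Int])
// joined == Seq(Row(Seq(1, 3)))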

File tree

10 files changed: +209 -105 lines changed


.vscode/settings.json

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+{
+  "files.associations": {
+    "*.tcc": "cpp"
+  }
+}

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 28 additions & 11 deletions
@@ -1790,6 +1790,7 @@ class FlatbuffersJoinExprEvaluator {
     join_type = join_expr->join_type();

     join_type = join_expr->join_type();
+    condition_eval = nullptr;
     if (join_expr->condition() != NULL) {
       condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
         new FlatbuffersExpressionEvaluator(join_expr->condition()));
@@ -1798,9 +1799,6 @@ class FlatbuffersJoinExprEvaluator {

     if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
       is_equi_join = true;
-      if (join_expr->condition() != NULL) {
-        throw std::runtime_error("Equi join cannot have condition");
-      }
       if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
         throw std::runtime_error("Mismatched join key lengths");
       }
@@ -1836,14 +1834,12 @@ class FlatbuffersJoinExprEvaluator {
     return is_primary(row1) ? row1 : row2;
   }

-  /** Return true if the two rows satisfy the join condition. */
-  bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
+  /** Return true if the two rows are from the same join group.
+   * Since the function calls `is_primary`, the rows must have been tagged in Scala. */
+  bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
     builder.Clear();
     bool row1_equals_row2;

-    /** Check equality for equi joins. If it is a non-equi join,
-     * the key evaluators will be empty, so the code never enters the for loop.
-     */
     auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
     auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
     for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
@@ -1856,9 +1852,8 @@ class FlatbuffersJoinExprEvaluator {
       auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);

       flatbuffers::Offset<tuix::Field> comparison = eval_binary_comparison<tuix::EqualTo, std::equal_to>(
-        builder,
-        row1_field,
-        row2_field);
+        builder, row1_field, row2_field);
+
       row1_equals_row2 =
         static_cast<const tuix::BooleanField *>(
           flatbuffers::GetTemporaryPointer<tuix::Field>(
@@ -1889,6 +1884,28 @@ class FlatbuffersJoinExprEvaluator {
       return true;
     }

+  /** Evaluate condition on the two input rows. */
+  bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
+    if (condition_eval != nullptr) {
+      std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
+      for (auto field : *row1->field_values()) {
+        concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
+      }
+      for (auto field : *row2->field_values()) {
+        concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
+      }
+      flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
+      const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);
+      const tuix::Field *condition_result = condition_eval->eval(concat_ptr);
+      return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
+    }
+
+    // `condition_eval` can only be empty when it's an equi-join.
+    // Since `condition_eval` is an extra predicate used to filter out *matched* rows in an
+    // equi-join, an empty condition means the matched row should not be filtered out; hence
+    // the default return value of true.
+    return true;
+  }
+
   tuix::JoinType get_join_type() {
     return join_type;
   }
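A rough Scala model of the evaluator split introduced above: `is_same_group` compares only the equi-join keys, while the new `eval_condition` applies the optional residual predicate and defaults to true when no condition was serialized. Plain field sequences and the `keys`/`conditionEval` parameters are illustrative stand-ins, not the enclave's flatbuffers API.

// Toy model of FlatbuffersJoinExprEvaluator's group/condition split.
class JoinExprEvaluator(
    keys: Seq[Any] => Seq[Any],                          // equi-join key extractor
    conditionEval: Option[(Seq[Any], Seq[Any]) => Boolean]) {

  // Same join group <=> equal equi-join keys (rows are tagged upstream).
  def isSameGroup(row1: Seq[Any], row2: Seq[Any]): Boolean =
    keys(row1) == keys(row2)

  // Residual predicate over the row pair; an absent condition must never
  // filter out a matched row, hence the default of true.
  def evalCondition(row1: Seq[Any], row2: Seq[Any]): Boolean =
    conditionEval.map(_(row1, row2)).getOrElse(true)
}

// Example: equi-join on column 0 with residual predicate col1(left) < col1(right).
val eval = new JoinExprEvaluator(
  row => Seq(row.head),
  Some((l, r) => l(1).asInstanceOf[Int] < r(1).asInstanceOf[Int]))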

src/enclave/Enclave/NonObliviousSortMergeJoin.cpp

Lines changed: 98 additions & 52 deletions
@@ -5,6 +5,53 @@
 #include "FlatbuffersWriters.h"
 #include "common.h"

+/**
+ * C++ implementation of a non-oblivious sort merge join.
+ * Rows MUST be tagged primary or secondary for this to work.
+ */
+
+void test_rows_same_group(FlatbuffersJoinExprEvaluator &join_expr_eval,
+                          const tuix::Row *primary,
+                          const tuix::Row *current) {
+  if (!join_expr_eval.is_same_group(primary, current)) {
+    throw std::runtime_error(
+      std::string("Invariant violation: rows of primary_group "
+                  "are not of the same group: ")
+      + to_string(primary)
+      + std::string(" vs ")
+      + to_string(current));
+  }
+}
+
+void write_output_rows(RowWriter &group, RowWriter &w) {
+  auto group_buffer = group.output_buffer(std::string(""));
+  RowReader group_reader(group_buffer.view());
+
+  while (group_reader.has_next()) {
+    const tuix::Row *row = group_reader.next();
+    w.append(row);
+  }
+}
+
+/**
+ * Sort merge equi join algorithm
+ * Input: the rows are unioned from both the primary (or left) table and the
+ * non-primary (or right) table
+ *
+ * Outer loop: iterate over all input rows
+ *
+ * If it's a row from the left table
+ * - If it belongs to the same group as the last primary row, add it to the current group
+ * - Otherwise start a new group
+ * - If it's a left semi/anti join, output the primary_matched_rows/primary_unmatched_rows
+ *
+ * If it's a row from the right table
+ * - Inner join: iterate over the current left group, output the joined row only if the
+ *   condition is satisfied
+ * - Left semi/anti join: iterate over `primary_unmatched_rows`, add a matched row to
+ *   `primary_matched_rows` and remove it from `primary_unmatched_rows`
+ *
+ * After the loop: output the last group for left semi/anti joins
+ */
+
 void non_oblivious_sort_merge_join(
   uint8_t *join_expr, size_t join_expr_length,
   uint8_t *input_rows, size_t input_rows_length,
@@ -17,86 +64,85 @@ void non_oblivious_sort_merge_join(

   RowWriter primary_group;
   FlatbuffersTemporaryRow last_primary_of_group;
+  RowWriter primary_matched_rows, primary_unmatched_rows; // This is only used for left semi/anti join

-  bool pk_fk_match = false;
-
+  EnclaveContext::getInstance().set_append_mac(false);
   while (r.has_next()) {
     const tuix::Row *current = r.next();
+
     if (join_expr_eval.is_primary(current)) {
-      EnclaveContext::getInstance().set_append_mac(false);
-      // If current row is from primary table
       if (last_primary_of_group.get()
-          && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
+          && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
+
         // Add this primary row to the current group
+        // If this is a left semi/anti join, also add the rows to primary_unmatched_rows
         primary_group.append(current);
+        if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
+          primary_unmatched_rows.append(current);
+        }
         last_primary_of_group.set(current);
+
       } else {
         // If a new primary group is encountered
-        if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
-          auto primary_group_buffer = primary_group.output_buffer(std::string(""));
-          RowReader primary_group_reader(primary_group_buffer.view());
-
-          while (primary_group_reader.has_next()) {
-            const tuix::Row *primary = primary_group_reader.next();
-            w.append(primary);
-          }
+        if (join_type == tuix::JoinType_LeftSemi) {
+          write_output_rows(primary_matched_rows, w);
+        } else if (join_type == tuix::JoinType_LeftAnti) {
+          write_output_rows(primary_unmatched_rows, w);
         }

         primary_group.clear();
+        primary_unmatched_rows.clear();
+        primary_matched_rows.clear();
+
         primary_group.append(current);
+        primary_unmatched_rows.append(current);
         last_primary_of_group.set(current);
-
-        pk_fk_match = false;
       }
     } else {
-      // Current row isn't from primary table
-      // Output the joined rows resulting from this foreign row
       if (last_primary_of_group.get()
-          && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
-        EnclaveContext::getInstance().set_append_mac(false);
-        auto primary_group_buffer = primary_group.output_buffer(std::string(""));
-        RowReader primary_group_reader(primary_group_buffer.view());
-        while (primary_group_reader.has_next()) {
-          // For each foreign key row, join all primary key rows in same group with it
-          const tuix::Row *primary = primary_group_reader.next();
+          && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
+        if (join_type == tuix::JoinType_Inner) {
+          auto primary_group_buffer = primary_group.output_buffer(std::string(""));
+          RowReader primary_group_reader(primary_group_buffer.view());
+          while (primary_group_reader.has_next()) {
+            const tuix::Row *primary = primary_group_reader.next();
+            test_rows_same_group(join_expr_eval, primary, current);

-          if (!join_expr_eval.eval_condition(primary, current)) {
-            throw std::runtime_error(
-              std::string("Invariant violation: rows of primary_group "
-                          "are not of the same group: ")
-              + to_string(primary)
-              + std::string(" vs ")
-              + to_string(current));
+            if (join_expr_eval.eval_condition(primary, current)) {
+              w.append(primary, current);
+            }
           }
+        } else if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
+          auto primary_unmatched_rows_buffer = primary_unmatched_rows.output_buffer(std::string(""));
+          RowReader primary_unmatched_rows_reader(primary_unmatched_rows_buffer.view());
+          RowWriter new_primary_unmatched_rows;

-          if (join_type == tuix::JoinType_Inner) {
-            w.append(primary, current);
-          } else if (join_type == tuix::JoinType_LeftSemi) {
-            // Only output the pk group ONCE
-            if (!pk_fk_match) {
-              w.append(primary);
+          while (primary_unmatched_rows_reader.has_next()) {
+            const tuix::Row *primary = primary_unmatched_rows_reader.next();
+            test_rows_same_group(join_expr_eval, primary, current);
+            if (join_expr_eval.eval_condition(primary, current)) {
+              primary_matched_rows.append(primary);
+            } else {
+              new_primary_unmatched_rows.append(primary);
             }
           }
+
+          // Reset primary_unmatched_rows
+          primary_unmatched_rows.clear();
+          auto new_primary_unmatched_rows_buffer = new_primary_unmatched_rows.output_buffer(std::string(""));
+          RowReader new_primary_unmatched_rows_reader(new_primary_unmatched_rows_buffer.view());
+          while (new_primary_unmatched_rows_reader.has_next()) {
+            primary_unmatched_rows.append(new_primary_unmatched_rows_reader.next());
+          }
         }
-
-        pk_fk_match = true;
-      } else {
-        // If pk_fk_match were true and the code got to here, then the group match has not
-        // been "cleared" yet; it will be processed when the code advances to the next pk group
-        pk_fk_match &= true;
       }
     }
   }

-  if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
-    EnclaveContext::getInstance().set_append_mac(false);
-    auto primary_group_buffer = primary_group.output_buffer(std::string(""));
-    RowReader primary_group_reader(primary_group_buffer.view());
-
-    while (primary_group_reader.has_next()) {
-      const tuix::Row *primary = primary_group_reader.next();
-      w.append(primary);
-    }
+  if (join_type == tuix::JoinType_LeftSemi) {
+    write_output_rows(primary_matched_rows, w);
+  } else if (join_type == tuix::JoinType_LeftAnti) {
+    write_output_rows(primary_unmatched_rows, w);
   }

   EnclaveContext::getInstance().set_append_mac(true);
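The rewritten loop can be summarized in a rough Scala sketch. In-memory lists stand in for the RowWriter buffers, MAC bookkeeping is omitted, and `isPrimary`/`isSameGroup`/`evalCondition` are passed as functions mirroring the evaluator methods above; this is an illustration of the algorithm, not the enclave implementation.

// Sketch of the non-oblivious sort merge join: input is sorted so each
// primary (left) group is immediately followed by its foreign (right) rows.
sealed trait JoinType
case object Inner extends JoinType
case object LeftSemi extends JoinType
case object LeftAnti extends JoinType

def sortMergeJoin(
    rows: Seq[Seq[Any]],
    joinType: JoinType,
    isPrimary: Seq[Any] => Boolean,
    isSameGroup: (Seq[Any], Seq[Any]) => Boolean,
    evalCondition: (Seq[Any], Seq[Any]) => Boolean): Seq[Seq[Any]] = {
  val out = scala.collection.mutable.ArrayBuffer[Seq[Any]]()
  var group = List[Seq[Any]]()      // current primary group (newest first)
  var matched = List[Seq[Any]]()    // left semi: primaries with a match
  var unmatched = List[Seq[Any]]()  // left anti: primaries with no match yet

  // Emit the finished group: semi joins output matched primaries,
  // anti joins output the still-unmatched ones.
  def flushGroup(): Unit = joinType match {
    case LeftSemi => out ++= matched.reverse
    case LeftAnti => out ++= unmatched.reverse
    case Inner    => () // inner join output is emitted per foreign row
  }

  for (current <- rows) {
    if (isPrimary(current)) {
      if (group.nonEmpty && isSameGroup(group.head, current)) {
        group ::= current            // extend the current primary group
        unmatched ::= current
      } else {
        flushGroup()                 // a new primary group is encountered
        group = List(current); matched = Nil; unmatched = List(current)
      }
    } else if (group.nonEmpty && isSameGroup(group.head, current)) {
      joinType match {
        case Inner =>                // join each group row passing the condition
          for (p <- group.reverse if evalCondition(p, current)) out += (p ++ current)
        case LeftSemi | LeftAnti =>  // move condition hits from unmatched to matched
          val (hits, misses) = unmatched.partition(p => evalCondition(p, current))
          matched = hits.reverse ::: matched
          unmatched = misses
      }
    }
  }
  flushGroup()                       // emit the final group
  out.toSeq
}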

src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala

Lines changed: 2 additions & 1 deletion
@@ -544,7 +544,8 @@ object JobVerificationEngine {
     if (!arePathsEqual) {
       // println(executedPathsToSink.toString)
       // println(expectedPathsToSink.toString)
-      println("===========DAGS NOT EQUAL===========")
+      // println("===========DAGS NOT EQUAL===========")
+      return false
     }
     return true
   }

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 2 additions & 2 deletions
@@ -1279,7 +1279,7 @@ object Utils extends Logging {
   def serializeJoinExpression(
       joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]],
      leftSchema: Seq[Attribute], rightSchema: Seq[Attribute],
-      condition: Option[Expression] = None): Array[Byte] = {
+      condition: Option[Expression]): Array[Byte] = {
     val builder = new FlatBufferBuilder
     builder.finish(
       tuix.JoinExpr.createJoinExpr(
@@ -1318,7 +1318,7 @@ object Utils extends Logging {
         condition match {
           case Some(condition) =>
             flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema)
-          case _ => 0
+          case None => 0
         }))
     builder.sizedByteArray()
   }

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
 object TPCHBenchmark {

   // Add query numbers here once they are supported
-  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
+  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22)

   def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
     val sqlStr = tpch.getQuery(queryNumber)

src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala

Lines changed: 2 additions & 1 deletion
@@ -292,6 +292,7 @@ case class EncryptedSortMergeJoinExec(
     rightKeys: Seq[Expression],
     leftSchema: Seq[Attribute],
     rightSchema: Seq[Attribute],
+    condition: Option[Expression],
     child: SparkPlan)
   extends UnaryExecNode with OpaqueOperatorExec {

@@ -304,7 +305,7 @@ case class EncryptedSortMergeJoinExec(

   override def executeBlocked(): RDD[Block] = {
     val joinExprSer = Utils.serializeJoinExpression(
-      joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema)
+      joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema, condition)

     timeOperator(
       child.asInstanceOf[OpaqueOperatorExec].executeBlocked(),

src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala

Lines changed: 21 additions & 5 deletions
@@ -98,19 +98,35 @@ object OpaqueOperators extends Strategy {
         rightKeysProj,
         leftProjSchema.map(_.toAttribute),
         rightProjSchema.map(_.toAttribute),
+        condition,
         sorted)

       val tagsDropped = joinType match {
         case Inner => EncryptedProjectExec(dropTags(left.output, right.output), joined)
         case LeftSemi | LeftAnti => EncryptedProjectExec(left.output, joined)
       }

-      val filtered = condition match {
-        case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
-        case None => tagsDropped
-      }
+      tagsDropped :: Nil
+
+    // Used to match non-equi joins
+    case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) =>
+      // How to pick broadcast side: if left join, broadcast right. If right join, broadcast left.
+      // This is the simplest and most performant method, but may be worth revisiting if one side
+      // is significantly smaller than the other. Otherwise, pick the smaller side to broadcast.
+      // NOTE: the current implementation of BNLJ only works under the assumption that
+      // left join <==> broadcast right AND right join <==> broadcast left.
+      val desiredBuildSide = if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter)
+        getSmallerSide(left, right) else
+        getBroadcastSideBNLJ(joinType)
+
+      val joined = EncryptedBroadcastNestedLoopJoinExec(
+        planLater(left),
+        planLater(right),
+        desiredBuildSide,
+        joinType,
+        condition)

-      filtered :: Nil
+      joined :: Nil

     // Used to match non-equi joins
     case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) =>
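The build-side choice in the new non-equi case reads as a small decision rule. A rough Scala sketch with simplified stand-ins for Spark's BuildSide and join types (the real getSmallerSide/getBroadcastSideBNLJ helpers are reduced to a boolean and a match; semi/anti joins would follow the same one-sided rule):

// Sketch of the broadcast-side rule for EncryptedBroadcastNestedLoopJoinExec.
sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

sealed trait JoinKind
case object InnerLike extends JoinKind   // stands in for Spark's InnerLike trait
case object FullOuter extends JoinKind
case object LeftOuter extends JoinKind
case object RightOuter extends JoinKind

// Inner/full-outer joins broadcast the smaller side; one-sided joins must
// broadcast the side opposite the join (left join <=> broadcast right).
def desiredBuildSide(joinType: JoinKind, leftIsSmaller: Boolean): BuildSide =
  joinType match {
    case InnerLike | FullOuter => if (leftIsSmaller) BuildLeft else BuildRight
    case LeftOuter             => BuildRight
    case RightOuter            => BuildLeft
  }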
