diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index 64013d2ab7..596e593d52 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( return ret; } +JNIEXPORT jbyteArray JNICALL +Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) { + (void)obj; + + jboolean if_copy; + + uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr); + uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy); + + uint32_t outer_rows_length = (uint32_t) env->GetArrayLength(outer_rows); + uint8_t *outer_rows_ptr = (uint8_t *) env->GetByteArrayElements(outer_rows, &if_copy); + + uint32_t inner_rows_length = (uint32_t) env->GetArrayLength(inner_rows); + uint8_t *inner_rows_ptr = (uint8_t *) env->GetByteArrayElements(inner_rows, &if_copy); + + uint8_t *output_rows = nullptr; + size_t output_rows_length = 0; + + if (outer_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array."); + } else if (inner_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array."); + } else { + oe_check_and_time("Broadcast Nested Loop Join", + ecall_broadcast_nested_loop_join( + (oe_enclave_t*)eid, + join_expr_ptr, join_expr_length, + outer_rows_ptr, outer_rows_length, + inner_rows_ptr, inner_rows_length, + &output_rows, &output_rows_length)); + } + + jbyteArray ret = env->NewByteArray(output_rows_length); + env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows); + free(output_rows); + + env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); + env->ReleaseByteArrayElements(outer_rows, (jbyte *) outer_rows_ptr, 0); + env->ReleaseByteArrayElements(inner_rows, (jbyte *) inner_rows_ptr, 0); + + return ret; +} + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) { diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index 2b74c42763..1ddd0d8497 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -41,6 +41,10 @@ extern "C" { Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); + JNIEXPORT jbyteArray JNICALL + Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray); + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean); diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp new file mode 100644 index 0000000000..c99297ebf5 --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp @@ -0,0 +1,54 @@ +#include "BroadcastNestedLoopJoin.h" + +#include "ExpressionEvaluation.h" +#include "FlatbuffersReaders.h" +#include "FlatbuffersWriters.h" +#include "common.h" + +/** C++ implementation of a broadcast nested loop join. + * Assumes outer_rows is streamed and inner_rows is broadcast. + * DOES NOT rely on rows to be tagged primary or secondary, and that + * assumption will break the implementation. + */ +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + + FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); + const tuix::JoinType join_type = join_expr_eval.get_join_type(); + + RowReader outer_r(BufferRefView(outer_rows, outer_rows_length)); + RowWriter w; + + while (outer_r.has_next()) { + const tuix::Row *outer = outer_r.next(); + bool o_i_match = false; + + RowReader inner_r(BufferRefView(inner_rows, inner_rows_length)); + const tuix::Row *inner; + while (inner_r.has_next()) { + inner = inner_r.next(); + o_i_match |= join_expr_eval.eval_condition(outer, inner); + } + + switch(join_type) { + case tuix::JoinType_LeftAnti: + if (!o_i_match) { + w.append(outer); + } + break; + case tuix::JoinType_LeftSemi: + if (o_i_match) { + w.append(outer); + } + break; + default: + throw std::runtime_error( + std::string("Join type not supported: ") + + std::string(to_string(join_type))); + } + } + w.output_buffer(output_rows, output_rows_length); +} diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.h b/src/enclave/Enclave/BroadcastNestedLoopJoin.h new file mode 100644 index 0000000000..55c934067b --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.h @@ -0,0 +1,8 @@ +#include +#include + +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length); diff --git a/src/enclave/Enclave/CMakeLists.txt b/src/enclave/Enclave/CMakeLists.txt index 6a72e76dfd..07e6130d80 100644 --- a/src/enclave/Enclave/CMakeLists.txt +++ b/src/enclave/Enclave/CMakeLists.txt @@ -10,7 +10,8 @@ set(SOURCES Flatbuffers.cpp FlatbuffersReaders.cpp FlatbuffersWriters.cpp - Join.cpp + NonObliviousSortMergeJoin.cpp + BroadcastNestedLoopJoin.cpp Limit.cpp Project.cpp Sort.cpp diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index e9342875b2..fde1806a97 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -6,7 +6,8 @@ #include "Aggregate.h" #include "Crypto.h" #include "Filter.h" -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" +#include "BroadcastNestedLoopJoin.h" #include "Limit.h" #include "Project.h" #include "Sort.h" @@ -161,6 +162,25 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le } } +void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + // Guard against operating on arbitrary enclave memory + assert(oe_is_outside_enclave(outer_rows, outer_rows_length) == 1); + assert(oe_is_outside_enclave(inner_rows, inner_rows_length) == 1); + __builtin_ia32_lfence(); + + try { + broadcast_nested_loop_join(join_expr, join_expr_length, + outer_rows, outer_rows_length, + inner_rows, inner_rows_length, + output_rows, output_rows_length); + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } +} + void ecall_non_oblivious_aggregate( uint8_t *agg_op, size_t agg_op_length, uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 44eccc7a76..1789ff2b64 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -51,6 +51,12 @@ enclave { [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_broadcast_nested_loop_join( + [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, + [user_check] uint8_t *outer_rows, size_t outer_rows_length, + [user_check] uint8_t *inner_rows, size_t inner_rows_length, + [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_non_oblivious_aggregate( [in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length, [user_check] uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 7b8dfe0b8b..e3c26f0b87 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -1787,60 +1787,104 @@ class FlatbuffersJoinExprEvaluator { } const tuix::JoinExpr* join_expr = flatbuffers::GetRoot(buf); - join_type = join_expr->join_type(); - if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { - throw std::runtime_error("Mismatched join key lengths"); - } - for (auto key_it = join_expr->left_keys()->begin(); - key_it != join_expr->left_keys()->end(); ++key_it) { - left_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + join_type = join_expr->join_type(); + if (join_expr->condition() != NULL) { + condition_eval = std::unique_ptr( + new FlatbuffersExpressionEvaluator(join_expr->condition())); } - for (auto key_it = join_expr->right_keys()->begin(); - key_it != join_expr->right_keys()->end(); ++key_it) { - right_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + is_equi_join = false; + + if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) { + is_equi_join = true; + if (join_expr->condition() != NULL) { + throw std::runtime_error("Equi join cannot have condition"); + } + if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { + throw std::runtime_error("Mismatched join key lengths"); + } + for (auto key_it = join_expr->left_keys()->begin(); + key_it != join_expr->left_keys()->end(); ++key_it) { + left_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } + for (auto key_it = join_expr->right_keys()->begin(); + key_it != join_expr->right_keys()->end(); ++key_it) { + right_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } } } - /** - * Return true if the given row is from the primary table, indicated by its first field, which - * must be an IntegerField. + /** Return true if the given row is from the primary table, indicated by its first field, which + * must be an IntegerField. + * Rows MUST have been tagged in Scala. */ bool is_primary(const tuix::Row *row) { return static_cast( row->field_values()->Get(0)->value())->value() == 0; } - /** Return true if the two rows are from the same join group. */ - bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) { - auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; - auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators; + /** Returns the row evaluator corresponding to the primary row + * Rows MUST have been tagged in Scala. + */ + const tuix::Row *get_primary_row( + const tuix::Row *row1, const tuix::Row *row2) { + return is_primary(row1) ? row1 : row2; + } + /** Return true if the two rows satisfy the join condition. */ + bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) { builder.Clear(); + bool row1_equals_row2; + + /** Check equality for equi joins. If it is a non-equi join, + * the key evaluators will be empty, so the code never enters the for loop. + */ + auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; + auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators; for (uint32_t i = 0; i < row1_evaluators.size(); i++) { const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1); auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder); + auto row1_field = flatbuffers::GetTemporaryPointer(builder, row1_eval_offset); + const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2); auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder); + auto row2_field = flatbuffers::GetTemporaryPointer(builder, row2_eval_offset); - bool row1_equals_row2 = + flatbuffers::Offset comparison = eval_binary_comparison( + builder, + row1_field, + row2_field); + row1_equals_row2 = static_cast( flatbuffers::GetTemporaryPointer( builder, - eval_binary_comparison( - builder, - flatbuffers::GetTemporaryPointer(builder, row1_eval_offset), - flatbuffers::GetTemporaryPointer(builder, row2_eval_offset))) - ->value())->value(); + comparison)->value())->value(); if (!row1_equals_row2) { return false; } } + + /* Check condition for non-equi joins */ + if (!is_equi_join) { + std::vector> concat_fields; + for (auto field : *row1->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + for (auto field : *row2->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + flatbuffers::Offset concat = tuix::CreateRowDirect(builder, &concat_fields); + const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer(builder, concat); + + const tuix::Field *condition_result = condition_eval->eval(concat_ptr); + + return static_cast(condition_result->value())->value(); + } return true; } @@ -1853,6 +1897,8 @@ class FlatbuffersJoinExprEvaluator { tuix::JoinType join_type; std::vector> left_key_evaluators; std::vector> right_key_evaluators; + bool is_equi_join; + std::unique_ptr condition_eval; }; class AggregateExpressionEvaluator { diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp similarity index 88% rename from src/enclave/Enclave/Join.cpp rename to src/enclave/Enclave/NonObliviousSortMergeJoin.cpp index 828c963d40..67bc546c0f 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp @@ -1,10 +1,13 @@ -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" #include "ExpressionEvaluation.h" #include "FlatbuffersReaders.h" #include "FlatbuffersWriters.h" #include "common.h" +/** C++ implementation of a non-oblivious sort merge join. + * Rows MUST be tagged primary or secondary for this to work. + */ void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, @@ -25,7 +28,7 @@ void non_oblivious_sort_merge_join( if (join_expr_eval.is_primary(current)) { if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { // Add this primary row to the current group primary_group.append(current); last_primary_of_group.set(current); @@ -50,13 +53,13 @@ void non_oblivious_sort_merge_join( } else { // Output the joined rows resulting from this foreign row if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { auto primary_group_buffer = primary_group.output_buffer(); RowReader primary_group_reader(primary_group_buffer.view()); while (primary_group_reader.has_next()) { const tuix::Row *primary = primary_group_reader.next(); - if (!join_expr_eval.is_same_group(primary, current)) { + if (!join_expr_eval.eval_condition(primary, current)) { throw std::runtime_error( std::string("Invariant violation: rows of primary_group " "are not of the same group: ") diff --git a/src/enclave/Enclave/Join.h b/src/enclave/Enclave/NonObliviousSortMergeJoin.h similarity index 85% rename from src/enclave/Enclave/Join.h rename to src/enclave/Enclave/NonObliviousSortMergeJoin.h index b380909027..ef60c38437 100644 --- a/src/enclave/Enclave/Join.h +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.h @@ -1,12 +1,7 @@ #include #include -#ifndef JOIN_H -#define JOIN_H - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, uint8_t **output_rows, size_t *output_rows_length); - -#endif diff --git a/src/flatbuffers/operators.fbs b/src/flatbuffers/operators.fbs index 1ebd06c971..9fa82b6cab 100644 --- a/src/flatbuffers/operators.fbs +++ b/src/flatbuffers/operators.fbs @@ -54,10 +54,11 @@ enum JoinType : ubyte { } table JoinExpr { join_type:JoinType; - // Currently only cross joins and equijoins are supported, so we store - // parallel arrays of key expressions and the join outputs pairs of rows - // where each expression from the left is equal to the matching expression - // from the right. + // In the case of equi joins, we store parallel arrays of key expressions and have the join output + // pairs of rows where each expression from the left is equal to the matching expression from the right. left_keys:[Expr]; right_keys:[Expr]; + // In the case of non-equi joins, we pass in a condition as an expression and evaluate that on each pair of rows. + // TODO: have equi joins use this condition rather than an additional filter operation. + condition:Expr; } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index cbe2f944dc..7845e9ea89 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -1257,8 +1257,9 @@ object Utils extends Logging { } def serializeJoinExpression( - joinType: JoinType, leftKeys: Seq[Expression], rightKeys: Seq[Expression], - leftSchema: Seq[Attribute], rightSchema: Seq[Attribute]): Array[Byte] = { + joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]], + leftSchema: Seq[Attribute], rightSchema: Seq[Attribute], + condition: Option[Expression] = None): Array[Byte] = { val builder = new FlatBufferBuilder builder.finish( tuix.JoinExpr.createJoinExpr( @@ -1277,12 +1278,28 @@ object Utils extends Logging { case UsingJoin(_, _) => ??? // scalastyle:on }, - tuix.JoinExpr.createLeftKeysVector( - builder, - leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray), - tuix.JoinExpr.createRightKeysVector( - builder, - rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray))) + // Non-zero when equi join + leftKeys match { + case Some(leftKeys) => + tuix.JoinExpr.createLeftKeysVector( + builder, + leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray) + case None => 0 + }, + // Non-zero when equi join + rightKeys match { + case Some(rightKeys) => + tuix.JoinExpr.createRightKeysVector( + builder, + rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray) + case None => 0 + }, + // Non-zero when non-equi join + condition match { + case Some(condition) => + flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema) + case _ => 0 + })) builder.sizedByteArray() } @@ -1382,8 +1399,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case c @ Count(children) => val count = c.aggBufferAttributes(0) @@ -1421,8 +1437,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case f @ First(child, false) => val first = f.aggBufferAttributes(0) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index b49090ced1..e1f1d31261 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -42,6 +42,9 @@ class SGXEnclave extends java.io.Serializable { @native def NonObliviousSortMergeJoin( eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] + @native def BroadcastNestedLoopJoin( + eid: Long, joinExpr: Array[Byte], outerBlock: Array[Byte], innerBlock: Array[Byte]): Array[Byte] + @native def NonObliviousAggregate( eid: Long, aggOp: Array[Byte], inputRows: Array[Byte], isPartial: Boolean): (Array[Byte]) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 4eb941157e..6983df047b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -26,12 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan +import edu.berkeley.cs.rise.opaque.OpaqueException trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil @@ -294,7 +293,7 @@ case class EncryptedSortMergeJoinExec( override def executeBlocked(): RDD[Block] = { val joinExprSer = Utils.serializeJoinExpression( - joinType, leftKeys, rightKeys, leftSchema, rightSchema) + joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema) timeOperator( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), @@ -308,6 +307,69 @@ case class EncryptedSortMergeJoinExec( } } +case class EncryptedBroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") + } + } + + override def executeBlocked(): RDD[Block] = { + val joinExprSer = Utils.serializeJoinExpression( + joinType, None, None, left.output, right.output, condition) + + val leftRDD = left.asInstanceOf[OpaqueOperatorExec].executeBlocked() + val rightRDD = right.asInstanceOf[OpaqueOperatorExec].executeBlocked() + + joinType match { + case LeftExistence(_) => { + join(leftRDD, rightRDD, joinExprSer) + } + case _ => + throw new OpaqueException(s"$joinType JoinType is not yet supported") + } + } + + def join(leftRDD: RDD[Block], rightRDD: RDD[Block], + joinExprSer: Array[Byte]): RDD[Block] = { + // We pick which side to broadcast/stream according to buildSide. + // BuildRight means the right relation <=> the broadcast relation. + // NOTE: outer_rows and inner_rows in C++ correspond to stream and broadcast side respectively. + var (streamRDD, broadcastRDD) = buildSide match { + case BuildRight => + (leftRDD, rightRDD) + case BuildLeft => + (rightRDD, leftRDD) + } + val broadcast = Utils.concatEncryptedBlocks(broadcastRDD.collect) + streamRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + Block(enclave.BroadcastNestedLoopJoin(eid, joinExprSer, block.bytes, broadcast.bytes)) + } + } +} + case class EncryptedUnionExec( left: SparkPlan, right: SparkPlan) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 0c8f188369..dd104d2ad2 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -32,13 +32,19 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.InnerLike import org.apache.spark.sql.catalyst.plans.LeftAnti import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan import edu.berkeley.cs.rise.opaque.execution._ import edu.berkeley.cs.rise.opaque.logical._ +import org.apache.spark.sql.catalyst.plans.LeftExistence object OpaqueOperators extends Strategy { @@ -73,6 +79,7 @@ object OpaqueOperators extends Strategy { case Sort(sortExprs, global, child) if isEncrypted(child) => EncryptedSortExec(sortExprs, global, planLater(child)) :: Nil + // Used to match equi joins case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if isEncrypted(p) => val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true) val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false) @@ -105,6 +112,26 @@ object OpaqueOperators extends Strategy { filtered :: Nil + // Used to match non-equi joins + case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) => + // How to pick broadcast side: if left join, broadcast right. If right join, broadcast left. + // This is the simplest and most performant method, but may be worth revisting if one side is + // significantly smaller than the other. Otherwise, pick the smallest side to broadcast. + // NOTE: the current implementation of BNLJ only works under the assumption that + // left join <==> broadcast right AND right join <==> broadcast left. + val desiredBuildSide = if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) + getSmallerSide(left, right) else + getBroadcastSideBNLJ(joinType) + + val joined = EncryptedBroadcastNestedLoopJoinExec( + planLater(left), + planLater(right), + desiredBuildSide, + joinType, + condition) + + joined :: Nil + case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => @@ -183,17 +210,29 @@ object OpaqueOperators extends Strategy { (Seq(tag) ++ keysProj ++ input, keysProj.map(_.toAttribute), tag.toAttribute) } - private def sortForJoin( - leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = - leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) - private def dropTags( leftOutput: Seq[Attribute], rightOutput: Seq[Attribute]): Seq[NamedExpression] = leftOutput ++ rightOutput + private def sortForJoin( + leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = + leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) + private def tagForGlobalAggregate(input: Seq[Attribute]) : (Seq[NamedExpression], NamedExpression) = { val tag = Alias(Literal(0), "_tag")() (Seq(tag) ++ input, tag.toAttribute) } + + private def getBroadcastSideBNLJ(joinType: JoinType): BuildSide = { + joinType match { + case LeftExistence(_) => BuildRight + case _ => BuildLeft + } + } + + // Everything below is a private method in SparkStrategies.scala + private def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 88a5550f17..859b3bdde4 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -326,6 +326,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left semi join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + + testAgainstSpark("non-equi left semi join negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + testAgainstSpark("left anti join 1") { securityLevel => val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) @@ -335,6 +353,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left anti join 1") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 1 negated") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + testAgainstSpark("left anti join 2") { securityLevel => val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) @@ -344,6 +380,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left anti join 2") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 2 negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + testAgainstSpark("join on floats") { securityLevel => val p_data = for (i <- 0 to 16) yield (i, i.toFloat, i * 10) val f_data = (0 until 256).map(x => {