Skip to content

Move join condition handling for equi-joins into enclave code #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,7 @@ class FlatbuffersJoinExprEvaluator {
const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);

join_type = join_expr->join_type();
condition_eval = nullptr;
if (join_expr->condition() != NULL) {
condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(join_expr->condition()));
Expand All @@ -1797,9 +1798,6 @@ class FlatbuffersJoinExprEvaluator {

if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
is_equi_join = true;
if (join_expr->condition() != NULL) {
throw std::runtime_error("Equi join cannot have condition");
}
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
}
Expand Down Expand Up @@ -1835,14 +1833,12 @@ class FlatbuffersJoinExprEvaluator {
return is_primary(row1) ? row1 : row2;
}

/** Return true if the two rows satisfy the join condition. */
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
/** Return true if the two rows are from the same join group
* Since the function calls `is_primary`, the rows must have been tagged in Scala */
bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
builder.Clear();
bool row1_equals_row2;

/** Check equality for equi joins. If it is a non-equi join,
* the key evaluators will be empty, so the code never enters the for loop.
*/
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
Expand All @@ -1855,9 +1851,8 @@ class FlatbuffersJoinExprEvaluator {
auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);

flatbuffers::Offset<tuix::Field> comparison = eval_binary_comparison<tuix::EqualTo, std::equal_to>(
builder,
row1_field,
row2_field);
builder, row1_field, row2_field);

row1_equals_row2 =
static_cast<const tuix::BooleanField *>(
flatbuffers::GetTemporaryPointer<tuix::Field>(
Expand All @@ -1868,9 +1863,12 @@ class FlatbuffersJoinExprEvaluator {
return false;
}
}
return true;
}

/* Check condition for non-equi joins */
if (!is_equi_join) {
/** Evaluate condition on the two input rows */
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
if (condition_eval != nullptr) {
std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
for (auto field : *row1->field_values()) {
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
Expand All @@ -1880,11 +1878,13 @@ class FlatbuffersJoinExprEvaluator {
}
flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);

const tuix::Field *condition_result = condition_eval->eval(concat_ptr);

return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
}

// The `condition_eval` can only be empty when it's an equi-join.
// Since `condition_eval` is an extra predicate used to filter out *matched* rows in an equi-join, an empty
// condition means the matched row should not be filtered out; hence the default return value of true
return true;
}

Expand Down
141 changes: 94 additions & 47 deletions src/enclave/Enclave/NonObliviousSortMergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,53 @@
#include "FlatbuffersWriters.h"
#include "common.h"

/** C++ implementation of a non-oblivious sort merge join.
/**
* C++ implementation of a non-oblivious sort merge join.
* Rows MUST be tagged primary or secondary for this to work.
*/

/**
 * Invariant check: verify that `primary` and `current` fall into the same join group
 * according to `join_expr_eval.is_same_group`.
 *
 * Throws std::runtime_error, with both rows rendered via `to_string`, when they do not.
 */
void test_rows_same_group(FlatbuffersJoinExprEvaluator &join_expr_eval,
                          const tuix::Row *primary,
                          const tuix::Row *current) {
  // Nothing to report when the invariant holds.
  if (join_expr_eval.is_same_group(primary, current)) {
    return;
  }

  std::string message("Invariant violation: rows of primary_group "
                      "are not of the same group: ");
  message += to_string(primary);
  message += std::string(" vs ");
  message += to_string(current);
  throw std::runtime_error(message);
}

/**
 * Append every row currently buffered in `group` to the writer `w`,
 * in the order a RowReader over `group`'s output buffer yields them.
 */
void write_output_rows(RowWriter &group, RowWriter &w) {
  // The buffer must outlive the reader, which only holds a view into it.
  auto buffered_rows = group.output_buffer();

  for (RowReader reader(buffered_rows.view()); reader.has_next();) {
    w.append(reader.next());
  }
}

/**
* Sort merge equi join algorithm
* Input: the rows are unioned from both the primary (or left) table and the non-primary (or right) table
*
* Outer loop: iterate over all input rows
*
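* If it's a row from the left table
* - If it belongs to the same group as the previous left row, add it to the current group
* - Otherwise, if it's a left semi/anti join, first output the finished group's
*   primary_matched_rows (semi) or primary_unmatched_rows (anti), then start a new group with this row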
*
* If it's a row from the right table
* - Inner join: iterate over current left group, output the joined row only if the condition is satisfied
* - Left semi/anti join: iterate over `primary_unmatched_rows`, add a matched row to `primary_matched_rows`
* and remove from `primary_unmatched_rows`
*
* After loop: for a left semi/anti join, output the rows of the last group
* (primary_matched_rows for semi, primary_unmatched_rows for anti)
*/

void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
Expand All @@ -20,81 +64,84 @@ void non_oblivious_sort_merge_join(

RowWriter primary_group;
FlatbuffersTemporaryRow last_primary_of_group;

bool pk_fk_match = false;
RowWriter primary_matched_rows, primary_unmatched_rows; // This is only used for left semi/anti join

while (r.has_next()) {
const tuix::Row *current = r.next();

if (join_expr_eval.is_primary(current)) {
if (last_primary_of_group.get()
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {

// Add this primary row to the current group
// If this is a left semi/anti join, also add the rows to primary_unmatched_rows
primary_group.append(current);
if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
primary_unmatched_rows.append(current);
}
last_primary_of_group.set(current);

} else {
// If a new primary group is encountered
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());

while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
w.append(primary);
}
if (join_type == tuix::JoinType_LeftSemi) {
write_output_rows(primary_matched_rows, w);
} else if (join_type == tuix::JoinType_LeftAnti) {
write_output_rows(primary_unmatched_rows, w);
}

primary_group.clear();
primary_unmatched_rows.clear();
primary_matched_rows.clear();

primary_group.append(current);
primary_unmatched_rows.append(current);
last_primary_of_group.set(current);

pk_fk_match = false;
}
} else {
// Output the joined rows resulting from this foreign row
if (last_primary_of_group.get()
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());
while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
if (join_type == tuix::JoinType_Inner) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());
while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
test_rows_same_group(join_expr_eval, primary, current);

if (!join_expr_eval.eval_condition(primary, current)) {
throw std::runtime_error(
std::string("Invariant violation: rows of primary_group "
"are not of the same group: ")
+ to_string(primary)
+ std::string(" vs ")
+ to_string(current));
if (join_expr_eval.eval_condition(primary, current)) {
w.append(primary, current);
}
}
} else if (join_type == tuix::JoinType_LeftSemi || join_type == tuix::JoinType_LeftAnti) {
auto primary_unmatched_rows_buffer = primary_unmatched_rows.output_buffer();
RowReader primary_unmatched_rows_reader(primary_unmatched_rows_buffer.view());
RowWriter new_primary_unmatched_rows;

if (join_type == tuix::JoinType_Inner) {
w.append(primary, current);
} else if (join_type == tuix::JoinType_LeftSemi) {
// Only output the pk group ONCE
if (!pk_fk_match) {
w.append(primary);
while (primary_unmatched_rows_reader.has_next()) {
const tuix::Row *primary = primary_unmatched_rows_reader.next();
test_rows_same_group(join_expr_eval, primary, current);
if (join_expr_eval.eval_condition(primary, current)) {
primary_matched_rows.append(primary);
} else {
new_primary_unmatched_rows.append(primary);
}
}

// Reset primary_unmatched_rows
primary_unmatched_rows.clear();
auto new_primary_unmatched_rows_buffer = new_primary_unmatched_rows.output_buffer();
RowReader new_primary_unmatched_rows_reader(new_primary_unmatched_rows_buffer.view());
while (new_primary_unmatched_rows_reader.has_next()) {
primary_unmatched_rows.append(new_primary_unmatched_rows_reader.next());
}
}

pk_fk_match = true;
} else {
// If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet
// It will be processed when the code advances to the next pk group
pk_fk_match &= true;
}
}
}

if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());

while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
w.append(primary);
}
if (join_type == tuix::JoinType_LeftSemi) {
write_output_rows(primary_matched_rows, w);
} else if (join_type == tuix::JoinType_LeftAnti) {
write_output_rows(primary_unmatched_rows, w);
}

w.output_buffer(output_rows, output_rows_length);
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ object Utils extends Logging {
def serializeJoinExpression(
joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]],
leftSchema: Seq[Attribute], rightSchema: Seq[Attribute],
condition: Option[Expression] = None): Array[Byte] = {
condition: Option[Expression]): Array[Byte] = {
val builder = new FlatBufferBuilder
builder.finish(
tuix.JoinExpr.createJoinExpr(
Expand Down Expand Up @@ -1298,7 +1298,7 @@ object Utils extends Logging {
condition match {
case Some(condition) =>
flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema)
case _ => 0
case None => 0
}))
builder.sizedByteArray()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
object TPCHBenchmark {

// Add query numbers here once they are supported
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22)
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22)

def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
val sqlStr = tpch.getQuery(queryNumber)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ case class EncryptedSortMergeJoinExec(
rightKeys: Seq[Expression],
leftSchema: Seq[Attribute],
rightSchema: Seq[Attribute],
condition: Option[Expression],
child: SparkPlan)
extends UnaryExecNode with OpaqueOperatorExec {

Expand All @@ -293,7 +294,7 @@ case class EncryptedSortMergeJoinExec(

override def executeBlocked(): RDD[Block] = {
val joinExprSer = Utils.serializeJoinExpression(
joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema)
joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema, condition)

timeOperator(
child.asInstanceOf[OpaqueOperatorExec].executeBlocked(),
Expand Down
8 changes: 2 additions & 6 deletions src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,15 @@ object OpaqueOperators extends Strategy {
rightKeysProj,
leftProjSchema.map(_.toAttribute),
rightProjSchema.map(_.toAttribute),
condition,
sorted)

val tagsDropped = joinType match {
case Inner => EncryptedProjectExec(dropTags(left.output, right.output), joined)
case LeftSemi | LeftAnti => EncryptedProjectExec(left.output, joined)
}

val filtered = condition match {
case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
case None => tagsDropped
}

filtered :: Nil
tagsDropped :: Nil

// Used to match non-equi joins
case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,20 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x")
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x")
val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "y")
val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1")
df.collect
}

// Left semi join whose condition combines key equality with the extra predicate x > y.
testAgainstSpark("left semi join with condition") { securityLevel =>
  val primaryRows = (1 to 16).map(i => (i, (i % 8).toString, i * 10))
  val foreignRows = (1 to 32).map(i => (i, (i % 8).toString, i * 10))
  val primary = makeDF(primaryRows, securityLevel, "id1", "join_col_1", "x")
  val foreign = makeDF(foreignRows, securityLevel, "id2", "join_col_2", "y")
  val joinCondition = $"join_col_1" === $"join_col_2" && $"x" > $"y"
  primary
    .join(foreign, joinCondition, "left_semi")
    .sort($"join_col_1", $"id1")
    .collect
}

testAgainstSpark("non-equi left semi join") { securityLevel =>
val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10)
val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10)
Expand All @@ -344,7 +353,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
df.collect
}

testAgainstSpark("left anti join 1") { securityLevel =>
testAgainstSpark("left anti join") { securityLevel =>
val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10)
val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10)
val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x")
Expand All @@ -353,6 +362,15 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
df.collect
}

// Left anti join whose condition combines key equality with the extra predicate x > y.
testAgainstSpark("left anti join with condition") { securityLevel =>
  val primaryRows = (1 to 16).map(i => (i, (i % 8).toString, i * 10))
  val foreignRows = (1 to 32).map(i => (i, (i % 8).toString, i * 10))
  val primary = makeDF(primaryRows, securityLevel, "id1", "join_col_1", "x")
  val foreign = makeDF(foreignRows, securityLevel, "id2", "join_col_2", "y")
  val joinCondition = $"join_col_1" === $"join_col_2" && $"x" > $"y"
  primary
    .join(foreign, joinCondition, "left_anti")
    .sort($"join_col_1", $"id1")
    .collect
}

testAgainstSpark("non-equi left anti join 1") { securityLevel =>
val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10)
val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10)
Expand Down