Skip to content

Commit a4a6ff9

Browse files
octaviansimawzheng
authored andcommitted
Broadcast Nested Loop Join - Left Anti and Left Semi (#159)
This PR is the first of two parts towards making TPC-H 16 work: the other will be implementing `is_distinct` for aggregate operations. `BroadcastNestedLoopJoin` is Spark's "catch all" for non-equi joins. It works by first picking a side to broadcast, then iterating through every possible row combination and checking the non-equi condition against the pair.
1 parent 3c28b5f commit a4a6ff9

16 files changed

+418
-63
lines changed

src/enclave/App/App.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
555555
return ret;
556556
}
557557

558+
JNIEXPORT jbyteArray JNICALL
559+
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
560+
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) {
561+
(void)obj;
562+
563+
jboolean if_copy;
564+
565+
uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr);
566+
uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy);
567+
568+
uint32_t outer_rows_length = (uint32_t) env->GetArrayLength(outer_rows);
569+
uint8_t *outer_rows_ptr = (uint8_t *) env->GetByteArrayElements(outer_rows, &if_copy);
570+
571+
uint32_t inner_rows_length = (uint32_t) env->GetArrayLength(inner_rows);
572+
uint8_t *inner_rows_ptr = (uint8_t *) env->GetByteArrayElements(inner_rows, &if_copy);
573+
574+
uint8_t *output_rows = nullptr;
575+
size_t output_rows_length = 0;
576+
577+
if (outer_rows_ptr == nullptr) {
578+
ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array.");
579+
} else if (inner_rows_ptr == nullptr) {
580+
ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array.");
581+
} else {
582+
oe_check_and_time("Broadcast Nested Loop Join",
583+
ecall_broadcast_nested_loop_join(
584+
(oe_enclave_t*)eid,
585+
join_expr_ptr, join_expr_length,
586+
outer_rows_ptr, outer_rows_length,
587+
inner_rows_ptr, inner_rows_length,
588+
&output_rows, &output_rows_length));
589+
}
590+
591+
jbyteArray ret = env->NewByteArray(output_rows_length);
592+
env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows);
593+
free(output_rows);
594+
595+
env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
596+
env->ReleaseByteArrayElements(outer_rows, (jbyte *) outer_rows_ptr, 0);
597+
env->ReleaseByteArrayElements(inner_rows, (jbyte *) inner_rows_ptr, 0);
598+
599+
return ret;
600+
}
601+
558602
JNIEXPORT jobject JNICALL
559603
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
560604
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) {

src/enclave/App/SGXEnclave.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ extern "C" {
4141
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
4242
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);
4343

44+
JNIEXPORT jbyteArray JNICALL
45+
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin(
46+
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);
47+
4448
JNIEXPORT jobject JNICALL
4549
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
4650
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "BroadcastNestedLoopJoin.h"
2+
3+
#include "ExpressionEvaluation.h"
4+
#include "FlatbuffersReaders.h"
5+
#include "FlatbuffersWriters.h"
6+
#include "common.h"
7+
8+
/** C++ implementation of a broadcast nested loop join.
9+
* Assumes outer_rows is streamed and inner_rows is broadcast.
10+
* DOES NOT rely on rows to be tagged primary or secondary, and that
11+
* assumption will break the implementation.
12+
*/
13+
void broadcast_nested_loop_join(
14+
uint8_t *join_expr, size_t join_expr_length,
15+
uint8_t *outer_rows, size_t outer_rows_length,
16+
uint8_t *inner_rows, size_t inner_rows_length,
17+
uint8_t **output_rows, size_t *output_rows_length) {
18+
19+
FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
20+
const tuix::JoinType join_type = join_expr_eval.get_join_type();
21+
22+
RowReader outer_r(BufferRefView<tuix::EncryptedBlocks>(outer_rows, outer_rows_length));
23+
RowWriter w;
24+
25+
while (outer_r.has_next()) {
26+
const tuix::Row *outer = outer_r.next();
27+
bool o_i_match = false;
28+
29+
RowReader inner_r(BufferRefView<tuix::EncryptedBlocks>(inner_rows, inner_rows_length));
30+
const tuix::Row *inner;
31+
while (inner_r.has_next()) {
32+
inner = inner_r.next();
33+
o_i_match |= join_expr_eval.eval_condition(outer, inner);
34+
}
35+
36+
switch(join_type) {
37+
case tuix::JoinType_LeftAnti:
38+
if (!o_i_match) {
39+
w.append(outer);
40+
}
41+
break;
42+
case tuix::JoinType_LeftSemi:
43+
if (o_i_match) {
44+
w.append(outer);
45+
}
46+
break;
47+
default:
48+
throw std::runtime_error(
49+
std::string("Join type not supported: ")
50+
+ std::string(to_string(join_type)));
51+
}
52+
}
53+
w.output_buffer(output_rows, output_rows_length);
54+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include <cstddef>
2+
#include <cstdint>
3+
4+
void broadcast_nested_loop_join(
5+
uint8_t *join_expr, size_t join_expr_length,
6+
uint8_t *outer_rows, size_t outer_rows_length,
7+
uint8_t *inner_rows, size_t inner_rows_length,
8+
uint8_t **output_rows, size_t *output_rows_length);

src/enclave/Enclave/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ set(SOURCES
1010
Flatbuffers.cpp
1111
FlatbuffersReaders.cpp
1212
FlatbuffersWriters.cpp
13-
Join.cpp
13+
NonObliviousSortMergeJoin.cpp
14+
BroadcastNestedLoopJoin.cpp
1415
Limit.cpp
1516
Project.cpp
1617
Sort.cpp

src/enclave/Enclave/Enclave.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#include "Aggregate.h"
77
#include "Crypto.h"
88
#include "Filter.h"
9-
#include "Join.h"
9+
#include "NonObliviousSortMergeJoin.h"
10+
#include "BroadcastNestedLoopJoin.h"
1011
#include "Limit.h"
1112
#include "Project.h"
1213
#include "Sort.h"
@@ -161,6 +162,25 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le
161162
}
162163
}
163164

165+
void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length,
166+
uint8_t *outer_rows, size_t outer_rows_length,
167+
uint8_t *inner_rows, size_t inner_rows_length,
168+
uint8_t **output_rows, size_t *output_rows_length) {
169+
// Guard against operating on arbitrary enclave memory
170+
assert(oe_is_outside_enclave(outer_rows, outer_rows_length) == 1);
171+
assert(oe_is_outside_enclave(inner_rows, inner_rows_length) == 1);
172+
__builtin_ia32_lfence();
173+
174+
try {
175+
broadcast_nested_loop_join(join_expr, join_expr_length,
176+
outer_rows, outer_rows_length,
177+
inner_rows, inner_rows_length,
178+
output_rows, output_rows_length);
179+
} catch (const std::runtime_error &e) {
180+
ocall_throw(e.what());
181+
}
182+
}
183+
164184
void ecall_non_oblivious_aggregate(
165185
uint8_t *agg_op, size_t agg_op_length,
166186
uint8_t *input_rows, size_t input_rows_length,

src/enclave/Enclave/Enclave.edl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ enclave {
5151
[user_check] uint8_t *input_rows, size_t input_rows_length,
5252
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
5353

54+
public void ecall_broadcast_nested_loop_join(
55+
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
56+
[user_check] uint8_t *outer_rows, size_t outer_rows_length,
57+
[user_check] uint8_t *inner_rows, size_t inner_rows_length,
58+
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
59+
5460
public void ecall_non_oblivious_aggregate(
5561
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
5662
[user_check] uint8_t *input_rows, size_t input_rows_length,

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,60 +1787,104 @@ class FlatbuffersJoinExprEvaluator {
17871787
}
17881788

17891789
const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
1790-
join_type = join_expr->join_type();
17911790

1792-
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
1793-
throw std::runtime_error("Mismatched join key lengths");
1794-
}
1795-
for (auto key_it = join_expr->left_keys()->begin();
1796-
key_it != join_expr->left_keys()->end(); ++key_it) {
1797-
left_key_evaluators.emplace_back(
1798-
std::unique_ptr<FlatbuffersExpressionEvaluator>(
1799-
new FlatbuffersExpressionEvaluator(*key_it)));
1791+
join_type = join_expr->join_type();
1792+
if (join_expr->condition() != NULL) {
1793+
condition_eval = std::unique_ptr<FlatbuffersExpressionEvaluator>(
1794+
new FlatbuffersExpressionEvaluator(join_expr->condition()));
18001795
}
1801-
for (auto key_it = join_expr->right_keys()->begin();
1802-
key_it != join_expr->right_keys()->end(); ++key_it) {
1803-
right_key_evaluators.emplace_back(
1804-
std::unique_ptr<FlatbuffersExpressionEvaluator>(
1805-
new FlatbuffersExpressionEvaluator(*key_it)));
1796+
is_equi_join = false;
1797+
1798+
if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) {
1799+
is_equi_join = true;
1800+
if (join_expr->condition() != NULL) {
1801+
throw std::runtime_error("Equi join cannot have condition");
1802+
}
1803+
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
1804+
throw std::runtime_error("Mismatched join key lengths");
1805+
}
1806+
for (auto key_it = join_expr->left_keys()->begin();
1807+
key_it != join_expr->left_keys()->end(); ++key_it) {
1808+
left_key_evaluators.emplace_back(
1809+
std::unique_ptr<FlatbuffersExpressionEvaluator>(
1810+
new FlatbuffersExpressionEvaluator(*key_it)));
1811+
}
1812+
for (auto key_it = join_expr->right_keys()->begin();
1813+
key_it != join_expr->right_keys()->end(); ++key_it) {
1814+
right_key_evaluators.emplace_back(
1815+
std::unique_ptr<FlatbuffersExpressionEvaluator>(
1816+
new FlatbuffersExpressionEvaluator(*key_it)));
1817+
}
18061818
}
18071819
}
18081820

1809-
/**
1810-
* Return true if the given row is from the primary table, indicated by its first field, which
1811-
* must be an IntegerField.
1821+
/** Return true if the given row is from the primary table, indicated by its first field, which
1822+
* must be an IntegerField.
1823+
* Rows MUST have been tagged in Scala.
18121824
*/
18131825
bool is_primary(const tuix::Row *row) {
18141826
return static_cast<const tuix::IntegerField *>(
18151827
row->field_values()->Get(0)->value())->value() == 0;
18161828
}
18171829

1818-
/** Return true if the two rows are from the same join group. */
1819-
bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) {
1820-
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
1821-
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
1830+
/** Returns the row evaluator corresponding to the primary row
1831+
* Rows MUST have been tagged in Scala.
1832+
*/
1833+
const tuix::Row *get_primary_row(
1834+
const tuix::Row *row1, const tuix::Row *row2) {
1835+
return is_primary(row1) ? row1 : row2;
1836+
}
18221837

1838+
/** Return true if the two rows satisfy the join condition. */
1839+
bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) {
18231840
builder.Clear();
1841+
bool row1_equals_row2;
1842+
1843+
/** Check equality for equi joins. If it is a non-equi join,
1844+
* the key evaluators will be empty, so the code never enters the for loop.
1845+
*/
1846+
auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators;
1847+
auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators;
18241848
for (uint32_t i = 0; i < row1_evaluators.size(); i++) {
18251849
const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1);
18261850
auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder);
1851+
auto row1_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row1_eval_offset);
1852+
18271853
const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2);
18281854
auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder);
1855+
auto row2_field = flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset);
18291856

1830-
bool row1_equals_row2 =
1857+
flatbuffers::Offset<tuix::Field> comparison = eval_binary_comparison<tuix::EqualTo, std::equal_to>(
1858+
builder,
1859+
row1_field,
1860+
row2_field);
1861+
row1_equals_row2 =
18311862
static_cast<const tuix::BooleanField *>(
18321863
flatbuffers::GetTemporaryPointer<tuix::Field>(
18331864
builder,
1834-
eval_binary_comparison<tuix::EqualTo, std::equal_to>(
1835-
builder,
1836-
flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row1_eval_offset),
1837-
flatbuffers::GetTemporaryPointer<tuix::Field>(builder, row2_eval_offset)))
1838-
->value())->value();
1865+
comparison)->value())->value();
18391866

18401867
if (!row1_equals_row2) {
18411868
return false;
18421869
}
18431870
}
1871+
1872+
/* Check condition for non-equi joins */
1873+
if (!is_equi_join) {
1874+
std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
1875+
for (auto field : *row1->field_values()) {
1876+
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
1877+
}
1878+
for (auto field : *row2->field_values()) {
1879+
concat_fields.push_back(flatbuffers_copy<tuix::Field>(field, builder));
1880+
}
1881+
flatbuffers::Offset<tuix::Row> concat = tuix::CreateRowDirect(builder, &concat_fields);
1882+
const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer<tuix::Row>(builder, concat);
1883+
1884+
const tuix::Field *condition_result = condition_eval->eval(concat_ptr);
1885+
1886+
return static_cast<const tuix::BooleanField *>(condition_result->value())->value();
1887+
}
18441888
return true;
18451889
}
18461890

@@ -1853,6 +1897,8 @@ class FlatbuffersJoinExprEvaluator {
18531897
tuix::JoinType join_type;
18541898
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> left_key_evaluators;
18551899
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> right_key_evaluators;
1900+
bool is_equi_join;
1901+
std::unique_ptr<FlatbuffersExpressionEvaluator> condition_eval;
18561902
};
18571903

18581904
class AggregateExpressionEvaluator {

src/enclave/Enclave/Join.cpp renamed to src/enclave/Enclave/NonObliviousSortMergeJoin.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
#include "Join.h"
1+
#include "NonObliviousSortMergeJoin.h"
22

33
#include "ExpressionEvaluation.h"
44
#include "FlatbuffersReaders.h"
55
#include "FlatbuffersWriters.h"
66
#include "common.h"
77

8+
/** C++ implementation of a non-oblivious sort merge join.
9+
* Rows MUST be tagged primary or secondary for this to work.
10+
*/
811
void non_oblivious_sort_merge_join(
912
uint8_t *join_expr, size_t join_expr_length,
1013
uint8_t *input_rows, size_t input_rows_length,
@@ -25,7 +28,7 @@ void non_oblivious_sort_merge_join(
2528

2629
if (join_expr_eval.is_primary(current)) {
2730
if (last_primary_of_group.get()
28-
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
31+
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
2932
// Add this primary row to the current group
3033
primary_group.append(current);
3134
last_primary_of_group.set(current);
@@ -50,13 +53,13 @@ void non_oblivious_sort_merge_join(
5053
} else {
5154
// Output the joined rows resulting from this foreign row
5255
if (last_primary_of_group.get()
53-
&& join_expr_eval.is_same_group(last_primary_of_group.get(), current)) {
56+
&& join_expr_eval.eval_condition(last_primary_of_group.get(), current)) {
5457
auto primary_group_buffer = primary_group.output_buffer();
5558
RowReader primary_group_reader(primary_group_buffer.view());
5659
while (primary_group_reader.has_next()) {
5760
const tuix::Row *primary = primary_group_reader.next();
5861

59-
if (!join_expr_eval.is_same_group(primary, current)) {
62+
if (!join_expr_eval.eval_condition(primary, current)) {
6063
throw std::runtime_error(
6164
std::string("Invariant violation: rows of primary_group "
6265
"are not of the same group: ")
Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
#include <cstddef>
22
#include <cstdint>
33

4-
#ifndef JOIN_H
5-
#define JOIN_H
6-
74
void non_oblivious_sort_merge_join(
85
uint8_t *join_expr, size_t join_expr_length,
96
uint8_t *input_rows, size_t input_rows_length,
107
uint8_t **output_rows, size_t *output_rows_length);
11-
12-
#endif

src/flatbuffers/operators.fbs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ enum JoinType : ubyte {
5454
}
5555
table JoinExpr {
5656
join_type:JoinType;
57-
// Currently only cross joins and equijoins are supported, so we store
58-
// parallel arrays of key expressions and the join outputs pairs of rows
59-
// where each expression from the left is equal to the matching expression
60-
// from the right.
57+
// In the case of equi joins, we store parallel arrays of key expressions and have the join output
58+
// pairs of rows where each expression from the left is equal to the matching expression from the right.
6159
left_keys:[Expr];
6260
right_keys:[Expr];
61+
// In the case of non-equi joins, we pass in a condition as an expression and evaluate that on each pair of rows.
62+
// TODO: have equi joins use this condition rather than an additional filter operation.
63+
condition:Expr;
6364
}

0 commit comments

Comments
 (0)