Skip to content

Commit b4ba2db

Browse files
authored
Join update (#145)
1 parent 823d95d commit b4ba2db

File tree

12 files changed

+166
-194
lines changed

12 files changed

+166
-194
lines changed

src/enclave/App/App.cpp

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -518,47 +518,9 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla
518518
return ret;
519519
}
520520

521-
JNIEXPORT jbyteArray JNICALL
522-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary(
523-
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) {
524-
(void)obj;
525-
526-
jboolean if_copy;
527-
528-
uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr);
529-
uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy);
530-
531-
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
532-
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);
533-
534-
uint8_t *output_rows = nullptr;
535-
size_t output_rows_length = 0;
536-
537-
if (input_rows_ptr == nullptr) {
538-
ocall_throw("ScanCollectLastPrimary: JNI failed to get input byte array.");
539-
} else {
540-
oe_check_and_time("Scan Collect Last Primary",
541-
ecall_scan_collect_last_primary(
542-
(oe_enclave_t*)eid,
543-
join_expr_ptr, join_expr_length,
544-
input_rows_ptr, input_rows_length,
545-
&output_rows, &output_rows_length));
546-
}
547-
548-
jbyteArray ret = env->NewByteArray(output_rows_length);
549-
env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows);
550-
free(output_rows);
551-
552-
env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
553-
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
554-
555-
return ret;
556-
}
557-
558521
JNIEXPORT jbyteArray JNICALL
559522
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
560-
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows,
561-
jbyteArray join_row) {
523+
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) {
562524
(void)obj;
563525

564526
jboolean if_copy;
@@ -569,9 +531,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
569531
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
570532
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);
571533

572-
uint32_t join_row_length = (uint32_t) env->GetArrayLength(join_row);
573-
uint8_t *join_row_ptr = (uint8_t *) env->GetByteArrayElements(join_row, &if_copy);
574-
575534
uint8_t *output_rows = nullptr;
576535
size_t output_rows_length = 0;
577536

@@ -583,7 +542,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
583542
(oe_enclave_t*)eid,
584543
join_expr_ptr, join_expr_length,
585544
input_rows_ptr, input_rows_length,
586-
join_row_ptr, join_row_length,
587545
&output_rows, &output_rows_length));
588546
}
589547

@@ -593,7 +551,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
593551

594552
env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
595553
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
596-
env->ReleaseByteArrayElements(join_row, (jbyte *) join_row_ptr, 0);
597554

598555
return ret;
599556
}

src/enclave/App/SGXEnclave.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,9 @@ extern "C" {
3737
JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ExternalSort(
3838
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);
3939

40-
JNIEXPORT jbyteArray JNICALL
41-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary(
42-
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);
43-
4440
JNIEXPORT jbyteArray JNICALL
4541
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
46-
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);
42+
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);
4743

4844
JNIEXPORT jobject JNICALL
4945
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(

src/enclave/Enclave/Enclave.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -145,35 +145,16 @@ void ecall_external_sort(uint8_t *sort_order, size_t sort_order_length,
145145
}
146146
}
147147

148-
void ecall_scan_collect_last_primary(uint8_t *join_expr, size_t join_expr_length,
149-
uint8_t *input_rows, size_t input_rows_length,
150-
uint8_t **output_rows, size_t *output_rows_length) {
151-
// Guard against operating on arbitrary enclave memory
152-
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
153-
__builtin_ia32_lfence();
154-
155-
try {
156-
scan_collect_last_primary(join_expr, join_expr_length,
157-
input_rows, input_rows_length,
158-
output_rows, output_rows_length);
159-
} catch (const std::runtime_error &e) {
160-
ocall_throw(e.what());
161-
}
162-
}
163-
164148
void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_length,
165149
uint8_t *input_rows, size_t input_rows_length,
166-
uint8_t *join_row, size_t join_row_length,
167150
uint8_t **output_rows, size_t *output_rows_length) {
168151
// Guard against operating on arbitrary enclave memory
169152
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
170-
assert(oe_is_outside_enclave(join_row, join_row_length) == 1);
171153
__builtin_ia32_lfence();
172154

173155
try {
174156
non_oblivious_sort_merge_join(join_expr, join_expr_length,
175157
input_rows, input_rows_length,
176-
join_row, join_row_length,
177158
output_rows, output_rows_length);
178159
} catch (const std::runtime_error &e) {
179160
ocall_throw(e.what());

src/enclave/Enclave/Enclave.edl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,9 @@ enclave {
4343
[user_check] uint8_t *input_rows, size_t input_rows_length,
4444
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
4545

46-
public void ecall_scan_collect_last_primary(
47-
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
48-
[user_check] uint8_t *input_rows, size_t input_rows_length,
49-
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
50-
5146
public void ecall_non_oblivious_sort_merge_join(
5247
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
5348
[user_check] uint8_t *input_rows, size_t input_rows_length,
54-
[user_check] uint8_t *join_row, size_t join_row_length,
5549
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
5650

5751
public void ecall_non_oblivious_aggregate(

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,7 @@ class FlatbuffersJoinExprEvaluator {
16821682
}
16831683

16841684
const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
1685+
join_type = join_expr->join_type();
16851686

16861687
if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
16871688
throw std::runtime_error("Mismatched join key lengths");
@@ -1738,8 +1739,13 @@ class FlatbuffersJoinExprEvaluator {
17381739
return true;
17391740
}
17401741

1742+
tuix::JoinType get_join_type() {
1743+
return join_type;
1744+
}
1745+
17411746
private:
17421747
flatbuffers::FlatBufferBuilder builder;
1748+
tuix::JoinType join_type;
17431749
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> left_key_evaluators;
17441750
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> right_key_evaluators;
17451751
};

src/enclave/Enclave/Join.cpp

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,59 +5,20 @@
55
#include "FlatbuffersWriters.h"
66
#include "common.h"
77

8-
void scan_collect_last_primary(
9-
uint8_t *join_expr, size_t join_expr_length,
10-
uint8_t *input_rows, size_t input_rows_length,
11-
uint8_t **output_rows, size_t *output_rows_length) {
12-
13-
FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
14-
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
15-
RowWriter w;
16-
17-
FlatbuffersTemporaryRow last_primary;
18-
19-
// Accumulate all primary table rows from the same group as the last primary row into `w`.
20-
//
21-
// Because our distributed sorting algorithm uses range partitioning over the join keys, all
22-
// primary rows belonging to the same group will be colocated in the same partition. (The
23-
// corresponding foreign rows may be in the same partition or the next partition.) Therefore it is
24-
// sufficient to send primary rows at most one partition forward.
25-
while (r.has_next()) {
26-
const tuix::Row *row = r.next();
27-
if (join_expr_eval.is_primary(row)) {
28-
if (!last_primary.get() || !join_expr_eval.is_same_group(last_primary.get(), row)) {
29-
w.clear();
30-
last_primary.set(row);
31-
}
32-
33-
w.append(row);
34-
} else {
35-
w.clear();
36-
last_primary.set(nullptr);
37-
}
38-
}
39-
40-
w.output_buffer(output_rows, output_rows_length);
41-
}
42-
438
void non_oblivious_sort_merge_join(
449
uint8_t *join_expr, size_t join_expr_length,
4510
uint8_t *input_rows, size_t input_rows_length,
46-
uint8_t *join_row, size_t join_row_length,
4711
uint8_t **output_rows, size_t *output_rows_length) {
4812

4913
FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
14+
tuix::JoinType join_type = join_expr_eval.get_join_type();
5015
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
51-
RowReader j(BufferRefView<tuix::EncryptedBlocks>(join_row, join_row_length));
5216
RowWriter w;
5317

5418
RowWriter primary_group;
5519
FlatbuffersTemporaryRow last_primary_of_group;
56-
while (j.has_next()) {
57-
const tuix::Row *row = j.next();
58-
primary_group.append(row);
59-
last_primary_of_group.set(row);
60-
}
20+
21+
bool pk_fk_match = false;
6122

6223
while (r.has_next()) {
6324
const tuix::Row *current = r.next();
@@ -69,10 +30,22 @@ void non_oblivious_sort_merge_join(
6930
primary_group.append(current);
7031
last_primary_of_group.set(current);
7132
} else {
72-
// Advance to a new group
33+
// If a new primary group is encountered
34+
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
35+
auto primary_group_buffer = primary_group.output_buffer();
36+
RowReader primary_group_reader(primary_group_buffer.view());
37+
38+
while (primary_group_reader.has_next()) {
39+
const tuix::Row *primary = primary_group_reader.next();
40+
w.append(primary);
41+
}
42+
}
43+
7344
primary_group.clear();
7445
primary_group.append(current);
7546
last_primary_of_group.set(current);
47+
48+
pk_fk_match = false;
7649
}
7750
} else {
7851
// Output the joined rows resulting from this foreign row
@@ -92,11 +65,34 @@ void non_oblivious_sort_merge_join(
9265
+ to_string(current));
9366
}
9467

95-
w.append(primary, current);
68+
if (join_type == tuix::JoinType_Inner) {
69+
w.append(primary, current);
70+
} else if (join_type == tuix::JoinType_LeftSemi) {
71+
// Only output the pk group ONCE
72+
if (!pk_fk_match) {
73+
w.append(primary);
74+
}
75+
}
9676
}
77+
78+
pk_fk_match = true;
79+
} else {
80+
// If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet
81+
// It will be processed when the code advances to the next pk group
82+
pk_fk_match &= true;
9783
}
9884
}
9985
}
10086

87+
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
88+
auto primary_group_buffer = primary_group.output_buffer();
89+
RowReader primary_group_reader(primary_group_buffer.view());
90+
91+
while (primary_group_reader.has_next()) {
92+
const tuix::Row *primary = primary_group_reader.next();
93+
w.append(primary);
94+
}
95+
}
96+
10197
w.output_buffer(output_rows, output_rows_length);
10298
}

src/enclave/Enclave/Join.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,9 @@
44
#ifndef JOIN_H
55
#define JOIN_H
66

7-
void scan_collect_last_primary(
8-
uint8_t *join_expr, size_t join_expr_length,
9-
uint8_t *input_rows, size_t input_rows_length,
10-
uint8_t **output_rows, size_t *output_rows_length);
11-
127
void non_oblivious_sort_merge_join(
138
uint8_t *join_expr, size_t join_expr_length,
149
uint8_t *input_rows, size_t input_rows_length,
15-
uint8_t *join_row, size_t join_row_length,
1610
uint8_t **output_rows, size_t *output_rows_length);
1711

1812
#endif

0 commit comments

Comments
 (0)