Skip to content

Join update #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 1 addition & 44 deletions src/enclave/App/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,47 +518,9 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla
return ret;
}

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary(
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) {
(void)obj;

jboolean if_copy;

uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr);
uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy);

uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);

uint8_t *output_rows = nullptr;
size_t output_rows_length = 0;

if (input_rows_ptr == nullptr) {
ocall_throw("ScanCollectLastPrimary: JNI failed to get input byte array.");
} else {
oe_check_and_time("Scan Collect Last Primary",
ecall_scan_collect_last_primary(
(oe_enclave_t*)eid,
join_expr_ptr, join_expr_length,
input_rows_ptr, input_rows_length,
&output_rows, &output_rows_length));
}

jbyteArray ret = env->NewByteArray(output_rows_length);
env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows);
free(output_rows);

env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);

return ret;
}

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows,
jbyteArray join_row) {
JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) {
(void)obj;

jboolean if_copy;
Expand All @@ -569,9 +531,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);

uint32_t join_row_length = (uint32_t) env->GetArrayLength(join_row);
uint8_t *join_row_ptr = (uint8_t *) env->GetByteArrayElements(join_row, &if_copy);

uint8_t *output_rows = nullptr;
size_t output_rows_length = 0;

Expand All @@ -583,7 +542,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
(oe_enclave_t*)eid,
join_expr_ptr, join_expr_length,
input_rows_ptr, input_rows_length,
join_row_ptr, join_row_length,
&output_rows, &output_rows_length));
}

Expand All @@ -593,7 +551,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(

env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0);
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
env->ReleaseByteArrayElements(join_row, (jbyte *) join_row_ptr, 0);

return ret;
}
Expand Down
6 changes: 1 addition & 5 deletions src/enclave/App/SGXEnclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@ extern "C" {
JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ExternalSort(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
Expand Down
19 changes: 0 additions & 19 deletions src/enclave/Enclave/Enclave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,35 +145,16 @@ void ecall_external_sort(uint8_t *sort_order, size_t sort_order_length,
}
}

void ecall_scan_collect_last_primary(uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {
// Guard against operating on arbitrary enclave memory
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
__builtin_ia32_lfence();

try {
scan_collect_last_primary(join_expr, join_expr_length,
input_rows, input_rows_length,
output_rows, output_rows_length);
} catch (const std::runtime_error &e) {
ocall_throw(e.what());
}
}

void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *join_row, size_t join_row_length,
uint8_t **output_rows, size_t *output_rows_length) {
// Guard against operating on arbitrary enclave memory
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
assert(oe_is_outside_enclave(join_row, join_row_length) == 1);
__builtin_ia32_lfence();

try {
non_oblivious_sort_merge_join(join_expr, join_expr_length,
input_rows, input_rows_length,
join_row, join_row_length,
output_rows, output_rows_length);
} catch (const std::runtime_error &e) {
ocall_throw(e.what());
Expand Down
6 changes: 0 additions & 6 deletions src/enclave/Enclave/Enclave.edl
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,9 @@ enclave {
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_scan_collect_last_primary(
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_non_oblivious_sort_merge_join(
[in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
[user_check] uint8_t *join_row, size_t join_row_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_non_oblivious_aggregate(
Expand Down
6 changes: 6 additions & 0 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1682,6 +1682,7 @@ class FlatbuffersJoinExprEvaluator {
}

const tuix::JoinExpr* join_expr = flatbuffers::GetRoot<tuix::JoinExpr>(buf);
join_type = join_expr->join_type();

if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) {
throw std::runtime_error("Mismatched join key lengths");
Expand Down Expand Up @@ -1738,8 +1739,13 @@ class FlatbuffersJoinExprEvaluator {
return true;
}

tuix::JoinType get_join_type() {
return join_type;
}

private:
flatbuffers::FlatBufferBuilder builder;
tuix::JoinType join_type;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> left_key_evaluators;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> right_key_evaluators;
};
Expand Down
84 changes: 40 additions & 44 deletions src/enclave/Enclave/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,59 +5,20 @@
#include "FlatbuffersWriters.h"
#include "common.h"

void scan_collect_last_primary(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length) {

FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
RowWriter w;

FlatbuffersTemporaryRow last_primary;

// Accumulate all primary table rows from the same group as the last primary row into `w`.
//
// Because our distributed sorting algorithm uses range partitioning over the join keys, all
// primary rows belonging to the same group will be colocated in the same partition. (The
// corresponding foreign rows may be in the same partition or the next partition.) Therefore it is
// sufficient to send primary rows at most one partition forward.
while (r.has_next()) {
const tuix::Row *row = r.next();
if (join_expr_eval.is_primary(row)) {
if (!last_primary.get() || !join_expr_eval.is_same_group(last_primary.get(), row)) {
w.clear();
last_primary.set(row);
}

w.append(row);
} else {
w.clear();
last_primary.set(nullptr);
}
}

w.output_buffer(output_rows, output_rows_length);
}

void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *join_row, size_t join_row_length,
uint8_t **output_rows, size_t *output_rows_length) {

FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length);
tuix::JoinType join_type = join_expr_eval.get_join_type();
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
RowReader j(BufferRefView<tuix::EncryptedBlocks>(join_row, join_row_length));
RowWriter w;

RowWriter primary_group;
FlatbuffersTemporaryRow last_primary_of_group;
while (j.has_next()) {
const tuix::Row *row = j.next();
primary_group.append(row);
last_primary_of_group.set(row);
}

bool pk_fk_match = false;

while (r.has_next()) {
const tuix::Row *current = r.next();
Expand All @@ -69,10 +30,22 @@ void non_oblivious_sort_merge_join(
primary_group.append(current);
last_primary_of_group.set(current);
} else {
// Advance to a new group
// If a new primary group is encountered
if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());

while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
w.append(primary);
}
}

primary_group.clear();
primary_group.append(current);
last_primary_of_group.set(current);

pk_fk_match = false;
}
} else {
// Output the joined rows resulting from this foreign row
Expand All @@ -92,11 +65,34 @@ void non_oblivious_sort_merge_join(
+ to_string(current));
}

w.append(primary, current);
if (join_type == tuix::JoinType_Inner) {
w.append(primary, current);
} else if (join_type == tuix::JoinType_LeftSemi) {
// Only output the pk group ONCE
if (!pk_fk_match) {
w.append(primary);
}
}
}

pk_fk_match = true;
} else {
// If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet
// It will be processed when the code advances to the next pk group
pk_fk_match &= true;
}
}
}

if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) {
auto primary_group_buffer = primary_group.output_buffer();
RowReader primary_group_reader(primary_group_buffer.view());

while (primary_group_reader.has_next()) {
const tuix::Row *primary = primary_group_reader.next();
w.append(primary);
}
}

w.output_buffer(output_rows, output_rows_length);
}
6 changes: 0 additions & 6 deletions src/enclave/Enclave/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
#ifndef JOIN_H
#define JOIN_H

void scan_collect_last_primary(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **output_rows, size_t *output_rows_length);

void non_oblivious_sort_merge_join(
uint8_t *join_expr, size_t join_expr_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *join_row, size_t join_row_length,
uint8_t **output_rows, size_t *output_rows_length);

#endif
Loading