Skip to content

Commit 8cfa2a1

Browse files
eric-feng-2011Eric Fengwzheng
authored
Merge master into multiparty (#134)
* Interval (#116) * add date_add, interval sql still running into issues * Add Interval SQL support * uncomment out the other tests * resolve comments * change interval equality Co-authored-by: Eric Feng <[email protected]> * Fix NULL handling for aggregation (#130) * Modify COUNT and SUM to correctly handle NULL values * Change average to support NULL values * Fix * Changing operator matching from logical to physical (#129) * WIP * Fix * Unapply change * Aggregation rewrite (#132) Co-authored-by: Eric Feng <[email protected]> Co-authored-by: Wenting Zheng <[email protected]>
1 parent 9c87e8e commit 8cfa2a1

File tree

17 files changed

+687
-468
lines changed

17 files changed

+687
-468
lines changed

build.sbt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ concurrentRestrictions in Global := Seq(
2424
fork in Test := true
2525
fork in run := true
2626

27+
testOptions in Test += Tests.Argument("-oF")
2728
javaOptions in Test ++= Seq("-Xmx2048m", "-XX:ReservedCodeCacheSize=384m")
2829
javaOptions in run ++= Seq(
2930
"-Xmx2048m", "-XX:ReservedCodeCacheSize=384m", "-Dspark.master=local[1]")

src/enclave/App/App.cpp

Lines changed: 10 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,8 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
599599
}
600600

601601
JNIEXPORT jobject JNICALL
602-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1(
603-
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows) {
602+
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
603+
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) {
604604
(void)obj;
605605

606606
jboolean if_copy;
@@ -611,98 +611,21 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1
611611
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
612612
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);
613613

614-
uint8_t *first_row = nullptr;
615-
size_t first_row_length = 0;
616-
617-
uint8_t *last_group = nullptr;
618-
size_t last_group_length = 0;
619-
620-
uint8_t *last_row = nullptr;
621-
size_t last_row_length = 0;
622-
623-
if (input_rows_ptr == nullptr) {
624-
ocall_throw("NonObliviousAggregateStep1: JNI failed to get input byte array.");
625-
} else {
626-
oe_check_and_time("Non-Oblivious Aggregate Step 1",
627-
ecall_non_oblivious_aggregate_step1(
628-
(oe_enclave_t*)eid,
629-
agg_op_ptr, agg_op_length,
630-
input_rows_ptr, input_rows_length,
631-
&first_row, &first_row_length,
632-
&last_group, &last_group_length,
633-
&last_row, &last_row_length));
634-
}
635-
636-
jbyteArray first_row_array = env->NewByteArray(first_row_length);
637-
env->SetByteArrayRegion(first_row_array, 0, first_row_length, (jbyte *) first_row);
638-
free(first_row);
639-
640-
jbyteArray last_group_array = env->NewByteArray(last_group_length);
641-
env->SetByteArrayRegion(last_group_array, 0, last_group_length, (jbyte *) last_group);
642-
free(last_group);
643-
644-
jbyteArray last_row_array = env->NewByteArray(last_row_length);
645-
env->SetByteArrayRegion(last_row_array, 0, last_row_length, (jbyte *) last_row);
646-
free(last_row);
647-
648-
env->ReleaseByteArrayElements(agg_op, (jbyte *) agg_op_ptr, 0);
649-
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
650-
651-
jclass tuple3_class = env->FindClass("scala/Tuple3");
652-
jobject ret = env->NewObject(
653-
tuple3_class,
654-
env->GetMethodID(tuple3_class, "<init>",
655-
"(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)V"),
656-
first_row_array, last_group_array, last_row_array);
657-
658-
return ret;
659-
}
660-
661-
JNIEXPORT jbyteArray JNICALL
662-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2(
663-
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows,
664-
jbyteArray next_partition_first_row, jbyteArray prev_partition_last_group,
665-
jbyteArray prev_partition_last_row) {
666-
(void)obj;
667-
668-
jboolean if_copy;
669-
670-
uint32_t agg_op_length = (uint32_t) env->GetArrayLength(agg_op);
671-
uint8_t *agg_op_ptr = (uint8_t *) env->GetByteArrayElements(agg_op, &if_copy);
672-
673-
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
674-
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);
675-
676-
uint32_t next_partition_first_row_length =
677-
(uint32_t) env->GetArrayLength(next_partition_first_row);
678-
uint8_t *next_partition_first_row_ptr =
679-
(uint8_t *) env->GetByteArrayElements(next_partition_first_row, &if_copy);
680-
681-
uint32_t prev_partition_last_group_length =
682-
(uint32_t) env->GetArrayLength(prev_partition_last_group);
683-
uint8_t *prev_partition_last_group_ptr =
684-
(uint8_t *) env->GetByteArrayElements(prev_partition_last_group, &if_copy);
685-
686-
uint32_t prev_partition_last_row_length =
687-
(uint32_t) env->GetArrayLength(prev_partition_last_row);
688-
uint8_t *prev_partition_last_row_ptr =
689-
(uint8_t *) env->GetByteArrayElements(prev_partition_last_row, &if_copy);
690-
691614
uint8_t *output_rows = nullptr;
692615
size_t output_rows_length = 0;
693616

617+
bool is_partial = (bool) isPartial;
618+
694619
if (input_rows_ptr == nullptr) {
695-
ocall_throw("NonObliviousAggregateStep2: JNI failed to get input byte array.");
620+
ocall_throw("NonObliviousAggregateStep: JNI failed to get input byte array.");
696621
} else {
697-
oe_check_and_time("Non-Oblivious Aggregate Step 2",
698-
ecall_non_oblivious_aggregate_step2(
622+
oe_check_and_time("Non-Oblivious Aggregate",
623+
ecall_non_oblivious_aggregate(
699624
(oe_enclave_t*)eid,
700625
agg_op_ptr, agg_op_length,
701626
input_rows_ptr, input_rows_length,
702-
next_partition_first_row_ptr, next_partition_first_row_length,
703-
prev_partition_last_group_ptr, prev_partition_last_group_length,
704-
prev_partition_last_row_ptr, prev_partition_last_row_length,
705-
&output_rows, &output_rows_length));
627+
&output_rows, &output_rows_length,
628+
is_partial));
706629
}
707630

708631
jbyteArray ret = env->NewByteArray(output_rows_length);
@@ -711,13 +634,7 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2
711634

712635
env->ReleaseByteArrayElements(agg_op, (jbyte *) agg_op_ptr, 0);
713636
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
714-
env->ReleaseByteArrayElements(
715-
next_partition_first_row, (jbyte *) next_partition_first_row_ptr, 0);
716-
env->ReleaseByteArrayElements(
717-
prev_partition_last_group, (jbyte *) prev_partition_last_group_ptr, 0);
718-
env->ReleaseByteArrayElements(
719-
prev_partition_last_row, (jbyte *) prev_partition_last_row_ptr, 0);
720-
637+
721638
return ret;
722639
}
723640

src/enclave/App/SGXEnclave.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,8 @@ extern "C" {
4646
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);
4747

4848
JNIEXPORT jobject JNICALL
49-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1(
50-
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);
51-
52-
JNIEXPORT jbyteArray JNICALL
53-
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2(
54-
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray, jbyteArray, jbyteArray);
49+
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
50+
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean);
5551

5652
JNIEXPORT jbyteArray JNICALL
5753
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_CountRowsPerPartition(

src/enclave/Enclave/Aggregate.cpp

Lines changed: 15 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -5,116 +5,38 @@
55
#include "FlatbuffersWriters.h"
66
#include "common.h"
77

8-
void non_oblivious_aggregate_step1(
8+
void non_oblivious_aggregate(
99
uint8_t *agg_op, size_t agg_op_length,
1010
uint8_t *input_rows, size_t input_rows_length,
11-
uint8_t **first_row, size_t *first_row_length,
12-
uint8_t **last_group, size_t *last_group_length,
13-
uint8_t **last_row, size_t *last_row_length) {
11+
uint8_t **output_rows, size_t *output_rows_length,
12+
bool is_partial) {
1413

1514
FlatbuffersAggOpEvaluator agg_op_eval(agg_op, agg_op_length);
1615
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
17-
RowWriter first_row_writer;
18-
RowWriter last_group_writer;
19-
RowWriter last_row_writer;
16+
RowWriter w;
2017

2118
FlatbuffersTemporaryRow prev, cur;
19+
size_t count = 0;
20+
2221
while (r.has_next()) {
2322
prev.set(cur.get());
2423
cur.set(r.next());
25-
26-
if (prev.get() == nullptr) {
27-
first_row_writer.append(cur.get());
28-
}
29-
30-
if (!r.has_next()) {
31-
last_row_writer.append(cur.get());
32-
}
33-
24+
3425
if (prev.get() != nullptr && !agg_op_eval.is_same_group(prev.get(), cur.get())) {
26+
w.append(agg_op_eval.evaluate());
3527
agg_op_eval.reset_group();
3628
}
3729
agg_op_eval.aggregate(cur.get());
30+
count += 1;
3831
}
39-
last_group_writer.append(agg_op_eval.get_partial_agg());
40-
41-
first_row_writer.output_buffer(first_row, first_row_length);
42-
last_group_writer.output_buffer(last_group, last_group_length);
43-
last_row_writer.output_buffer(last_row, last_row_length);
44-
}
45-
46-
void non_oblivious_aggregate_step2(
47-
uint8_t *agg_op, size_t agg_op_length,
48-
uint8_t *input_rows, size_t input_rows_length,
49-
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
50-
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
51-
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
52-
uint8_t **output_rows, size_t *output_rows_length) {
53-
54-
FlatbuffersAggOpEvaluator agg_op_eval(agg_op, agg_op_length);
55-
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
56-
RowReader next_partition_first_row_reader(
57-
BufferRefView<tuix::EncryptedBlocks>(
58-
next_partition_first_row, next_partition_first_row_length));
59-
RowReader prev_partition_last_group_reader(
60-
BufferRefView<tuix::EncryptedBlocks>(
61-
prev_partition_last_group, prev_partition_last_group_length));
62-
RowReader prev_partition_last_row_reader(
63-
BufferRefView<tuix::EncryptedBlocks>(
64-
prev_partition_last_row, prev_partition_last_row_length));
65-
RowWriter w;
66-
67-
if (next_partition_first_row_reader.num_rows() > 1) {
68-
throw std::runtime_error(
69-
std::string("Incorrect number of starting rows from next partition passed: expected 0 or 1, got ")
70-
+ std::to_string(next_partition_first_row_reader.num_rows()));
71-
}
72-
if (prev_partition_last_group_reader.num_rows() > 1) {
73-
throw std::runtime_error(
74-
std::string("Incorrect number of ending groups from prev partition passed: expected 0 or 1, got ")
75-
+ std::to_string(prev_partition_last_group_reader.num_rows()));
76-
}
77-
if (prev_partition_last_row_reader.num_rows() > 1) {
78-
throw std::runtime_error(
79-
std::string("Incorrect number of ending rows from prev partition passed: expected 0 or 1, got ")
80-
+ std::to_string(prev_partition_last_row_reader.num_rows()));
81-
}
82-
83-
const tuix::Row *next_partition_first_row_ptr =
84-
next_partition_first_row_reader.has_next() ? next_partition_first_row_reader.next() : nullptr;
85-
agg_op_eval.set(prev_partition_last_group_reader.has_next() ?
86-
prev_partition_last_group_reader.next() : nullptr);
87-
const tuix::Row *prev_partition_last_row_ptr =
88-
prev_partition_last_row_reader.has_next() ? prev_partition_last_row_reader.next() : nullptr;
8932

90-
FlatbuffersTemporaryRow prev, cur(prev_partition_last_row_ptr), next;
91-
bool stop = false;
92-
if (r.has_next()) {
93-
next.set(r.next());
94-
} else {
95-
stop = true;
96-
}
97-
while (!stop) {
98-
// Populate prev, cur, next to enable lookbehind and lookahead
99-
prev.set(cur.get());
100-
cur.set(next.get());
101-
if (r.has_next()) {
102-
next.set(r.next());
103-
} else {
104-
next.set(next_partition_first_row_ptr);
105-
stop = true;
106-
}
107-
108-
if (prev.get() != nullptr && !agg_op_eval.is_same_group(prev.get(), cur.get())) {
109-
agg_op_eval.reset_group();
110-
}
111-
agg_op_eval.aggregate(cur.get());
112-
113-
// Output the current aggregate if it is the last aggregate for its run
114-
if (next.get() == nullptr || !agg_op_eval.is_same_group(cur.get(), next.get())) {
115-
w.append(agg_op_eval.evaluate());
116-
}
33+
// Skip outputting the final row if the number of input rows is 0 AND
34+
// 1. It's a grouping aggregation, OR
35+
// 2. It's a global aggregation, the mode is final
36+
if (!(count == 0 && (agg_op_eval.get_num_grouping_keys() > 0 || (agg_op_eval.get_num_grouping_keys() == 0 && !is_partial)))) {
37+
w.append(agg_op_eval.evaluate());
11738
}
11839

11940
w.output_buffer(output_rows, output_rows_length);
12041
}
42+

src/enclave/Enclave/Aggregate.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,10 @@
44
#ifndef AGGREGATE_H
55
#define AGGREGATE_H
66

7-
void non_oblivious_aggregate_step1(
7+
void non_oblivious_aggregate(
88
uint8_t *agg_op, size_t agg_op_length,
99
uint8_t *input_rows, size_t input_rows_length,
10-
uint8_t **first_row, size_t *first_row_length,
11-
uint8_t **last_group, size_t *last_group_length,
12-
uint8_t **last_row, size_t *last_row_length);
13-
14-
void non_oblivious_aggregate_step2(
15-
uint8_t *agg_op, size_t agg_op_length,
16-
uint8_t *input_rows, size_t input_rows_length,
17-
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
18-
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
19-
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
20-
uint8_t **output_rows, size_t *output_rows_length);
10+
uint8_t **output_rows, size_t *output_rows_length,
11+
bool is_partial);
2112

2213
#endif // AGGREGATE_H

src/enclave/Enclave/Enclave.cpp

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -190,50 +190,21 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le
190190
}
191191
}
192192

193-
void ecall_non_oblivious_aggregate_step1(
193+
void ecall_non_oblivious_aggregate(
194194
uint8_t *agg_op, size_t agg_op_length,
195195
uint8_t *input_rows, size_t input_rows_length,
196-
uint8_t **first_row, size_t *first_row_length,
197-
uint8_t **last_group, size_t *last_group_length,
198-
uint8_t **last_row, size_t *last_row_length) {
196+
uint8_t **output_rows, size_t *output_rows_length,
197+
bool is_partial) {
199198
// Guard against operating on arbitrary enclave memory
200199
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
201200
__builtin_ia32_lfence();
202201

203202
try {
204-
non_oblivious_aggregate_step1(
205-
agg_op, agg_op_length,
206-
input_rows, input_rows_length,
207-
first_row, first_row_length,
208-
last_group, last_group_length,
209-
last_row, last_row_length);
210-
} catch (const std::runtime_error &e) {
211-
ocall_throw(e.what());
212-
}
213-
}
214-
215-
void ecall_non_oblivious_aggregate_step2(
216-
uint8_t *agg_op, size_t agg_op_length,
217-
uint8_t *input_rows, size_t input_rows_length,
218-
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
219-
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
220-
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
221-
uint8_t **output_rows, size_t *output_rows_length) {
222-
// Guard against operating on arbitrary enclave memory
223-
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
224-
assert(oe_is_outside_enclave(next_partition_first_row, next_partition_first_row_length) == 1);
225-
assert(oe_is_outside_enclave(prev_partition_last_group, prev_partition_last_group_length) == 1);
226-
assert(oe_is_outside_enclave(prev_partition_last_row, prev_partition_last_row_length) == 1);
227-
__builtin_ia32_lfence();
228-
229-
try {
230-
non_oblivious_aggregate_step2(
231-
agg_op, agg_op_length,
232-
input_rows, input_rows_length,
233-
next_partition_first_row, next_partition_first_row_length,
234-
prev_partition_last_group, prev_partition_last_group_length,
235-
prev_partition_last_row, prev_partition_last_row_length,
236-
output_rows, output_rows_length);
203+
non_oblivious_aggregate(agg_op, agg_op_length,
204+
input_rows, input_rows_length,
205+
output_rows, output_rows_length,
206+
is_partial);
207+
237208
} catch (const std::runtime_error &e) {
238209
ocall_throw(e.what());
239210
}

src/enclave/Enclave/Enclave.edl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,11 @@ enclave {
5454
[user_check] uint8_t *join_row, size_t join_row_length,
5555
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
5656

57-
public void ecall_non_oblivious_aggregate_step1(
57+
public void ecall_non_oblivious_aggregate(
5858
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
5959
[user_check] uint8_t *input_rows, size_t input_rows_length,
60-
[out] uint8_t **first_row, [out] size_t *first_row_length,
61-
[out] uint8_t **last_group, [out] size_t *last_group_length,
62-
[out] uint8_t **last_row, [out] size_t *last_row_length);
63-
64-
public void ecall_non_oblivious_aggregate_step2(
65-
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
66-
[user_check] uint8_t *input_rows, size_t input_rows_length,
67-
[user_check] uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
68-
[user_check] uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
69-
[user_check] uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
70-
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
60+
[out] uint8_t **output_rows, [out] size_t *output_rows_length,
61+
bool is_partial);
7162

7263
public void ecall_count_rows_per_partition(
7364
[user_check] uint8_t *input_rows, size_t input_rows_length,

0 commit comments

Comments
 (0)