diff --git a/.github/scripts/build.sh b/.github/scripts/build.sh new file mode 100755 index 0000000000..4662f1ec2d --- /dev/null +++ b/.github/scripts/build.sh @@ -0,0 +1,25 @@ +# Install OpenEnclave 0.12.0 +echo 'deb [arch=amd64] https://download.01.org/intel-sgx/sgx_repo/ubuntu bionic main' | sudo tee /etc/apt/sources.list.d/intel-sgx.list +wget -qO - https://download.01.org/intel-sgx/sgx_repo/ubuntu/intel-sgx-deb.key | sudo apt-key add - +echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-7 main" | sudo tee /etc/apt/sources.list.d/llvm-toolchain-bionic-7.list +wget -qO - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - +echo "deb [arch=amd64] https://packages.microsoft.com/ubuntu/18.04/prod bionic main" | sudo tee /etc/apt/sources.list.d/msprod.list +wget -qO - https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add - + +sudo apt update +sudo apt -y install clang-7 libssl-dev gdb libsgx-enclave-common libsgx-enclave-common-dev libprotobuf10 libsgx-dcap-ql libsgx-dcap-ql-dev az-dcap-client open-enclave=0.12.0 + +# Install Opaque Dependencies +sudo apt -y install wget build-essential openjdk-8-jdk python libssl-dev + +wget https://github.com/Kitware/CMake/releases/download/v3.15.6/cmake-3.15.6-Linux-x86_64.sh +sudo bash cmake-3.15.6-Linux-x86_64.sh --skip-license --prefix=/usr/local + +# Generate keypair for attestation
openssl genrsa -out ./private_key.pem -3 3072 + +source opaqueenv +source /opt/openenclave/share/openenclave/openenclaverc +export MODE=SIMULATE + +build/sbt test diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000000..f4695ac8b8 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,40 @@ +name: CI + +# Controls when the action will run.
+on: + # Triggers the workflow on push or pull request events but only for the master branch + push: + branches: [ master ] + pull_request: + branches: [ master ] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + build: + # Define the OS to run on + runs-on: ubuntu-18.04 + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + # Specify the version of Java that is installed + - uses: actions/setup-java@v1 + with: + java-version: '8' + # Caching (from https://www.scala-sbt.org/1.x/docs/GitHub-Actions-with-sbt.html) + - uses: coursier/cache-action@v5 + # Run the test + - name: Install dependencies, set environment variables, and run sbt tests + run: | + ./.github/scripts/build.sh + + rm -rf "$HOME/.ivy2/local" || true + find $HOME/Library/Caches/Coursier/v1 -name "ivydata-*.properties" -delete || true + find $HOME/.ivy2/cache -name "ivydata-*.properties" -delete || true + find $HOME/.cache/coursier/v1 -name "ivydata-*.properties" -delete || true + find $HOME/.sbt -name "*.lock" -delete || true + shell: bash + diff --git a/.travis.yml b/.travis.yml index f3e91c6831..4f1ee055ac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ before_install: - sudo apt update - sudo apt -y install clang-7 libssl-dev gdb libsgx-enclave-common libsgx-enclave-common-dev libprotobuf10 libsgx-dcap-ql libsgx-dcap-ql-dev - sudo apt-get -y install wget build-essential openjdk-8-jdk python libssl-dev - - sudo apt-get -y install open-enclave=0.9.0 + - sudo apt-get -y install open-enclave=0.12.0 - wget https://github.com/Kitware/CMake/releases/download/v3.15.6/cmake-3.15.6-Linux-x86_64.sh - sudo bash cmake-3.15.6-Linux-x86_64.sh --skip-license --prefix=/usr/local - export PATH=/usr/local/bin:"$PATH" diff --git a/README.md b/README.md index a5e606e134..2cc4a41f3d 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ UDFs must be [implemented in C++](#user-defined-functions-udfs). After downloading the Opaque codebase, build and test it as follows. -1. Install dependencies and the [OpenEnclave SDK](https://github.com/openenclave/openenclave/blob/v0.9.x/docs/GettingStartedDocs/install_oe_sdk-Ubuntu_18.04.md). We currently support OE version 0.9.0 (so please install with `open-enclave=0.9.0`) and Ubuntu 18.04. +1. Install dependencies and the [OpenEnclave SDK](https://github.com/openenclave/openenclave/blob/v0.12.0/docs/GettingStartedDocs/install_oe_sdk-Ubuntu_18.04.md). We currently support OE version 0.12.0 (so please install with `open-enclave=0.12.0`) and Ubuntu 18.04. ```sh # For Ubuntu 18.04: @@ -206,7 +206,3 @@ Now we can port this UDF to Opaque as follows: ``` 3. Finally, implement the UDF in C++. In [`FlatbuffersExpressionEvaluator#eval_helper`](src/enclave/Enclave/ExpressionEvaluation.h), add a case for `tuix::ExprUnion_DotProduct`. Within that case, cast the expression to a `tuix::DotProduct`, recursively evaluate the left and right children, perform the dot product computation on them, and construct a `DoubleField` containing the result. - -## Contact - -If you want to know more about our project or have questions, please contact Wenting (wzheng13@gmail.com) and/or Ankur (ankurdave@gmail.com). 
diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index 99c9a23965..74c9da868a 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( return ret; } +JNIEXPORT jbyteArray JNICALL +Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) { + (void)obj; + + jboolean if_copy; + + uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr); + uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy); + + uint32_t outer_rows_length = (uint32_t) env->GetArrayLength(outer_rows); + uint8_t *outer_rows_ptr = (uint8_t *) env->GetByteArrayElements(outer_rows, &if_copy); + + uint32_t inner_rows_length = (uint32_t) env->GetArrayLength(inner_rows); + uint8_t *inner_rows_ptr = (uint8_t *) env->GetByteArrayElements(inner_rows, &if_copy); + + uint8_t *output_rows = nullptr; + size_t output_rows_length = 0; + + if (outer_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array."); + } else if (inner_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array."); + } else { + oe_check_and_time("Broadcast Nested Loop Join", + ecall_broadcast_nested_loop_join( + (oe_enclave_t*)eid, + join_expr_ptr, join_expr_length, + outer_rows_ptr, outer_rows_length, + inner_rows_ptr, inner_rows_length, + &output_rows, &output_rows_length)); + } + + jbyteArray ret = env->NewByteArray(output_rows_length); + env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows); + free(output_rows); + + env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); + env->ReleaseByteArrayElements(outer_rows, (jbyte *) outer_rows_ptr, 0); + env->ReleaseByteArrayElements(inner_rows, (jbyte *) inner_rows_ptr, 0); + + return ret; +} + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) { diff --git a/src/enclave/App/CMakeLists.txt b/src/enclave/App/CMakeLists.txt index e2f6cf6f60..44c0ae648e 100644 --- a/src/enclave/App/CMakeLists.txt +++ b/src/enclave/App/CMakeLists.txt @@ -7,7 +7,10 @@ set(SOURCES ${CMAKE_CURRENT_BINARY_DIR}/Enclave_u.c) add_custom_command( - COMMAND oeedger8r --untrusted ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl --search-path ${CMAKE_SOURCE_DIR}/Enclave + COMMAND oeedger8r --untrusted ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl + --search-path ${CMAKE_SOURCE_DIR}/Enclave + --search-path ${OE_INCLUDEDIR} + --search-path ${OE_INCLUDEDIR}/openenclave/edl/sgx DEPENDS ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/Enclave_u.h ${CMAKE_CURRENT_BINARY_DIR}/Enclave_u.c ${CMAKE_CURRENT_BINARY_DIR}/Enclave_args.h) @@ -22,6 +25,6 @@ if ("$ENV{MODE}" STREQUAL "SIMULATE") target_compile_definitions(enclave_jni PUBLIC -DSIMULATE) endif() -target_link_libraries(enclave_jni openenclave::oehost openenclave::oehostverify) +target_link_libraries(enclave_jni openenclave::oehost) install(TARGETS enclave_jni DESTINATION lib) diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index 2b74c42763..1ddd0d8497 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -41,6 +41,10 @@ extern "C" { 
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); + JNIEXPORT jbyteArray JNICALL + Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray); + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean); diff --git a/src/enclave/CMakeLists.txt b/src/enclave/CMakeLists.txt index e29a67be65..d2ca34aa46 100644 --- a/src/enclave/CMakeLists.txt +++ b/src/enclave/CMakeLists.txt @@ -1,13 +1,17 @@ cmake_minimum_required(VERSION 3.13) project(OpaqueEnclave) - enable_language(ASM) option(FLATBUFFERS_LIB_DIR "Location of Flatbuffers library headers.") option(FLATBUFFERS_GEN_CPP_DIR "Location of Flatbuffers generated C++ files.") -find_package(OpenEnclave CONFIG REQUIRED) +set(OE_MIN_VERSION 0.12.0) +find_package(OpenEnclave ${OE_MIN_VERSION} CONFIG REQUIRED) + +set(OE_CRYPTO_LIB + mbed + CACHE STRING "Crypto library used by enclaves.") include_directories(App) include_directories(${CMAKE_BINARY_DIR}/App) @@ -18,7 +22,7 @@ include_directories(${CMAKE_BINARY_DIR}/Enclave) include_directories(ServiceProvider) include_directories(${FLATBUFFERS_LIB_DIR}) include_directories(${FLATBUFFERS_GEN_CPP_DIR}) -include_directories("/opt/openenclave/include") +include_directories(${OE_INCLUDEDIR}) if(CMAKE_SIZEOF_VOID_P EQUAL 4) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -m32") @@ -31,14 +35,11 @@ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -DDEBUG -UNDEBUG -UED set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2 -DNDEBUG -DEDEBUG -UDEBUG") set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_PROFILE} -O2 -DNDEBUG -DEDEBUG -UDEBUG -DPERF") -message("openssl rsa -in $ENV{OPAQUE_HOME}/private_key.pem -pubout -out $ENV{OPAQUE_HOME}/public_key.pub") -message("$ENV{OPAQUE_HOME}/public_key.pub") - add_custom_target(run ALL DEPENDS $ENV{OPAQUE_HOME}/public_key.pub) add_custom_command( - COMMAND openssl rsa -in $ENV{OPAQUE_HOME}/private_key.pem -pubout -out $ENV{OPAQUE_HOME}/public_key.pub + COMMAND openssl rsa -in $ENV{PRIVATE_KEY_PATH} -pubout -out $ENV{OPAQUE_HOME}/public_key.pub OUTPUT $ENV{OPAQUE_HOME}/public_key.pub) add_subdirectory(App) diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp new file mode 100644 index 0000000000..cd6bfabbd2 --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp @@ -0,0 +1,54 @@ +#include "BroadcastNestedLoopJoin.h" + +#include "ExpressionEvaluation.h" +#include "FlatbuffersReaders.h" +#include "FlatbuffersWriters.h" +#include "common.h" + +/** C++ implementation of a broadcast nested loop join. + * Assumes outer_rows is streamed and inner_rows is broadcast. + * DOES NOT rely on rows to be tagged primary or secondary, and that + * assumption will break the implementation. 
+ */ +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + + FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); + const tuix::JoinType join_type = join_expr_eval.get_join_type(); + + RowReader outer_r(BufferRefView(outer_rows, outer_rows_length)); + RowWriter w; + + while (outer_r.has_next()) { + const tuix::Row *outer = outer_r.next(); + bool o_i_match = false; + + RowReader inner_r(BufferRefView(inner_rows, inner_rows_length)); + const tuix::Row *inner; + while (inner_r.has_next()) { + inner = inner_r.next(); + o_i_match |= join_expr_eval.eval_condition(outer, inner); + } + + switch(join_type) { + case tuix::JoinType_LeftAnti: + if (!o_i_match) { + w.append(outer); + } + break; + case tuix::JoinType_LeftSemi: + if (o_i_match) { + w.append(outer); + } + break; + default: + throw std::runtime_error( + std::string("Join type not supported: ") + + std::string(to_string(join_type))); + } + } + w.output_buffer(output_rows, output_rows_length, std::string("broadcastNestedLoopJoin")); +} diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.h b/src/enclave/Enclave/BroadcastNestedLoopJoin.h new file mode 100644 index 0000000000..55c934067b --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.h @@ -0,0 +1,8 @@ +#include +#include + +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length); diff --git a/src/enclave/Enclave/CMakeLists.txt b/src/enclave/Enclave/CMakeLists.txt index 996a0d1742..ab3c1bb856 100644 --- a/src/enclave/Enclave/CMakeLists.txt +++ b/src/enclave/Enclave/CMakeLists.txt @@ -11,7 +11,8 @@ set(SOURCES FlatbuffersReaders.cpp FlatbuffersWriters.cpp IntegrityUtils.cpp - Join.cpp + NonObliviousSortMergeJoin.cpp + BroadcastNestedLoopJoin.cpp Limit.cpp Project.cpp Sort.cpp @@ -23,7 +24,10 @@ set(SOURCES ${CMAKE_CURRENT_BINARY_DIR}/Enclave_t.c) add_custom_command( - COMMAND oeedger8r --trusted ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl --search-path ${CMAKE_SOURCE_DIR}/Enclave + COMMAND oeedger8r --trusted ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl + --search-path ${CMAKE_SOURCE_DIR}/Enclave + --search-path ${OE_INCLUDEDIR} + --search-path ${OE_INCLUDEDIR}/openenclave/edl/sgx DEPENDS ${CMAKE_SOURCE_DIR}/Enclave/Enclave.edl OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/Enclave_t.h ${CMAKE_CURRENT_BINARY_DIR}/Enclave_t.c ${CMAKE_CURRENT_BINARY_DIR}/Enclave_args.h) @@ -42,22 +46,21 @@ endif() target_compile_definitions(enclave_trusted PUBLIC OE_API_VERSION=2) # Need for the generated file Enclave_t.h -target_include_directories(enclave_trusted PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(enclave_trusted PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${OE_INCLUDEDIR}/openenclave/3rdparty) -target_link_libraries(enclave_trusted - openenclave::oeenclave - openenclave::oelibc +link_directories(${OE_LIBDIR} ${OE_LIBDIR}/openenclave/enclave) +target_link_libraries(enclave_trusted + openenclave::oeenclave + openenclave::oecrypto${OE_CRYPTO_LIB} + openenclave::oelibc openenclave::oelibcxx - openenclave::oehostsock - openenclave::oehostresolver) + openenclave::oecore) add_custom_command( - COMMAND oesign sign -e $ -c ${CMAKE_CURRENT_SOURCE_DIR}/Enclave.conf -k $ENV{PRIVATE_KEY_PATH} + COMMAND openenclave::oesign sign -e $ 
-c ${CMAKE_CURRENT_SOURCE_DIR}/Enclave.conf -k $ENV{PRIVATE_KEY_PATH} DEPENDS enclave_trusted ${CMAKE_CURRENT_SOURCE_DIR}/Enclave.conf OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/enclave_trusted.signed) -# TODO: Use the user-generated private key to sign the enclave code. -# Currently we use the sample private key from the Intel SGX SDK. add_custom_command( COMMAND mv ${CMAKE_CURRENT_BINARY_DIR}/libenclave_trusted.so.signed ${CMAKE_CURRENT_BINARY_DIR}/libenclave_trusted_signed.so DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/enclave_trusted.signed diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index b4a4e32680..48537b194e 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -6,7 +6,8 @@ #include "Aggregate.h" #include "Crypto.h" #include "Filter.h" -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" +#include "BroadcastNestedLoopJoin.h" #include "Limit.h" #include "Project.h" #include "Sort.h" @@ -196,7 +197,6 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le __builtin_ia32_lfence(); try { - debug("Ecall: NonObliviousSortMergJoin\n"); non_oblivious_sort_merge_join(join_expr, join_expr_length, input_rows, input_rows_length, output_rows, output_rows_length); @@ -208,6 +208,28 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le } } +void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + // Guard against operating on arbitrary enclave memory + assert(oe_is_outside_enclave(outer_rows, outer_rows_length) == 1); + assert(oe_is_outside_enclave(inner_rows, inner_rows_length) == 1); + __builtin_ia32_lfence(); + + try { + broadcast_nested_loop_join(join_expr, join_expr_length, + outer_rows, outer_rows_length, + inner_rows, inner_rows_length, + output_rows, output_rows_length); + complete_encrypted_blocks(*output_rows); + EnclaveContext::getInstance().finish_ecall(); + } catch (const std::runtime_error &e) { + EnclaveContext::getInstance().finish_ecall(); + ocall_throw(e.what()); + } +} + void ecall_non_oblivious_aggregate( uint8_t *agg_op, size_t agg_op_length, uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 0225c64efa..1789ff2b64 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -3,6 +3,9 @@ enclave { + from "openenclave/edl/syscall.edl" import *; + from "platform.edl" import *; + include "stdbool.h" trusted { @@ -48,6 +51,12 @@ enclave { [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_broadcast_nested_loop_join( + [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, + [user_check] uint8_t *outer_rows, size_t outer_rows_length, + [user_check] uint8_t *inner_rows, size_t inner_rows_length, + [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_non_oblivious_aggregate( [in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length, [user_check] uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/EnclaveContext.h b/src/enclave/Enclave/EnclaveContext.h index 195b1d878b..a4cfc30868 100644 --- a/src/enclave/Enclave/EnclaveContext.h +++ b/src/enclave/Enclave/EnclaveContext.h @@ -199,7 +199,8 @@ class EnclaveContext { 
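The ecall added above routes ecall_broadcast_nested_loop_join to the new BroadcastNestedLoopJoin.cpp, which streams the outer rows and scans the broadcast inner rows once per outer row. A minimal plaintext Scala sketch of that LeftSemi/LeftAnti logic (the row type, condition evaluator, and join-type strings are illustrative stand-ins for the encrypted FlatBuffers rows and FlatbuffersJoinExprEvaluator, not the enclave code itself):

```scala
// Sketch of the per-partition loop in BroadcastNestedLoopJoin.cpp: emit the outer
// row when a match exists (LeftSemi) or when no match exists (LeftAnti).
def broadcastNestedLoopJoin[Row](
    outerRows: Seq[Row],                        // streamed side
    innerRows: Seq[Row],                        // broadcast side
    joinType: String,                           // "LeftSemi" or "LeftAnti"
    evalCondition: (Row, Row) => Boolean): Seq[Row] =
  outerRows.filter { outer =>
    val matched = innerRows.exists(inner => evalCondition(outer, inner))
    joinType match {
      case "LeftSemi" => matched
      case "LeftAnti" => !matched
      case other      => throw new RuntimeException(s"Join type not supported: $other")
    }
  }
```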
{"countRowsPerPartition", 10}, {"computeNumRowsPerPartition", 11}, {"localLimit", 12}, - {"limitReturnRows", 13} + {"limitReturnRows", 13}, + {"broadcastNestedLoopJoin", 14} }; return ecall_id[ecall]; } diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 9405ddd34f..06a2100fc1 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -288,6 +288,49 @@ class FlatbuffersExpressionEvaluator { static_cast(expr->expr())->value(), builder); } + case tuix::ExprUnion_Decrypt: + { + auto decrypt_expr = static_cast(expr->expr()); + const tuix::Field *value = + flatbuffers::GetTemporaryPointer(builder, eval_helper(row, decrypt_expr->value())); + + if (value->value_type() != tuix::FieldUnion_StringField) { + throw std::runtime_error( + std::string("tuix::Decrypt only accepts a string input, not ") + + std::string(tuix::EnumNameFieldUnion(value->value_type()))); + } + + bool result_is_null = value->is_null(); + if (!result_is_null) { + auto str_field = static_cast(value->value()); + + std::vector str_vec( + flatbuffers::VectorIterator(str_field->value()->Data(), + static_cast(0)), + flatbuffers::VectorIterator(str_field->value()->Data(), + static_cast(str_field->length()))); + + std::string ciphertext(str_vec.begin(), str_vec.end()); + std::string ciphertext_decoded = ciphertext_base64_decode(ciphertext); + + uint8_t *plaintext = new uint8_t[dec_size(ciphertext_decoded.size())]; + decrypt(reinterpret_cast(ciphertext_decoded.data()), ciphertext_decoded.size(), plaintext); + + BufferRefView buf(plaintext, ciphertext_decoded.size()); + buf.verify(); + + const tuix::Rows *rows = buf.root(); + const tuix::Field *field = rows->rows()->Get(0)->field_values()->Get(0); + auto ret = flatbuffers_copy(field, builder); + + delete plaintext; + return ret; + } else { + throw std::runtime_error(std::string("tuix::Decrypt does not accept a NULL string\n")); + } + + } + case tuix::ExprUnion_Cast: { auto cast = static_cast(expr->expr()); @@ -1571,6 +1614,68 @@ class FlatbuffersExpressionEvaluator { result_is_null); } + case tuix::ExprUnion_NormalizeNaNAndZero: + { + auto normalize = static_cast(expr->expr()); + auto child_offset = eval_helper(row, normalize->child()); + + const tuix::Field *value = flatbuffers::GetTemporaryPointer(builder, child_offset); + + if (value->value_type() != tuix::FieldUnion_FloatField && value->value_type() != tuix::FieldUnion_DoubleField) { + throw std::runtime_error( + std::string("tuix::NormalizeNaNAndZero requires type Float or Double, not ") + + std::string(tuix::EnumNameFieldUnion(value->value_type()))); + } + + bool result_is_null = value->is_null(); + + if (value->value_type() == tuix::FieldUnion_FloatField) { + if (!result_is_null) { + float v = value->value_as_FloatField()->value(); + if (isnan(v)) { + v = std::numeric_limits::quiet_NaN(); + } else if (v == -0.0f) { + v = 0.0f; + } + + return tuix::CreateField( + builder, + tuix::FieldUnion_FloatField, + tuix::CreateFloatField(builder, v).Union(), + result_is_null); + } + + return tuix::CreateField( + builder, + tuix::FieldUnion_FloatField, + tuix::CreateFloatField(builder, 0).Union(), + result_is_null); + + } else { + + if (!result_is_null) { + double v = value->value_as_DoubleField()->value(); + if (isnan(v)) { + v = std::numeric_limits::quiet_NaN(); + } else if (v == -0.0d) { + v = 0.0d; + } + + return tuix::CreateField( + builder, + tuix::FieldUnion_DoubleField, + tuix::CreateDoubleField(builder, v).Union(), + result_is_null); 
+ } + + return tuix::CreateField( + builder, + tuix::FieldUnion_DoubleField, + tuix::CreateDoubleField(builder, 0).Union(), + result_is_null); + } + } + default: throw std::runtime_error( std::string("Can't evaluate expression of type ") @@ -1684,58 +1789,103 @@ class FlatbuffersJoinExprEvaluator { const tuix::JoinExpr* join_expr = flatbuffers::GetRoot(buf); join_type = join_expr->join_type(); - if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { - throw std::runtime_error("Mismatched join key lengths"); - } - for (auto key_it = join_expr->left_keys()->begin(); - key_it != join_expr->left_keys()->end(); ++key_it) { - left_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + join_type = join_expr->join_type(); + if (join_expr->condition() != NULL) { + condition_eval = std::unique_ptr( + new FlatbuffersExpressionEvaluator(join_expr->condition())); } - for (auto key_it = join_expr->right_keys()->begin(); - key_it != join_expr->right_keys()->end(); ++key_it) { - right_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + is_equi_join = false; + + if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) { + is_equi_join = true; + if (join_expr->condition() != NULL) { + throw std::runtime_error("Equi join cannot have condition"); + } + if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { + throw std::runtime_error("Mismatched join key lengths"); + } + for (auto key_it = join_expr->left_keys()->begin(); + key_it != join_expr->left_keys()->end(); ++key_it) { + left_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } + for (auto key_it = join_expr->right_keys()->begin(); + key_it != join_expr->right_keys()->end(); ++key_it) { + right_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } } } - /** - * Return true if the given row is from the primary table, indicated by its first field, which - * must be an IntegerField. + /** Return true if the given row is from the primary table, indicated by its first field, which + * must be an IntegerField. + * Rows MUST have been tagged in Scala. */ bool is_primary(const tuix::Row *row) { return static_cast( row->field_values()->Get(0)->value())->value() == 0; } - /** Return true if the two rows are from the same join group. */ - bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) { - auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; - auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators; + /** Returns the row evaluator corresponding to the primary row + * Rows MUST have been tagged in Scala. + */ + const tuix::Row *get_primary_row( + const tuix::Row *row1, const tuix::Row *row2) { + return is_primary(row1) ? row1 : row2; + } + /** Return true if the two rows satisfy the join condition. */ + bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) { builder.Clear(); + bool row1_equals_row2; + + /** Check equality for equi joins. If it is a non-equi join, + * the key evaluators will be empty, so the code never enters the for loop. + */ + auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; + auto &row2_evaluators = is_primary(row2) ? 
left_key_evaluators : right_key_evaluators; for (uint32_t i = 0; i < row1_evaluators.size(); i++) { const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1); auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder); + auto row1_field = flatbuffers::GetTemporaryPointer(builder, row1_eval_offset); + const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2); auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder); + auto row2_field = flatbuffers::GetTemporaryPointer(builder, row2_eval_offset); - bool row1_equals_row2 = + flatbuffers::Offset comparison = eval_binary_comparison( + builder, + row1_field, + row2_field); + row1_equals_row2 = static_cast( flatbuffers::GetTemporaryPointer( builder, - eval_binary_comparison( - builder, - flatbuffers::GetTemporaryPointer(builder, row1_eval_offset), - flatbuffers::GetTemporaryPointer(builder, row2_eval_offset))) - ->value())->value(); + comparison)->value())->value(); if (!row1_equals_row2) { return false; } } + + /* Check condition for non-equi joins */ + if (!is_equi_join) { + std::vector> concat_fields; + for (auto field : *row1->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + for (auto field : *row2->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + flatbuffers::Offset concat = tuix::CreateRowDirect(builder, &concat_fields); + const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer(builder, concat); + + const tuix::Field *condition_result = condition_eval->eval(concat_ptr); + + return static_cast(condition_result->value())->value(); + } return true; } @@ -1748,6 +1898,8 @@ class FlatbuffersJoinExprEvaluator { tuix::JoinType join_type; std::vector> left_key_evaluators; std::vector> right_key_evaluators; + bool is_equi_join; + std::unique_ptr condition_eval; }; class AggregateExpressionEvaluator { diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp similarity index 93% rename from src/enclave/Enclave/Join.cpp rename to src/enclave/Enclave/NonObliviousSortMergeJoin.cpp index 53d9814f00..db9e707718 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp @@ -1,4 +1,4 @@ -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" #include "ExpressionEvaluation.h" #include "FlatbuffersReaders.h" @@ -26,7 +26,7 @@ void non_oblivious_sort_merge_join( EnclaveContext::getInstance().set_append_mac(false); // If current row is from primary table if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { // Add this primary row to the current group primary_group.append(current); last_primary_of_group.set(current); @@ -52,7 +52,7 @@ void non_oblivious_sort_merge_join( // Current row isn't from primary table // Output the joined rows resulting from this foreign row if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { EnclaveContext::getInstance().set_append_mac(false); auto primary_group_buffer = primary_group.output_buffer(std::string("")); RowReader primary_group_reader(primary_group_buffer.view()); @@ -60,7 +60,7 @@ void non_oblivious_sort_merge_join( // For each foreign key row, join all primary key rows in same group with it const tuix::Row *primary = primary_group_reader.next(); - if 
(!join_expr_eval.is_same_group(primary, current)) { + if (!join_expr_eval.eval_condition(primary, current)) { throw std::runtime_error( std::string("Invariant violation: rows of primary_group " "are not of the same group: ") diff --git a/src/enclave/Enclave/Join.h b/src/enclave/Enclave/NonObliviousSortMergeJoin.h similarity index 85% rename from src/enclave/Enclave/Join.h rename to src/enclave/Enclave/NonObliviousSortMergeJoin.h index b380909027..ef60c38437 100644 --- a/src/enclave/Enclave/Join.h +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.h @@ -1,12 +1,7 @@ #include #include -#ifndef JOIN_H -#define JOIN_H - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, uint8_t **output_rows, size_t *output_rows_length); - -#endif diff --git a/src/enclave/Enclave/util.cpp b/src/enclave/Enclave/util.cpp index 0f13e6af49..6cd2a898b0 100644 --- a/src/enclave/Enclave/util.cpp +++ b/src/enclave/Enclave/util.cpp @@ -142,3 +142,79 @@ int secs_to_tm(long long t, struct tm *tm) { return 0; } + +// Code adapted from https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +/* + Copyright (C) 2004-2008 Rene Nyffenegger + + This source code is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this source code must not be misrepresented; you must not + claim that you wrote the original source code. If you use this source code + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original source code. + + 3. This notice may not be removed or altered from any source distribution. 
+ + Rene Nyffenegger rene.nyffenegger@adp-gmbh.ch + +*/ + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(unsigned char c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::string ciphertext_base64_decode(const std::string &encoded_string) { + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + uint8_t char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i ==4) { + for (i = 0; i <4; i++) + char_array_4[i] = base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret += char_array_3[i]; + i = 0; + } + } + + if (i) { + for (j = i; j <4; j++) + char_array_4[j] = 0; + + for (j = 0; j <4; j++) + char_array_4[j] = base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; + } + + return ret; +} diff --git a/src/enclave/Enclave/util.h b/src/enclave/Enclave/util.h index b4e0b52327..df80ba7cd0 100644 --- a/src/enclave/Enclave/util.h +++ b/src/enclave/Enclave/util.h @@ -41,4 +41,6 @@ int pow_2(int value); int secs_to_tm(long long t, struct tm *tm); +std::string ciphertext_base64_decode(const std::string &encoded_string); + #endif // UTIL_H diff --git a/src/enclave/ServiceProvider/CMakeLists.txt b/src/enclave/ServiceProvider/CMakeLists.txt index aed31320d6..2047dc15f2 100644 --- a/src/enclave/ServiceProvider/CMakeLists.txt +++ b/src/enclave/ServiceProvider/CMakeLists.txt @@ -12,9 +12,10 @@ set(SOURCES iasrequest.cpp sp_crypto.cpp) -link_directories("$ENV{OE_SDK_PATH}/lib/openenclave/enclave") -include_directories("$ENV{OE_SDK_PATH}/include") -include_directories("$ENV{OE_SDK_PATH}/include/openenclave/3rdparty") +link_directories(${OE_LIBDIR}) +link_directories(${OE_LIBDIR}/openenclave/enclave) +include_directories(${OE_INCLUDEDIR}) +include_directories(${OE_INCLUDEDIR}/openenclave/3rdparty) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -Wno-attributes") set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_CXX_FLAGS}") @@ -27,6 +28,6 @@ endif() find_library(CRYPTO_LIB crypto) find_library(SSL_LIB ssl) -target_link_libraries(ra_jni "${CRYPTO_LIB}" "${SSL_LIB}" mbedcrypto mbedtls openenclave::oehost openenclave::oehostverify) +target_link_libraries(ra_jni ${CRYPTO_LIB} ${SSL_LIB} mbedcrypto mbedtls openenclave::oehost) install(TARGETS ra_jni DESTINATION lib) diff --git a/src/enclave/ServiceProvider/sp_crypto.h b/src/enclave/ServiceProvider/sp_crypto.h index 5cf9c1479b..d5323af4ed 100644 --- a/src/enclave/ServiceProvider/sp_crypto.h +++ b/src/enclave/ServiceProvider/sp_crypto.h @@ -42,7 +42,7 @@ #include #include -#include +#include #include "openssl/evp.h" #include "openssl/pem.h" diff --git a/src/flatbuffers/Expr.fbs b/src/flatbuffers/Expr.fbs index a96215b5a2..a1e4d92aeb 100644 --- a/src/flatbuffers/Expr.fbs +++ b/src/flatbuffers/Expr.fbs @@ -36,11 +36,13 @@ union ExprUnion { VectorMultiply, DotProduct, Exp, + 
NormalizeNaNAndZero, ClosestPoint, CreateArray, Upper, DateAdd, - DateAddInterval + DateAddInterval, + Decrypt } table Expr { @@ -198,6 +200,10 @@ table CreateArray { children:[Expr]; } +table NormalizeNaNAndZero { + child:Expr; +} + // Opaque UDFs table VectorAdd { left:Expr; @@ -221,4 +227,8 @@ table ClosestPoint { table Upper { child:Expr; -} \ No newline at end of file +} + +table Decrypt { + value:Expr; +} diff --git a/src/flatbuffers/operators.fbs b/src/flatbuffers/operators.fbs index 1ebd06c971..9fa82b6cab 100644 --- a/src/flatbuffers/operators.fbs +++ b/src/flatbuffers/operators.fbs @@ -54,10 +54,11 @@ enum JoinType : ubyte { } table JoinExpr { join_type:JoinType; - // Currently only cross joins and equijoins are supported, so we store - // parallel arrays of key expressions and the join outputs pairs of rows - // where each expression from the left is equal to the matching expression - // from the right. + // In the case of equi joins, we store parallel arrays of key expressions and have the join output + // pairs of rows where each expression from the left is equal to the matching expression from the right. left_keys:[Expr]; right_keys:[Expr]; + // In the case of non-equi joins, we pass in a condition as an expression and evaluate that on each pair of rows. + // TODO: have equi joins use this condition rather than an additional filter operation. + condition:Expr; } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala index dc1cbae97b..991aa99c1e 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala @@ -161,7 +161,8 @@ object JobVerificationEngine { "EncryptedFilter", "EncryptedAggregate", "EncryptedGlobalLimit", - "EncryptedLocalLimit") + "EncryptedLocalLimit", + "EncryptedBroadcastNestedLoopJoin") def addLogEntryChain(logEntryChain: tuix.LogEntryChain): Unit = { logEntryChains += logEntryChain @@ -340,6 +341,10 @@ object JobVerificationEngine { for (i <- 0 until numPartitions) { parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) } + } else if (ecall == 14) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } } else { throw new Exception("Job Verification Error creating expected DAG: " + "ecall not supported - " + ecall) @@ -374,6 +379,9 @@ object JobVerificationEngine { } else if (operatorName == "EncryptedGlobalLimit") { // ("countRowsPerPartition", "computeNumRowsPerPartition", "limitReturnRows") expectedEcalls.append(10, 11, 13) + } else if (operatorName == "EncryptedBroadcastNestedLoopJoin") { + // ("broadcastNestedLoopJoin") + expectedEcalls.append(14) } else { throw new Exception("Executed unknown operator: " + operatorName) } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index fd0796ac5b..43a126213f 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -21,7 +21,9 @@ import java.io.File import java.io.FileNotFoundException import java.nio.ByteBuffer import java.nio.ByteOrder +import java.nio.charset.StandardCharsets; import java.security.SecureRandom +import java.util.Base64 import java.util.UUID import javax.crypto._ @@ -60,6 +62,7 @@ import org.apache.spark.sql.catalyst.expressions.If import org.apache.spark.sql.catalyst.expressions.In import 
org.apache.spark.sql.catalyst.expressions.IsNotNull import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized import org.apache.spark.sql.catalyst.expressions.LessThan import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual import org.apache.spark.sql.catalyst.expressions.Literal @@ -90,9 +93,12 @@ import org.apache.spark.sql.catalyst.plans.NaturalJoin import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.util.MapData +import org.apache.spark.sql.execution.SubqueryExec +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -103,6 +109,7 @@ import edu.berkeley.cs.rise.opaque.execution.Block import edu.berkeley.cs.rise.opaque.execution.OpaqueOperatorExec import edu.berkeley.cs.rise.opaque.execution.SGXEnclave import edu.berkeley.cs.rise.opaque.expressions.ClosestPoint +import edu.berkeley.cs.rise.opaque.expressions.Decrypt import edu.berkeley.cs.rise.opaque.expressions.DotProduct import edu.berkeley.cs.rise.opaque.expressions.VectorAdd import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply @@ -590,6 +597,7 @@ object Utils extends Logging { tuix.StringField.createValueVector(builder, Array.empty), 0), isNull) + case _ => throw new OpaqueException(s"FlatbuffersCreateField failed to match on ${value} of type {value.getClass.getName()}, ${dataType}") } } @@ -664,6 +672,50 @@ object Utils extends Logging { val MaxBlockSize = 1000 + /** + * Encrypts/decrypts a given scalar value + **/ + def encryptScalar(value: Any, dataType: DataType): String = { + // First serialize the scalar value + var builder = new FlatBufferBuilder + var rowOffsets = ArrayBuilder.make[Int] + + val v = dataType match { + case StringType => UTF8String.fromString(value.asInstanceOf[String]) + case _ => value + } + + val isNull = (value == null) + + // TODO: the NULL variable for field value could be set to true + builder.finish( + tuix.Rows.createRows( + builder, + tuix.Rows.createRowsVector( + builder, + Array(tuix.Row.createRow( + builder, + tuix.Row.createFieldValuesVector( + builder, + Array(flatbuffersCreateField(builder, v, dataType, false))), + isNull))))) + + val plaintext = builder.sizedByteArray() + val ciphertext = encrypt(plaintext) + val ciphertext_str = Base64.getEncoder().encodeToString(ciphertext); + ciphertext_str + } + + def decryptScalar(ciphertext: String): Any = { + val ciphertext_bytes = Base64.getDecoder().decode(ciphertext); + val plaintext = decrypt(ciphertext_bytes) + val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext)) + val row = rows.rows(0) + val field = row.fieldValues(0) + val value = flatbuffersExtractFieldValue(field) + value + } + /** * Encrypts the given Spark SQL [[InternalRow]]s into a [[Block]] (a serialized * tuix.EncryptedBlocks). 
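The encryptScalar and decryptScalar helpers above move a single value in and out of Opaque's encrypted row format as a Base64 string; the enclave-side counterpart is the new tuix::Decrypt expression together with ciphertext_base64_decode. A minimal round-trip sketch, assuming Utils.initSQLContext has already been called so the shared encryption key is available (decryptScalar is meant for testing against Spark, not for production use):

```scala
import org.apache.spark.sql.types.IntegerType
import edu.berkeley.cs.rise.opaque.Utils

// Serialize 42 into a single-row tuix.Rows buffer, encrypt it, and Base64-encode
// the ciphertext; then reverse the process on the driver.
val ciphertextStr: String = Utils.encryptScalar(42, IntegerType) // Base64 string literal
val recovered: Any = Utils.decryptScalar(ciphertextStr)          // expected to be 42
```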
@@ -842,6 +894,13 @@ object Utils extends Logging { tuix.ExprUnion.Literal, tuix.Literal.createLiteral(builder, valueOffset)) + // This expression should never be evaluated on the driver + case (Decrypt(child, dataType), Seq(childOffset)) => + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.Decrypt, + tuix.Decrypt.createDecrypt(builder, childOffset)) + case (Alias(child, _), Seq(childOffset)) => // TODO: Use an expression for aliases so we can refer to them elsewhere in the expression // tree. For now we just ignore them when evaluating expressions. @@ -1132,6 +1191,45 @@ object Utils extends Logging { // TODO: Implement decimal serialization, followed by CheckOverflow childOffset + case (NormalizeNaNAndZero(child), Seq(childOffset)) => + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.NormalizeNaNAndZero, + tuix.NormalizeNaNAndZero.createNormalizeNaNAndZero(builder, childOffset)) + + case (KnownFloatingPointNormalized(NormalizeNaNAndZero(child)), Seq(childOffset)) => + flatbuffersSerializeExpression(builder, NormalizeNaNAndZero(child), input) + + case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) => + val output = child.output(0) + val dataType = output match { + case AttributeReference(name, dataType, _, _) => dataType + case _ => throw new OpaqueException("Scalar subquery cannot match to AttributeReference") + } + // Need to deserialize the encrypted blocks to get the encrypted block + val blockList = child.asInstanceOf[OpaqueOperatorExec].collectEncrypted() + val encryptedBlocksList = blockList.map { block => + val buf = ByteBuffer.wrap(block.bytes) + tuix.EncryptedBlocks.getRootAsEncryptedBlocks(buf) + } + val encryptedBlocks = encryptedBlocksList.find(_.blocksLength > 0).getOrElse(encryptedBlocksList(0)) + if (encryptedBlocks.blocksLength == 0) { + // If empty, the returned result is null + flatbuffersSerializeExpression(builder, Literal(null, dataType), input) + } else { + assert(encryptedBlocks.blocksLength == 1) + val encryptedBlock = encryptedBlocks.blocks(0) + val ciphertextBuf = encryptedBlock.encRowsAsByteBuffer + val ciphertext = new Array[Byte](ciphertextBuf.remaining) + ciphertextBuf.get(ciphertext) + val ciphertext_str = Base64.getEncoder().encodeToString(ciphertext) + flatbuffersSerializeExpression( + builder, + Decrypt(Literal(UTF8String.fromString(ciphertext_str), StringType), dataType), + input + ) + } + case (_, Seq(childOffset)) => throw new OpaqueException("Expression not supported: " + expr.toString()) } @@ -1179,8 +1277,9 @@ object Utils extends Logging { } def serializeJoinExpression( - joinType: JoinType, leftKeys: Seq[Expression], rightKeys: Seq[Expression], - leftSchema: Seq[Attribute], rightSchema: Seq[Attribute]): Array[Byte] = { + joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]], + leftSchema: Seq[Attribute], rightSchema: Seq[Attribute], + condition: Option[Expression] = None): Array[Byte] = { val builder = new FlatBufferBuilder builder.finish( tuix.JoinExpr.createJoinExpr( @@ -1199,12 +1298,28 @@ object Utils extends Logging { case UsingJoin(_, _) => ??? 
// scalastyle:on }, - tuix.JoinExpr.createLeftKeysVector( - builder, - leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray), - tuix.JoinExpr.createRightKeysVector( - builder, - rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray))) + // Non-zero when equi join + leftKeys match { + case Some(leftKeys) => + tuix.JoinExpr.createLeftKeysVector( + builder, + leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray) + case None => 0 + }, + // Non-zero when equi join + rightKeys match { + case Some(rightKeys) => + tuix.JoinExpr.createRightKeysVector( + builder, + rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray) + case None => 0 + }, + // Non-zero when non-equi join + condition match { + case Some(condition) => + flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema) + case _ => 0 + })) builder.sizedByteArray() } @@ -1304,8 +1419,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case c @ Count(children) => val count = c.aggBufferAttributes(0) @@ -1343,8 +1457,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case f @ First(child, false) => val first = f.aggBufferAttributes(0) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/Benchmark.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/Benchmark.scala index b46a94d00c..13c4d288a3 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/Benchmark.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/Benchmark.scala @@ -24,11 +24,33 @@ import org.apache.spark.sql.SparkSession * Convenient runner for benchmarks. * * To run locally, use - * `$OPAQUE_HOME/build/sbt 'run edu.berkeley.cs.rise.opaque.benchmark.Benchmark'`. + * `$OPAQUE_HOME/build/sbt 'run edu.berkeley.cs.rise.opaque.benchmark.Benchmark '`. + * Available flags: + * --num-partitions: specify the number of partitions the data should be split into. + * Default: 2 * number of executors if exists, 4 otherwise + * --size: specify the size of the dataset that should be loaded into Spark. + * Default: sf_small + * --operations: select the different operations that should be benchmarked. + * Default: all + * Available operations: logistic-regression, tpc-h + * Syntax: --operations "logistic-regression,tpc-h" + * --run-local: boolean whether to use HDFS or the local filesystem + * Default: HDFS + * Leave --operations flag blank to run all benchmarks * * To run on a cluster, use `$SPARK_HOME/bin/spark-submit` with appropriate arguments. 
*/ object Benchmark { + + val spark = SparkSession.builder() + .appName("Benchmark") + .getOrCreate() + var numPartitions = spark.sparkContext.defaultParallelism + var size = "sf_med" + + // Configure your HDFS namenode url here + var fileUrl = "hdfs://10.0.3.4:8020" + def dataDir: String = { if (System.getenv("SPARKSGX_DATA_DIR") == null) { throw new Exception("Set SPARKSGX_DATA_DIR") @@ -36,15 +58,9 @@ object Benchmark { System.getenv("SPARKSGX_DATA_DIR") } - def main(args: Array[String]): Unit = { - val spark = SparkSession.builder() - .appName("QEDBenchmark") - .getOrCreate() - Utils.initSQLContext(spark.sqlContext) - - // val numPartitions = - // if (spark.sparkContext.isLocal) 1 else spark.sparkContext.defaultParallelism - + def logisticRegression() = { + // TODO: this fails when Spark is ran on a cluster + /* // Warmup LogisticRegression.train(spark, Encrypted, 1000, 1) LogisticRegression.train(spark, Encrypted, 1000, 1) @@ -52,7 +68,73 @@ object Benchmark { // Run LogisticRegression.train(spark, Insecure, 100000, 1) LogisticRegression.train(spark, Encrypted, 100000, 1) + */ + } + def runAll() = { + logisticRegression() + TPCHBenchmark.run(spark.sqlContext, numPartitions, size, fileUrl) + } + + def main(args: Array[String]): Unit = { + Utils.initSQLContext(spark.sqlContext) + + if (args.length >= 2 && args(1) == "--help") { + println( +"""Available flags: + --num-partitions: specify the number of partitions the data should be split into. + Default: 2 * number of executors if exists, 4 otherwise + --size: specify the size of the dataset that should be loaded into Spark. + Default: sf_small + --operations: select the different operations that should be benchmarked. + Default: all + Available operations: logistic-regression, tpc-h + Syntax: --operations "logistic-regression,tpc-h" + Leave --operations flag blank to run all benchmarks + --run-local: boolean whether to use HDFS or the local filesystem + Default: HDFS""" + ) + } + + var runAll = true + args.slice(1, args.length).sliding(2, 2).toList.collect { + case Array("--num-partitions", numPartitions: String) => { + this.numPartitions = numPartitions.toInt + } + case Array("--size", size: String) => { + val supportedSizes = Set("sf_small, sf_med") + if (supportedSizes.contains(size)) { + this.size = size + } else { + println("Given size is not supported: available values are " + supportedSizes.toString()) + } + } + case Array("--run-local", runLocal: String) => { + runLocal match { + case "true" => { + fileUrl = "file://" + } + case _ => {} + } + } + case Array("--operations", operations: String) => { + runAll = false + val operationsArr = operations.split(",").map(_.trim) + for (operation <- operationsArr) { + operation match { + case "logistic-regression" => { + logisticRegression() + } + case "tpc-h" => { + TPCHBenchmark.run(spark.sqlContext, numPartitions, size, fileUrl) + } + } + } + } + } + if (runAll) { + this.runAll(); + } spark.stop() } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala index e0bb4d4caf..ee905026c8 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala @@ -17,6 +17,7 @@ package edu.berkeley.cs.rise.opaque.benchmark +import java.io.File import scala.io.Source import org.apache.spark.sql.DataFrame @@ -162,7 +163,7 @@ object TPCH { .option("delimiter", "|") .load(s"${Benchmark.dataDir}/tpch/$size/customer.tbl") - def 
generateMap( + def generateDFs( sqlContext: SQLContext, size: String) : Map[String, DataFrame] = { Map("part" -> part(sqlContext, size), @@ -175,42 +176,73 @@ object TPCH { "customer" -> customer(sqlContext, size) ), } - - def apply(sqlContext: SQLContext, size: String) : TPCH = { - val tpch = new TPCH(sqlContext, size) - tpch.tableNames = tableNames - tpch.nameToDF = generateMap(sqlContext, size) - tpch.ensureCached() - tpch - } } -class TPCH(val sqlContext: SQLContext, val size: String) { +class TPCH(val sqlContext: SQLContext, val size: String, val fileUrl: String) { - var tableNames : Seq[String] = Seq() - var nameToDF : Map[String, DataFrame] = Map() + val tableNames = TPCH.tableNames + val nameToDF = TPCH.generateDFs(sqlContext, size) - def ensureCached() = { - for (name <- tableNames) { - nameToDF.get(name).foreach(df => { - Utils.ensureCached(df) - Utils.ensureCached(Encrypted.applyTo(df)) - }) - } + private var numPartitions: Int = -1 + private var nameToPath = Map[String, File]() + private var nameToEncryptedPath = Map[String, File]() + + def getQuery(queryNumber: Int) : String = { + val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/" + Source.fromFile(queryLocation + s"q$queryNumber.sql").getLines().mkString("\n") } - def setupViews(securityLevel: SecurityLevel, numPartitions: Int) = { - for ((name, df) <- nameToDF) { - securityLevel.applyTo(df.repartition(numPartitions)).createOrReplaceTempView(name) + def generateFiles(numPartitions: Int) = { + if (numPartitions != this.numPartitions) { + this.numPartitions = numPartitions + for ((name, df) <- nameToDF) { + nameToPath.get(name).foreach{ path => Utils.deleteRecursively(path) } + + nameToPath += (name -> createPath(df, Insecure, numPartitions)) + nameToEncryptedPath += (name -> createPath(df, Encrypted, numPartitions)) + } } } - def query(queryNumber: Int, securityLevel: SecurityLevel, sqlContext: SQLContext, numPartitions: Int) : DataFrame = { - setupViews(securityLevel, numPartitions) + private def createPath(df: DataFrame, securityLevel: SecurityLevel, numPartitions: Int): File = { + val partitionedDF = securityLevel.applyTo(df.repartition(numPartitions)) + val path = Utils.createTempDir() + path.delete() + securityLevel match { + case Insecure => { + partitionedDF.write.format("com.databricks.spark.csv") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .save(fileUrl + path.toString) + } + case Encrypted => { + partitionedDF.write.format("edu.berkeley.cs.rise.opaque.EncryptedSource").save(fileUrl + path.toString) + } + } + path + } - val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/" - val sqlStr = Source.fromFile(queryLocation + s"q$queryNumber.sql").getLines().mkString("\n") + private def loadViews(securityLevel: SecurityLevel) = { + val (map, formatStr) = if (securityLevel == Insecure) + (nameToPath, "com.databricks.spark.csv") else + (nameToEncryptedPath, "edu.berkeley.cs.rise.opaque.EncryptedSource") + for ((name, path) <- map) { + val df = sqlContext.sparkSession.read + .format(formatStr) + .schema(nameToDF.get(name).get.schema) + .load(fileUrl + path.toString) + df.createOrReplaceTempView(name) + } + } + def performQuery(sqlStr: String, securityLevel: SecurityLevel): DataFrame = { + loadViews(securityLevel) sqlContext.sparkSession.sql(sqlStr) } + + def query(queryNumber: Int, securityLevel: SecurityLevel, numPartitions: Int): DataFrame = { + val sqlStr = getQuery(queryNumber) + generateFiles(numPartitions) 
+ performQuery(sqlStr, securityLevel) + } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala new file mode 100644 index 0000000000..c235265624 --- /dev/null +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package edu.berkeley.cs.rise.opaque.benchmark + +import edu.berkeley.cs.rise.opaque.Utils + +import org.apache.spark.sql.SQLContext + +object TPCHBenchmark { + + // Add query numbers here once they are supported + val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19, 20, 22) + + def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = { + val sqlStr = tpch.getQuery(queryNumber) + tpch.generateFiles(numPartitions) + + Utils.timeBenchmark( + "distributed" -> (numPartitions > 1), + "query" -> s"TPC-H $queryNumber", + "system" -> Insecure.name) { + + tpch.performQuery(sqlStr, Insecure).collect + } + + Utils.timeBenchmark( + "distributed" -> (numPartitions > 1), + "query" -> s"TPC-H $queryNumber", + "system" -> Encrypted.name) { + + tpch.performQuery(sqlStr, Encrypted).collect + } + } + + def run(sqlContext: SQLContext, numPartitions: Int, size: String, fileUrl: String) = { + val tpch = new TPCH(sqlContext, size, fileUrl) + + for (queryNumber <- supportedQueries) { + query(queryNumber, tpch, sqlContext, numPartitions) + } + } +} diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index b49090ced1..e1f1d31261 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -42,6 +42,9 @@ class SGXEnclave extends java.io.Serializable { @native def NonObliviousSortMergeJoin( eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] + @native def BroadcastNestedLoopJoin( + eid: Long, joinExpr: Array[Byte], outerBlock: Array[Byte], innerBlock: Array[Byte]): Array[Byte] + @native def NonObliviousAggregate( eid: Long, aggOp: Array[Byte], inputRows: Array[Byte], isPartial: Boolean): (Array[Byte]) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 0497b3cf2a..5aa1173c2c 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -26,12 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.catalyst.expressions.aggregate._ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan +import edu.berkeley.cs.rise.opaque.OpaqueException trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil @@ -134,15 +133,20 @@ trait OpaqueOperatorExec extends SparkPlan { * method and persist the resulting RDD. [[ConvertToOpaqueOperators]] later eliminates the dummy * relation from the logical plan, but this only happens after InMemoryRelation has called this * method. We therefore have to silently return an empty RDD here. - */ + */ + override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.emptyRDD // throw new UnsupportedOperationException("use executeBlocked") } + def collectEncrypted(): Array[Block] = { + executeBlocked().collect + } + override def executeCollect(): Array[InternalRow] = { - val collectedRDD = executeBlocked().collect() + val collectedRDD = collectEncrypted() collectedRDD.map { block => Utils.addBlockForVerification(block) } @@ -300,7 +304,7 @@ case class EncryptedSortMergeJoinExec( override def executeBlocked(): RDD[Block] = { val joinExprSer = Utils.serializeJoinExpression( - joinType, leftKeys, rightKeys, leftSchema, rightSchema) + joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema) timeOperator( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), @@ -314,6 +318,69 @@ case class EncryptedSortMergeJoinExec( } } +case class EncryptedBroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") + } + } + + override def executeBlocked(): RDD[Block] = { + val joinExprSer = Utils.serializeJoinExpression( + joinType, None, None, left.output, right.output, condition) + + val leftRDD = left.asInstanceOf[OpaqueOperatorExec].executeBlocked() + val rightRDD = right.asInstanceOf[OpaqueOperatorExec].executeBlocked() + + joinType match { + case LeftExistence(_) => { + join(leftRDD, rightRDD, joinExprSer) + } + case _ => + throw new OpaqueException(s"$joinType JoinType is not yet supported") + } + } + + def join(leftRDD: RDD[Block], rightRDD: RDD[Block], + joinExprSer: Array[Byte]): RDD[Block] = { + // We pick which side to broadcast/stream according to buildSide. + // BuildRight means the right relation <=> the broadcast relation. + // NOTE: outer_rows and inner_rows in C++ correspond to stream and broadcast side respectively. 
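+      // For example, with BuildRight (the LeftExistence case) the left relation is streamed one
+      // encrypted block at a time, while the collected right relation is concatenated into a single
+      // broadcast block, so each enclave call below sees one stream block plus the entire broadcast side.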
+ var (streamRDD, broadcastRDD) = buildSide match { + case BuildRight => + (leftRDD, rightRDD) + case BuildLeft => + (rightRDD, leftRDD) + } + val broadcast = Utils.concatEncryptedBlocks(broadcastRDD.collect) + streamRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + Block(enclave.BroadcastNestedLoopJoin(eid, joinExprSer, block.bytes, broadcast.bytes)) + } + } +} + case class EncryptedUnionExec( left: SparkPlan, right: SparkPlan) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala index b4f1e27200..7eac3c990c 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala @@ -29,9 +29,6 @@ object ClosestPoint { * point - list of coordinates representing a point * centroids - list of lists of coordinates, each representing a point """) -/** - * - */ case class ClosestPoint(left: Expression, right: Expression) extends BinaryExpression with NullIntolerant with CodegenFallback with ExpectsInputTypes { diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala new file mode 100644 index 0000000000..a52ecb113e --- /dev/null +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala @@ -0,0 +1,49 @@ +package edu.berkeley.cs.rise.opaque.expressions + +import edu.berkeley.cs.rise.opaque.Utils + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExpressionDescription +import org.apache.spark.sql.catalyst.expressions.NullIntolerant +import org.apache.spark.sql.catalyst.expressions.Nondeterministic +import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.StringType +import org.apache.spark.unsafe.types.UTF8String + +object Decrypt { + def decrypt(v: Column, dataType: DataType): Column = new Column(Decrypt(v.expr, dataType)) +} + +@ExpressionDescription( + usage = """ + _FUNC_(child, outputDataType) - Decrypt the input evaluated expression, which should always be a string + """, + arguments = """ + Arguments: + * child - an encrypted literal of string type + * outputDataType - the decrypted data type + """) +case class Decrypt(child: Expression, outputDataType: DataType) + extends UnaryExpression with NullIntolerant with CodegenFallback with Nondeterministic { + + override def dataType: DataType = outputDataType + + protected def initializeInternal(partitionIndex: Int): Unit = { } + + protected override def evalInternal(input: InternalRow): Any = { + val v = child.eval() + nullSafeEval(v) + } + + protected override def nullSafeEval(input: Any): Any = { + // This function is implemented so that we can test against Spark; + // should never be used in production because we want to keep the literal encrypted + val v = input.asInstanceOf[UTF8String].toString + Utils.decryptScalar(v) + } +} diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 0c8f188369..dd104d2ad2 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ 
b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -32,13 +32,19 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.InnerLike import org.apache.spark.sql.catalyst.plans.LeftAnti import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan import edu.berkeley.cs.rise.opaque.execution._ import edu.berkeley.cs.rise.opaque.logical._ +import org.apache.spark.sql.catalyst.plans.LeftExistence object OpaqueOperators extends Strategy { @@ -73,6 +79,7 @@ object OpaqueOperators extends Strategy { case Sort(sortExprs, global, child) if isEncrypted(child) => EncryptedSortExec(sortExprs, global, planLater(child)) :: Nil + // Used to match equi joins case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if isEncrypted(p) => val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true) val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false) @@ -105,6 +112,26 @@ object OpaqueOperators extends Strategy { filtered :: Nil + // Used to match non-equi joins + case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) => + // How to pick the broadcast side: for a left join, broadcast the right side; for a right join, + // broadcast the left side; otherwise, broadcast the smaller side. This is the simplest and most + // performant method, but may be worth revisiting if one side is significantly smaller than the other. + // NOTE: the current implementation of BNLJ only works under the assumption that + // left join <==> broadcast right AND right join <==> broadcast left.
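+      // For example, df1.join(df2, $"a" < $"b", "left_semi") has no equi-join keys, so it falls
+      // through to this case and is planned as an EncryptedBroadcastNestedLoopJoinExec with BuildRight.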
+ val desiredBuildSide = if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) + getSmallerSide(left, right) else + getBroadcastSideBNLJ(joinType) + + val joined = EncryptedBroadcastNestedLoopJoinExec( + planLater(left), + planLater(right), + desiredBuildSide, + joinType, + condition) + + joined :: Nil + case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => @@ -183,17 +210,29 @@ object OpaqueOperators extends Strategy { (Seq(tag) ++ keysProj ++ input, keysProj.map(_.toAttribute), tag.toAttribute) } - private def sortForJoin( - leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = - leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) - private def dropTags( leftOutput: Seq[Attribute], rightOutput: Seq[Attribute]): Seq[NamedExpression] = leftOutput ++ rightOutput + private def sortForJoin( + leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = + leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) + private def tagForGlobalAggregate(input: Seq[Attribute]) : (Seq[NamedExpression], NamedExpression) = { val tag = Alias(Literal(0), "_tag")() (Seq(tag) ++ input, tag.toAttribute) } + + private def getBroadcastSideBNLJ(joinType: JoinType): BuildSide = { + joinType match { + case LeftExistence(_) => BuildRight + case _ => BuildLeft + } + } + + // Everything below is a private method in SparkStrategies.scala + private def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 0aa55d3138..4f5119d50e 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval import edu.berkeley.cs.rise.opaque.benchmark._ import edu.berkeley.cs.rise.opaque.execution.EncryptedBlockRDDScanExec +import edu.berkeley.cs.rise.opaque.expressions.Decrypt.decrypt import edu.berkeley.cs.rise.opaque.expressions.DotProduct.dot import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply.vectormultiply import edu.berkeley.cs.rise.opaque.expressions.VectorSum @@ -332,7 +333,25 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1") - integrityCollect(df) + df.collect + } + + testAgainstSpark("non-equi left semi join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + + testAgainstSpark("non-equi left semi join negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = 
makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect } testAgainstSpark("left anti join 1") { securityLevel => @@ -341,7 +360,25 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") - integrityCollect(df) + df.collect + } + + testAgainstSpark("non-equi left anti join 1") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 1 negated") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect } testAgainstSpark("left anti join 2") { securityLevel => @@ -350,7 +387,55 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") - integrityCollect(df) + df.collect + } + + testAgainstSpark("non-equi left anti join 2") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 2 negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("join on floats") { securityLevel => + val p_data = for (i <- 0 to 16) yield (i, i.toFloat, i * 10) + val f_data = (0 until 256).map(x => { + if (x % 3 == 0) + (x, null.asInstanceOf[Float], x * 10) + else + (x, (x % 16).asInstanceOf[Float], x * 10) + }).toSeq + + val p = makeDF(p_data, securityLevel, "id", "pk", "x") + val f = makeDF(f_data, securityLevel, "id", "fk", "x") + val df = p.join(f, $"pk" === $"fk") + df.collect.toSet + } + + testAgainstSpark("join on doubles") { securityLevel => + val p_data = for (i <- 0 to 16) yield (i, i.toDouble, i * 10) + val f_data = (0 until 256).map(x => { + if (x % 3 == 0) + (x, 
null.asInstanceOf[Double], x * 10) + else + (x, (x % 16).asInstanceOf[Double], x * 10) + }).toSeq + + val p = makeDF(p_data, securityLevel, "id", "pk", "x") + val f = makeDF(f_data, securityLevel, "id", "fk", "x") + val df = p.join(f, $"pk" === $"fk") + df.collect.toSet } def abc(i: Int): String = (i % 3) match { @@ -525,54 +610,6 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => integrityCollect(result) } - testAgainstSpark("concat with string") { securityLevel => - val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toString) - val df = makeDF(data, securityLevel, "str", "x") - df.select(concat(col("str"),lit(","),col("x"))).collect - } - - testAgainstSpark("concat with other datatype") { securityLevel => - // float causes a formating issue where opaque outputs 1.000000 and spark produces 1.0 so the following line is commented out - // val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, 1.0f) - // you can't serialize date so that's not supported as well - // opaque doesn't support byte - val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, null.asInstanceOf[Int],"") - val df = makeDF(data, securityLevel, "str", "int","null","emptystring") - df.select(concat(col("str"),lit(","),col("int"),col("null"),col("emptystring"))).collect - } - - testAgainstSpark("isin1") { securityLevel => - val ids = Seq((1, 2, 2), (2, 3, 1)) - val df = makeDF(ids, securityLevel, "x", "y", "id") - val c = $"id" isin ($"x", $"y") - val result = df.filter(c) - result.collect - } - - testAgainstSpark("isin2") { securityLevel => - val ids2 = Seq((1, 1, 1), (2, 2, 2), (3,3,3), (4,4,4)) - val df2 = makeDF(ids2, securityLevel, "x", "y", "id") - val c2 = $"id" isin (1 ,2, 4, 5, 6) - val result = df2.filter(c2) - result.collect - } - - testAgainstSpark("isin with string") { securityLevel => - val ids3 = Seq(("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), ("b", "b", "b"), ("c","c","c"), ("d","d","d")) - val df3 = makeDF(ids3, securityLevel, "x", "y", "id") - val c3 = $"id" isin ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" ,"b", "c", "d", "e") - val result = df3.filter(c3) - result.collect - } - - testAgainstSpark("isin with null") { securityLevel => - val ids4 = Seq((1, 1, 1), (2, 2, 2), (3,3,null.asInstanceOf[Int]), (4,4,4)) - val df4 = makeDF(ids4, securityLevel, "x", "y", "id") - val c4 = $"id" isin (null.asInstanceOf[Int]) - val result = df4.filter(c4) - result.collect - } - testAgainstSpark("between") { securityLevel => val data = for (i <- 0 until 256) yield(i.toString, i) val df = makeDF(data, securityLevel, "word", "count") @@ -936,6 +973,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => KMeans.train(spark, securityLevel, numPartitions, 10, 2, 3, 0.01).map(_.toSeq).sorted } + testAgainstSpark("encrypted literal") { securityLevel => + val input = 10 + val enc_str = Utils.encryptScalar(input, IntegerType) + + val data = for (i <- 0 until 256) yield (i, abc(i), 1) + val words = makeDF(data, securityLevel, "id", "word", "count") + val df = words.filter($"id" < decrypt(lit(enc_str), IntegerType)).sort($"id") + df.collect + } + + testAgainstSpark("scalar subquery") { securityLevel => + // Example taken from https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/2728434780191932/1483312212640900/6987336228780374/latest.html + val data = for (i <- 0 until 256) yield (i, abc(i), i) + val words = makeDF(data, securityLevel, "id", "word", "count") + 
words.createTempView("words") + + try { + val df = spark.sql("""SELECT id, word, (SELECT MAX(count) FROM words) max_age FROM words ORDER BY id, word""") + df.collect + } finally { + spark.catalog.dropTempView("words") + } + } + testAgainstSpark("pagerank") { securityLevel => integrityCollect(PageRank.run(spark, securityLevel, "256", numPartitions)).toSet } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala index 8117fb8de1..54ded162bc 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala @@ -68,7 +68,7 @@ trait OpaqueTestsBase extends FunSuite with BeforeAndAfterAll { self => testFunc(name + " - encrypted") { // The === operator uses implicitly[Equality[A]], which compares Double and Array[Double] // using the numeric tolerance specified above - assert(f(Encrypted) === f(Insecure)) + assert(f(Insecure) === f(Encrypted)) } } @@ -102,4 +102,4 @@ trait OpaqueTestsBase extends FunSuite with BeforeAndAfterAll { self => } } } -} \ No newline at end of file +} diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala index 8b68e69be2..fec76426cb 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala @@ -21,99 +21,19 @@ package edu.berkeley.cs.rise.opaque import org.apache.spark.sql.SparkSession import edu.berkeley.cs.rise.opaque.benchmark._ -import edu.berkeley.cs.rise.opaque.benchmark.TPCH trait TPCHTests extends OpaqueTestsBase { self => def size = "sf_small" - def tpch = TPCH(spark.sqlContext, size) - - testAgainstSpark("TPC-H 1") { securityLevel => - tpch.query(1, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 2", ignore) { securityLevel => - tpch.query(2, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 3") { securityLevel => - tpch.query(3, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 4", ignore) { securityLevel => - tpch.query(4, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 5") { securityLevel => - tpch.query(5, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 6") { securityLevel => - tpch.query(6, securityLevel, spark.sqlContext, numPartitions).collect.toSet - } - - testAgainstSpark("TPC-H 7") { securityLevel => - tpch.query(7, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 8") { securityLevel => - tpch.query(8, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 9") { securityLevel => - tpch.query(9, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 10") { securityLevel => - tpch.query(10, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 11", ignore) { securityLevel => - tpch.query(11, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 12") { securityLevel => - tpch.query(12, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 13", ignore) { securityLevel => - tpch.query(13, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 14") { securityLevel => - tpch.query(14, securityLevel, 
spark.sqlContext, numPartitions).collect.toSet - } - - testAgainstSpark("TPC-H 15", ignore) { securityLevel => - tpch.query(15, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 16", ignore) { securityLevel => - tpch.query(16, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 17") { securityLevel => - tpch.query(17, securityLevel, spark.sqlContext, numPartitions).collect.toSet - } - - testAgainstSpark("TPC-H 18", ignore) { securityLevel => - tpch.query(18, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 19") { securityLevel => - tpch.query(19, securityLevel, spark.sqlContext, numPartitions).collect.toSet - } - - testAgainstSpark("TPC-H 20") { securityLevel => - tpch.query(20, securityLevel, spark.sqlContext, numPartitions).collect.toSet - } - - testAgainstSpark("TPC-H 21", ignore) { securityLevel => - tpch.query(21, securityLevel, spark.sqlContext, numPartitions).collect - } - - testAgainstSpark("TPC-H 22", ignore) { securityLevel => - tpch.query(22, securityLevel, spark.sqlContext, numPartitions).collect + def tpch = new TPCH(spark.sqlContext, size, "file://") + + def runTests() = { + for (queryNum <- TPCHBenchmark.supportedQueries) { + val testStr = s"TPC-H $queryNum" + testAgainstSpark(testStr) { securityLevel => + tpch.query(queryNum, securityLevel, numPartitions).collect + } + } } }
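For reference, a minimal sketch of how the new TPCHBenchmark entry point might be driven outside the test suite. The driver object, the local[*] master, and the choice of 3 partitions are illustrative only; Utils.initSQLContext is assumed from the existing Opaque setup instructions, while TPCHBenchmark.run, the "sf_small" size, and the "file://" prefix come from the code added above.

```scala
import org.apache.spark.sql.SparkSession

import edu.berkeley.cs.rise.opaque.Utils
import edu.berkeley.cs.rise.opaque.benchmark.TPCHBenchmark

// Hypothetical driver, not part of this patch: runs every query in
// TPCHBenchmark.supportedQueries under both the Insecure and Encrypted
// security levels, recording timings via Utils.timeBenchmark.
object TPCHBenchmarkDriver {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TPCHBenchmarkDriver")
      .master("local[*]") // local mode for illustration; drop when using spark-submit on a cluster
      .getOrCreate()
    Utils.initSQLContext(spark.sqlContext)

    // "sf_small" matches the scale factor used by TPCHTests; "file://" writes the
    // generated tables to local temporary directories, as the test suite does.
    TPCHBenchmark.run(spark.sqlContext, numPartitions = 3, size = "sf_small", fileUrl = "file://")

    spark.stop()
  }
}
```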