Skip to content

Support for scalar subquery #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,49 @@ class FlatbuffersExpressionEvaluator {
static_cast<const tuix::Literal *>(expr->expr())->value(), builder);
}

case tuix::ExprUnion_Decrypt:
{
  // Evaluates the child expression, which must yield a non-NULL string field holding a
  // base64-encoded ciphertext, decrypts it, and returns a copy of the first field of the
  // first row of the decrypted tuix::Rows buffer.
  auto decrypt_expr = static_cast<const tuix::Decrypt *>(expr->expr());
  const tuix::Field *value =
    flatbuffers::GetTemporaryPointer(builder, eval_helper(row, decrypt_expr->value()));

  if (value->value_type() != tuix::FieldUnion_StringField) {
    throw std::runtime_error(
      std::string("tuix::Decrypt only accepts a string input, not ")
      + std::string(tuix::EnumNameFieldUnion(value->value_type())));
  }

  if (value->is_null()) {
    throw std::runtime_error(std::string("tuix::Decrypt does not accept a NULL string\n"));
  }

  auto str_field = static_cast<const tuix::StringField *>(value->value());

  // Copy exactly length() bytes of the string field into an owned std::string.
  std::vector<uint8_t> str_vec(
    flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
                                                  static_cast<uint32_t>(0)),
    flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
                                                  static_cast<uint32_t>(str_field->length())));

  std::string ciphertext(str_vec.begin(), str_vec.end());
  std::string ciphertext_decoded = ciphertext_base64_decode(ciphertext);

  // The plaintext buffer is owned by a std::vector so that it is released on every exit
  // path, including when decrypt() or verify() throws. (The previous code paired new[]
  // with a scalar delete and leaked on exceptions.)
  auto plaintext_len = dec_size(ciphertext_decoded.size());
  std::vector<uint8_t> plaintext(plaintext_len);
  decrypt(reinterpret_cast<const uint8_t *>(ciphertext_decoded.data()),
          ciphertext_decoded.size(),
          plaintext.data());

  // Verify against the size of the plaintext buffer, not the (larger) ciphertext size,
  // so the verifier can never be told the buffer is longer than what was allocated.
  BufferRefView<tuix::Rows> buf(plaintext.data(), plaintext_len);
  buf.verify();

  const tuix::Rows *rows = buf.root();
  const tuix::Field *field = rows->rows()->Get(0)->field_values()->Get(0);

  // flatbuffers_copy writes the field into `builder`, so the result stays valid after
  // `plaintext` goes out of scope.
  return flatbuffers_copy<tuix::Field>(field, builder);
}

case tuix::ExprUnion_Cast:
{
auto cast = static_cast<const tuix::Cast *>(expr->expr());
Expand Down
76 changes: 76 additions & 0 deletions src/enclave/Enclave/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,79 @@ int secs_to_tm(long long t, struct tm *tm) {

return 0;
}

// Code adapted from https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
/*
Copyright (C) 2004-2008 Rene Nyffenegger

This source code is provided 'as-is', without any express or implied
warranty. In no event will the author be held liable for any damages
arising from the use of this software.

Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:

1. The origin of this source code must not be misrepresented; you must not
claim that you wrote the original source code. If you use this source code
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.

2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original source code.

3. This notice may not be removed or altered from any source distribution.

Rene Nyffenegger [email protected]

*/

// The base64 alphabet. ciphertext_base64_decode() uses the position of a character in
// this string (looked up with find) as that character's decoded value.
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";

// Returns true when `c` is alphanumeric (per isalnum), '+' or '/'.
// The padding character '=' is not accepted here.
static inline bool is_base64(unsigned char c) {
  if (c == '+' || c == '/') {
    return true;
  }
  return isalnum(c) != 0;
}

// Decodes base64 text into raw bytes.
//
// Input is consumed character by character until the end of the string, the first '='
// or the first character rejected by is_base64(), whichever comes first. Every complete
// group of four characters yields three bytes; a trailing group of k (< 4) characters
// yields k - 1 bytes.
std::string ciphertext_base64_decode(const std::string &encoded_string) {
  std::string decoded;
  uint8_t quad[4];
  uint8_t triple[3];
  int filled = 0;

  // Replaces every entry of `quad` by its index in the alphabet and packs the four
  // 6-bit values into the three bytes of `triple`.
  auto unpack = [&quad, &triple]() {
    for (int k = 0; k < 4; ++k) {
      quad[k] = static_cast<uint8_t>(base64_chars.find(quad[k]));
    }
    triple[0] = (quad[0] << 2) + ((quad[1] & 0x30) >> 4);
    triple[1] = ((quad[1] & 0xf) << 4) + ((quad[2] & 0x3c) >> 2);
    triple[2] = ((quad[2] & 0x3) << 6) + quad[3];
  };

  for (std::string::size_type pos = 0; pos < encoded_string.size(); ++pos) {
    const char current = encoded_string[pos];
    if (current == '=' || !is_base64(current)) {
      break;
    }

    quad[filled++] = current;
    if (filled == 4) {
      unpack();
      for (int k = 0; k < 3; ++k) {
        decoded += triple[k];
      }
      filled = 0;
    }
  }

  if (filled != 0) {
    // Zero the unused slots of the partial group before unpacking; only the first
    // filled - 1 bytes are emitted, and those depend on real input characters only.
    for (int k = filled; k < 4; ++k) {
      quad[k] = 0;
    }
    unpack();
    for (int k = 0; k < filled - 1; ++k) {
      decoded += triple[k];
    }
  }

  return decoded;
}
2 changes: 2 additions & 0 deletions src/enclave/Enclave/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,6 @@ int pow_2(int value);

int secs_to_tm(long long t, struct tm *tm);

// Decodes a base64-encoded string into raw bytes (returned in a std::string). Decoding
// stops at the first '=' or at the first character outside the base64 alphabet.
std::string ciphertext_base64_decode(const std::string &encoded_string);

#endif // UTIL_H
9 changes: 7 additions & 2 deletions src/flatbuffers/Expr.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ union ExprUnion {
CreateArray,
Upper,
DateAdd,
DateAddInterval
DateAddInterval,
Decrypt
}

table Expr {
Expand Down Expand Up @@ -221,4 +222,8 @@ table ClosestPoint {

table Upper {
child:Expr;
}
}

// Decrypts the result of the child expression `value`. The enclave evaluator requires
// `value` to evaluate to a non-NULL string field containing a base64-encoded ciphertext.
table Decrypt {
value:Expr;
}
87 changes: 87 additions & 0 deletions src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.io.File
import java.io.FileNotFoundException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom
import java.util.Base64
import java.util.UUID

import javax.crypto._
Expand Down Expand Up @@ -92,6 +94,8 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.execution.SubqueryExec
import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
Expand All @@ -102,6 +106,7 @@ import edu.berkeley.cs.rise.opaque.execution.Block
import edu.berkeley.cs.rise.opaque.execution.OpaqueOperatorExec
import edu.berkeley.cs.rise.opaque.execution.SGXEnclave
import edu.berkeley.cs.rise.opaque.expressions.ClosestPoint
import edu.berkeley.cs.rise.opaque.expressions.Decrypt
import edu.berkeley.cs.rise.opaque.expressions.DotProduct
import edu.berkeley.cs.rise.opaque.expressions.VectorAdd
import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply
Expand Down Expand Up @@ -589,6 +594,7 @@ object Utils extends Logging {
tuix.StringField.createValueVector(builder, Array.empty),
0),
isNull)
case _ =>
  // Fallback for any (value, dataType) combination not handled by the cases above.
  // The class name is looked up null-safely so that a null value still produces this
  // OpaqueException instead of a NullPointerException.
  val valueClassName = Option(value).map(_.getClass.getName).getOrElse("null")
  throw new OpaqueException(
    s"FlatbuffersCreateField failed to match on ${value} of type ${valueClassName}, ${dataType}")
}
}

Expand Down Expand Up @@ -663,6 +669,50 @@ object Utils extends Logging {

val MaxBlockSize = 1000

/**
 * Serializes a scalar value as a tuix.Rows buffer containing one row with one field,
 * encrypts that buffer and returns the ciphertext as a Base64 string.
 *
 * @param value the scalar to encrypt; when `dataType` is StringType it is cast to String
 *              and converted to a UTF8String before serialization
 * @param dataType the Spark SQL type used to serialize `value`
 * @return the Base64 encoding of the encrypted, serialized value
 */
def encryptScalar(value: Any, dataType: DataType): String = {
  // First serialize the scalar value
  val builder = new FlatBufferBuilder

  val v = dataType match {
    case StringType => UTF8String.fromString(value.asInstanceOf[String])
    case _ => value
  }

  val isNull = (value == null)

  // Flatbuffers objects are created innermost-first: field, field vector, row, row
  // vector, and finally the enclosing Rows table.
  // TODO: the NULL variable for field value could be set to true
  // (currently the field is always created with isNull = false, and `isNull` is passed
  // as the last argument of createRow instead).
  val fieldOffset = flatbuffersCreateField(builder, v, dataType, false)
  val fieldValuesOffset = tuix.Row.createFieldValuesVector(builder, Array(fieldOffset))
  val rowOffset = tuix.Row.createRow(builder, fieldValuesOffset, isNull)
  val rowsVectorOffset = tuix.Rows.createRowsVector(builder, Array(rowOffset))
  builder.finish(tuix.Rows.createRows(builder, rowsVectorOffset))

  val plaintext = builder.sizedByteArray()
  val ciphertext = encrypt(plaintext)
  Base64.getEncoder().encodeToString(ciphertext)
}

/**
 * Decodes a Base64 ciphertext string, decrypts it, reads the result as a tuix.Rows
 * buffer and returns the extracted value of the first field of the first row.
 *
 * @param ciphertext Base64-encoded ciphertext
 * @return the value extracted by `flatbuffersExtractFieldValue` from that field
 */
def decryptScalar(ciphertext: String): Any = {
  val plaintext = decrypt(Base64.getDecoder().decode(ciphertext))
  val firstField =
    tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext)).rows(0).fieldValues(0)
  flatbuffersExtractFieldValue(firstField)
}

/**
* Encrypts the given Spark SQL [[InternalRow]]s into a [[Block]] (a serialized
* tuix.EncryptedBlocks).
Expand Down Expand Up @@ -822,6 +872,13 @@ object Utils extends Logging {
tuix.ExprUnion.Literal,
tuix.Literal.createLiteral(builder, valueOffset))

// This expression should never be evaluated on the driver
case (Decrypt(child, dataType), Seq(childOffset)) =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe include a comment stating that this should never be evaluated on the driver?

tuix.Expr.createExpr(
builder,
tuix.ExprUnion.Decrypt,
tuix.Decrypt.createDecrypt(builder, childOffset))

case (Alias(child, _), Seq(childOffset)) =>
// TODO: Use an expression for aliases so we can refer to them elsewhere in the expression
// tree. For now we just ignore them when evaluating expressions.
Expand Down Expand Up @@ -1112,6 +1169,36 @@ object Utils extends Logging {
// TODO: Implement decimal serialization, followed by CheckOverflow
childOffset

case (ScalarSubquery(SubqueryExec(_, child), _), Seq()) =>
  // Runs the subquery plan, keeps its result encrypted, and serializes it as a
  // Decrypt expression over a Base64 string literal. A subquery with no result is
  // serialized as a null literal of the subquery's output type.
  val dataType = child.output(0) match {
    case AttributeReference(_, outputType, _, _) => outputType
    case _ => throw new OpaqueException("Scalar subquery cannot match to AttributeReference")
  }
  // Need to deserialize the encrypted blocks to get the encrypted block.
  // Only the first EncryptedBlocks that actually contains a block is of interest.
  val nonEmptyBlocks = child.asInstanceOf[OpaqueOperatorExec].collectEncrypted()
    .iterator
    .map(block => tuix.EncryptedBlocks.getRootAsEncryptedBlocks(ByteBuffer.wrap(block.bytes)))
    .find(_.blocksLength > 0)

  nonEmptyBlocks match {
    case None =>
      // If empty, the returned result is null. This also covers the case where no
      // blocks were collected at all, which previously failed on an index lookup.
      flatbuffersSerializeExpression(builder, Literal(null, dataType), input)
    case Some(encryptedBlocks) =>
      assert(encryptedBlocks.blocksLength == 1)
      val ciphertextBuf = encryptedBlocks.blocks(0).encRowsAsByteBuffer
      val ciphertext = new Array[Byte](ciphertextBuf.remaining)
      ciphertextBuf.get(ciphertext)
      val ciphertextStr = Base64.getEncoder().encodeToString(ciphertext)
      flatbuffersSerializeExpression(
        builder,
        Decrypt(Literal(UTF8String.fromString(ciphertextStr), StringType), dataType),
        input
      )
  }

case (_, Seq(childOffset)) =>
throw new OpaqueException("Expression not supported: " + expr.toString())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,19 @@ trait OpaqueOperatorExec extends SparkPlan {
* method and persist the resulting RDD. [[ConvertToOpaqueOperators]] later eliminates the dummy
* relation from the logical plan, but this only happens after InMemoryRelation has called this
* method. We therefore have to silently return an empty RDD here.
*/
*/

// Always yields an empty RDD rather than throwing
// UnsupportedOperationException("use executeBlocked").
override def doExecute(): RDD[InternalRow] =
  sqlContext.sparkContext.emptyRDD[InternalRow]

/** Runs `executeBlocked()` and collects the resulting [[Block]]s without decrypting them. */
def collectEncrypted(): Array[Block] = executeBlocked().collect()

/**
 * Collects the encrypted blocks via `collectEncrypted()` and decrypts each of them with
 * `Utils.decryptBlockFlatbuffers`, concatenating the resulting rows.
 */
override def executeCollect(): Array[InternalRow] = {
  collectEncrypted().flatMap { block =>
    Utils.decryptBlockFlatbuffers(block)
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ object ClosestPoint {
* point - list of coordinates representing a point
* centroids - list of lists of coordinates, each representing a point
""")
/**
*
*/
case class ClosestPoint(left: Expression, right: Expression)
extends BinaryExpression with NullIntolerant with CodegenFallback with ExpectsInputTypes {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package edu.berkeley.cs.rise.opaque.expressions

import edu.berkeley.cs.rise.opaque.Utils

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
import org.apache.spark.sql.catalyst.expressions.NullIntolerant
import org.apache.spark.sql.catalyst.expressions.Nondeterministic
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object Decrypt {
  /** Builds a [[Column]] that wraps the expression of `v` in a [[Decrypt]] with output type `dataType`. */
  def decrypt(v: Column, dataType: DataType): Column = {
    val expression = Decrypt(v.expr, dataType)
    new Column(expression)
  }
}

@ExpressionDescription(
  usage = """
    _FUNC_(child, outputDataType) - Decrypt the input evaluated expression, which should always be a string
  """,
  arguments = """
    Arguments:
      * child - an encrypted literal of string type
      * outputDataType - the decrypted data type
  """)
/**
 * Expression that decrypts a Base64-encoded, encrypted scalar produced by its child.
 *
 * @param child expression evaluating to the encrypted string
 * @param outputDataType the data type reported for the decrypted value
 */
case class Decrypt(child: Expression, outputDataType: DataType)
    extends UnaryExpression with NullIntolerant with CodegenFallback with Nondeterministic {

  override def dataType: DataType = outputDataType

  // No per-partition state to set up.
  protected def initializeInternal(partitionIndex: Int): Unit = { }

  /**
   * Evaluates the child against `input` and decrypts the result. A null child result
   * yields null instead of being handed to `nullSafeEval`, which cannot handle null.
   */
  protected override def evalInternal(input: InternalRow): Any = {
    val v = child.eval(input)
    if (v == null) null else nullSafeEval(v)
  }

  protected override def nullSafeEval(input: Any): Any = {
    // This function is implemented so that we can test against Spark;
    // should never be used in production because we want to keep the literal encrypted
    val v = input.asInstanceOf[UTF8String].toString
    Utils.decryptScalar(v)
  }
}
Loading