diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h
index 9405ddd34f..0f48c56d48 100644
--- a/src/enclave/Enclave/ExpressionEvaluation.h
+++ b/src/enclave/Enclave/ExpressionEvaluation.h
@@ -288,6 +288,49 @@ class FlatbuffersExpressionEvaluator {
         static_cast<const tuix::Literal *>(expr->expr())->value(), builder);
     }
 
+    case tuix::ExprUnion_Decrypt:
+    {
+      auto decrypt_expr = static_cast<const tuix::Decrypt *>(expr->expr());
+      const tuix::Field *value =
+        flatbuffers::GetTemporaryPointer(builder, eval_helper(row, decrypt_expr->value()));
+
+      if (value->value_type() != tuix::FieldUnion_StringField) {
+        throw std::runtime_error(
+          std::string("tuix::Decrypt only accepts a string input, not ")
+          + std::string(tuix::EnumNameFieldUnion(value->value_type())));
+      }
+
+      bool result_is_null = value->is_null();
+      if (!result_is_null) {
+        auto str_field = static_cast<const tuix::StringField *>(value->value());
+
+        std::vector<uint8_t> str_vec(
+          flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
+                                                        static_cast<uint32_t>(0)),
+          flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
+                                                        static_cast<uint32_t>(str_field->length())));
+
+        std::string ciphertext(str_vec.begin(), str_vec.end());
+        std::string ciphertext_decoded = ciphertext_base64_decode(ciphertext);
+
+        // The plaintext buffer holds dec_size(...) bytes; it is array-allocated,
+        // so it must be released with delete[] below.
+        uint8_t *plaintext = new uint8_t[dec_size(ciphertext_decoded.size())];
+        decrypt(reinterpret_cast<const uint8_t *>(ciphertext_decoded.data()),
+                ciphertext_decoded.size(), plaintext);
+
+        BufferRefView<tuix::Rows> buf(plaintext, ciphertext_decoded.size());
+        buf.verify();
+
+        const tuix::Rows *rows = buf.root();
+        const tuix::Field *field = rows->rows()->Get(0)->field_values()->Get(0);
+        auto ret = flatbuffers_copy<tuix::Field>(field, builder);
+
+        delete[] plaintext;
+        return ret;
+      } else {
+        throw std::runtime_error(std::string("tuix::Decrypt does not accept a NULL string"));
+      }
+    }
+
     case tuix::ExprUnion_Cast:
     {
       auto cast = static_cast<const tuix::Cast *>(expr->expr());
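The evaluator above Base64-decodes the ciphertext (via `ciphertext_base64_decode`, added in `util.cpp` below) before authenticated decryption, so the enclave-side decoder must accept exactly what the driver's `java.util.Base64` encoder emits. A minimal, self-contained Scala round trip stating that assumption (the byte payload here is illustrative, not a real ciphertext):

```scala
import java.nio.charset.StandardCharsets
import java.util.Base64

object Base64Check {
  def main(args: Array[String]): Unit = {
    // Stand-in bytes; in the real flow this is AES-GCM ciphertext.
    val raw = "not-a-real-ciphertext".getBytes(StandardCharsets.UTF_8)
    val encoded = Base64.getEncoder.encodeToString(raw) // RFC 4648 alphabet, '=' padded
    // The C++ decoder uses the same alphabet and stops at '=' padding,
    // so a round trip through the standard decoder must be lossless.
    val decoded = Base64.getDecoder.decode(encoded)
    assert(decoded.sameElements(raw))
  }
}
```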
diff --git a/src/enclave/Enclave/util.cpp b/src/enclave/Enclave/util.cpp
index 0f13e6af49..6cd2a898b0 100644
--- a/src/enclave/Enclave/util.cpp
+++ b/src/enclave/Enclave/util.cpp
@@ -142,3 +142,79 @@ int secs_to_tm(long long t, struct tm *tm) {
 
   return 0;
 }
+
+// Code adapted from https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
+/*
+  Copyright (C) 2004-2008 Rene Nyffenegger
+
+  This source code is provided 'as-is', without any express or implied
+  warranty. In no event will the author be held liable for any damages
+  arising from the use of this software.
+
+  Permission is granted to anyone to use this software for any purpose,
+  including commercial applications, and to alter it and redistribute it
+  freely, subject to the following restrictions:
+
+  1. The origin of this source code must not be misrepresented; you must not
+  claim that you wrote the original source code. If you use this source code
+  in a product, an acknowledgment in the product documentation would be
+  appreciated but is not required.
+
+  2. Altered source versions must be plainly marked as such, and must not be
+  misrepresented as being the original source code.
+
+  3. This notice may not be removed or altered from any source distribution.
+
+  Rene Nyffenegger rene.nyffenegger@adp-gmbh.ch
+*/
+
+static const std::string base64_chars =
+  "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+  "abcdefghijklmnopqrstuvwxyz"
+  "0123456789+/";
+
+static inline bool is_base64(unsigned char c) {
+  return (isalnum(c) || (c == '+') || (c == '/'));
+}
+
+std::string ciphertext_base64_decode(const std::string &encoded_string) {
+  int in_len = encoded_string.size();
+  int i = 0;
+  int j = 0;
+  int in_ = 0;
+  uint8_t char_array_4[4], char_array_3[3];
+  std::string ret;
+
+  while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
+    char_array_4[i++] = encoded_string[in_]; in_++;
+    if (i == 4) {
+      for (i = 0; i < 4; i++)
+        char_array_4[i] = base64_chars.find(char_array_4[i]);
+
+      char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+      for (i = 0; (i < 3); i++)
+        ret += char_array_3[i];
+      i = 0;
+    }
+  }
+
+  if (i) {
+    for (j = i; j < 4; j++)
+      char_array_4[j] = 0;
+
+    for (j = 0; j < 4; j++)
+      char_array_4[j] = base64_chars.find(char_array_4[j]);
+
+    char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+    char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+    char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+    for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
+  }
+
+  return ret;
+}
diff --git a/src/enclave/Enclave/util.h b/src/enclave/Enclave/util.h
index b4e0b52327..df80ba7cd0 100644
--- a/src/enclave/Enclave/util.h
+++ b/src/enclave/Enclave/util.h
@@ -41,4 +41,6 @@ int pow_2(int value);
 
 int secs_to_tm(long long t, struct tm *tm);
 
+std::string ciphertext_base64_decode(const std::string &encoded_string);
+
 #endif // UTIL_H
diff --git a/src/flatbuffers/Expr.fbs b/src/flatbuffers/Expr.fbs
index a96215b5a2..4acce5e53d 100644
--- a/src/flatbuffers/Expr.fbs
+++ b/src/flatbuffers/Expr.fbs
@@ -40,7 +40,8 @@ union ExprUnion {
   CreateArray,
   Upper,
   DateAdd,
-  DateAddInterval
+  DateAddInterval,
+  Decrypt
 }
 
 table Expr {
@@ -221,4 +222,8 @@ table ClosestPoint {
 
 table Upper {
   child:Expr;
-}
\ No newline at end of file
+}
+
+table Decrypt {
+  value:Expr;
+}
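Once the flatbuffers bindings are regenerated from this schema, wrapping an already-serialized child expression in the new table is a single call pair. A sketch, assuming the generated `tuix` classes on the classpath (the same calls appear in `Utils.scala` below):

```scala
import com.google.flatbuffers.FlatBufferBuilder

// Sketch: emit a tuix.Expr whose union member is the new Decrypt table.
// `childOffset` is the offset of an already-serialized child Expr.
def mkDecryptExpr(builder: FlatBufferBuilder, childOffset: Int): Int =
  tuix.Expr.createExpr(
    builder,
    tuix.ExprUnion.Decrypt,
    tuix.Decrypt.createDecrypt(builder, childOffset))
```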
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
index cb054a3d36..4c6970e489 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -21,7 +21,9 @@ import java.io.File
 import java.io.FileNotFoundException
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
+import java.nio.charset.StandardCharsets
 import java.security.SecureRandom
+import java.util.Base64
 import java.util.UUID
 
 import javax.crypto._
@@ -92,6 +94,8 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.MapData
+import org.apache.spark.sql.execution.SubqueryExec
+import org.apache.spark.sql.execution.ScalarSubquery
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -102,6 +106,7 @@ import edu.berkeley.cs.rise.opaque.execution.Block
 import edu.berkeley.cs.rise.opaque.execution.OpaqueOperatorExec
 import edu.berkeley.cs.rise.opaque.execution.SGXEnclave
 import edu.berkeley.cs.rise.opaque.expressions.ClosestPoint
+import edu.berkeley.cs.rise.opaque.expressions.Decrypt
 import edu.berkeley.cs.rise.opaque.expressions.DotProduct
 import edu.berkeley.cs.rise.opaque.expressions.VectorAdd
 import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply
@@ -589,6 +594,7 @@ object Utils extends Logging {
           tuix.StringField.createValueVector(builder, Array.empty),
           0),
         isNull)
+    case _ => throw new OpaqueException(
+      s"FlatbuffersCreateField failed to match on ${value} of type ${value.getClass.getName()}, ${dataType}")
   }
 }
@@ -663,6 +669,50 @@ object Utils extends Logging {
 
   val MaxBlockSize = 1000
 
+  /**
+   * Encrypts a given scalar value into a Base64-encoded ciphertext string.
+   */
+  def encryptScalar(value: Any, dataType: DataType): String = {
+    // First serialize the scalar value as a single-row, single-field tuix.Rows
+    val builder = new FlatBufferBuilder
+
+    val v = dataType match {
+      case StringType => UTF8String.fromString(value.asInstanceOf[String])
+      case _ => value
+    }
+
+    val isNull = (value == null)
+
+    // TODO: the NULL variable for the field value could be set to true
+    builder.finish(
+      tuix.Rows.createRows(
+        builder,
+        tuix.Rows.createRowsVector(
+          builder,
+          Array(tuix.Row.createRow(
+            builder,
+            tuix.Row.createFieldValuesVector(
+              builder,
+              Array(flatbuffersCreateField(builder, v, dataType, false))),
+            isNull)))))
+
+    val plaintext = builder.sizedByteArray()
+    val ciphertext = encrypt(plaintext)
+    Base64.getEncoder().encodeToString(ciphertext)
+  }
+
+  /**
+   * Decrypts a Base64-encoded ciphertext string produced by encryptScalar.
+   */
+  def decryptScalar(ciphertext: String): Any = {
+    val ciphertext_bytes = Base64.getDecoder().decode(ciphertext)
+    val plaintext = decrypt(ciphertext_bytes)
+    val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext))
+    val row = rows.rows(0)
+    val field = row.fieldValues(0)
+    flatbuffersExtractFieldValue(field)
+  }
+
   /**
    * Encrypts the given Spark SQL [[InternalRow]]s into a [[Block]] (a serialized
    * tuix.EncryptedBlocks).
@@ -822,6 +872,13 @@ object Utils extends Logging {
         tuix.ExprUnion.Literal,
         tuix.Literal.createLiteral(builder, valueOffset))
 
+    // This expression should never be evaluated on the driver
+    case (Decrypt(child, dataType), Seq(childOffset)) =>
+      tuix.Expr.createExpr(
+        builder,
+        tuix.ExprUnion.Decrypt,
+        tuix.Decrypt.createDecrypt(builder, childOffset))
+
     case (Alias(child, _), Seq(childOffset)) =>
       // TODO: Use an expression for aliases so we can refer to them elsewhere in the expression
       // tree. For now we just ignore them when evaluating expressions.
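`encryptScalar` and `decryptScalar` are intended to be inverses (modulo the `UTF8String` conversion for strings); a quick round-trip sketch, assuming the shared key has been initialized:

```scala
import org.apache.spark.sql.types.IntegerType

// Serialize -> encrypt -> Base64, then decode back to the original value.
val enc: String = Utils.encryptScalar(10, IntegerType)
assert(Utils.decryptScalar(enc) == 10)
```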
@@ -1112,6 +1169,36 @@ object Utils extends Logging {
       // TODO: Implement decimal serialization, followed by CheckOverflow
       childOffset
 
+    case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) =>
+      val output = child.output(0)
+      val dataType = output match {
+        case AttributeReference(name, dataType, _, _) => dataType
+        case _ => throw new OpaqueException("Scalar subquery output did not match an AttributeReference")
+      }
+      // Deserialize the child's encrypted output to locate its single encrypted block
+      val blockList = child.asInstanceOf[OpaqueOperatorExec].collectEncrypted()
+      val encryptedBlocksList = blockList.map { block =>
+        val buf = ByteBuffer.wrap(block.bytes)
+        tuix.EncryptedBlocks.getRootAsEncryptedBlocks(buf)
+      }
+      val encryptedBlocks =
+        encryptedBlocksList.find(_.blocksLength > 0).getOrElse(encryptedBlocksList(0))
+      if (encryptedBlocks.blocksLength == 0) {
+        // If empty, the returned result is null
+        flatbuffersSerializeExpression(builder, Literal(null, dataType), input)
+      } else {
+        assert(encryptedBlocks.blocksLength == 1)
+        val encryptedBlock = encryptedBlocks.blocks(0)
+        val ciphertextBuf = encryptedBlock.encRowsAsByteBuffer
+        val ciphertext = new Array[Byte](ciphertextBuf.remaining)
+        ciphertextBuf.get(ciphertext)
+        val ciphertext_str = Base64.getEncoder().encodeToString(ciphertext)
+        flatbuffersSerializeExpression(
+          builder,
+          Decrypt(Literal(UTF8String.fromString(ciphertext_str), StringType), dataType),
+          input
+        )
+      }
+
     case (_, Seq(childOffset)) =>
       throw new OpaqueException("Expression not supported: " + expr.toString())
   }
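The net effect of the `ScalarSubquery` case: the subquery's one-row encrypted result block is spliced back into the expression tree as `Decrypt(<Base64 ciphertext>, dataType)` rather than a plaintext `Literal`, so the value is only revealed inside the enclave. Schematically, with names from the patch and `ciphertextStr` standing in for the Base64-encoded block:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import edu.berkeley.cs.rise.opaque.expressions.Decrypt

// What flatbuffersSerializeExpression is handed in the non-empty case,
// here for a subquery of IntegerType output:
def asEncryptedLiteral(ciphertextStr: String): Decrypt =
  Decrypt(Literal(UTF8String.fromString(ciphertextStr), StringType), IntegerType)
```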
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
index 7ed6862b6b..4eb941157e 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
@@ -134,14 +134,19 @@ trait OpaqueOperatorExec extends SparkPlan {
    * method and persist the resulting RDD. [[ConvertToOpaqueOperators]] later eliminates the dummy
    * relation from the logical plan, but this only happens after InMemoryRelation has called this
    * method. We therefore have to silently return an empty RDD here.
-   */
+   */
+
   override def doExecute(): RDD[InternalRow] = {
     sqlContext.sparkContext.emptyRDD
     // throw new UnsupportedOperationException("use executeBlocked")
   }
 
+  def collectEncrypted(): Array[Block] = {
+    executeBlocked().collect
+  }
+
   override def executeCollect(): Array[InternalRow] = {
-    executeBlocked().collect().flatMap { block =>
+    collectEncrypted().flatMap { block =>
       Utils.decryptBlockFlatbuffers(block)
     }
   }
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala
index b4f1e27200..7eac3c990c 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala
@@ -29,9 +29,6 @@ object ClosestPoint {
     * point - list of coordinates representing a point
     * centroids - list of lists of coordinates, each representing a point
   """)
-/**
- *
- */
 case class ClosestPoint(left: Expression, right: Expression)
   extends BinaryExpression with NullIntolerant with CodegenFallback with ExpectsInputTypes {
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala
new file mode 100644
index 0000000000..a52ecb113e
--- /dev/null
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/expressions/Decrypt.scala
@@ -0,0 +1,49 @@
+package edu.berkeley.cs.rise.opaque.expressions
+
+import edu.berkeley.cs.rise.opaque.Utils
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
+import org.apache.spark.sql.catalyst.expressions.NullIntolerant
+import org.apache.spark.sql.catalyst.expressions.Nondeterministic
+import org.apache.spark.sql.catalyst.expressions.UnaryExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.unsafe.types.UTF8String
+
+object Decrypt {
+  def decrypt(v: Column, dataType: DataType): Column = new Column(Decrypt(v.expr, dataType))
+}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(child, outputDataType) - Decrypt the input evaluated expression, which should always be a string
+  """,
+  arguments = """
+    Arguments:
+      * child - an encrypted literal of string type
+      * outputDataType - the decrypted data type
+  """)
+case class Decrypt(child: Expression, outputDataType: DataType)
+  extends UnaryExpression with NullIntolerant with CodegenFallback with Nondeterministic {
+
+  override def dataType: DataType = outputDataType
+
+  protected def initializeInternal(partitionIndex: Int): Unit = { }
+
+  protected override def evalInternal(input: InternalRow): Any = {
+    // The child is always a literal, so no input row is needed
+    val v = child.eval()
+    nullSafeEval(v)
+  }
+
+  protected override def nullSafeEval(input: Any): Any = {
+    // This function is implemented only so that we can test against vanilla Spark;
+    // it should never be used in production because we want to keep the literal encrypted
+    val v = input.asInstanceOf[UTF8String].toString
+    Utils.decryptScalar(v)
+  }
+}
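A note on the trait choice: `Decrypt` mixes in `Nondeterministic`, which reports `deterministic = false` to Catalyst; presumably this keeps the optimizer from evaluating the expression eagerly on the untrusted driver, given that `nullSafeEval` exists only for testing against vanilla Spark. A small check of that property (`"enc"` is a placeholder, not a real ciphertext):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import edu.berkeley.cs.rise.opaque.expressions.Decrypt

val e = Decrypt(Literal(UTF8String.fromString("enc"), StringType), IntegerType)
assert(!e.deterministic) // Catalyst will not constant-fold this on the driver
```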
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
index 16c8082fbd..a69894d13c 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
@@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
 
 import edu.berkeley.cs.rise.opaque.benchmark._
 import edu.berkeley.cs.rise.opaque.execution.EncryptedBlockRDDScanExec
+import edu.berkeley.cs.rise.opaque.expressions.Decrypt.decrypt
 import edu.berkeley.cs.rise.opaque.expressions.DotProduct.dot
 import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply.vectormultiply
 import edu.berkeley.cs.rise.opaque.expressions.VectorSum
@@ -879,6 +880,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
     KMeans.train(spark, securityLevel, numPartitions, 10, 2, 3, 0.01).map(_.toSeq).sorted
   }
 
+  testAgainstSpark("encrypted literal") { securityLevel =>
+    val input = 10
+    val enc_str = Utils.encryptScalar(input, IntegerType)
+
+    val data = for (i <- 0 until 256) yield (i, abc(i), 1)
+    val words = makeDF(data, securityLevel, "id", "word", "count")
+    val df = words.filter($"id" < decrypt(lit(enc_str), IntegerType)).sort($"id")
+    df.collect
+  }
+
+  testAgainstSpark("scalar subquery") { securityLevel =>
+    // Example taken from https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/2728434780191932/1483312212640900/6987336228780374/latest.html
+    val data = for (i <- 0 until 256) yield (i, abc(i), i)
+    val words = makeDF(data, securityLevel, "id", "word", "count")
+    words.createTempView("words")
+
+    try {
+      val df = spark.sql("""SELECT id, word, (SELECT MAX(count) FROM words) max_count FROM words ORDER BY id, word""")
+      df.collect
+    } finally {
+      spark.catalog.dropTempView("words")
+    }
+  }
+
   testAgainstSpark("pagerank") { securityLevel =>
     PageRank.run(spark, securityLevel, "256", numPartitions).collect.toSet
   }
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala
index 8117fb8de1..54ded162bc 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala
@@ -68,7 +68,7 @@ trait OpaqueTestsBase extends FunSuite with BeforeAndAfterAll { self =>
     testFunc(name + " - encrypted") {
       // The === operator uses implicitly[Equality[A]], which compares Double and Array[Double]
       // using the numeric tolerance specified above
-      assert(f(Encrypted) === f(Insecure))
+      assert(f(Insecure) === f(Encrypted))
     }
   }
 
@@ -102,4 +102,4 @@ trait OpaqueTestsBase extends FunSuite with BeforeAndAfterAll { self =>
       }
     }
   }
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
index ed8da375c5..8d60dfa550 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
@@ -40,7 +40,7 @@ trait TPCHTests extends OpaqueTestsBase { self =>
     tpch.query(3, securityLevel, spark.sqlContext, numPartitions).collect
   }
 
-  testAgainstSpark("TPC-H 4", ignore) { securityLevel =>
+  testAgainstSpark("TPC-H 4") { securityLevel =>
     tpch.query(4, securityLevel, spark.sqlContext, numPartitions).collect
   }
 
@@ -68,7 +68,7 @@ trait TPCHTests extends OpaqueTestsBase { self =>
     tpch.query(10, securityLevel, spark.sqlContext, numPartitions).collect
   }
 
-  testAgainstSpark("TPC-H 11", ignore) { securityLevel =>
+  testAgainstSpark("TPC-H 11") { securityLevel =>
     tpch.query(11, securityLevel, spark.sqlContext, numPartitions).collect
  }
 
@@ -84,7 +84,7 @@ trait TPCHTests extends OpaqueTestsBase { self =>
     tpch.query(14, securityLevel, spark.sqlContext, numPartitions).collect.toSet
   }
 
-  testAgainstSpark("TPC-H 15", ignore) { securityLevel =>
+  testAgainstSpark("TPC-H 15") { securityLevel =>
     tpch.query(15, securityLevel, spark.sqlContext, numPartitions).collect
   }
 
@@ -112,7 +112,7 @@ trait TPCHTests extends OpaqueTestsBase { self =>
     tpch.query(21, securityLevel, spark.sqlContext, numPartitions).collect
   }
 
-  testAgainstSpark("TPC-H 22", ignore) { securityLevel =>
+  testAgainstSpark("TPC-H 22") { securityLevel =>
     tpch.query(22, securityLevel, spark.sqlContext, numPartitions).collect
   }
 }
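Taken together, the patch lets scalar subqueries run end to end under encryption, which is what re-enables the ignored TPC-H queries above. A usage sketch, assuming an Opaque-encrypted `words` view as in the test suite:

```scala
// The optimizer plans (SELECT MAX(count) FROM words) as a SubqueryExec;
// Utils.flatbuffersSerializeExpression collects its encrypted result and
// re-emits it as Decrypt(<Base64 ciphertext>), so the maximum is never
// exposed to the driver in plaintext.
val df = spark.sql(
  "SELECT id, word, (SELECT MAX(count) FROM words) max_count FROM words ORDER BY id")
df.collect()
```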