Skip to content

Commit 4d89ecb

Browse files
authored
Support for scalar subquery (#157)
This PR implements the scalar subquery expression, which is triggered whenever a subquery returns a scalar value. There were two main problems that needed to be solved. First, support for matching the scalar subquery expression is necessary. Spark implements this by wrapping a SparkPlan within the expression and calls executeCollect. Then it constructs a literal with that value. However, this is problematic for us because that value should not be decrypted by the driver and serialized into an expression, since it's an intermediate value. Therefore, the second issue to be addressed here is supporting an encrypted literal. This is implemented in this PR by serializing an encrypted ciphertext into a base64 encoded string, and wrapping a Decrypt expression on top of it. This expression is then evaluated in the enclave and returns a literal. Note that, in order to test our implementation, we also implement a Decrypt expression in Scala. However, this should never be evaluated on the driver side and serialized into a plaintext literal. This is because Decrypt is designated as a Nondeterministic expression, and therefore will always evaluate on the workers.
1 parent 29da474 commit 4d89ecb

File tree

11 files changed

+302
-13
lines changed

11 files changed

+302
-13
lines changed

src/enclave/Enclave/ExpressionEvaluation.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,49 @@ class FlatbuffersExpressionEvaluator {
288288
static_cast<const tuix::Literal *>(expr->expr())->value(), builder);
289289
}
290290

291+
case tuix::ExprUnion_Decrypt:
{
  // Decrypts an encrypted scalar literal. The child expression must evaluate
  // to a non-null string field containing a base64-encoded ciphertext of a
  // serialized tuix.Rows with a single row and a single field; that field is
  // copied into the current builder and returned.
  auto decrypt_expr = static_cast<const tuix::Decrypt *>(expr->expr());
  const tuix::Field *value =
    flatbuffers::GetTemporaryPointer(builder, eval_helper(row, decrypt_expr->value()));

  if (value->value_type() != tuix::FieldUnion_StringField) {
    throw std::runtime_error(
      std::string("tuix::Decrypt only accepts a string input, not ")
      + std::string(tuix::EnumNameFieldUnion(value->value_type())));
  }

  if (value->is_null()) {
    throw std::runtime_error(std::string("tuix::Decrypt does not accept a NULL string\n"));
  }

  auto str_field = static_cast<const tuix::StringField *>(value->value());

  // Copy the length-delimited string payload out of the flatbuffer.
  std::vector<uint8_t> str_vec(
    flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
                                                  static_cast<uint32_t>(0)),
    flatbuffers::VectorIterator<uint8_t, uint8_t>(str_field->value()->Data(),
                                                  static_cast<uint32_t>(str_field->length())));

  std::string ciphertext(str_vec.begin(), str_vec.end());
  std::string ciphertext_decoded = ciphertext_base64_decode(ciphertext);

  // BUG FIX: the original allocated with `new uint8_t[...]` but released with
  // `delete` (not `delete[]`), and leaked the buffer if buf.verify() threw.
  // A vector fixes both. It also sized the BufferRefView with the full
  // ciphertext length while only dec_size(...) bytes were allocated; the view
  // must span only the decrypted plaintext (dec_size accounts for the
  // ciphertext's IV/MAC overhead — consistent with the original allocation).
  std::vector<uint8_t> plaintext(dec_size(ciphertext_decoded.size()));
  decrypt(reinterpret_cast<const uint8_t *>(ciphertext_decoded.data()),
          ciphertext_decoded.size(), plaintext.data());

  BufferRefView<tuix::Rows> buf(plaintext.data(), plaintext.size());
  buf.verify();

  const tuix::Rows *rows = buf.root();
  const tuix::Field *field = rows->rows()->Get(0)->field_values()->Get(0);
  return flatbuffers_copy<tuix::Field>(field, builder);
}
333+
291334
case tuix::ExprUnion_Cast:
292335
{
293336
auto cast = static_cast<const tuix::Cast *>(expr->expr());

src/enclave/Enclave/util.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,79 @@ int secs_to_tm(long long t, struct tm *tm) {
142142

143143
return 0;
144144
}
145+
146+
// Code adapted from https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
147+
/*
148+
Copyright (C) 2004-2008 Rene Nyffenegger
149+
150+
This source code is provided 'as-is', without any express or implied
151+
warranty. In no event will the author be held liable for any damages
152+
arising from the use of this software.
153+
154+
Permission is granted to anyone to use this software for any purpose,
155+
including commercial applications, and to alter it and redistribute it
156+
freely, subject to the following restrictions:
157+
158+
1. The origin of this source code must not be misrepresented; you must not
159+
claim that you wrote the original source code. If you use this source code
160+
in a product, an acknowledgment in the product documentation would be
161+
appreciated but is not required.
162+
163+
2. Altered source versions must be plainly marked as such, and must not be
164+
misrepresented as being the original source code.
165+
166+
3. This notice may not be removed or altered from any source distribution.
167+
168+
Rene Nyffenegger [email protected]
169+
170+
*/
171+
172+
static const std::string base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";

// True iff c is a character of the base64 alphabet used above ('=' excluded).
static inline bool is_base64(unsigned char c) {
  return (isalnum(c) || (c == '+') || (c == '/'));
}

/*
 * Decodes a base64-encoded ciphertext string into raw bytes.
 *
 * Decoding stops at the first '=' (padding) or any character outside the
 * base64 alphabet; trailing groups of 2 or 3 sextets yield 1 or 2 bytes
 * respectively. Adapted from Rene Nyffenegger's public base64 snippet
 * (https://stackoverflow.com/questions/180947), license retained in history.
 */
std::string ciphertext_base64_decode(const std::string &encoded_string) {
  std::string decoded;
  uint8_t sextets[4];
  size_t count = 0;

  for (size_t pos = 0; pos < encoded_string.size(); ++pos) {
    unsigned char c = static_cast<unsigned char>(encoded_string[pos]);
    if (c == '=' || !is_base64(c)) {
      break;  // padding or foreign character terminates the payload
    }
    sextets[count++] = static_cast<uint8_t>(base64_chars.find(static_cast<char>(c)));
    if (count == 4) {
      // Repack four 6-bit values into three 8-bit bytes.
      decoded += static_cast<char>((sextets[0] << 2) | ((sextets[1] & 0x30) >> 4));
      decoded += static_cast<char>(((sextets[1] & 0x0f) << 4) | ((sextets[2] & 0x3c) >> 2));
      decoded += static_cast<char>(((sextets[2] & 0x03) << 6) | sextets[3]);
      count = 0;
    }
  }

  if (count > 0) {
    // Partial final group: zero-fill the missing sextets, then emit only the
    // count-1 bytes that carry real data (the rest encode the padding).
    for (size_t j = count; j < 4; ++j) {
      sextets[j] = 0;
    }
    uint8_t triple[3];
    triple[0] = static_cast<uint8_t>((sextets[0] << 2) | ((sextets[1] & 0x30) >> 4));
    triple[1] = static_cast<uint8_t>(((sextets[1] & 0x0f) << 4) | ((sextets[2] & 0x3c) >> 2));
    triple[2] = static_cast<uint8_t>(((sextets[2] & 0x03) << 6) | sextets[3]);
    for (size_t j = 0; j + 1 < count; ++j) {
      decoded += static_cast<char>(triple[j]);
    }
  }

  return decoded;
}

src/enclave/Enclave/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,6 @@ int pow_2(int value);
4141

4242
int secs_to_tm(long long t, struct tm *tm);
4343

44+
std::string ciphertext_base64_decode(const std::string &encoded_string);
45+
4446
#endif // UTIL_H

src/flatbuffers/Expr.fbs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ union ExprUnion {
4040
CreateArray,
4141
Upper,
4242
DateAdd,
43-
DateAddInterval
43+
DateAddInterval,
44+
Decrypt
4445
}
4546

4647
table Expr {
@@ -221,4 +222,8 @@ table ClosestPoint {
221222

222223
table Upper {
223224
child:Expr;
224-
}
225+
}
226+
227+
table Decrypt {
228+
value:Expr;
229+
}

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ import java.io.File
2121
import java.io.FileNotFoundException
2222
import java.nio.ByteBuffer
2323
import java.nio.ByteOrder
24+
import java.nio.charset.StandardCharsets;
2425
import java.security.SecureRandom
26+
import java.util.Base64
2527
import java.util.UUID
2628

2729
import javax.crypto._
@@ -92,6 +94,8 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
9294
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
9395
import org.apache.spark.sql.catalyst.util.ArrayData
9496
import org.apache.spark.sql.catalyst.util.MapData
97+
import org.apache.spark.sql.execution.SubqueryExec
98+
import org.apache.spark.sql.execution.ScalarSubquery
9599
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
96100
import org.apache.spark.sql.types._
97101
import org.apache.spark.storage.StorageLevel
@@ -102,6 +106,7 @@ import edu.berkeley.cs.rise.opaque.execution.Block
102106
import edu.berkeley.cs.rise.opaque.execution.OpaqueOperatorExec
103107
import edu.berkeley.cs.rise.opaque.execution.SGXEnclave
104108
import edu.berkeley.cs.rise.opaque.expressions.ClosestPoint
109+
import edu.berkeley.cs.rise.opaque.expressions.Decrypt
105110
import edu.berkeley.cs.rise.opaque.expressions.DotProduct
106111
import edu.berkeley.cs.rise.opaque.expressions.VectorAdd
107112
import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply
@@ -589,6 +594,7 @@ object Utils extends Logging {
589594
tuix.StringField.createValueVector(builder, Array.empty),
590595
0),
591596
isNull)
597+
case _ => throw new OpaqueException(s"FlatbuffersCreateField failed to match on ${value} of type {value.getClass.getName()}, ${dataType}")
592598
}
593599
}
594600

@@ -663,6 +669,50 @@ object Utils extends Logging {
663669

664670
val MaxBlockSize = 1000
665671

672+
/**
 * Encrypts a scalar value into a base64-encoded ciphertext string.
 *
 * The value is serialized as a tuix.Rows flatbuffer containing a single row
 * with a single field, encrypted, and base64-encoded so it can be embedded
 * as a string literal in an expression tree (see the Decrypt expression).
 */
def encryptScalar(value: Any, dataType: DataType): String = {
  val builder = new FlatBufferBuilder

  // Spark hands scalar strings over as java.lang.String, but the
  // flatbuffers field serializer expects UTF8String.
  val v = dataType match {
    case StringType => UTF8String.fromString(value.asInstanceOf[String])
    case _ => value
  }

  val isNull = (value == null)

  // TODO: the NULL flag on the field value itself could be set to true
  builder.finish(
    tuix.Rows.createRows(
      builder,
      tuix.Rows.createRowsVector(
        builder,
        Array(tuix.Row.createRow(
          builder,
          tuix.Row.createFieldValuesVector(
            builder,
            Array(flatbuffersCreateField(builder, v, dataType, false))),
          isNull)))))

  val plaintext = builder.sizedByteArray()
  val ciphertext = encrypt(plaintext)
  Base64.getEncoder().encodeToString(ciphertext)
}
705+
706+
def decryptScalar(ciphertext: String): Any = {
707+
val ciphertext_bytes = Base64.getDecoder().decode(ciphertext);
708+
val plaintext = decrypt(ciphertext_bytes)
709+
val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext))
710+
val row = rows.rows(0)
711+
val field = row.fieldValues(0)
712+
val value = flatbuffersExtractFieldValue(field)
713+
value
714+
}
715+
666716
/**
667717
* Encrypts the given Spark SQL [[InternalRow]]s into a [[Block]] (a serialized
668718
* tuix.EncryptedBlocks).
@@ -822,6 +872,13 @@ object Utils extends Logging {
822872
tuix.ExprUnion.Literal,
823873
tuix.Literal.createLiteral(builder, valueOffset))
824874

875+
// This expression should never be evaluated on the driver
876+
case (Decrypt(child, dataType), Seq(childOffset)) =>
877+
tuix.Expr.createExpr(
878+
builder,
879+
tuix.ExprUnion.Decrypt,
880+
tuix.Decrypt.createDecrypt(builder, childOffset))
881+
825882
case (Alias(child, _), Seq(childOffset)) =>
826883
// TODO: Use an expression for aliases so we can refer to them elsewhere in the expression
827884
// tree. For now we just ignore them when evaluating expressions.
@@ -1112,6 +1169,36 @@ object Utils extends Logging {
11121169
// TODO: Implement decimal serialization, followed by CheckOverflow
11131170
childOffset
11141171

1172+
case (ScalarSubquery(SubqueryExec(name, child), exprId), Seq()) =>
1173+
val output = child.output(0)
1174+
val dataType = output match {
1175+
case AttributeReference(name, dataType, _, _) => dataType
1176+
case _ => throw new OpaqueException("Scalar subquery cannot match to AttributeReference")
1177+
}
1178+
// Need to deserialize the encrypted blocks to get the encrypted block
1179+
val blockList = child.asInstanceOf[OpaqueOperatorExec].collectEncrypted()
1180+
val encryptedBlocksList = blockList.map { block =>
1181+
val buf = ByteBuffer.wrap(block.bytes)
1182+
tuix.EncryptedBlocks.getRootAsEncryptedBlocks(buf)
1183+
}
1184+
val encryptedBlocks = encryptedBlocksList.find(_.blocksLength > 0).getOrElse(encryptedBlocksList(0))
1185+
if (encryptedBlocks.blocksLength == 0) {
1186+
// If empty, the returned result is null
1187+
flatbuffersSerializeExpression(builder, Literal(null, dataType), input)
1188+
} else {
1189+
assert(encryptedBlocks.blocksLength == 1)
1190+
val encryptedBlock = encryptedBlocks.blocks(0)
1191+
val ciphertextBuf = encryptedBlock.encRowsAsByteBuffer
1192+
val ciphertext = new Array[Byte](ciphertextBuf.remaining)
1193+
ciphertextBuf.get(ciphertext)
1194+
val ciphertext_str = Base64.getEncoder().encodeToString(ciphertext)
1195+
flatbuffersSerializeExpression(
1196+
builder,
1197+
Decrypt(Literal(UTF8String.fromString(ciphertext_str), StringType), dataType),
1198+
input
1199+
)
1200+
}
1201+
11151202
case (_, Seq(childOffset)) =>
11161203
throw new OpaqueException("Expression not supported: " + expr.toString())
11171204
}

src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,19 @@ trait OpaqueOperatorExec extends SparkPlan {
134134
* method and persist the resulting RDD. [[ConvertToOpaqueOperators]] later eliminates the dummy
135135
* relation from the logical plan, but this only happens after InMemoryRelation has called this
136136
* method. We therefore have to silently return an empty RDD here.
137-
*/
137+
*/
138+
138139
override def doExecute(): RDD[InternalRow] = {
139140
sqlContext.sparkContext.emptyRDD
140141
// throw new UnsupportedOperationException("use executeBlocked")
141142
}
142143

144+
def collectEncrypted(): Array[Block] = {
145+
executeBlocked().collect
146+
}
147+
143148
override def executeCollect(): Array[InternalRow] = {
144-
executeBlocked().collect().flatMap { block =>
149+
collectEncrypted().flatMap { block =>
145150
Utils.decryptBlockFlatbuffers(block)
146151
}
147152
}

src/main/scala/edu/berkeley/cs/rise/opaque/expressions/ClosestPoint.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ object ClosestPoint {
2929
* point - list of coordinates representing a point
3030
* centroids - list of lists of coordinates, each representing a point
3131
""")
32-
/**
33-
*
34-
*/
3532
case class ClosestPoint(left: Expression, right: Expression)
3633
extends BinaryExpression with NullIntolerant with CodegenFallback with ExpectsInputTypes {
3734

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package edu.berkeley.cs.rise.opaque.expressions
2+
3+
import edu.berkeley.cs.rise.opaque.Utils
4+
5+
import org.apache.spark.sql.Column
6+
import org.apache.spark.sql.catalyst.InternalRow
7+
import org.apache.spark.sql.catalyst.expressions.Expression
8+
import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
9+
import org.apache.spark.sql.catalyst.expressions.NullIntolerant
10+
import org.apache.spark.sql.catalyst.expressions.Nondeterministic
11+
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
12+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
13+
import org.apache.spark.sql.types.DataType
14+
import org.apache.spark.sql.types.DataTypes
15+
import org.apache.spark.sql.types.StringType
16+
import org.apache.spark.unsafe.types.UTF8String
17+
18+
/** DataFrame-level entry point: wraps a [[Decrypt]] expression around a column. */
object Decrypt {
  def decrypt(v: Column, dataType: DataType): Column =
    new Column(Decrypt(v.expr, dataType))
}
21+
22+
@ExpressionDescription(
  usage = """
    _FUNC_(child, outputDataType) - Decrypt the input evaluated expression, which should always be a string
  """,
  arguments = """
    Arguments:
      * child - an encrypted literal of string type
      * outputDataType - the decrypted data type
  """)
case class Decrypt(child: Expression, outputDataType: DataType)
    extends UnaryExpression with NullIntolerant with CodegenFallback with Nondeterministic {

  override def dataType: DataType = outputDataType

  protected def initializeInternal(partitionIndex: Int): Unit = { }

  protected override def evalInternal(input: InternalRow): Any = {
    // BUG FIX: the original called child.eval() without the input row, which
    // only works for literal children, and invoked nullSafeEval even when the
    // child evaluated to null (NPE on the cast below). Pass the row through
    // and short-circuit on null.
    val v = child.eval(input)
    if (v == null) null else nullSafeEval(v)
  }

  protected override def nullSafeEval(input: Any): Any = {
    // This function is implemented so that we can test against Spark;
    // it should never be used in production because we want to keep the
    // literal encrypted. Being Nondeterministic forces evaluation onto the
    // workers rather than constant-folding on the driver.
    val v = input.asInstanceOf[UTF8String].toString
    Utils.decryptScalar(v)
  }
}

0 commit comments

Comments
 (0)