diff --git a/build.sbt b/build.sbt
index c98816d9f3..9dfd59f16f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -24,6 +24,7 @@ concurrentRestrictions in Global := Seq(
 
 fork in Test := true
 fork in run := true
+testOptions in Test += Tests.Argument("-oF")
 javaOptions in Test ++= Seq("-Xmx2048m", "-XX:ReservedCodeCacheSize=384m")
 javaOptions in run ++= Seq(
   "-Xmx2048m", "-XX:ReservedCodeCacheSize=384m", "-Dspark.master=local[1]")
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
index 7da1a4e21a..1359dcb6f7 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -74,6 +74,7 @@ import org.apache.spark.sql.catalyst.expressions.Upper
 import org.apache.spark.sql.catalyst.expressions.Year
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.aggregate.Average
+import org.apache.spark.sql.catalyst.expressions.aggregate.Complete
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.expressions.aggregate.Final
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
@@ -1136,17 +1137,15 @@ object Utils extends Logging {
   }
 
   def serializeAggOp(
-      groupingExpressions: Seq[Expression],
-      aggExpressions: Seq[NamedExpression],
+      groupingExpressions: Seq[NamedExpression],
+      aggExpressions: Seq[AggregateExpression],
       input: Seq[Attribute]): Array[Byte] = {
-    // aggExpressions contains both grouping expressions and AggregateExpressions. Transform the
-    // grouping expressions into AggregateExpressions that collect the first seen value.
-    val aggExpressionsWithFirst = aggExpressions.map {
-      case Alias(e: AggregateExpression, _) => e
-      case e: NamedExpression => AggregateExpression(First(e, Literal(false)), Final, false)
+    val aggGroupingExpressions = groupingExpressions.map {
+      case e: NamedExpression => AggregateExpression(First(e, Literal(false)), Complete, false)
     }
+    val aggregateExpressions = aggGroupingExpressions ++ aggExpressions
 
-    val aggSchema = aggExpressionsWithFirst.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val aggSchema = aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     // For aggregation, we concatenate the current aggregate row with the new input row and run
     // the update expressions as a projection to obtain a new aggregate row. concatSchema
     // describes the schema of the temporary concatenated row.
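
For reference, a minimal sketch (not part of the patch) of what serializeAggOp now does with grouping expressions: each one is wrapped in a First aggregate evaluated in Complete mode, so the enclave receives one uniform list of AggregateExpressions whose buffer attributes make up aggSchema. The sketch assumes Spark 2.4-era Catalyst APIs (where First takes its ignoreNulls flag as an Expression, matching the patch); GroupingAsFirstSketch is a hypothetical name.

    // Illustrative only; assumes Spark 2.4-era Catalyst APIs.
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, First}
    import org.apache.spark.sql.types.StringType

    object GroupingAsFirstSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for a grouping column, e.g. the `word` in GROUP BY word.
        val word = AttributeReference("word", StringType)()

        // serializeAggOp lifts each grouping expression into an aggregate that
        // keeps the first value seen per group, evaluated in Complete mode.
        val grouping =
          AggregateExpression(First(word, Literal(false)), Complete, isDistinct = false)

        // Its buffer attributes flow into aggSchema alongside the real aggregates.
        println(grouping.aggregateFunction.aggBufferAttributes.map(_.name))
      }
    }
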
@@ -1161,7 +1160,7 @@ object Utils extends Logging {
         groupingExpressions.map(e => flatbuffersSerializeExpression(builder, e, input)).toArray),
       tuix.AggregateOp.createAggregateExpressionsVector(
         builder,
-        aggExpressionsWithFirst
+        aggregateExpressions
           .map(e => serializeAggExpression(builder, e, input, aggSchema, concatSchema))
           .toArray)))
     builder.sizedByteArray()
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
index fa50c23f7e..aa8a968c91 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.SparkPlan
@@ -223,15 +224,16 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan)
 }
 
 case class EncryptedAggregateExec(
-    groupingExpressions: Seq[Expression],
-    aggExpressions: Seq[NamedExpression],
+    groupingExpressions: Seq[NamedExpression],
+    aggExpressions: Seq[AggregateExpression],
     child: SparkPlan)
   extends UnaryExecNode with OpaqueOperatorExec {
 
   override def producedAttributes: AttributeSet =
     AttributeSet(aggExpressions) -- AttributeSet(groupingExpressions)
 
-  override def output: Seq[Attribute] = aggExpressions.map(_.toAttribute)
+  override def output: Seq[Attribute] =
+    groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
 
   override def executeBlocked(): RDD[Block] = {
     val aggExprSer = Utils.serializeAggOp(groupingExpressions, aggExpressions, child.output)
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/logical/rules.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/logical/rules.scala
index b48f3f22d8..70257d8c6d 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/logical/rules.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/logical/rules.scala
@@ -57,53 +57,6 @@ object ConvertToOpaqueOperators extends Rule[LogicalPlan] {
     case l @ LogicalRelation(baseRelation: EncryptedScan, _, _, false) =>
       EncryptedBlockRDD(l.output, baseRelation.buildBlockedScan())
 
-    case p @ Project(projectList, child) if isEncrypted(child) =>
-      EncryptedProject(projectList, child.asInstanceOf[OpaqueOperator])
-
-    // We don't support null values yet, so there's no point in checking whether the output of an
-    // encrypted operator is null
-    case p @ Filter(And(IsNotNull(_), IsNotNull(_)), child) if isEncrypted(child) =>
-      child
-    case p @ Filter(IsNotNull(_), child) if isEncrypted(child) =>
-      child
-
-    case p @ Filter(condition, child) if isEncrypted(child) =>
-      EncryptedFilter(condition, child.asInstanceOf[OpaqueOperator])
-
-    case p @ Sort(order, true, child) if isEncrypted(child) =>
-      EncryptedSort(order, child.asInstanceOf[OpaqueOperator])
-
-    case p @ Join(left, right, joinType, condition, _) if isEncrypted(p) =>
-      EncryptedJoin(
-        left.asInstanceOf[OpaqueOperator], right.asInstanceOf[OpaqueOperator], joinType, condition)
-
-    case p @ Aggregate(groupingExprs, aggExprs, child) if isEncrypted(p) =>
-      UndoCollapseProject.separateProjectAndAgg(p) match {
-        case Some((projectExprs, aggExprs)) =>
-          EncryptedProject(
-            projectExprs,
-            EncryptedAggregate(
-              groupingExprs, aggExprs,
-              EncryptedSort(
-                groupingExprs.map(e => SortOrder(e, Ascending)),
-                child.asInstanceOf[OpaqueOperator])))
-        case None =>
-          EncryptedAggregate(
-            groupingExprs, aggExprs,
-            EncryptedSort(
-              groupingExprs.map(e => SortOrder(e, Ascending)),
-              child.asInstanceOf[OpaqueOperator]))
-      }
-
-    case p @ Union(Seq(left, right)) if isEncrypted(p) =>
-      EncryptedUnion(left.asInstanceOf[OpaqueOperator], right.asInstanceOf[OpaqueOperator])
-
-    case p @ LocalLimit(limitExpr, child) if isEncrypted(p) =>
-      EncryptedLocalLimit(limitExpr, child.asInstanceOf[OpaqueOperator])
-
-    case p @ GlobalLimit(limitExpr, child) if isEncrypted(p) =>
-      EncryptedGlobalLimit(limitExpr, child.asInstanceOf[OpaqueOperator])
-
     case InMemoryRelationMatcher(output, storageLevel, child) if isEncrypted(child) =>
       EncryptedBlockRDD(
         output,
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
index b6d5ce4e72..0e1f3f3716 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
@@ -19,69 +19,120 @@ package edu.berkeley.cs.rise.opaque
 
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.And
 import org.apache.spark.sql.catalyst.expressions.Ascending
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
+import org.apache.spark.sql.catalyst.expressions.IsNotNull
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.catalyst.plans.logical.JoinHint
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkPlan
 
 import edu.berkeley.cs.rise.opaque.execution._
 import edu.berkeley.cs.rise.opaque.logical._
 
 object OpaqueOperators extends Strategy {
+
+  def isEncrypted(plan: LogicalPlan): Boolean = {
+    plan.find {
+      case _: OpaqueOperator => true
+      case _ => false
+    }.nonEmpty
+  }
+
+  def isEncrypted(plan: SparkPlan): Boolean = {
+    plan.find {
+      case _: OpaqueOperatorExec => true
+      case _ => false
+    }.nonEmpty
+  }
+
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case EncryptedProject(projectList, child) =>
+    case Project(projectList, child) if isEncrypted(child) =>
       EncryptedProjectExec(projectList, planLater(child)) :: Nil
 
-    case EncryptedFilter(condition, child) =>
+    // We don't support null values yet, so there's no point in checking whether the output of an
+    // encrypted operator is null
+    case p @ Filter(And(IsNotNull(_), IsNotNull(_)), child) if isEncrypted(child) =>
+      planLater(child) :: Nil
+    case p @ Filter(IsNotNull(_), child) if isEncrypted(child) =>
+      planLater(child) :: Nil
+
+    case Filter(condition, child) if isEncrypted(child) =>
       EncryptedFilterExec(condition, planLater(child)) :: Nil
 
-    case EncryptedSort(order, child) =>
-      EncryptedSortExec(order, planLater(child)) :: Nil
-
-    case EncryptedJoin(left, right, joinType, condition) =>
-      Join(left, right, joinType, condition, JoinHint.NONE) match {
-        case ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) =>
-          val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true)
-          val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false)
-          val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left))
-          val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right))
-          val unioned = EncryptedUnionExec(leftProj, rightProj)
-          val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), unioned)
-          val joined = EncryptedSortMergeJoinExec(
-            joinType,
-            leftKeysProj,
-            rightKeysProj,
-            leftProjSchema.map(_.toAttribute),
-            rightProjSchema.map(_.toAttribute),
-            (leftProjSchema ++ rightProjSchema).map(_.toAttribute),
-            sorted)
-          val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined)
-          val filtered = condition match {
-            case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
-            case None => tagsDropped
-          }
-          filtered :: Nil
-        case _ => Nil
+    case Sort(sortExprs, global, child) if isEncrypted(child) =>
+      EncryptedSortExec(sortExprs, planLater(child)) :: Nil
+
+    case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if isEncrypted(p) =>
+      val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true)
+      val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false)
+      val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left))
+      val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right))
+      val unioned = EncryptedUnionExec(leftProj, rightProj)
+      val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), unioned)
+      val joined = EncryptedSortMergeJoinExec(
+        joinType,
+        leftKeysProj,
+        rightKeysProj,
+        leftProjSchema.map(_.toAttribute),
+        rightProjSchema.map(_.toAttribute),
+        (leftProjSchema ++ rightProjSchema).map(_.toAttribute),
+        sorted)
+      val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined)
+      val filtered = condition match {
+        case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
+        case None => tagsDropped
       }
+      filtered :: Nil
 
-    case a @ EncryptedAggregate(groupingExpressions, aggExpressions, child) =>
-      EncryptedAggregateExec(groupingExpressions, aggExpressions, planLater(child)) :: Nil
+    case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
+        if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>
+      val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]).map(_.copy(mode = Complete))
 
-    case EncryptedUnion(left, right) =>
+      EncryptedProjectExec(resultExpressions,
+        EncryptedAggregateExec(
+          groupingExpressions, aggregateExpressions,
+          EncryptedSortExec(
+            groupingExpressions.map(e => SortOrder(e, Ascending)), planLater(child)))) :: Nil
+
+    case p @ Union(Seq(left, right)) if isEncrypted(p) =>
       EncryptedUnionExec(planLater(left), planLater(right)) :: Nil
 
-    case EncryptedLocalLimit(IntegerLiteral(limit), child) =>
+    case ReturnAnswer(rootPlan) => rootPlan match {
+      case Limit(IntegerLiteral(limit), Sort(sortExprs, true, child)) if isEncrypted(child) =>
+        EncryptedGlobalLimitExec(limit,
+          EncryptedLocalLimitExec(limit,
+            EncryptedSortExec(sortExprs, planLater(child)))) :: Nil
+
+      case Limit(IntegerLiteral(limit), Project(projectList, child))
+          if isEncrypted(child) =>
+        EncryptedGlobalLimitExec(limit,
+          EncryptedLocalLimitExec(limit,
+            EncryptedProjectExec(projectList, planLater(child)))) :: Nil
+
+      case _ => Nil
+    }
+
+    case Limit(IntegerLiteral(limit), Sort(sortExprs, true, child)) if isEncrypted(child) =>
+      EncryptedGlobalLimitExec(limit,
+        EncryptedLocalLimitExec(limit,
+          EncryptedSortExec(sortExprs, planLater(child)))) :: Nil
+
+    case Limit(IntegerLiteral(limit), Project(projectList, child)) if isEncrypted(child) =>
+      EncryptedGlobalLimitExec(limit,
+        EncryptedLocalLimitExec(limit,
+          EncryptedProjectExec(projectList, planLater(child)))) :: Nil
+
+    case LocalLimit(IntegerLiteral(limit), child) if isEncrypted(child) =>
       EncryptedLocalLimitExec(limit, planLater(child)) :: Nil
 
-    case EncryptedGlobalLimit(IntegerLiteral(limit), child) =>
+    case GlobalLimit(IntegerLiteral(limit), child) if isEncrypted(child) =>
       EncryptedGlobalLimitExec(limit, planLater(child)) :: Nil
 
     case Encrypt(child) =>
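
As a sanity check on the new planner paths, a hypothetical spark-shell session (not part of the patch). The setup lines follow the project README; `spark` is the session provided by spark-shell, and the data values are made up.

    import org.apache.spark.sql.functions.sum
    import edu.berkeley.cs.rise.opaque.implicits._

    edu.berkeley.cs.rise.opaque.Utils.initSQLContext(spark.sqlContext)

    val df = spark.createDataFrame(Seq(("foo", 4), ("bar", 1), ("foo", 5)))
      .toDF("word", "count")
    val dfEncrypted = df.encrypted

    // Should hit the PhysicalAggregation case: an EncryptedSortExec on the
    // grouping key, then EncryptedAggregateExec, then EncryptedProjectExec
    // for the result expressions.
    val summed = dfEncrypted.groupBy("word").agg(sum("count"))
    summed.explain()
    summed.show()

    // Should hit the ReturnAnswer case: EncryptedGlobalLimitExec over
    // EncryptedLocalLimitExec over EncryptedSortExec.
    dfEncrypted.sort("count").limit(1).show()
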