Skip to content

Commit 51b621b

Browse files
authored
Changing operator matching from logical to physical (#129)
* WIP * Fix * Unapply change
1 parent 29e3312 commit 51b621b

File tree

5 files changed

+102
-96
lines changed

5 files changed

+102
-96
lines changed

build.sbt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ concurrentRestrictions in Global := Seq(
2424
fork in Test := true
2525
fork in run := true
2626

27+
testOptions in Test += Tests.Argument("-oF")
2728
javaOptions in Test ++= Seq("-Xmx2048m", "-XX:ReservedCodeCacheSize=384m")
2829
javaOptions in run ++= Seq(
2930
"-Xmx2048m", "-XX:ReservedCodeCacheSize=384m", "-Dspark.master=local[1]")

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ import org.apache.spark.sql.catalyst.expressions.Upper
7777
import org.apache.spark.sql.catalyst.expressions.Year
7878
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
7979
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
80+
import org.apache.spark.sql.catalyst.expressions.aggregate.Complete
8081
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
8182
import org.apache.spark.sql.catalyst.expressions.aggregate.Final
8283
import org.apache.spark.sql.catalyst.expressions.aggregate.First
@@ -1154,17 +1155,15 @@ object Utils extends Logging {
11541155
}
11551156

11561157
def serializeAggOp(
1157-
groupingExpressions: Seq[Expression],
1158-
aggExpressions: Seq[NamedExpression],
1158+
groupingExpressions: Seq[NamedExpression],
1159+
aggExpressions: Seq[AggregateExpression],
11591160
input: Seq[Attribute]): Array[Byte] = {
1160-
// aggExpressions contains both grouping expressions and AggregateExpressions. Transform the
1161-
// grouping expressions into AggregateExpressions that collect the first seen value.
1162-
val aggExpressionsWithFirst = aggExpressions.map {
1163-
case Alias(e: AggregateExpression, _) => e
1164-
case e: NamedExpression => AggregateExpression(First(e, Literal(false)), Final, false)
1161+
val aggGroupingExpressions = groupingExpressions.map {
1162+
case e: NamedExpression => AggregateExpression(First(e, Literal(false)), Complete, false)
11651163
}
1164+
val aggregateExpressions = aggGroupingExpressions ++ aggExpressions
11661165

1167-
val aggSchema = aggExpressionsWithFirst.flatMap(_.aggregateFunction.aggBufferAttributes)
1166+
val aggSchema = aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
11681167
// For aggregation, we concatenate the current aggregate row with the new input row and run
11691168
// the update expressions as a projection to obtain a new aggregate row. concatSchema
11701169
// describes the schema of the temporary concatenated row.
@@ -1179,7 +1178,7 @@ object Utils extends Logging {
11791178
groupingExpressions.map(e => flatbuffersSerializeExpression(builder, e, input)).toArray),
11801179
tuix.AggregateOp.createAggregateExpressionsVector(
11811180
builder,
1182-
aggExpressionsWithFirst
1181+
aggregateExpressions
11831182
.map(e => serializeAggExpression(builder, e, input, aggSchema, concatSchema))
11841183
.toArray)))
11851184
builder.sizedByteArray()

src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions.AttributeSet
2626
import org.apache.spark.sql.catalyst.expressions._
27+
import org.apache.spark.sql.catalyst.expressions.aggregate._
2728
import org.apache.spark.sql.catalyst.plans.JoinType
2829
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
2930
import org.apache.spark.sql.execution.SparkPlan
@@ -223,15 +224,16 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan)
223224
}
224225

225226
case class EncryptedAggregateExec(
226-
groupingExpressions: Seq[Expression],
227-
aggExpressions: Seq[NamedExpression],
227+
groupingExpressions: Seq[NamedExpression],
228+
aggExpressions: Seq[AggregateExpression],
228229
child: SparkPlan)
229230
extends UnaryExecNode with OpaqueOperatorExec {
230231

231232
override def producedAttributes: AttributeSet =
232233
AttributeSet(aggExpressions) -- AttributeSet(groupingExpressions)
233234

234-
override def output: Seq[Attribute] = aggExpressions.map(_.toAttribute)
235+
override def output: Seq[Attribute] =
236+
groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
235237

236238
override def executeBlocked(): RDD[Block] = {
237239
val aggExprSer = Utils.serializeAggOp(groupingExpressions, aggExpressions, child.output)

src/main/scala/edu/berkeley/cs/rise/opaque/logical/rules.scala

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -57,53 +57,6 @@ object ConvertToOpaqueOperators extends Rule[LogicalPlan] {
5757
case l @ LogicalRelation(baseRelation: EncryptedScan, _, _, false) =>
5858
EncryptedBlockRDD(l.output, baseRelation.buildBlockedScan())
5959

60-
case p @ Project(projectList, child) if isEncrypted(child) =>
61-
EncryptedProject(projectList, child.asInstanceOf[OpaqueOperator])
62-
63-
// We don't support null values yet, so there's no point in checking whether the output of an
64-
// encrypted operator is null
65-
case p @ Filter(And(IsNotNull(_), IsNotNull(_)), child) if isEncrypted(child) =>
66-
child
67-
case p @ Filter(IsNotNull(_), child) if isEncrypted(child) =>
68-
child
69-
70-
case p @ Filter(condition, child) if isEncrypted(child) =>
71-
EncryptedFilter(condition, child.asInstanceOf[OpaqueOperator])
72-
73-
case p @ Sort(order, true, child) if isEncrypted(child) =>
74-
EncryptedSort(order, child.asInstanceOf[OpaqueOperator])
75-
76-
case p @ Join(left, right, joinType, condition, _) if isEncrypted(p) =>
77-
EncryptedJoin(
78-
left.asInstanceOf[OpaqueOperator], right.asInstanceOf[OpaqueOperator], joinType, condition)
79-
80-
case p @ Aggregate(groupingExprs, aggExprs, child) if isEncrypted(p) =>
81-
UndoCollapseProject.separateProjectAndAgg(p) match {
82-
case Some((projectExprs, aggExprs)) =>
83-
EncryptedProject(
84-
projectExprs,
85-
EncryptedAggregate(
86-
groupingExprs, aggExprs,
87-
EncryptedSort(
88-
groupingExprs.map(e => SortOrder(e, Ascending)),
89-
child.asInstanceOf[OpaqueOperator])))
90-
case None =>
91-
EncryptedAggregate(
92-
groupingExprs, aggExprs,
93-
EncryptedSort(
94-
groupingExprs.map(e => SortOrder(e, Ascending)),
95-
child.asInstanceOf[OpaqueOperator]))
96-
}
97-
98-
case p @ Union(Seq(left, right)) if isEncrypted(p) =>
99-
EncryptedUnion(left.asInstanceOf[OpaqueOperator], right.asInstanceOf[OpaqueOperator])
100-
101-
case p @ LocalLimit(limitExpr, child) if isEncrypted(p) =>
102-
EncryptedLocalLimit(limitExpr, child.asInstanceOf[OpaqueOperator])
103-
104-
case p @ GlobalLimit(limitExpr, child) if isEncrypted(p) =>
105-
EncryptedGlobalLimit(limitExpr, child.asInstanceOf[OpaqueOperator])
106-
10760
case InMemoryRelationMatcher(output, storageLevel, child) if isEncrypted(child) =>
10861
EncryptedBlockRDD(
10962
output,

src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,69 +19,120 @@ package edu.berkeley.cs.rise.opaque
1919

2020
import org.apache.spark.sql.Strategy
2121
import org.apache.spark.sql.catalyst.expressions.Alias
22+
import org.apache.spark.sql.catalyst.expressions.And
2223
import org.apache.spark.sql.catalyst.expressions.Ascending
2324
import org.apache.spark.sql.catalyst.expressions.Attribute
2425
import org.apache.spark.sql.catalyst.expressions.Expression
2526
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
27+
import org.apache.spark.sql.catalyst.expressions.IsNotNull
2628
import org.apache.spark.sql.catalyst.expressions.Literal
2729
import org.apache.spark.sql.catalyst.expressions.NamedExpression
2830
import org.apache.spark.sql.catalyst.expressions.SortOrder
31+
import org.apache.spark.sql.catalyst.expressions.aggregate._
2932
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
30-
import org.apache.spark.sql.catalyst.plans.logical.Join
31-
import org.apache.spark.sql.catalyst.plans.logical.JoinHint
32-
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
33+
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
34+
import org.apache.spark.sql.catalyst.plans.logical._
3335
import org.apache.spark.sql.execution.SparkPlan
3436

3537
import edu.berkeley.cs.rise.opaque.execution._
3638
import edu.berkeley.cs.rise.opaque.logical._
3739

3840
object OpaqueOperators extends Strategy {
41+
42+
def isEncrypted(plan: LogicalPlan): Boolean = {
43+
plan.find {
44+
case _: OpaqueOperator => true
45+
case _ => false
46+
}.nonEmpty
47+
}
48+
49+
def isEncrypted(plan: SparkPlan): Boolean = {
50+
plan.find {
51+
case _: OpaqueOperatorExec => true
52+
case _ => false
53+
}.nonEmpty
54+
}
55+
3956
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
40-
case EncryptedProject(projectList, child) =>
57+
case Project(projectList, child) if isEncrypted(child) =>
4158
EncryptedProjectExec(projectList, planLater(child)) :: Nil
4259

43-
case EncryptedFilter(condition, child) =>
60+
// We don't support null values yet, so there's no point in checking whether the output of an
61+
// encrypted operator is null
62+
case p @ Filter(And(IsNotNull(_), IsNotNull(_)), child) if isEncrypted(child) =>
63+
planLater(child) :: Nil
64+
case p @ Filter(IsNotNull(_), child) if isEncrypted(child) =>
65+
planLater(child) :: Nil
66+
67+
case Filter(condition, child) if isEncrypted(child) =>
4468
EncryptedFilterExec(condition, planLater(child)) :: Nil
4569

46-
case EncryptedSort(order, child) =>
47-
EncryptedSortExec(order, planLater(child)) :: Nil
48-
49-
case EncryptedJoin(left, right, joinType, condition) =>
50-
Join(left, right, joinType, condition, JoinHint.NONE) match {
51-
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) =>
52-
val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true)
53-
val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false)
54-
val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left))
55-
val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right))
56-
val unioned = EncryptedUnionExec(leftProj, rightProj)
57-
val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), unioned)
58-
val joined = EncryptedSortMergeJoinExec(
59-
joinType,
60-
leftKeysProj,
61-
rightKeysProj,
62-
leftProjSchema.map(_.toAttribute),
63-
rightProjSchema.map(_.toAttribute),
64-
(leftProjSchema ++ rightProjSchema).map(_.toAttribute),
65-
sorted)
66-
val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined)
67-
val filtered = condition match {
68-
case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
69-
case None => tagsDropped
70-
}
71-
filtered :: Nil
72-
case _ => Nil
70+
case Sort(sortExprs, global, child) if isEncrypted(child) =>
71+
EncryptedSortExec(sortExprs, planLater(child)) :: Nil
72+
73+
case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if isEncrypted(p) =>
74+
val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true)
75+
val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false)
76+
val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left))
77+
val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right))
78+
val unioned = EncryptedUnionExec(leftProj, rightProj)
79+
val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), unioned)
80+
val joined = EncryptedSortMergeJoinExec(
81+
joinType,
82+
leftKeysProj,
83+
rightKeysProj,
84+
leftProjSchema.map(_.toAttribute),
85+
rightProjSchema.map(_.toAttribute),
86+
(leftProjSchema ++ rightProjSchema).map(_.toAttribute),
87+
sorted)
88+
val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined)
89+
val filtered = condition match {
90+
case Some(condition) => EncryptedFilterExec(condition, tagsDropped)
91+
case None => tagsDropped
7392
}
93+
filtered :: Nil
7494

75-
case a @ EncryptedAggregate(groupingExpressions, aggExpressions, child) =>
76-
EncryptedAggregateExec(groupingExpressions, aggExpressions, planLater(child)) :: Nil
95+
case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
96+
if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>
97+
val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]).map(_.copy(mode = Complete))
7798

78-
case EncryptedUnion(left, right) =>
99+
EncryptedProjectExec(resultExpressions,
100+
EncryptedAggregateExec(
101+
groupingExpressions, aggregateExpressions,
102+
EncryptedSortExec(
103+
groupingExpressions.map(e => SortOrder(e, Ascending)), planLater(child)))) :: Nil
104+
105+
case p @ Union(Seq(left, right)) if isEncrypted(p) =>
79106
EncryptedUnionExec(planLater(left), planLater(right)) :: Nil
80107

81-
case EncryptedLocalLimit(IntegerLiteral(limit), child) =>
108+
case ReturnAnswer(rootPlan) => rootPlan match {
109+
case Limit(IntegerLiteral(limit), Sort(sortExprs, true, child)) if isEncrypted(child) =>
110+
EncryptedGlobalLimitExec(limit,
111+
EncryptedLocalLimitExec(limit,
112+
EncryptedSortExec(sortExprs, planLater(child)))) :: Nil
113+
114+
case Limit(IntegerLiteral(limit), Project(projectList, child)) if isEncrypted(child) =>
115+
EncryptedGlobalLimitExec(limit,
116+
EncryptedLocalLimitExec(limit,
117+
EncryptedProjectExec(projectList, planLater(child)))) :: Nil
118+
119+
case _ => Nil
120+
}
121+
122+
case Limit(IntegerLiteral(limit), Sort(sortExprs, true, child)) if isEncrypted(child) =>
123+
EncryptedGlobalLimitExec(limit,
124+
EncryptedLocalLimitExec(limit,
125+
EncryptedSortExec(sortExprs, planLater(child)))) :: Nil
126+
127+
case Limit(IntegerLiteral(limit), Project(projectList, child)) if isEncrypted(child) =>
128+
EncryptedGlobalLimitExec(limit,
129+
EncryptedLocalLimitExec(limit,
130+
EncryptedProjectExec(projectList, planLater(child)))) :: Nil
131+
132+
case LocalLimit(IntegerLiteral(limit), child) if isEncrypted(child) =>
82133
EncryptedLocalLimitExec(limit, planLater(child)) :: Nil
83134

84-
case EncryptedGlobalLimit(IntegerLiteral(limit), child) =>
135+
case GlobalLimit(IntegerLiteral(limit), child) if isEncrypted(child) =>
85136
EncryptedGlobalLimitExec(limit, planLater(child)) :: Nil
86137

87138
case Encrypt(child) =>

0 commit comments

Comments
 (0)