
Distinct aggregation support #163


Merged 33 commits on Mar 1, 2021
7c89dd1
matching in strategies.scala
octaviansima Feb 10, 2021
59d228e
added test case for one distinct one non, reverted comment
octaviansima Feb 23, 2021
aa4c127
removed C++ level implementation of is_distinct
octaviansima Feb 24, 2021
4635081
PartialMerge in operators.scala
octaviansima Feb 24, 2021
c6a4750
stage 1: grouping with distinct expressions
octaviansima Feb 24, 2021
3afb949
stage 2: WIP
octaviansima Feb 24, 2021
db409a1
saving, sorting by group expressions ++ name distinct expressions worked
octaviansima Feb 24, 2021
cb2220f
stage 1 & 2 printing the expected results
octaviansima Feb 24, 2021
e84ffa8
removed extraneous call to sorted, #3 in place but not working
octaviansima Feb 24, 2021
400ffc9
stage 3 has the final, correct result: refactoring the Aggregate code…
octaviansima Feb 24, 2021
c4e84a2
refactor done, C++ still printing the correct values
octaviansima Feb 24, 2021
37c9d69
need to formalize None case in EncryptedAggregateExec.output, but sta…
octaviansima Feb 24, 2021
3e629aa
distinct and indistinct passes (git add -u)
octaviansima Feb 24, 2021
4abe290
general cleanup, None case looks nicer
octaviansima Feb 24, 2021
f4e6019
throw error with >1 distinct, add test case for global distinct
octaviansima Feb 24, 2021
39acc1a
no need for global aggregation case
octaviansima Feb 24, 2021
33b0d5a
single partition passes all aggregate tests, multiple partition doesn't
octaviansima Feb 24, 2021
1c02769
works with global sort first
octaviansima Feb 24, 2021
dbcb18c
works with non-global sort first
octaviansima Feb 24, 2021
3fa7577
cleanup
octaviansima Feb 24, 2021
84a79b6
cleanup tests
octaviansima Feb 24, 2021
39384d6
removed iostream, other nit
octaviansima Feb 24, 2021
6388aea
Merge branch 'master' into distinct-aggregate
octaviansima Feb 24, 2021
00dfe76
added test case for 13
octaviansima Feb 24, 2021
db978ae
None case in isPartial match done properly
octaviansima Feb 25, 2021
a6f6e37
added test cases for sumDistinct
octaviansima Feb 25, 2021
11b746d
case-specific namedDistinctExpressions working
octaviansima Feb 25, 2021
8ba80b6
distinct sum is done
octaviansima Feb 25, 2021
f9bb4c4
removed comments
octaviansima Feb 25, 2021
5aefa65
got rid of mode argument
octaviansima Feb 26, 2021
c07843b
tests include null values
octaviansima Feb 26, 2021
2fff873
partition followed by local sort instead of first global sort
octaviansima Feb 26, 2021
8531c02
Merge branch 'master' into distinct-aggregate
octaviansima Mar 1, 2021
11 changes: 10 additions & 1 deletion src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -1415,6 +1415,10 @@ object Utils extends Logging {
}
(Seq(countUpdateExpr), Seq(count))
}
case PartialMerge => {
val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0))
(Seq(countUpdateExpr), Seq(count))
}
case Final => {
val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0))
(Seq(countUpdateExpr), Seq(count))
@@ -1423,7 +1427,7 @@
val countUpdateExpr = Add(count, Literal(1L))
(Seq(countUpdateExpr), Seq(count))
}
case _ =>
case _ =>
}

tuix.AggregateExpr.createAggregateExpr(
@@ -1594,6 +1598,11 @@ object Utils extends Logging {
val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
(Seq(sumUpdateExpr), Seq(sum))
}
case PartialMerge => {
val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0))
val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
(Seq(sumUpdateExpr), Seq(sum))
}
case Final => {
val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0))
val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
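The hunk above adds a `PartialMerge` branch for `Count` that sums already-aggregated buffer values rather than counting raw rows, and the `Final` branch applies the same merge once more. A minimal plain-Scala sketch of the three update modes (the names and helper below are illustrative stand-ins, not the Opaque API):

```scala
object CountModes {
  // Partial: fold raw input rows into a local count (count = count + 1 per row).
  def partial(rows: Seq[Any]): Long =
    rows.foldLeft(0L)((count, _) => count + 1L)

  // PartialMerge / Final: add already-computed partial counts together
  // (count = count + inputAggBufferAttributes(0) in the diff above).
  def partialMerge(buffers: Seq[Long]): Long =
    buffers.foldLeft(0L)(_ + _)

  // End-to-end: one partial pass per partition, then a single merge pass.
  def run(partitions: Seq[Seq[Any]]): Long =
    partialMerge(partitions.map(partial))
}
```

Because `Final` and `PartialMerge` share the same merge expression in the diff, the sketch routes both through `partialMerge`.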
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
object TPCHBenchmark {

// Add query numbers here once they are supported
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22)
val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22)

def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
val sqlStr = tpch.getQuery(queryNumber)
@@ -233,43 +233,34 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan)

case class EncryptedAggregateExec(
groupingExpressions: Seq[NamedExpression],
aggExpressions: Seq[AggregateExpression],
mode: AggregateMode,
aggregateExpressions: Seq[AggregateExpression],
child: SparkPlan)
extends UnaryExecNode with OpaqueOperatorExec {

override def producedAttributes: AttributeSet =
AttributeSet(aggExpressions) -- AttributeSet(groupingExpressions)

override def output: Seq[Attribute] = mode match {
case Partial => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.copy(mode = Partial)).flatMap(_.aggregateFunction.inputAggBufferAttributes)
case Final => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
case Complete => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
}
AttributeSet(aggregateExpressions) -- AttributeSet(groupingExpressions)

override def output: Seq[Attribute] = groupingExpressions.map(_.toAttribute) ++
aggregateExpressions.flatMap(expr => {
expr.mode match {
case Partial | PartialMerge =>
expr.aggregateFunction.inputAggBufferAttributes
case _ =>
Seq(expr.resultAttribute)
}
})

override def executeBlocked(): RDD[Block] = {

val (groupingExprs, aggExprs) = mode match {
case Partial => {
val partialAggExpressions = aggExpressions.map(_.copy(mode = Partial))
(groupingExpressions, partialAggExpressions)
}
case Final => {
val finalGroupingExpressions = groupingExpressions.map(_.toAttribute)
val finalAggExpressions = aggExpressions.map(_.copy(mode = Final))
(finalGroupingExpressions, finalAggExpressions)
}
case Complete => {
(groupingExpressions, aggExpressions.map(_.copy(mode = Complete)))
}
}
val aggExprSer = Utils.serializeAggOp(groupingExpressions, aggregateExpressions, child.output)
val isPartial = aggregateExpressions.map(expr => expr.mode)
.exists(mode => mode == Partial || mode == PartialMerge)

val aggExprSer = Utils.serializeAggOp(groupingExprs, aggExprs, child.output)

timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedPartialAggregateExec") {
childRDD => childRDD.map { block =>
val (enclave, eid) = Utils.initEnclave()
Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, (mode == Partial)))
Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, isPartial))
}
}
}
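With the operator-level `mode` argument removed, `EncryptedAggregateExec` now derives the boolean passed to `NonObliviousAggregate` from the per-expression modes. A self-contained sketch of that derivation, using a stand-in enum for Spark's `AggregateMode` (illustrative only):

```scala
// Stand-ins for org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode.
sealed trait AggregateMode
case object Partial extends AggregateMode
case object PartialMerge extends AggregateMode
case object Final extends AggregateMode
case object Complete extends AggregateMode

// Mirrors the isPartial computation in EncryptedAggregateExec.executeBlocked:
// the enclave call runs in "partial" mode when any expression is Partial or PartialMerge.
def isPartial(modes: Seq[AggregateMode]): Boolean =
  modes.exists(m => m == Partial || m == PartialMerge)
```

This is what lets a single physical operator host a mix of `PartialMerge` indistinct functions and `Partial` distinct functions, as stage 3 of the new planner does.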
103 changes: 84 additions & 19 deletions src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
@@ -132,25 +132,90 @@ object OpaqueOperators extends Strategy {
if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>

val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression])

if (groupingExpressions.size == 0) {
// Global aggregation
val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
val partialOutput = partialAggregate.output
val (projSchema, tag) = tagForGlobalAggregate(partialOutput)

EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
EncryptedProjectExec(partialOutput,
EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
} else {
// Grouping aggregation
EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct)

functionsWithDistinct.size match {


Review comment: why not `functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length`?

Review comment: can we support this query: `select count(distinct age), sum(distinct age) from table_a`?

case 0 => // No distinct aggregate operations
if (groupingExpressions.size == 0) {
// Global aggregation
val partialAggregate = EncryptedAggregateExec(groupingExpressions,
aggregateExpressions.map(_.copy(mode = Partial)), planLater(child))
val partialOutput = partialAggregate.output
val (projSchema, tag) = tagForGlobalAggregate(partialOutput)

EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)),
EncryptedProjectExec(partialOutput,
EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
} else {
// Grouping aggregation
EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)),
EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Partial)),
EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
}
case size if size == 1 => // One distinct aggregate operation
// Because we are also grouping on the columns used in the distinct expressions,
// we do not need separate cases for global and grouping aggregation.

// We need to extract named expressions from the children of the distinct aggregate functions
// in order to group by those columns.
val namedDistinctExpressions = functionsWithDistinct.head.aggregateFunction.children.flatMap{ e =>
e match {
case ne: NamedExpression =>
Seq(ne)
case _ =>
e.children.filter(child => child.isInstanceOf[NamedExpression])
.map(child => child.asInstanceOf[NamedExpression])
}
}
val combinedGroupingExpressions = groupingExpressions ++ namedDistinctExpressions

// 1. Create an Aggregate operator for partial aggregations.
val partialAggregate = {
val sorted = EncryptedSortExec(combinedGroupingExpressions.map(e => SortOrder(e, Ascending)), false,
planLater(child))
EncryptedAggregateExec(combinedGroupingExpressions, functionsWithoutDistinct.map(_.copy(mode = Partial)), sorted)
}

// 2. Create an Aggregate operator for partial merge aggregations.
val partialMergeAggregate = {
// Partition based on the final grouping expressions.
val partitionOrder = groupingExpressions.map(e => SortOrder(e, Ascending))
val partitioned = EncryptedRangePartitionExec(partitionOrder, partialAggregate)

// Local sort on the combined grouping expressions.
val sortOrder = combinedGroupingExpressions.map(e => SortOrder(e, Ascending))
val sorted = EncryptedSortExec(sortOrder, false, partitioned)

EncryptedAggregateExec(combinedGroupingExpressions,
functionsWithoutDistinct.map(_.copy(mode = PartialMerge)), sorted)
}

// 3. Create an Aggregate operator for partial aggregation of distinct aggregate expressions.
val partialDistinctAggregate = {
// Indistinct functions operate on aggregation buffers since partial aggregation was already called,
// but distinct functions operate on the original input to the aggregation.
EncryptedAggregateExec(groupingExpressions,
functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) ++
functionsWithDistinct.map(_.copy(mode = Partial)), partialMergeAggregate)
}

// 4. Create an Aggregate operator for the final aggregation.
val finalAggregate = {
val sorted = EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)),
true, partialDistinctAggregate)
EncryptedAggregateExec(groupingExpressions,
(functionsWithoutDistinct ++ functionsWithDistinct).map(_.copy(mode = Final)), sorted)
}

EncryptedProjectExec(resultExpressions, finalAggregate) :: Nil

case _ => { // More than one distinct operation
throw new UnsupportedOperationException("Aggregate operations with more than one distinct expression are not yet supported.")
}
}

case p @ Union(Seq(left, right)) if isEncrypted(p) =>
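The four-stage plan above can be simulated with plain collections: grouping on the combined `(grouping ++ distinct)` columns collapses duplicate distinct values, after which re-grouping on the original grouping columns alone yields both the distinct count and the merged indistinct sum. A sketch with a hypothetical helper (not the Opaque planner) for `SELECT count(DISTINCT id), sum(price) … GROUP BY category`:

```scala
// Rows are (category, id, price); the result maps each category to
// (count(distinct id), sum(price)).
def countDistinctAndSum(rows: Seq[(String, Int, Int)]): Map[String, (Int, Int)] = {
  // Stages 1-2: partial aggregation grouped on (category, id); duplicate ids
  // collapse while price is partially summed.
  val partiallyMerged = rows
    .groupBy { case (cat, id, _) => (cat, id) }
    .toSeq
    .map { case ((cat, id), rs) => (cat, id, rs.map(_._3).sum) }
  // Stages 3-4: re-group on category alone; each surviving id is now unique
  // within its group, and the indistinct sum merges the partial sums.
  partiallyMerged
    .groupBy(_._1)
    .map { case (cat, rs) => cat -> (rs.map(_._2).toSet.size, rs.map(_._3).sum) }
}
```

Because the combined grouping already deduplicates the distinct column, the same shape covers global aggregation too, which is why the planner needs no separate global case here.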
@@ -479,6 +479,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
.collect.sortBy { case Row(category: String, _) => category }
}

testAgainstSpark("aggregate count distinct and indistinct") { securityLevel =>
val data = (0 until 64).map{ i =>
if (i % 6 == 0)
(abc(i), null.asInstanceOf[Int], i % 8)
else
(abc(i), i % 4, i % 8)
}.toSeq
val words = makeDF(data, securityLevel, "category", "id", "price")
words.groupBy("category").agg(countDistinct("id").as("num_unique_ids"),
count("price").as("num_prices")).collect.toSet
}

testAgainstSpark("aggregate count distinct") { securityLevel =>
val data = (0 until 64).map{ i =>
if (i % 6 == 0)
(abc(i), null.asInstanceOf[Int])
else
(abc(i), i % 8)
}.toSeq
val words = makeDF(data, securityLevel, "category", "price")
words.groupBy("category").agg(countDistinct("price").as("num_unique_prices"))
.collect.sortBy { case Row(category: String, _) => category }
}

testAgainstSpark("aggregate first") { securityLevel =>
val data = for (i <- 0 until 256) yield (i, abc(i), 1)
val words = makeDF(data, securityLevel, "id", "category", "price")
@@ -526,6 +550,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
.collect.sortBy { case Row(word: String, _) => word }
}

testAgainstSpark("aggregate sum distinct and indistinct") { securityLevel =>
val data = (0 until 64).map{ i =>
if (i % 6 == 0)
(abc(i), null.asInstanceOf[Int], i % 8)
else
(abc(i), i % 4, i % 8)
}.toSeq
val words = makeDF(data, securityLevel, "category", "id", "price")
words.groupBy("category").agg(sumDistinct("id").as("sum_unique_ids"),
sum("price").as("sum_prices")).collect.toSet
}

testAgainstSpark("aggregate sum distinct") { securityLevel =>
val data = (0 until 64).map{ i =>
if (i % 6 == 0)
(abc(i), null.asInstanceOf[Int])
else
(abc(i), i % 8)
}.toSeq
val words = makeDF(data, securityLevel, "category", "price")
words.groupBy("category").agg(sumDistinct("price").as("sum_unique_prices"))
.collect.sortBy { case Row(category: String, _) => category }
}

testAgainstSpark("aggregate on multiple columns") { securityLevel =>
val data = for (i <- 0 until 256) yield (abc(i), 1, 1.0f)
val words = makeDF(data, securityLevel, "str", "x", "y")
@@ -557,6 +605,12 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
words.agg(sum("count").as("totalCount")).collect
}

testAgainstSpark("global aggregate count distinct") { securityLevel =>
val data = for (i <- 0 until 256) yield (i, abc(i), i % 64)
val words = makeDF(data, securityLevel, "id", "word", "price")
words.agg(countDistinct("price").as("num_unique_prices")).collect
}

testAgainstSpark("global aggregate with 0 rows") { securityLevel =>
val data = for (i <- 0 until 256) yield (i, abc(i), 1)
val words = makeDF(data, securityLevel, "id", "word", "count")