Skip to content

Commit 29e3312

Browse files
authored
Fix NULL handling for aggregation (#130)
* Modify COUNT and SUM to correctly handle NULL values * Change average to support NULL values * Fix
1 parent 366e92c commit 29e3312

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,30 +1196,47 @@ object Utils extends Logging {
11961196
case avg @ Average(child) =>
11971197
val sum = avg.aggBufferAttributes(0)
11981198
val count = avg.aggBufferAttributes(1)
1199+
val dataType = child.dataType
1200+
1201+
val sumInitValue = child.nullable match {
1202+
case true => Literal.create(null, dataType)
1203+
case false => Cast(Literal(0), dataType)
1204+
}
1205+
val sumExpr = child.nullable match {
1206+
case true => If(IsNull(child), sum, If(IsNull(sum), Cast(child, dataType), Add(sum, Cast(child, dataType))))
1207+
case false => Add(sum, Cast(child, dataType))
1208+
}
1209+
val countExpr = If(IsNull(child), count, Add(count, Literal(1L)))
11991210

1200-
// TODO: support aggregating null values
12011211
// TODO: support DecimalType to match Spark SQL behavior
12021212
tuix.AggregateExpr.createAggregateExpr(
12031213
builder,
12041214
tuix.AggregateExpr.createInitialValuesVector(
12051215
builder,
12061216
Array(
1207-
/* sum = */ flatbuffersSerializeExpression(builder, Literal(0.0), input),
1217+
/* sum = */ flatbuffersSerializeExpression(builder, sumInitValue, input),
12081218
/* count = */ flatbuffersSerializeExpression(builder, Literal(0L), input))),
12091219
tuix.AggregateExpr.createUpdateExprsVector(
12101220
builder,
12111221
Array(
12121222
/* sum = */ flatbuffersSerializeExpression(
1213-
builder, Add(sum, Cast(child, DoubleType)), concatSchema),
1223+
builder, sumExpr, concatSchema),
12141224
/* count = */ flatbuffersSerializeExpression(
1215-
builder, Add(count, Literal(1L)), concatSchema))),
1225+
builder, countExpr, concatSchema))),
12161226
flatbuffersSerializeExpression(
1217-
builder, Divide(sum, Cast(count, DoubleType)), aggSchema))
1227+
builder, Divide(Cast(sum, DoubleType), Cast(count, DoubleType)), aggSchema))
12181228

12191229
case c @ Count(children) =>
12201230
val count = c.aggBufferAttributes(0)
1231+
// COUNT(*) should count NULL values
1232+
// COUNT(expr) should return the number of rows for which the supplied expressions are non-NULL
1233+
1234+
val nullableChildren = children.filter(_.nullable)
1235+
val countExpr = nullableChildren.isEmpty match {
1236+
case true => Add(count, Literal(1L))
1237+
case false => If(nullableChildren.map(IsNull).reduce(Or), count, Add(count, Literal(1L)))
1238+
}
12211239

1222-
// TODO: support skipping null values
12231240
tuix.AggregateExpr.createAggregateExpr(
12241241
builder,
12251242
tuix.AggregateExpr.createInitialValuesVector(
@@ -1230,7 +1247,7 @@ object Utils extends Logging {
12301247
builder,
12311248
Array(
12321249
/* count = */ flatbuffersSerializeExpression(
1233-
builder, Add(count, Literal(1L)), concatSchema))),
1250+
builder, countExpr, concatSchema))),
12341251
flatbuffersSerializeExpression(
12351252
builder, count, aggSchema))
12361253

@@ -1316,22 +1333,31 @@ object Utils extends Logging {
13161333

13171334
case s @ Sum(child) =>
13181335
val sum = s.aggBufferAttributes(0)
1319-
13201336
val sumDataType = s.dataType
1337+
// If any value is not NULL, return a non-NULL value
1338+
// If all values are NULL, return NULL
1339+
1340+
val initValue = child.nullable match {
1341+
case true => Literal.create(null, sumDataType)
1342+
case false => Cast(Literal(0), sumDataType)
1343+
}
1344+
val sumExpr = child.nullable match {
1345+
case true => If(IsNull(child), sum, If(IsNull(sum), Cast(child, sumDataType), Add(sum, Cast(child, sumDataType))))
1346+
case false => Add(sum, Cast(child, sumDataType))
1347+
}
13211348

1322-
// TODO: support aggregating null values
13231349
tuix.AggregateExpr.createAggregateExpr(
13241350
builder,
13251351
tuix.AggregateExpr.createInitialValuesVector(
13261352
builder,
13271353
Array(
13281354
/* sum = */ flatbuffersSerializeExpression(
1329-
builder, Cast(Literal(0), sumDataType), input))),
1355+
builder, initValue, input))),
13301356
tuix.AggregateExpr.createUpdateExprsVector(
13311357
builder,
13321358
Array(
13331359
/* sum = */ flatbuffersSerializeExpression(
1334-
builder, Add(sum, Cast(child, sumDataType)), concatSchema))),
1360+
builder, sumExpr, concatSchema))),
13351361
flatbuffersSerializeExpression(
13361362
builder, sum, aggSchema))
13371363

src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,30 @@ trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self =>
122122
}
123123
}
124124

125+
/** Modified from https://stackoverflow.com/questions/33193958/change-nullable-property-of-column-in-spark-dataframe
126+
* and https://stackoverflow.com/questions/32585670/what-is-the-best-way-to-define-custom-methods-on-a-dataframe
127+
* Set nullable property of column.
128+
* @param cn is the column name to change
129+
* @param nullable is the flag to set, such that the column is either nullable or not
130+
*/
131+
object ExtraDFOperations {
132+
implicit class AlternateDF(df : DataFrame) {
133+
def setNullableStateOfColumn(cn: String, nullable: Boolean) : DataFrame = {
134+
// get schema
135+
val schema = df.schema
136+
// modify the [[StructField]] with name `cn`
137+
val newSchema = StructType(schema.map {
138+
case StructField( c, t, _, m) if c.equals(cn) => StructField( c, t, nullable = nullable, m)
139+
case y: StructField => y
140+
})
141+
// apply new schema
142+
df.sqlContext.createDataFrame( df.rdd, newSchema )
143+
}
144+
}
145+
}
146+
147+
import ExtraDFOperations._
148+
125149
testAgainstSpark("Interval SQL") { securityLevel =>
126150
val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime())))
127151
val df = makeDF(data, securityLevel, "index", "time")
@@ -375,17 +399,28 @@ trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self =>
375399
}
376400

377401
testAgainstSpark("aggregate average") { securityLevel =>
378-
val data = for (i <- 0 until 256) yield (i, abc(i), i.toDouble)
402+
val data = (0 until 256).map{ i =>
403+
if (i % 3 == 0 || (i + 1) % 6 == 0)
404+
(i, abc(i), None)
405+
else
406+
(i, abc(i), Some(i.toDouble))
407+
}.toSeq
379408
val words = makeDF(data, securityLevel, "id", "category", "price")
409+
words.setNullableStateOfColumn("price", true)
380410

381-
words.groupBy("category").agg(avg("price").as("avgPrice"))
382-
.collect.sortBy { case Row(category: String, _) => category }
411+
val result = words.groupBy("category").agg(avg("price").as("avgPrice"))
412+
result.collect.sortBy { case Row(category: String, _) => category }
383413
}
384414

385415
testAgainstSpark("aggregate count") { securityLevel =>
386-
val data = for (i <- 0 until 256) yield (i, abc(i), 1)
416+
val data = (0 until 256).map{ i =>
417+
if (i % 3 == 0 || (i + 1) % 6 == 0)
418+
(i, abc(i), None)
419+
else
420+
(i, abc(i), Some(i))
421+
}.toSeq
387422
val words = makeDF(data, securityLevel, "id", "category", "price")
388-
423+
words.setNullableStateOfColumn("price", true)
389424
words.groupBy("category").agg(count("category").as("itemsInCategory"))
390425
.collect.sortBy { case Row(category: String, _) => category }
391426
}
@@ -423,8 +458,15 @@ trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self =>
423458
}
424459

425460
testAgainstSpark("aggregate sum") { securityLevel =>
426-
val data = for (i <- 0 until 256) yield (i, abc(i), 1)
461+
val data = (0 until 256).map{ i =>
462+
if (i % 3 == 0 || i % 4 == 0)
463+
(i, abc(i), None)
464+
else
465+
(i, abc(i), Some(i.toDouble))
466+
}.toSeq
467+
427468
val words = makeDF(data, securityLevel, "id", "word", "count")
469+
words.setNullableStateOfColumn("count", true)
428470

429471
words.groupBy("word").agg(sum("count").as("totalCount"))
430472
.collect.sortBy { case Row(word: String, _) => word }

0 commit comments

Comments
 (0)