@@ -1196,30 +1196,47 @@ object Utils extends Logging {
1196
1196
case avg @ Average (child) =>
1197
1197
val sum = avg.aggBufferAttributes(0 )
1198
1198
val count = avg.aggBufferAttributes(1 )
1199
+ val dataType = child.dataType
1200
+
1201
+ val sumInitValue = child.nullable match {
1202
+ case true => Literal .create(null , dataType)
1203
+ case false => Cast (Literal (0 ), dataType)
1204
+ }
1205
+ val sumExpr = child.nullable match {
1206
+ case true => If (IsNull (child), sum, If (IsNull (sum), Cast (child, dataType), Add (sum, Cast (child, dataType))))
1207
+ case false => Add (sum, Cast (child, dataType))
1208
+ }
1209
+ val countExpr = If (IsNull (child), count, Add (count, Literal (1L )))
1199
1210
1200
- // TODO: support aggregating null values
1201
1211
// TODO: support DecimalType to match Spark SQL behavior
1202
1212
tuix.AggregateExpr .createAggregateExpr(
1203
1213
builder,
1204
1214
tuix.AggregateExpr .createInitialValuesVector(
1205
1215
builder,
1206
1216
Array (
1207
- /* sum = */ flatbuffersSerializeExpression(builder, Literal ( 0.0 ) , input),
1217
+ /* sum = */ flatbuffersSerializeExpression(builder, sumInitValue , input),
1208
1218
/* count = */ flatbuffersSerializeExpression(builder, Literal (0L ), input))),
1209
1219
tuix.AggregateExpr .createUpdateExprsVector(
1210
1220
builder,
1211
1221
Array (
1212
1222
/* sum = */ flatbuffersSerializeExpression(
1213
- builder, Add (sum, Cast (child, DoubleType )) , concatSchema),
1223
+ builder, sumExpr , concatSchema),
1214
1224
/* count = */ flatbuffersSerializeExpression(
1215
- builder, Add (count, Literal ( 1L )) , concatSchema))),
1225
+ builder, countExpr , concatSchema))),
1216
1226
flatbuffersSerializeExpression(
1217
- builder, Divide (sum, Cast (count, DoubleType )), aggSchema))
1227
+ builder, Divide (Cast ( sum, DoubleType ) , Cast (count, DoubleType )), aggSchema))
1218
1228
1219
1229
case c @ Count (children) =>
1220
1230
val count = c.aggBufferAttributes(0 )
1231
+ // COUNT(*) should count NULL values
1232
+ // COUNT(expr) should return the number or rows for which the supplied expressions are non-NULL
1233
+
1234
+ val nullableChildren = children.filter(_.nullable)
1235
+ val countExpr = nullableChildren.isEmpty match {
1236
+ case true => Add (count, Literal (1L ))
1237
+ case false => If (nullableChildren.map(IsNull ).reduce(Or ), count, Add (count, Literal (1L )))
1238
+ }
1221
1239
1222
- // TODO: support skipping null values
1223
1240
tuix.AggregateExpr .createAggregateExpr(
1224
1241
builder,
1225
1242
tuix.AggregateExpr .createInitialValuesVector(
@@ -1230,7 +1247,7 @@ object Utils extends Logging {
1230
1247
builder,
1231
1248
Array (
1232
1249
/* count = */ flatbuffersSerializeExpression(
1233
- builder, Add (count, Literal ( 1L )) , concatSchema))),
1250
+ builder, countExpr , concatSchema))),
1234
1251
flatbuffersSerializeExpression(
1235
1252
builder, count, aggSchema))
1236
1253
@@ -1316,22 +1333,31 @@ object Utils extends Logging {
1316
1333
1317
1334
case s @ Sum (child) =>
1318
1335
val sum = s.aggBufferAttributes(0 )
1319
-
1320
1336
val sumDataType = s.dataType
1337
+ // If any value is not NULL, return a non-NULL value
1338
+ // If all values are NULL, return NULL
1339
+
1340
+ val initValue = child.nullable match {
1341
+ case true => Literal .create(null , sumDataType)
1342
+ case false => Cast (Literal (0 ), sumDataType)
1343
+ }
1344
+ val sumExpr = child.nullable match {
1345
+ case true => If (IsNull (child), sum, If (IsNull (sum), Cast (child, sumDataType), Add (sum, Cast (child, sumDataType))))
1346
+ case false => Add (sum, Cast (child, sumDataType))
1347
+ }
1321
1348
1322
- // TODO: support aggregating null values
1323
1349
tuix.AggregateExpr .createAggregateExpr(
1324
1350
builder,
1325
1351
tuix.AggregateExpr .createInitialValuesVector(
1326
1352
builder,
1327
1353
Array (
1328
1354
/* sum = */ flatbuffersSerializeExpression(
1329
- builder, Cast ( Literal ( 0 ), sumDataType) , input))),
1355
+ builder, initValue , input))),
1330
1356
tuix.AggregateExpr .createUpdateExprsVector(
1331
1357
builder,
1332
1358
Array (
1333
1359
/* sum = */ flatbuffersSerializeExpression(
1334
- builder, Add (sum, Cast (child, sumDataType)) , concatSchema))),
1360
+ builder, sumExpr , concatSchema))),
1335
1361
flatbuffersSerializeExpression(
1336
1362
builder, sum, aggSchema))
1337
1363
0 commit comments