Skip to content

Commit a18c169

Browse files
clockflycloud-fan
authored andcommitted
[SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.
## What changes were proposed in this pull request? This PR implements aggregation function `percentile_approx`. Function `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`. ### Syntax: ``` # Returns percentile at a given percentage value. The approximation error can be reduced by increasing parameter accuracy, at the cost of memory. percentile_approx(col, percentage [, accuracy]) # Returns percentile value array at given percentage value array percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy]) ``` ### Features: 1. This function supports partial aggregation. 2. The memory consumption is bounded. The larger `accuracy` parameter we choose, we smaller error we get. The default accuracy value is 10000, to match with Hive default setting. Choose a smaller value for smaller memory footprint. 3. This function supports window function aggregation. ### Example usages: ``` ## Returns the 25th percentile value, with default accuracy SELECT percentile_approx(col, 0.25) FROM table ## Returns an array of percentile value (25th, 50th, 75th), with default accuracy SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table ## Returns 25th percentile value, with custom accuracy value 100, larger accuracy parameter yields smaller approximation error SELECT percentile_approx(col, 0.25, 100) FROM table ## Returns the 25th, and 50th percentile values, with custom accuracy value 100 SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table ``` ### NOTE: 1. The `percentile_approx` implementation is different from Hive, so the result returned on same query maybe slightly different with Hive. This implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows paper `Space-efficient Online Computation of Quantile Summaries` by Greenwald, Michael and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)` 2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR has a rule to do compression automatically at the caller side, but it may not be optimal. ## How was this patch tested? Unit test, and Sql query test. ## Acknowledgement 1. This PR's work in based on lw-lin's PR #14298, with improvements like supporting partial aggregation, fixing out of memory issue. Author: Sean Zhong <[email protected]> Closes #14868 from clockfly/appro_percentile_try_2.
1 parent 536fa91 commit a18c169

File tree

6 files changed

+893
-2
lines changed

6 files changed

+893
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ object FunctionRegistry {
250250
expression[Average]("mean"),
251251
expression[Min]("min"),
252252
expression[Skewness]("skewness"),
253+
expression[ApproximatePercentile]("percentile_approx"),
253254
expression[StddevSamp]("std"),
254255
expression[StddevSamp]("stddev"),
255256
expression[StddevPop]("stddev_pop"),
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import java.nio.ByteBuffer
21+
22+
import com.google.common.primitives.{Doubles, Ints, Longs}
23+
24+
import org.apache.spark.sql.AnalysisException
25+
import org.apache.spark.sql.catalyst.{InternalRow}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
28+
import org.apache.spark.sql.catalyst.expressions._
29+
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest}
30+
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
31+
import org.apache.spark.sql.catalyst.util.QuantileSummaries
32+
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
33+
import org.apache.spark.sql.types._
34+
35+
/**
36+
* The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
37+
* percentage(s). A percentile is a watermark value below which a given percentage of the column
38+
* values fall. For example, the percentile of column `col` at percentage 50% is the median of
39+
* column `col`.
40+
*
41+
* This function supports partial aggregation.
42+
*
43+
* @param child child expression that can produce column value with `child.eval(inputRow)`
44+
* @param percentageExpression Expression that represents a single percentage value or
45+
* an array of percentage values. Each percentage value must be between
46+
* 0.0 and 1.0.
47+
* @param accuracyExpression Integer literal expression of approximation accuracy. Higher value
48+
* yields better accuracy, the default value is
49+
* DEFAULT_PERCENTILE_ACCURACY.
50+
*/
51+
@ExpressionDescription(
52+
usage =
53+
"""
54+
_FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric
55+
column `col` at the given percentage. The value of percentage must be between 0.0
56+
and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which
57+
controls approximation accuracy at the cost of memory. Higher value of `accuracy` yields
58+
better accuracy, `1.0/accuracy` is the relative error of the approximation.
59+
60+
_FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate
61+
percentile array of column `col` at the given percentage array. Each value of the
62+
percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is
63+
a positive integer literal which controls approximation accuracy at the cost of memory.
64+
Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative error of
65+
the approximation.
66+
""")
67+
case class ApproximatePercentile(
68+
child: Expression,
69+
percentageExpression: Expression,
70+
accuracyExpression: Expression,
71+
override val mutableAggBufferOffset: Int,
72+
override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
73+
74+
def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
75+
this(child, percentageExpression, accuracyExpression, 0, 0)
76+
}
77+
78+
def this(child: Expression, percentageExpression: Expression) = {
79+
this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
80+
}
81+
82+
// Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
83+
private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
84+
85+
override def inputTypes: Seq[AbstractDataType] = {
86+
Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
87+
}
88+
89+
// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
90+
private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
91+
(percentageExpression.dataType, percentageExpression.eval()) match {
92+
// Rule ImplicitTypeCasts can cast other numeric types to double
93+
case (_, num: Double) => (false, Array(num))
94+
case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
95+
val numericArray = arrayData.toObjectArray(baseType)
96+
(true, numericArray.map { x =>
97+
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
98+
})
99+
case other =>
100+
throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
101+
}
102+
}
103+
104+
override def checkInputDataTypes(): TypeCheckResult = {
105+
val defaultCheck = super.checkInputDataTypes()
106+
if (defaultCheck.isFailure) {
107+
defaultCheck
108+
} else if (!percentageExpression.foldable || !accuracyExpression.foldable) {
109+
TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal")
110+
} else if (accuracy <= 0) {
111+
TypeCheckFailure(
112+
s"The accuracy provided must be a positive integer literal (current value = $accuracy)")
113+
} else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) {
114+
TypeCheckFailure(
115+
s"All percentage values must be between 0.0 and 1.0 " +
116+
s"(current = ${percentages.mkString(", ")})")
117+
} else {
118+
TypeCheckSuccess
119+
}
120+
}
121+
122+
override def createAggregationBuffer(): PercentileDigest = {
123+
val relativeError = 1.0D / accuracy
124+
new PercentileDigest(relativeError)
125+
}
126+
127+
override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = {
128+
val value = child.eval(inputRow)
129+
// Ignore empty rows, for example: percentile_approx(null)
130+
if (value != null) {
131+
buffer.add(value.asInstanceOf[Double])
132+
}
133+
}
134+
135+
override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = {
136+
buffer.merge(other)
137+
}
138+
139+
override def eval(buffer: PercentileDigest): Any = {
140+
val result = buffer.getPercentiles(percentages)
141+
if (result.length == 0) {
142+
null
143+
} else if (returnPercentileArray) {
144+
new GenericArrayData(result)
145+
} else {
146+
result(0)
147+
}
148+
}
149+
150+
override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile =
151+
copy(mutableAggBufferOffset = newOffset)
152+
153+
override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile =
154+
copy(inputAggBufferOffset = newOffset)
155+
156+
override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression)
157+
158+
// Returns null for empty inputs
159+
override def nullable: Boolean = true
160+
161+
override def dataType: DataType = {
162+
if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
163+
}
164+
165+
override def prettyName: String = "percentile_approx"
166+
167+
override def serialize(obj: PercentileDigest): Array[Byte] = {
168+
ApproximatePercentile.serializer.serialize(obj)
169+
}
170+
171+
override def deserialize(bytes: Array[Byte]): PercentileDigest = {
172+
ApproximatePercentile.serializer.deserialize(bytes)
173+
}
174+
}
175+
176+
object ApproximatePercentile {
177+
178+
// Default accuracy of Percentile approximation. Larger value means better accuracy.
179+
// The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
180+
val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
181+
182+
/**
183+
* PercentileDigest is a probabilistic data structure used for approximating percentiles
184+
* with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
185+
*
186+
* @param summaries underlying probabilistic data structure [[QuantileSummaries]].
187+
* @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the
188+
* underlying quantileSummaries is compressed.
189+
*/
190+
class PercentileDigest(
191+
private var summaries: QuantileSummaries,
192+
private var isCompressed: Boolean) {
193+
194+
// Trigger compression if the QuantileSummaries's buffer length exceeds
195+
// compressThresHoldBufferLength. The buffer length can be get by
196+
// quantileSummaries.sampled.length
197+
private[this] final val compressThresHoldBufferLength: Int = {
198+
// Max buffer length after compression.
199+
val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2
200+
// A safe upper bound for buffer length before compression
201+
maxBufferLengthAfterCompression * 2
202+
}
203+
204+
def this(relativeError: Double) = {
205+
this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true)
206+
}
207+
208+
/** Returns compressed object of [[QuantileSummaries]] */
209+
def quantileSummaries: QuantileSummaries = {
210+
if (!isCompressed) compress()
211+
summaries
212+
}
213+
214+
/** Insert an observation value into the PercentileDigest data structure. */
215+
def add(value: Double): Unit = {
216+
summaries = summaries.insert(value)
217+
// The result of QuantileSummaries.insert is un-compressed
218+
isCompressed = false
219+
220+
// Currently, QuantileSummaries ignores the construction parameter compressThresHold,
221+
// which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here
222+
// to make sure QuantileSummaries doesn't occupy infinite memory.
223+
// TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold
224+
if (summaries.sampled.length >= compressThresHoldBufferLength) compress()
225+
}
226+
227+
/** In-place merges in another PercentileDigest. */
228+
def merge(other: PercentileDigest): Unit = {
229+
if (!isCompressed) compress()
230+
summaries = summaries.merge(other.quantileSummaries)
231+
}
232+
233+
/**
234+
* Returns the approximate percentiles of all observation values at the given percentages.
235+
* A percentile is a watermark value below which a given percentage of observation values fall.
236+
* For example, the following code returns the 25th, median, and 75th percentiles of
237+
* all observation values:
238+
*
239+
* {{{
240+
* val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
241+
* }}}
242+
*/
243+
def getPercentiles(percentages: Array[Double]): Array[Double] = {
244+
if (!isCompressed) compress()
245+
if (summaries.count == 0 || percentages.length == 0) {
246+
Array.empty[Double]
247+
} else {
248+
val result = new Array[Double](percentages.length)
249+
var i = 0
250+
while (i < percentages.length) {
251+
result(i) = summaries.query(percentages(i))
252+
i += 1
253+
}
254+
result
255+
}
256+
}
257+
258+
private final def compress(): Unit = {
259+
summaries = summaries.compress()
260+
isCompressed = true
261+
}
262+
}
263+
264+
/**
265+
* Serializer for class [[PercentileDigest]]
266+
*
267+
* This class is thread safe.
268+
*/
269+
class PercentileDigestSerializer {
270+
271+
private final def length(summaries: QuantileSummaries): Int = {
272+
// summaries.compressThreshold, summary.relativeError, summary.count
273+
Ints.BYTES + Doubles.BYTES + Longs.BYTES +
274+
// length of summary.sampled
275+
Ints.BYTES +
276+
// summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)]
277+
summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
278+
}
279+
280+
final def serialize(obj: PercentileDigest): Array[Byte] = {
281+
val summary = obj.quantileSummaries
282+
val buffer = ByteBuffer.wrap(new Array(length(summary)))
283+
buffer.putInt(summary.compressThreshold)
284+
buffer.putDouble(summary.relativeError)
285+
buffer.putLong(summary.count)
286+
buffer.putInt(summary.sampled.length)
287+
288+
var i = 0
289+
while (i < summary.sampled.length) {
290+
val stat = summary.sampled(i)
291+
buffer.putDouble(stat.value)
292+
buffer.putInt(stat.g)
293+
buffer.putInt(stat.delta)
294+
i += 1
295+
}
296+
buffer.array()
297+
}
298+
299+
final def deserialize(bytes: Array[Byte]): PercentileDigest = {
300+
val buffer = ByteBuffer.wrap(bytes)
301+
val compressThreshold = buffer.getInt()
302+
val relativeError = buffer.getDouble()
303+
val count = buffer.getLong()
304+
val sampledLength = buffer.getInt()
305+
val sampled = new Array[Stats](sampledLength)
306+
307+
var i = 0
308+
while (i < sampledLength) {
309+
val value = buffer.getDouble()
310+
val g = buffer.getInt()
311+
val delta = buffer.getInt()
312+
sampled(i) = Stats(value, g, delta)
313+
i += 1
314+
}
315+
val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count)
316+
new PercentileDigest(summary, isCompressed = true)
317+
}
318+
}
319+
320+
val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
321+
}

0 commit comments

Comments
 (0)