|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +//! Analyzer rule that replaces `grouping` aggregate function calls |
| 19 | +//! with expressions derived from the internal grouping id column. |
| 20 | +
|
| 21 | +use std::cmp::Ordering; |
| 22 | +use std::collections::HashMap; |
| 23 | +use std::sync::Arc; |
| 24 | + |
| 25 | +use crate::analyzer::AnalyzerRule; |
| 26 | + |
| 27 | +use arrow::datatypes::DataType; |
| 28 | +use datafusion_common::config::ConfigOptions; |
| 29 | +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
| 30 | +use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result, ScalarValue}; |
| 31 | +use datafusion_expr::expr::{AggregateFunction, Alias}; |
| 32 | +use datafusion_expr::logical_plan::LogicalPlan; |
| 33 | +use datafusion_expr::utils::grouping_set_to_exprlist; |
| 34 | +use datafusion_expr::{ |
| 35 | + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, |
| 36 | + Expr, Projection, |
| 37 | +}; |
| 38 | +use itertools::Itertools; |
| 39 | + |
/// Analyzer rule that swaps calls to the `grouping` aggregate function for
/// an expression computed from the aggregate's internal grouping id.
#[derive(Debug, Default)]
pub struct ResolveGroupingFunction;
| 43 | + |
| 44 | +impl ResolveGroupingFunction { |
| 45 | + pub fn new() -> Self { |
| 46 | + Self {} |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +impl AnalyzerRule for ResolveGroupingFunction { |
| 51 | + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { |
| 52 | + plan.transform_up(analyze_internal).data() |
| 53 | + } |
| 54 | + |
| 55 | + fn name(&self) -> &str { |
| 56 | + "resolve_grouping_function" |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> { |
| 61 | + Ok(grouping_set_to_exprlist(group_expr)? |
| 62 | + .into_iter() |
| 63 | + .rev() |
| 64 | + .enumerate() |
| 65 | + .map(|(idx, v)| (v, idx)) |
| 66 | + .collect::<HashMap<_, _>>()) |
| 67 | +} |
| 68 | + |
| 69 | +fn replace_grouping_exprs( |
| 70 | + input: Arc<LogicalPlan>, |
| 71 | + schema: DFSchemaRef, |
| 72 | + group_expr: Vec<Expr>, |
| 73 | + aggr_expr: Vec<Expr>, |
| 74 | +) -> Result<LogicalPlan> { |
| 75 | + // Create HashMap from Expr to index in the grouping_id bitmap |
| 76 | + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); |
| 77 | + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; |
| 78 | + let columns = schema.columns(); |
| 79 | + let mut new_agg_expr = Vec::new(); |
| 80 | + let mut projection_exprs = Vec::new(); |
| 81 | + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; |
| 82 | + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; |
| 83 | + projection_exprs.extend( |
| 84 | + columns |
| 85 | + .iter() |
| 86 | + .take(group_expr_len) |
| 87 | + .map(|column| Expr::Column(column.clone())), |
| 88 | + ); |
| 89 | + for (expr, column) in aggr_expr |
| 90 | + .into_iter() |
| 91 | + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) |
| 92 | + { |
| 93 | + match expr { |
| 94 | + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { |
| 95 | + let grouping_expr = grouping_function_on_id( |
| 96 | + function, |
| 97 | + &group_expr_to_bitmap_index, |
| 98 | + is_grouping_set, |
| 99 | + )?; |
| 100 | + projection_exprs.push(Expr::Alias(Alias::new( |
| 101 | + grouping_expr, |
| 102 | + column.relation, |
| 103 | + column.name, |
| 104 | + ))); |
| 105 | + } |
| 106 | + _ => { |
| 107 | + projection_exprs.push(Expr::Column(column)); |
| 108 | + new_agg_expr.push(expr); |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | + // Recreate aggregate without grouping functions |
| 113 | + let new_aggregate = |
| 114 | + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); |
| 115 | + // Create projection with grouping functions calculations |
| 116 | + let projection = LogicalPlan::Projection(Projection::try_new( |
| 117 | + projection_exprs, |
| 118 | + new_aggregate.into(), |
| 119 | + )?); |
| 120 | + Ok(projection) |
| 121 | +} |
| 122 | + |
| 123 | +fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { |
| 124 | + // rewrite any subqueries in the plan first |
| 125 | + let transformed_plan = |
| 126 | + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; |
| 127 | + |
| 128 | + let transformed_plan = transformed_plan.transform_data(|plan| match plan { |
| 129 | + LogicalPlan::Aggregate(Aggregate { |
| 130 | + input, |
| 131 | + group_expr, |
| 132 | + aggr_expr, |
| 133 | + schema, |
| 134 | + .. |
| 135 | + }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( |
| 136 | + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, |
| 137 | + )), |
| 138 | + _ => Ok(Transformed::no(plan)), |
| 139 | + })?; |
| 140 | + |
| 141 | + Ok(transformed_plan) |
| 142 | +} |
| 143 | + |
| 144 | +fn is_grouping_function(expr: &Expr) -> bool { |
| 145 | + // TODO: Do something better than name here should grouping be a built |
| 146 | + // in expression? |
| 147 | + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") |
| 148 | +} |
| 149 | + |
| 150 | +fn contains_grouping_function(exprs: &[Expr]) -> bool { |
| 151 | + exprs.iter().any(is_grouping_function) |
| 152 | +} |
| 153 | + |
| 154 | +fn validate_args( |
| 155 | + function: &AggregateFunction, |
| 156 | + group_by_expr: &HashMap<&Expr, usize>, |
| 157 | +) -> Result<()> { |
| 158 | + let expr_not_in_group_by = function |
| 159 | + .args |
| 160 | + .iter() |
| 161 | + .find(|expr| !group_by_expr.contains_key(expr)); |
| 162 | + if let Some(expr) = expr_not_in_group_by { |
| 163 | + Err(DataFusionError::Plan(format!( |
| 164 | + "Argument {} to grouping function is not in grouping columns {}", |
| 165 | + expr, |
| 166 | + group_by_expr.keys().map(|e| e.to_string()).join(", ") |
| 167 | + ))) |
| 168 | + } else { |
| 169 | + Ok(()) |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +fn grouping_function_on_id( |
| 174 | + function: &AggregateFunction, |
| 175 | + group_by_expr: &HashMap<&Expr, usize>, |
| 176 | + is_grouping_set: bool, |
| 177 | +) -> Result<Expr> { |
| 178 | + validate_args(function, group_by_expr)?; |
| 179 | + let args = &function.args; |
| 180 | + |
| 181 | + // Postgres allows grouping function for group by without grouping sets, the result is then |
| 182 | + // always 0 |
| 183 | + if !is_grouping_set { |
| 184 | + return Ok(Expr::Literal(ScalarValue::from(0u32))); |
| 185 | + } |
| 186 | + |
| 187 | + let group_by_expr_count = group_by_expr.len(); |
| 188 | + let literal = |value: usize| { |
| 189 | + if group_by_expr_count < 8 { |
| 190 | + Expr::Literal(ScalarValue::from(value as u8)) |
| 191 | + } else if group_by_expr_count < 16 { |
| 192 | + Expr::Literal(ScalarValue::from(value as u16)) |
| 193 | + } else if group_by_expr_count < 32 { |
| 194 | + Expr::Literal(ScalarValue::from(value as u32)) |
| 195 | + } else { |
| 196 | + Expr::Literal(ScalarValue::from(value as u64)) |
| 197 | + } |
| 198 | + }; |
| 199 | + |
| 200 | + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); |
| 201 | + // The grouping call is exactly our internal grouping id |
| 202 | + if args.len() == group_by_expr_count |
| 203 | + && args |
| 204 | + .iter() |
| 205 | + .rev() |
| 206 | + .enumerate() |
| 207 | + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) |
| 208 | + { |
| 209 | + return Ok(cast(grouping_id_column, DataType::UInt32)); |
| 210 | + } |
| 211 | + |
| 212 | + args.iter() |
| 213 | + .rev() |
| 214 | + .enumerate() |
| 215 | + .map(|(arg_idx, expr)| { |
| 216 | + group_by_expr.get(expr).map(|group_by_idx| { |
| 217 | + let group_by_bit = |
| 218 | + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); |
| 219 | + match group_by_idx.cmp(&arg_idx) { |
| 220 | + Ordering::Less => { |
| 221 | + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) |
| 222 | + } |
| 223 | + Ordering::Greater => { |
| 224 | + bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) |
| 225 | + } |
| 226 | + Ordering::Equal => group_by_bit, |
| 227 | + } |
| 228 | + }) |
| 229 | + }) |
| 230 | + .collect::<Option<Vec<_>>>() |
| 231 | + .and_then(|bit_exprs| { |
| 232 | + bit_exprs |
| 233 | + .into_iter() |
| 234 | + .reduce(bitwise_or) |
| 235 | + .map(|expr| cast(expr, DataType::UInt32)) |
| 236 | + }) |
| 237 | + .ok_or_else(|| { |
| 238 | + DataFusionError::Internal( |
| 239 | + "Grouping sets should contains at least one element".to_string(), |
| 240 | + ) |
| 241 | + }) |
| 242 | +} |
0 commit comments