Skip to content

Commit e36e183

Browse files
committed
fix: fix pivot extra columns on projection
1 parent ff62551 commit e36e183

File tree

10 files changed

+189
-38
lines changed

10 files changed

+189
-38
lines changed

src/query/ast/src/ast/query.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,8 @@ impl Display for TableReference {
11041104
pub struct TableAlias {
11051105
pub name: Identifier,
11061106
pub columns: Vec<Identifier>,
1107+
/// When true, keep the original database name on bound columns even after aliasing.
1108+
pub keep_database_name: bool,
11071109
}
11081110

11091111
impl Display for TableAlias {

src/query/ast/src/parser/query.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ pub fn table_alias(i: Input) -> IResult<TableAlias> {
662662
|(name, opt_columns)| TableAlias {
663663
name,
664664
columns: opt_columns.map(|(_, cols, _)| cols).unwrap_or_default(),
665+
keep_database_name: false,
665666
},
666667
)
667668
.parse(i)
@@ -673,6 +674,7 @@ pub fn table_alias_without_as(i: Input) -> IResult<TableAlias> {
673674
|(name, opt_columns)| TableAlias {
674675
name,
675676
columns: opt_columns.map(|(_, cols, _)| cols).unwrap_or_default(),
677+
keep_database_name: false,
676678
},
677679
)
678680
.parse(i)

src/query/ast/src/parser/statement.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5288,6 +5288,7 @@ pub fn table_reference_with_alias(i: Input) -> IResult<TableReference> {
52885288
alias: alias.map(|v| TableAlias {
52895289
name: v,
52905290
columns: vec![],
5291+
keep_database_name: false,
52915292
}),
52925293
temporal: None,
52935294
with_options: None,

src/query/sql/src/planner/binder/bind_context.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,9 @@ pub fn apply_alias_for_columns(
683683
name_resolution_ctx: &NameResolutionContext,
684684
) -> Result<()> {
685685
for column in columns.iter_mut() {
686-
column.database_name = None;
686+
if !alias.keep_database_name {
687+
column.database_name = None;
688+
}
687689
column.table_name = Some(normalize_identifier(&alias.name, name_resolution_ctx).name);
688690
}
689691

src/query/sql/src/planner/binder/bind_query/bind_select.rs

Lines changed: 108 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,19 @@ use databend_common_ast::ast::Expr::Array;
2323
use databend_common_ast::ast::FunctionCall;
2424
use databend_common_ast::ast::GroupBy;
2525
use databend_common_ast::ast::Identifier;
26+
use databend_common_ast::ast::Indirection;
2627
use databend_common_ast::ast::Join;
2728
use databend_common_ast::ast::JoinCondition;
2829
use databend_common_ast::ast::JoinOperator;
2930
use databend_common_ast::ast::Literal;
3031
use databend_common_ast::ast::OrderByExpr;
3132
use databend_common_ast::ast::Pivot;
3233
use databend_common_ast::ast::PivotValues;
34+
use databend_common_ast::ast::Query;
3335
use databend_common_ast::ast::SelectStmt;
3436
use databend_common_ast::ast::SelectTarget;
37+
use databend_common_ast::ast::SetExpr;
38+
use databend_common_ast::ast::TableAlias;
3539
use databend_common_ast::ast::TableReference;
3640
use databend_common_ast::ast::UnpivotName;
3741
use databend_common_ast::Span;
@@ -50,7 +54,6 @@ use crate::planner::binder::BindContext;
5054
use crate::planner::binder::Binder;
5155
use crate::planner::QueryExecutor;
5256
use crate::AsyncFunctionRewriter;
53-
use crate::ColumnBinding;
5457

5558
// A normalized IR for `SELECT` clause.
5659
#[derive(Debug, Default)]
@@ -84,6 +87,12 @@ impl Binder {
8487
.check_enterprise_enabled(self.ctx.get_license_key(), Feature::VirtualColumn)
8588
.is_ok();
8689

90+
let mut rewriter =
91+
SelectRewriter::new(self.name_resolution_ctx.unquoted_ident_case_sensitive)
92+
.with_subquery_executor(self.subquery_executor.clone());
93+
let new_stmt = rewriter.rewrite(stmt)?;
94+
let stmt = new_stmt.as_ref().unwrap_or(stmt);
95+
8796
let (mut s_expr, mut from_context) = if stmt.from.is_empty() {
8897
let select_list = &stmt.select_list;
8998
self.bind_dummy_table(bind_context, select_list)?
@@ -111,14 +120,6 @@ impl Binder {
111120
self.bind_table_reference(bind_context, &cross_joins)?
112121
};
113122

114-
let mut rewriter = SelectRewriter::new(
115-
from_context.all_column_bindings(),
116-
self.name_resolution_ctx.unquoted_ident_case_sensitive,
117-
)
118-
.with_subquery_executor(self.subquery_executor.clone());
119-
let new_stmt = rewriter.rewrite(stmt)?;
120-
let stmt = new_stmt.as_ref().unwrap_or(stmt);
121-
122123
// Try put window definitions into bind context.
123124
// This operation should be before `normalize_select_list` because window functions can be used in select list.
124125
self.analyze_window_definition(&mut from_context, &stmt.window_list)?;
@@ -275,19 +276,16 @@ impl Binder {
275276

276277
/// It is useful when implementing some SQL syntax sugar,
277278
///
278-
/// [`column_binding`] contains the column binding information of the SelectStmt.
279-
///
280279
/// to rewrite the SelectStmt, just add a new rewrite_* function and call it in the `rewrite` function.
281280
#[allow(dead_code)]
282-
struct SelectRewriter<'a> {
283-
column_binding: &'a [ColumnBinding],
281+
struct SelectRewriter {
284282
new_stmt: Option<SelectStmt>,
285283
is_unquoted_ident_case_sensitive: bool,
286284
subquery_executor: Option<Arc<dyn QueryExecutor>>,
287285
}
288286

289287
// helper functions to SelectRewriter
290-
impl SelectRewriter<'_> {
288+
impl SelectRewriter {
291289
fn parse_aggregate_function(expr: &Expr) -> Result<(&Identifier, &[Expr])> {
292290
match expr {
293291
Expr::FunctionCall {
@@ -372,10 +370,9 @@ impl SelectRewriter<'_> {
372370
}
373371
}
374372

375-
impl<'a> SelectRewriter<'a> {
376-
fn new(column_binding: &'a [ColumnBinding], is_unquoted_ident_case_sensitive: bool) -> Self {
373+
impl SelectRewriter {
374+
fn new(is_unquoted_ident_case_sensitive: bool) -> Self {
377375
SelectRewriter {
378-
column_binding,
379376
new_stmt: None,
380377
is_unquoted_ident_case_sensitive,
381378
subquery_executor: None,
@@ -430,14 +427,14 @@ impl<'a> SelectRewriter<'a> {
430427
.set_span(expr.span())),
431428
})
432429
.collect::<Result<Vec<_>>>()?;
433-
let new_group_by = stmt.group_by.clone().unwrap_or_else(|| GroupBy::All);
434-
435-
let mut new_select_list = stmt.select_list.clone();
436-
if let Some(star) = new_select_list.iter_mut().find(|target| target.is_star()) {
437-
let mut exclude_columns = aggregate_args_names;
438-
exclude_columns.push(pivot.value_column.clone());
439-
star.exclude(exclude_columns);
430+
let mut exclude_columns = aggregate_args_names.clone();
431+
exclude_columns.push(pivot.value_column.clone());
432+
let mut star_target = SelectTarget::StarColumns {
433+
qualified: vec![Indirection::Star(None)],
434+
column_filter: None,
440435
};
436+
star_target.exclude(exclude_columns);
437+
let mut inner_select_list = vec![star_target];
441438
let new_aggregate_name = Identifier {
442439
name: format!("{}_if", aggregate_name.name),
443440
..aggregate_name.clone()
@@ -465,7 +462,7 @@ impl<'a> SelectRewriter<'a> {
465462
&values,
466463
&new_aggregate_name,
467464
aggregate_args,
468-
&mut new_select_list,
465+
&mut inner_select_list,
469466
stmt,
470467
)?;
471468
}
@@ -485,7 +482,7 @@ impl<'a> SelectRewriter<'a> {
485482
&values,
486483
&new_aggregate_name,
487484
aggregate_args,
488-
&mut new_select_list,
485+
&mut inner_select_list,
489486
stmt,
490487
)?;
491488
} else {
@@ -526,7 +523,7 @@ impl<'a> SelectRewriter<'a> {
526523
&values,
527524
&new_aggregate_name,
528525
aggregate_args,
529-
&mut new_select_list,
526+
&mut inner_select_list,
530527
stmt,
531528
)?;
532529
} else {
@@ -537,16 +534,46 @@ impl<'a> SelectRewriter<'a> {
537534
}
538535
}
539536

540-
if let Some(ref mut new_stmt) = self.new_stmt {
541-
new_stmt.select_list = new_select_list;
542-
new_stmt.group_by = Some(new_group_by);
543-
} else {
544-
self.new_stmt = Some(SelectStmt {
545-
select_list: new_select_list,
546-
group_by: Some(new_group_by),
547-
..stmt.clone()
548-
});
549-
}
537+
let mut inner_from = stmt.from[0].clone();
538+
Self::strip_pivot(&mut inner_from);
539+
540+
let inner_stmt = SelectStmt {
541+
span: stmt.span,
542+
hints: None,
543+
distinct: false,
544+
top_n: None,
545+
select_list: inner_select_list,
546+
from: vec![inner_from],
547+
selection: None,
548+
group_by: Some(GroupBy::All),
549+
having: None,
550+
window_list: None,
551+
qualify: None,
552+
};
553+
554+
let inner_query = Query {
555+
span: stmt.span,
556+
with: None,
557+
body: SetExpr::Select(Box::new(inner_stmt)),
558+
order_by: vec![],
559+
limit: vec![],
560+
offset: None,
561+
ignore_result: false,
562+
};
563+
564+
let subquery_ref = TableReference::Subquery {
565+
span: Self::table_ref_span(&stmt.from[0]),
566+
lateral: false,
567+
subquery: Box::new(inner_query),
568+
alias: Some(Self::table_ref_alias(&stmt.from[0])),
569+
pivot: None,
570+
unpivot: None,
571+
};
572+
573+
let mut outer_stmt = stmt.clone();
574+
outer_stmt.from = vec![subquery_ref];
575+
576+
self.new_stmt = Some(outer_stmt);
550577
Ok(())
551578
}
552579

@@ -625,6 +652,50 @@ impl<'a> SelectRewriter<'a> {
625652
Ok(values)
626653
}
627654

655+
fn strip_pivot(table_ref: &mut TableReference) {
656+
match table_ref {
657+
TableReference::Table { pivot, .. } => {
658+
*pivot = None;
659+
}
660+
TableReference::Subquery { pivot, .. } => {
661+
*pivot = None;
662+
}
663+
_ => {}
664+
}
665+
}
666+
667+
fn table_ref_span(table_ref: &TableReference) -> Span {
668+
match table_ref {
669+
TableReference::Table { span, .. } => *span,
670+
TableReference::TableFunction { span, .. } => *span,
671+
TableReference::Subquery { span, .. } => *span,
672+
TableReference::Join { span, .. } => *span,
673+
TableReference::Location { span, .. } => *span,
674+
}
675+
}
676+
677+
fn table_ref_alias(table_ref: &TableReference) -> TableAlias {
678+
match table_ref {
679+
TableReference::Table { table, alias, .. } => {
680+
alias.clone().unwrap_or_else(|| TableAlias {
681+
name: table.clone(),
682+
columns: vec![],
683+
keep_database_name: true,
684+
})
685+
}
686+
TableReference::Subquery { alias, .. } => alias.clone().unwrap_or_else(|| TableAlias {
687+
name: Identifier::from_name(Self::table_ref_span(table_ref), "__pivot_subquery"),
688+
columns: vec![],
689+
keep_database_name: false,
690+
}),
691+
_ => TableAlias {
692+
name: Identifier::from_name(Self::table_ref_span(table_ref), "__pivot_subquery"),
693+
columns: vec![],
694+
keep_database_name: false,
695+
},
696+
}
697+
}
698+
628699
fn build_pivot_source_query(&self, stmt: &SelectStmt) -> Result<String> {
629700
// Build the source query for the pivot table without the pivot clause
630701
// This is used to get distinct values for ANY pivot

src/query/sql/src/planner/binder/bind_table_reference/bind_obfuscate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ fn build_subquery(
267267
alias: Some(TableAlias {
268268
name: ident(model_table_name.to_string()),
269269
columns: vec![],
270+
keep_database_name: false,
270271
}),
271272
pivot: None,
272273
unpivot: None,

src/query/sql/src/planner/binder/copy_into_table.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ impl Binder {
110110
let alias = alias_name.as_ref().map(|name| TableAlias {
111111
name: name.clone(),
112112
columns: vec![],
113+
keep_database_name: false,
113114
});
114115
self.bind_copy_from_query_into_table(bind_context, plan, select_list, &alias)
115116
.await

src/query/sql/src/planner/semantic/distinct_to_groupby.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ impl DistinctToGroupBy {
132132
alias: Some(TableAlias {
133133
name: Identifier::from_name(None, sub_query_name),
134134
columns: vec![Identifier::from_name(None, "_1")],
135+
keep_database_name: false,
135136
}),
136137
pivot: None,
137138
unpivot: None,

src/tests/sqlsmith/src/sql_gen/query.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,7 @@ impl<R: Rng> SqlGenerator<'_, R> {
878878
let alias = TableAlias {
879879
name: table_name.clone(),
880880
columns,
881+
keep_database_name: false,
881882
};
882883
let table = Table::new(None, table_name, schema);
883884

0 commit comments

Comments
 (0)