diff --git a/src/binder/expr.rs b/src/binder/expr.rs index b021c904..f35a4afa 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -110,22 +110,21 @@ impl<'a, T: Transaction> Binder<'a, T> { }), Expr::Subquery(subquery) => { let (sub_query, column) = self.bind_subquery(subquery)?; - self.context.sub_query(SubQueryType::SubQuery(sub_query)); - - if self.context.is_step(&QueryBindStep::Where) { - Ok(self.bind_temp_column(column)) + let (expr, sub_query) = if self.context.is_step(&QueryBindStep::Where) { + self.bind_temp_table(column, sub_query)? } else { - Ok(ScalarExpression::ColumnRef(column)) - } + (ScalarExpression::ColumnRef(column), sub_query) + }; + self.context.sub_query(SubQueryType::SubQuery(sub_query)); + Ok(expr) } Expr::InSubquery { expr, subquery, negated, } => { + let left_expr = Box::new(self.bind_expr(expr)?); let (sub_query, column) = self.bind_subquery(subquery)?; - self.context - .sub_query(SubQueryType::InSubQuery(*negated, sub_query)); if !self.context.is_step(&QueryBindStep::Where) { return Err(DatabaseError::UnsupportedStmt( @@ -133,11 +132,13 @@ impl<'a, T: Transaction> Binder<'a, T> { )); } - let alias_expr = self.bind_temp_column(column); + let (alias_expr, sub_query) = self.bind_temp_table(column, sub_query)?; + self.context + .sub_query(SubQueryType::InSubQuery(*negated, sub_query)); Ok(ScalarExpression::Binary { op: expression::BinaryOperator::Eq, - left_expr: Box::new(self.bind_expr(expr)?), + left_expr, right_expr: Box::new(alias_expr), ty: LogicalType::Boolean, }) @@ -203,16 +204,22 @@ impl<'a, T: Transaction> Binder<'a, T> { } } - fn bind_temp_column(&mut self, column: ColumnRef) -> ScalarExpression { + fn bind_temp_table( + &mut self, + column: ColumnRef, + sub_query: LogicalPlan, + ) -> Result<(ScalarExpression, LogicalPlan), DatabaseError> { let mut alias_column = ColumnCatalog::clone(&column); alias_column.set_table_name(self.context.temp_table()); - ScalarExpression::Alias { + let alias_expr = ScalarExpression::Alias { expr: Box::new(ScalarExpression::ColumnRef(column)), alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new( alias_column, )))), - } + }; + let alias_plan = self.bind_project(sub_query, vec![alias_expr.clone()])?; + Ok((alias_expr, alias_plan)) } fn bind_subquery( @@ -289,7 +296,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } else { // handle col syntax let mut got_column = None; - for table_catalog in self.context.bind_table.values() { + for table_catalog in self.context.bind_table.values().rev() { if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) { got_column = Some(column_catalog); } diff --git a/src/binder/select.rs b/src/binder/select.rs index 7e2da9d7..0832fb74 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -76,7 +76,7 @@ impl<'a, T: Transaction> Binder<'a, T> { // Resolve scalar function call. // TODO support SRF(Set-Returning Function). - let mut select_list = self.normalize_select_item(&select.projection)?; + let mut select_list = self.normalize_select_item(&select.projection, &plan)?; if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; @@ -341,6 +341,7 @@ impl<'a, T: Transaction> Binder<'a, T> { fn normalize_select_item( &mut self, items: &[SelectItem], + plan: &LogicalPlan, ) -> Result, DatabaseError> { let mut select_items = vec![]; @@ -359,6 +360,9 @@ impl<'a, T: Transaction> Binder<'a, T> { }); } SelectItem::Wildcard(_) => { + if let Operator::Project(op) = &plan.operator { + return Ok(op.exprs.clone()); + } for (table_name, _) in self.context.bind_table.keys() { self.bind_table_column_refs(&mut select_items, table_name.clone())?; } @@ -510,7 +514,7 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(FilterOperator::build(having, children, true)) } - fn bind_project( + pub(crate) fn bind_project( &mut self, children: LogicalPlan, select_list: Vec, diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 7c9cb659..78672a4c 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -121,7 +121,13 @@ pub enum ScalarExpression { impl ScalarExpression { pub fn unpack_alias(self) -> ScalarExpression { - if let ScalarExpression::Alias { expr, .. } = self { + if let ScalarExpression::Alias { + alias: AliasType::Expr(expr), + .. + } = self + { + expr.unpack_alias() + } else if let ScalarExpression::Alias { expr, .. } = self { expr.unpack_alias() } else { self @@ -129,7 +135,13 @@ impl ScalarExpression { } pub fn unpack_alias_ref(&self) -> &ScalarExpression { - if let ScalarExpression::Alias { expr, .. } = self { + if let ScalarExpression::Alias { + alias: AliasType::Expr(expr), + .. + } = self + { + expr.unpack_alias_ref() + } else if let ScalarExpression::Alias { expr, .. } = self { expr.unpack_alias_ref() } else { self @@ -137,7 +149,7 @@ impl ScalarExpression { } pub fn try_reference(&mut self, output_exprs: &[ScalarExpression]) { - let fn_output_column = |expr: &ScalarExpression| expr.unpack_alias_ref().output_column(); + let fn_output_column = |expr: &ScalarExpression| expr.output_column(); let self_column = fn_output_column(self); if let Some((pos, _)) = output_exprs .iter() diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index 785b3106..2b59760b 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -61,8 +61,6 @@ impl NormalizationRule for CollapseProject { if let Operator::Project(child_op) = graph.operator(child_id) { if is_subset_exprs(&op.exprs, &child_op.exprs) { graph.remove_node(child_id, false); - } else { - graph.remove_node(node_id, false); } } }