Skip to content

Commit 6e34280

Browse files
aalexandrovalamb
andauthored
Plan LATERAL subqueries (#11456)
* Planner: support `LATERAL` subqueries * Planner: use `DFSchema::merge` in `create_relation_subquery` In order to compute the `set_outer_from_schema` argument we currently use `DFSchema::join`. When we combine the current outer FROM schema with the current outer query schema columns from the latter should override columns from the first, so the correct way is to use `DFSchema::merge`. To witness the fix, note that the query in the fixed test case isn't planned as expected without the accompanying changes. * Update plans --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 37e54ee commit 6e34280

File tree

6 files changed

+300
-14
lines changed

6 files changed

+300
-14
lines changed

datafusion/sql/src/planner.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ pub struct PlannerContext {
135135
ctes: HashMap<String, Arc<LogicalPlan>>,
136136
/// The query schema of the outer query plan, used to resolve the columns in subquery
137137
outer_query_schema: Option<DFSchemaRef>,
138+
/// The joined schemas of all FROM clauses planned so far. When planning LATERAL
139+
/// FROM clauses, this should become a suffix of the `outer_query_schema`.
140+
outer_from_schema: Option<DFSchemaRef>,
138141
}
139142

140143
impl Default for PlannerContext {
@@ -150,6 +153,7 @@ impl PlannerContext {
150153
prepare_param_data_types: Arc::new(vec![]),
151154
ctes: HashMap::new(),
152155
outer_query_schema: None,
156+
outer_from_schema: None,
153157
}
154158
}
155159

@@ -177,6 +181,29 @@ impl PlannerContext {
177181
schema
178182
}
179183

184+
// return a clone of the outer FROM schema
185+
pub fn outer_from_schema(&self) -> Option<Arc<DFSchema>> {
186+
self.outer_from_schema.clone()
187+
}
188+
189+
/// sets the outer FROM schema, returning the existing one, if any
190+
pub fn set_outer_from_schema(
191+
&mut self,
192+
mut schema: Option<DFSchemaRef>,
193+
) -> Option<DFSchemaRef> {
194+
std::mem::swap(&mut self.outer_from_schema, &mut schema);
195+
schema
196+
}
197+
198+
/// extends the FROM schema, returning the existing one, if any
199+
pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
200+
self.outer_from_schema = match self.outer_from_schema.as_ref() {
201+
Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)),
202+
None => Some(Arc::clone(schema)),
203+
};
204+
Ok(())
205+
}
206+
180207
/// Return the types of parameters (`$1`, `$2`, etc) if known
181208
pub fn prepare_param_data_types(&self) -> &[DataType] {
182209
&self.prepare_param_data_types

datafusion/sql/src/relation/join.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1919
use datafusion_common::{not_impl_err, Column, Result};
2020
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
21-
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
21+
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins};
2222
use std::collections::HashSet;
2323

2424
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -27,10 +27,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
2727
t: TableWithJoins,
2828
planner_context: &mut PlannerContext,
2929
) -> Result<LogicalPlan> {
30-
let mut left = self.create_relation(t.relation, planner_context)?;
31-
for join in t.joins.into_iter() {
30+
let mut left = if is_lateral(&t.relation) {
31+
self.create_relation_subquery(t.relation, planner_context)?
32+
} else {
33+
self.create_relation(t.relation, planner_context)?
34+
};
35+
let old_outer_from_schema = planner_context.outer_from_schema();
36+
for join in t.joins {
37+
planner_context.extend_outer_from_schema(left.schema())?;
3238
left = self.parse_relation_join(left, join, planner_context)?;
3339
}
40+
planner_context.set_outer_from_schema(old_outer_from_schema);
3441
Ok(left)
3542
}
3643

@@ -40,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
4047
join: Join,
4148
planner_context: &mut PlannerContext,
4249
) -> Result<LogicalPlan> {
43-
let right = self.create_relation(join.relation, planner_context)?;
50+
let right = if is_lateral_join(&join)? {
51+
self.create_relation_subquery(join.relation, planner_context)?
52+
} else {
53+
self.create_relation(join.relation, planner_context)?
54+
};
4455
match join.join_operator {
4556
JoinOperator::LeftOuter(constraint) => {
4657
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
@@ -144,3 +155,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
144155
}
145156
}
146157
}
158+
159+
/// Return `true` iff the given [`TableFactor`] is lateral.
160+
pub(crate) fn is_lateral(factor: &TableFactor) -> bool {
161+
match factor {
162+
TableFactor::Derived { lateral, .. } => *lateral,
163+
TableFactor::Function { lateral, .. } => *lateral,
164+
_ => false,
165+
}
166+
}
167+
168+
/// Return `true` iff the given [`Join`] is lateral.
169+
pub(crate) fn is_lateral_join(join: &Join) -> Result<bool> {
170+
let is_lateral_syntax = is_lateral(&join.relation);
171+
let is_apply_syntax = match join.join_operator {
172+
JoinOperator::FullOuter(..)
173+
| JoinOperator::RightOuter(..)
174+
| JoinOperator::RightAnti(..)
175+
| JoinOperator::RightSemi(..)
176+
if is_lateral_syntax =>
177+
{
178+
return not_impl_err!(
179+
"LATERAL syntax is not supported for \
180+
FULL OUTER and RIGHT [OUTER | ANTI | SEMI] joins"
181+
);
182+
}
183+
JoinOperator::CrossApply | JoinOperator::OuterApply => true,
184+
_ => false,
185+
};
186+
Ok(is_lateral_syntax || is_apply_syntax)
187+
}

datafusion/sql/src/relation/mod.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1921

2022
use datafusion_common::tree_node::{Transformed, TreeNode};
2123
use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference};
24+
use datafusion_expr::builder::subquery_alias;
2225
use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
26+
use datafusion_expr::{Subquery, SubqueryAlias};
2327
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};
2428

2529
mod join;
@@ -153,6 +157,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
153157
Ok(optimized_plan)
154158
}
155159
}
160+
161+
pub(crate) fn create_relation_subquery(
162+
&self,
163+
subquery: TableFactor,
164+
planner_context: &mut PlannerContext,
165+
) -> Result<LogicalPlan> {
166+
// At this point for a syntacitally valid query the outer_from_schema is
167+
// guaranteed to be set, so the `.unwrap()` call will never panic. This
168+
// is the case because we only call this method for lateral table
169+
// factors, and those can never be the first factor in a FROM list. This
170+
// means we arrived here through the `for` loop in `plan_from_tables` or
171+
// the `for` loop in `plan_table_with_joins`.
172+
let old_from_schema = planner_context
173+
.set_outer_from_schema(None)
174+
.unwrap_or_else(|| Arc::new(DFSchema::empty()));
175+
let new_query_schema = match planner_context.outer_query_schema() {
176+
Some(old_query_schema) => {
177+
let mut new_query_schema = old_from_schema.as_ref().clone();
178+
new_query_schema.merge(old_query_schema);
179+
Some(Arc::new(new_query_schema))
180+
}
181+
None => Some(Arc::clone(&old_from_schema)),
182+
};
183+
let old_query_schema = planner_context.set_outer_query_schema(new_query_schema);
184+
185+
let plan = self.create_relation(subquery, planner_context)?;
186+
let outer_ref_columns = plan.all_out_ref_exprs();
187+
188+
planner_context.set_outer_query_schema(old_query_schema);
189+
planner_context.set_outer_from_schema(Some(old_from_schema));
190+
191+
match plan {
192+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
193+
subquery_alias(
194+
LogicalPlan::Subquery(Subquery {
195+
subquery: input,
196+
outer_ref_columns,
197+
}),
198+
alias,
199+
)
200+
}
201+
plan => Ok(LogicalPlan::Subquery(Subquery {
202+
subquery: Arc::new(plan),
203+
outer_ref_columns,
204+
})),
205+
}
206+
}
156207
}
157208

158209
fn optimize_subquery_sort(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {

datafusion/sql/src/select.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -496,19 +496,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
496496
match from.len() {
497497
0 => Ok(LogicalPlanBuilder::empty(true).build()?),
498498
1 => {
499-
let from = from.remove(0);
500-
self.plan_table_with_joins(from, planner_context)
499+
let input = from.remove(0);
500+
self.plan_table_with_joins(input, planner_context)
501501
}
502502
_ => {
503-
let mut plans = from
504-
.into_iter()
505-
.map(|t| self.plan_table_with_joins(t, planner_context));
506-
507-
let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?);
508-
509-
for right in plans {
510-
left = left.cross_join(right?)?;
503+
let mut from = from.into_iter();
504+
505+
let mut left = LogicalPlanBuilder::from({
506+
let input = from.next().unwrap();
507+
self.plan_table_with_joins(input, planner_context)?
508+
});
509+
let old_outer_from_schema = {
510+
let left_schema = Some(Arc::clone(left.schema()));
511+
planner_context.set_outer_from_schema(left_schema)
512+
};
513+
for input in from {
514+
// Join `input` with the current result (`left`).
515+
let right = self.plan_table_with_joins(input, planner_context)?;
516+
left = left.cross_join(right)?;
517+
// Update the outer FROM schema.
518+
let left_schema = Some(Arc::clone(left.schema()));
519+
planner_context.set_outer_from_schema(left_schema);
511520
}
521+
planner_context.set_outer_from_schema(old_outer_from_schema);
522+
512523
Ok(left.build()?)
513524
}
514525
}

datafusion/sql/tests/sql_integration.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3103,6 +3103,114 @@ fn join_on_complex_condition() {
31033103
quick_test(sql, expected);
31043104
}
31053105

3106+
#[test]
3107+
fn lateral_constant() {
3108+
let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2";
3109+
let expected = "Projection: *\
3110+
\n CrossJoin:\
3111+
\n TableScan: j1\
3112+
\n SubqueryAlias: j2\
3113+
\n Subquery:\
3114+
\n Projection: Int64(1)\
3115+
\n EmptyRelation";
3116+
quick_test(sql, expected);
3117+
}
3118+
3119+
#[test]
3120+
fn lateral_comma_join() {
3121+
let sql = "SELECT j1_string, j2_string FROM
3122+
j1, \
3123+
LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2";
3124+
let expected = "Projection: j1.j1_string, j2.j2_string\
3125+
\n CrossJoin:\
3126+
\n TableScan: j1\
3127+
\n SubqueryAlias: j2\
3128+
\n Subquery:\
3129+
\n Projection: *\
3130+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3131+
\n TableScan: j2";
3132+
quick_test(sql, expected);
3133+
}
3134+
3135+
#[test]
3136+
fn lateral_comma_join_referencing_join_rhs() {
3137+
let sql = "SELECT * FROM\
3138+
\n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\
3139+
\n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;";
3140+
let expected = "Projection: *\
3141+
\n CrossJoin:\
3142+
\n Inner Join: Filter: j1.j1_id = j2.j2_id\
3143+
\n TableScan: j1\
3144+
\n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\
3145+
\n TableScan: j2\
3146+
\n TableScan: j3\
3147+
\n SubqueryAlias: j4\
3148+
\n Subquery:\
3149+
\n Projection: *\
3150+
\n Filter: j3.j3_string = outer_ref(j2.j2_string)\
3151+
\n TableScan: j3";
3152+
quick_test(sql, expected);
3153+
}
3154+
3155+
#[test]
3156+
fn lateral_comma_join_with_shadowing() {
3157+
// The j1_id on line 3 references the (closest) j1 definition from line 2.
3158+
let sql = "\
3159+
SELECT * FROM j1, LATERAL (\
3160+
SELECT * FROM j1, LATERAL (\
3161+
SELECT * FROM j2 WHERE j1_id = j2_id\
3162+
) as j2\
3163+
) as j2;";
3164+
let expected = "Projection: *\
3165+
\n CrossJoin:\
3166+
\n TableScan: j1\
3167+
\n SubqueryAlias: j2\
3168+
\n Subquery:\
3169+
\n Projection: *\
3170+
\n CrossJoin:\
3171+
\n TableScan: j1\
3172+
\n SubqueryAlias: j2\
3173+
\n Subquery:\
3174+
\n Projection: *\
3175+
\n Filter: outer_ref(j1.j1_id) = j2.j2_id\
3176+
\n TableScan: j2";
3177+
quick_test(sql, expected);
3178+
}
3179+
3180+
#[test]
3181+
fn lateral_left_join() {
3182+
let sql = "SELECT j1_string, j2_string FROM \
3183+
j1 \
3184+
LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);";
3185+
let expected = "Projection: j1.j1_string, j2.j2_string\
3186+
\n Left Join: Filter: Boolean(true)\
3187+
\n TableScan: j1\
3188+
\n SubqueryAlias: j2\
3189+
\n Subquery:\
3190+
\n Projection: *\
3191+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3192+
\n TableScan: j2";
3193+
quick_test(sql, expected);
3194+
}
3195+
3196+
#[test]
3197+
fn lateral_nested_left_join() {
3198+
let sql = "SELECT * FROM
3199+
j1, \
3200+
(j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))";
3201+
let expected = "Projection: *\
3202+
\n CrossJoin:\
3203+
\n TableScan: j1\
3204+
\n Left Join: Filter: Boolean(true)\
3205+
\n TableScan: j2\
3206+
\n SubqueryAlias: j3\
3207+
\n Subquery:\
3208+
\n Projection: *\
3209+
\n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\
3210+
\n TableScan: j3";
3211+
quick_test(sql, expected);
3212+
}
3213+
31063214
#[test]
31073215
fn hive_aggregate_with_filter() -> Result<()> {
31083216
let dialect = &HiveDialect {};

0 commit comments

Comments
 (0)