Skip to content

Commit ee34089

Browse files
committed
Planner: support LATERAL subqueries
1 parent a7041fe commit ee34089

File tree

5 files changed

+230
-14
lines changed

5 files changed

+230
-14
lines changed

datafusion/sql/src/planner.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ pub struct PlannerContext {
109109
ctes: HashMap<String, Arc<LogicalPlan>>,
110110
/// The query schema of the outer query plan, used to resolve the columns in subquery
111111
outer_query_schema: Option<DFSchemaRef>,
112+
/// The joined schemas of all FROM clauses planned so far. When planning LATERAL
113+
/// FROM clauses, this should become a suffix of the `outer_query_schema`.
114+
outer_from_schema: Option<DFSchemaRef>,
112115
}
113116

114117
impl Default for PlannerContext {
@@ -124,6 +127,7 @@ impl PlannerContext {
124127
prepare_param_data_types: Arc::new(vec![]),
125128
ctes: HashMap::new(),
126129
outer_query_schema: None,
130+
outer_from_schema: None,
127131
}
128132
}
129133

@@ -151,6 +155,29 @@ impl PlannerContext {
151155
schema
152156
}
153157

158+
// return a clone of the outer FROM schema
159+
pub fn outer_from_schema(&self) -> Option<Arc<DFSchema>> {
160+
self.outer_from_schema.clone()
161+
}
162+
163+
/// sets the outer FROM schema, returning the existing one, if any
164+
pub fn set_outer_from_schema(
165+
&mut self,
166+
mut schema: Option<DFSchemaRef>,
167+
) -> Option<DFSchemaRef> {
168+
std::mem::swap(&mut self.outer_from_schema, &mut schema);
169+
schema
170+
}
171+
172+
/// extends the FROM schema, returning the existing one, if any
173+
pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
174+
self.outer_from_schema = match self.outer_from_schema.as_ref() {
175+
Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)),
176+
None => Some(Arc::clone(schema)),
177+
};
178+
Ok(())
179+
}
180+
154181
/// Return the types of parameters (`$1`, `$2`, etc) if known
155182
pub fn prepare_param_data_types(&self) -> &[DataType] {
156183
&self.prepare_param_data_types

datafusion/sql/src/relation/join.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1919
use datafusion_common::{not_impl_err, Column, Result};
2020
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
21-
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
21+
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins};
2222
use std::collections::HashSet;
2323

2424
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -27,10 +27,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
2727
t: TableWithJoins,
2828
planner_context: &mut PlannerContext,
2929
) -> Result<LogicalPlan> {
30-
let mut left = self.create_relation(t.relation, planner_context)?;
31-
for join in t.joins.into_iter() {
30+
let mut left = if is_lateral(&t.relation) {
31+
self.create_relation_subquery(t.relation, planner_context)?
32+
} else {
33+
self.create_relation(t.relation, planner_context)?
34+
};
35+
let old_outer_from_schema = planner_context.outer_from_schema();
36+
for join in t.joins {
37+
planner_context.extend_outer_from_schema(left.schema())?;
3238
left = self.parse_relation_join(left, join, planner_context)?;
3339
}
40+
planner_context.set_outer_from_schema(old_outer_from_schema);
3441
Ok(left)
3542
}
3643

@@ -40,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
4047
join: Join,
4148
planner_context: &mut PlannerContext,
4249
) -> Result<LogicalPlan> {
43-
let right = self.create_relation(join.relation, planner_context)?;
50+
let right = if is_lateral_join(&join)? {
51+
self.create_relation_subquery(join.relation, planner_context)?
52+
} else {
53+
self.create_relation(join.relation, planner_context)?
54+
};
4455
match join.join_operator {
4556
JoinOperator::LeftOuter(constraint) => {
4657
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
@@ -144,3 +155,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
144155
}
145156
}
146157
}
158+
159+
/// Return `true` iff the given [`TableFactor`] is lateral.
160+
pub(crate) fn is_lateral(factor: &TableFactor) -> bool {
161+
match factor {
162+
TableFactor::Derived { lateral, .. } => *lateral,
163+
TableFactor::Function { lateral, .. } => *lateral,
164+
_ => false,
165+
}
166+
}
167+
168+
/// Return `true` iff the given [`Join`] is lateral.
169+
pub(crate) fn is_lateral_join(join: &Join) -> Result<bool> {
170+
let is_lateral_syntax = is_lateral(&join.relation);
171+
let is_apply_syntax = match join.join_operator {
172+
JoinOperator::FullOuter(..)
173+
| JoinOperator::RightOuter(..)
174+
| JoinOperator::RightAnti(..)
175+
| JoinOperator::RightSemi(..)
176+
if is_lateral_syntax =>
177+
{
178+
return not_impl_err!("NONE constraint is not supported");
179+
}
180+
JoinOperator::CrossApply | JoinOperator::OuterApply => true,
181+
_ => false,
182+
};
183+
Ok(is_lateral_syntax || is_apply_syntax)
184+
}

datafusion/sql/src/relation/mod.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1921
use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference};
22+
use datafusion_expr::builder::subquery_alias;
2023
use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
24+
use datafusion_expr::{Subquery, SubqueryAlias};
2125
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};
2226

2327
mod join;
@@ -143,4 +147,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
143147
Ok(plan)
144148
}
145149
}
150+
151+
pub(crate) fn create_relation_subquery(
152+
&self,
153+
subquery: TableFactor,
154+
planner_context: &mut PlannerContext,
155+
) -> Result<LogicalPlan> {
156+
// At this point for a syntacitally valid query the outer_from_schema is
157+
// guaranteed to be set, so the `.unwrap()` call will never panic. This
158+
// is the case because we only call this method for lateral table
159+
// factors, and those can never be the first factor in a FROM list. This
160+
// means we arrived here through the `for` loop in `plan_from_tables` or
161+
// the `for` loop in `plan_table_with_joins`.
162+
let old_from_schema = planner_context.set_outer_from_schema(None).unwrap();
163+
let new_query_schema = match planner_context.outer_query_schema() {
164+
Some(lhs) => Some(Arc::new(lhs.join(&old_from_schema)?)),
165+
None => Some(Arc::clone(&old_from_schema)),
166+
};
167+
let old_query_schema = planner_context.set_outer_query_schema(new_query_schema);
168+
169+
let plan = self.create_relation(subquery, planner_context)?;
170+
let outer_ref_columns = plan.all_out_ref_exprs();
171+
172+
planner_context.set_outer_query_schema(old_query_schema);
173+
planner_context.set_outer_from_schema(Some(old_from_schema));
174+
175+
match plan {
176+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
177+
subquery_alias(
178+
LogicalPlan::Subquery(Subquery {
179+
subquery: input,
180+
outer_ref_columns,
181+
}),
182+
alias,
183+
)
184+
}
185+
plan => Ok(LogicalPlan::Subquery(Subquery {
186+
subquery: Arc::new(plan),
187+
outer_ref_columns,
188+
})),
189+
}
190+
}
146191
}

datafusion/sql/src/select.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -396,19 +396,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
396396
match from.len() {
397397
0 => Ok(LogicalPlanBuilder::empty(true).build()?),
398398
1 => {
399-
let from = from.remove(0);
400-
self.plan_table_with_joins(from, planner_context)
399+
let input = from.remove(0);
400+
self.plan_table_with_joins(input, planner_context)
401401
}
402402
_ => {
403-
let mut plans = from
404-
.into_iter()
405-
.map(|t| self.plan_table_with_joins(t, planner_context));
406-
407-
let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?);
408-
409-
for right in plans {
410-
left = left.cross_join(right?)?;
403+
let mut from = from.into_iter();
404+
405+
let mut left = LogicalPlanBuilder::from({
406+
let input = from.next().unwrap();
407+
self.plan_table_with_joins(input, planner_context)?
408+
});
409+
let old_outer_from_schema = {
410+
let left_schema = Some(Arc::clone(left.schema()));
411+
planner_context.set_outer_from_schema(left_schema)
412+
};
413+
for input in from {
414+
// Join `input` with the current result (`left`).
415+
let right = self.plan_table_with_joins(input, planner_context)?;
416+
left = left.cross_join(right)?;
417+
// Update the outer FROM schema.
418+
let left_schema = Some(Arc::clone(left.schema()));
419+
planner_context.set_outer_from_schema(left_schema);
411420
}
421+
planner_context.set_outer_from_schema(old_outer_from_schema);
422+
412423
Ok(left.build()?)
413424
}
414425
}

datafusion/sql/tests/sql_integration.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,6 +3177,101 @@ fn join_on_complex_condition() {
31773177
quick_test(sql, expected);
31783178
}
31793179

3180+
#[test]
3181+
fn lateral_comma_join() {
3182+
let sql = "SELECT j1_string, j2_string FROM
3183+
j1, \
3184+
LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2";
3185+
let expected = "Projection: j1.j1_string, j2.j2_string\
3186+
\n CrossJoin:\
3187+
\n TableScan: j1\
3188+
\n SubqueryAlias: j2\
3189+
\n Subquery:\
3190+
\n Projection: j2.j2_id, j2.j2_string\
3191+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3192+
\n TableScan: j2";
3193+
quick_test(sql, expected);
3194+
}
3195+
3196+
#[test]
3197+
fn lateral_comma_join_referencing_join_rhs() {
3198+
let sql = "SELECT * FROM\
3199+
\n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\
3200+
\n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;";
3201+
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string, j3.j3_id, j3.j3_string, j4.j3_id, j4.j3_string\
3202+
\n CrossJoin:\
3203+
\n Inner Join: Filter: j1.j1_id = j2.j2_id\
3204+
\n TableScan: j1\
3205+
\n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\
3206+
\n TableScan: j2\
3207+
\n TableScan: j3\
3208+
\n SubqueryAlias: j4\
3209+
\n Subquery:\
3210+
\n Projection: j3.j3_id, j3.j3_string\
3211+
\n Filter: j3.j3_string = outer_ref(j2.j2_string)\
3212+
\n TableScan: j3";
3213+
quick_test(sql, expected);
3214+
}
3215+
3216+
#[test]
3217+
fn lateral_comma_join_with_shadowing() {
3218+
// The j1_id on line 3 references the (closest) j1 definition from line 2.
3219+
let sql = "-- Triple nested correlated queries queries\
3220+
\nSELECT * FROM j1, LATERAL ( -- line 1\
3221+
\n SELECT * FROM j1, LATERAL ( -- line 2\
3222+
\n SELECT * FROM j2 WHERE j1.j1_id = j2_id -- line 3\
3223+
\n ) as j2\
3224+
\n) as j2;";
3225+
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j1_id, j2.j1_string, j2.j2_id, j2.j2_string\
3226+
\n CrossJoin:\
3227+
\n TableScan: j1\
3228+
\n SubqueryAlias: j2\
3229+
\n Subquery:\
3230+
\n Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string\
3231+
\n CrossJoin:\
3232+
\n TableScan: j1\
3233+
\n SubqueryAlias: j2\
3234+
\n Subquery:\
3235+
\n Projection: j2.j2_id, j2.j2_string\
3236+
\n Filter: outer_ref(j1.j1_id) = j2.j2_id\
3237+
\n TableScan: j2";
3238+
quick_test(sql, expected);
3239+
}
3240+
3241+
#[test]
3242+
fn lateral_left_join() {
3243+
let sql = "SELECT j1_string, j2_string FROM
3244+
j1 \
3245+
LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);";
3246+
let expected = "Projection: j1.j1_string, j2.j2_string\
3247+
\n Left Join: Filter: Boolean(true)\
3248+
\n TableScan: j1\
3249+
\n SubqueryAlias: j2\
3250+
\n Subquery:\
3251+
\n Projection: j2.j2_id, j2.j2_string\
3252+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3253+
\n TableScan: j2";
3254+
quick_test(sql, expected);
3255+
}
3256+
3257+
#[test]
3258+
fn lateral_nested_left_join() {
3259+
let sql = "SELECT * FROM
3260+
j1, \
3261+
(j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))";
3262+
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string, j3.j3_id, j3.j3_string\
3263+
\n CrossJoin:\
3264+
\n TableScan: j1\
3265+
\n Left Join: Filter: Boolean(true)\
3266+
\n TableScan: j2\
3267+
\n SubqueryAlias: j3\
3268+
\n Subquery:\
3269+
\n Projection: j3.j3_id, j3.j3_string\
3270+
\n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\
3271+
\n TableScan: j3";
3272+
quick_test(sql, expected);
3273+
}
3274+
31803275
#[test]
31813276
fn hive_aggregate_with_filter() -> Result<()> {
31823277
let dialect = &HiveDialect {};

0 commit comments

Comments
 (0)