Skip to content

Commit f54781e

Browse files
committed
Planner: support LATERAL subqueries
1 parent a7041fe commit f54781e

File tree

5 files changed

+176
-14
lines changed

5 files changed

+176
-14
lines changed

datafusion/sql/src/planner.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ pub struct PlannerContext {
109109
ctes: HashMap<String, Arc<LogicalPlan>>,
110110
/// The query schema of the outer query plan, used to resolve the columns in subquery
111111
outer_query_schema: Option<DFSchemaRef>,
112+
/// The joined schemas of all FROM clauses planned so far. When planning LATERAL
113+
/// FROM clauses, this should become a suffix of the `outer_query_schema`.
114+
outer_from_schema: Option<DFSchemaRef>,
112115
}
113116

114117
impl Default for PlannerContext {
@@ -124,6 +127,7 @@ impl PlannerContext {
124127
prepare_param_data_types: Arc::new(vec![]),
125128
ctes: HashMap::new(),
126129
outer_query_schema: None,
130+
outer_from_schema: None,
127131
}
128132
}
129133

@@ -151,6 +155,25 @@ impl PlannerContext {
151155
schema
152156
}
153157

158+
/// sets the FROM schema, returning the existing one, if
159+
/// any
160+
pub fn set_outer_from_schema(
161+
&mut self,
162+
mut schema: Option<DFSchemaRef>,
163+
) -> Option<DFSchemaRef> {
164+
std::mem::swap(&mut self.outer_from_schema, &mut schema);
165+
schema
166+
}
167+
168+
/// extends the FROM schema, returning the existing one, if any
169+
pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
170+
self.outer_from_schema = match self.outer_from_schema.as_ref() {
171+
Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)),
172+
None => Some(Arc::clone(schema)),
173+
};
174+
Ok(())
175+
}
176+
154177
/// Return the types of parameters (`$1`, `$2`, etc) if known
155178
pub fn prepare_param_data_types(&self) -> &[DataType] {
156179
&self.prepare_param_data_types

datafusion/sql/src/relation/join.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1919
use datafusion_common::{not_impl_err, Column, Result};
2020
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
21-
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
21+
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins};
2222
use std::collections::HashSet;
2323

2424
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -27,10 +27,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
2727
t: TableWithJoins,
2828
planner_context: &mut PlannerContext,
2929
) -> Result<LogicalPlan> {
30-
let mut left = self.create_relation(t.relation, planner_context)?;
31-
for join in t.joins.into_iter() {
30+
let origin_planner_context = planner_context.clone();
31+
let mut left = if is_lateral(&t.relation) {
32+
self.create_relation_subquery(t.relation, planner_context)?
33+
} else {
34+
self.create_relation(t.relation, planner_context)?
35+
};
36+
for join in t.joins {
37+
*planner_context = origin_planner_context.clone();
38+
planner_context.extend_outer_from_schema(left.schema())?;
3239
left = self.parse_relation_join(left, join, planner_context)?;
3340
}
41+
*planner_context = origin_planner_context;
3442
Ok(left)
3543
}
3644

@@ -40,7 +48,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
4048
join: Join,
4149
planner_context: &mut PlannerContext,
4250
) -> Result<LogicalPlan> {
43-
let right = self.create_relation(join.relation, planner_context)?;
51+
let right = if is_lateral_join(&join)? {
52+
self.create_relation_subquery(join.relation, planner_context)?
53+
} else {
54+
self.create_relation(join.relation, planner_context)?
55+
};
4456
match join.join_operator {
4557
JoinOperator::LeftOuter(constraint) => {
4658
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
@@ -144,3 +156,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
144156
}
145157
}
146158
}
159+
160+
/// Return `true` iff the given [`TableFactor`] is lateral.
161+
pub(crate) fn is_lateral(factor: &TableFactor) -> bool {
162+
match factor {
163+
TableFactor::Derived { lateral, .. } => *lateral,
164+
TableFactor::Function { lateral, .. } => *lateral,
165+
_ => false,
166+
}
167+
}
168+
169+
/// Return `true` iff the given [`Join`] is lateral.
170+
pub(crate) fn is_lateral_join(join: &Join) -> Result<bool> {
171+
let is_lateral_syntax = is_lateral(&join.relation);
172+
let is_apply_syntax = match join.join_operator {
173+
JoinOperator::FullOuter(..)
174+
| JoinOperator::RightOuter(..)
175+
| JoinOperator::RightAnti(..)
176+
| JoinOperator::RightSemi(..)
177+
if is_lateral_syntax =>
178+
{
179+
return not_impl_err!("NONE constraint is not supported");
180+
}
181+
JoinOperator::CrossApply | JoinOperator::OuterApply => true,
182+
_ => false,
183+
};
184+
Ok(is_lateral_syntax || is_apply_syntax)
185+
}

datafusion/sql/src/relation/mod.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1921
use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference};
22+
use datafusion_expr::builder::subquery_alias;
2023
use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
24+
use datafusion_expr::{Subquery, SubqueryAlias};
2125
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};
2226

2327
mod join;
@@ -143,4 +147,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
143147
Ok(plan)
144148
}
145149
}
150+
151+
pub(crate) fn create_relation_subquery(
152+
&self,
153+
subquery: TableFactor,
154+
planner_context: &mut PlannerContext,
155+
) -> Result<LogicalPlan> {
156+
let old_from_schema = planner_context.set_outer_from_schema(None).unwrap(); // set in plan_table_with_joins
157+
let new_outer_schema = match planner_context.outer_query_schema() {
158+
Some(lhs) => Some(Arc::new(lhs.join(&old_from_schema)?)),
159+
None => Some(Arc::clone(&old_from_schema)),
160+
};
161+
let old_outer_sch = planner_context.set_outer_query_schema(new_outer_schema);
162+
163+
let plan = self.create_relation(subquery, planner_context)?;
164+
let outer_ref_columns = plan.all_out_ref_exprs();
165+
166+
planner_context.set_outer_query_schema(old_outer_sch);
167+
planner_context.set_outer_from_schema(Some(old_from_schema));
168+
169+
match plan {
170+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
171+
subquery_alias(
172+
LogicalPlan::Subquery(Subquery {
173+
subquery: input,
174+
outer_ref_columns,
175+
}),
176+
alias,
177+
)
178+
}
179+
plan => Ok(LogicalPlan::Subquery(Subquery {
180+
subquery: Arc::new(plan),
181+
outer_ref_columns,
182+
})),
183+
}
184+
}
146185
}

datafusion/sql/src/select.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -396,19 +396,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
396396
match from.len() {
397397
0 => Ok(LogicalPlanBuilder::empty(true).build()?),
398398
1 => {
399-
let from = from.remove(0);
400-
self.plan_table_with_joins(from, planner_context)
399+
let input = from.remove(0);
400+
self.plan_table_with_joins(input, planner_context)
401401
}
402402
_ => {
403-
let mut plans = from
404-
.into_iter()
405-
.map(|t| self.plan_table_with_joins(t, planner_context));
406-
407-
let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?);
408-
409-
for right in plans {
410-
left = left.cross_join(right?)?;
403+
let mut from = from.into_iter();
404+
405+
let mut left = LogicalPlanBuilder::from({
406+
let input = from.next().unwrap();
407+
self.plan_table_with_joins(input, planner_context)?
408+
});
409+
let old_outer_from_schema = {
410+
let left_schema = Some(Arc::clone(left.schema()));
411+
planner_context.set_outer_from_schema(left_schema)
412+
};
413+
for input in from {
414+
// Join `input` with the current result (`left`).
415+
let right = self.plan_table_with_joins(input, planner_context)?;
416+
left = left.cross_join(right)?;
417+
// Update the outer FROM schema.
418+
let left_schema = Some(Arc::clone(left.schema()));
419+
planner_context.set_outer_from_schema(left_schema);
411420
}
421+
planner_context.set_outer_from_schema(old_outer_from_schema);
422+
412423
Ok(left.build()?)
413424
}
414425
}

datafusion/sql/tests/sql_integration.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,6 +3177,56 @@ fn join_on_complex_condition() {
31773177
quick_test(sql, expected);
31783178
}
31793179

3180+
// Checks the plan text for a comma-separated FROM list whose second item is a
// LATERAL derived table: `expected` shows the lateral side as a `Subquery`
// under a `SubqueryAlias`, with `j1_id` referenced through `outer_ref(...)`.
#[test]
3181+
fn lateral_comma_join() {
3182+
let sql = "SELECT j1_string, j2_string FROM
3183+
j1, \
3184+
LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2";
3185+
let expected = "Projection: j1.j1_string, j2.j2_string\
3186+
\n CrossJoin:\
3187+
\n TableScan: j1\
3188+
\n SubqueryAlias: j2\
3189+
\n Subquery:\
3190+
\n Projection: j2.j2_id, j2.j2_string\
3191+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3192+
\n TableScan: j2";
3193+
quick_test(sql, expected);
3194+
}
3195+
3196+
// Checks the plan text for `LEFT JOIN LATERAL (...) ON(true)`: `expected`
// shows a `Left Join` whose right input is a `Subquery` under a
// `SubqueryAlias`, with `j1_id` referenced through `outer_ref(...)`.
#[test]
3197+
fn lateral_left_join() {
3198+
let sql = "SELECT j1_string, j2_string FROM
3199+
j1 \
3200+
LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);";
3201+
let expected = "Projection: j1.j1_string, j2.j2_string\
3202+
\n Left Join: Filter: Boolean(true)\
3203+
\n TableScan: j1\
3204+
\n SubqueryAlias: j2\
3205+
\n Subquery:\
3206+
\n Projection: j2.j2_id, j2.j2_string\
3207+
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
3208+
\n TableScan: j2";
3209+
quick_test(sql, expected);
3210+
}
3211+
3212+
// Checks the plan text for a LATERAL subquery nested inside a parenthesised
// join that is itself the second FROM item: `expected` shows both `j1_id`
// (from the first FROM item) and `j2_id` (from the left side of the nested
// join) referenced through `outer_ref(...)`.
#[test]
3213+
fn lateral_nested_left_join() {
3214+
let sql = "SELECT * FROM
3215+
j1, \
3216+
(j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))";
3217+
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string, j3.j3_id, j3.j3_string\
3218+
\n CrossJoin:\
3219+
\n TableScan: j1\
3220+
\n Left Join: Filter: Boolean(true)\
3221+
\n TableScan: j2\
3222+
\n SubqueryAlias: j3\
3223+
\n Subquery:\
3224+
\n Projection: j3.j3_id, j3.j3_string\
3225+
\n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\
3226+
\n TableScan: j3";
3227+
quick_test(sql, expected);
3228+
}
3229+
31803230
#[test]
31813231
fn hive_aggregate_with_filter() -> Result<()> {
31823232
let dialect = &HiveDialect {};

0 commit comments

Comments
 (0)