Skip to content

Commit 41b10ca

Browse files
authored
Push down null filters for more join types (#12348)
* Push down null filters for more join types * fix tests * Fix docs
1 parent 9d819e1 commit 41b10ca

File tree

3 files changed

+75
-41
lines changed

3 files changed

+75
-41
lines changed

datafusion/optimizer/src/filter_null_join_keys.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818
//! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable
1919
2020
use crate::optimizer::ApplyOrder;
21+
use crate::push_down_filter::on_lr_is_preserved;
2122
use crate::{OptimizerConfig, OptimizerRule};
2223
use datafusion_common::tree_node::Transformed;
2324
use datafusion_common::Result;
2425
use datafusion_expr::utils::conjunction;
25-
use datafusion_expr::{
26-
logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan,
27-
};
26+
use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
2827
use std::sync::Arc;
2928

30-
/// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions
31-
/// where the join key is nullable on one side and non-nullable on the other side
32-
/// and then insert an `IsNotNull` filter on the nullable side since null values
29+
/// The FilterNullJoinKeys rule will identify joins with equi-join conditions
30+
/// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values
3331
/// can never match.
3432
#[derive(Default)]
3533
pub struct FilterNullJoinKeys {}
@@ -51,21 +49,23 @@ impl OptimizerRule for FilterNullJoinKeys {
5149
if !config.options().optimizer.filter_null_join_keys {
5250
return Ok(Transformed::no(plan));
5351
}
54-
5552
match plan {
56-
LogicalPlan::Join(mut join) if join.join_type == JoinType::Inner => {
53+
LogicalPlan::Join(mut join) if !join.on.is_empty() => {
54+
let (left_preserved, right_preserved) =
55+
on_lr_is_preserved(join.join_type);
56+
5757
let left_schema = join.left.schema();
5858
let right_schema = join.right.schema();
5959

6060
let mut left_filters = vec![];
6161
let mut right_filters = vec![];
6262

6363
for (l, r) in &join.on {
64-
if l.nullable(left_schema)? {
64+
if left_preserved && l.nullable(left_schema)? {
6565
left_filters.push(l.clone());
6666
}
6767

68-
if r.nullable(right_schema)? {
68+
if right_preserved && r.nullable(right_schema)? {
6969
right_filters.push(r.clone());
7070
}
7171
}
@@ -109,7 +109,7 @@ mod tests {
109109
use arrow::datatypes::{DataType, Field, Schema};
110110
use datafusion_common::Column;
111111
use datafusion_expr::logical_plan::table_scan;
112-
use datafusion_expr::{col, lit, LogicalPlanBuilder};
112+
use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};
113113

114114
fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
115115
assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected)
@@ -118,18 +118,41 @@ mod tests {
118118
#[test]
119119
fn left_nullable() -> Result<()> {
120120
let (t1, t2) = test_tables()?;
121-
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
121+
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
122122
let expected = "Inner Join: t1.optional_id = t2.id\
123123
\n Filter: t1.optional_id IS NOT NULL\
124124
\n TableScan: t1\
125125
\n TableScan: t2";
126126
assert_optimized_plan_equal(plan, expected)
127127
}
128128

129+
#[test]
130+
fn left_nullable_left_join() -> Result<()> {
131+
let (t1, t2) = test_tables()?;
132+
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?;
133+
let expected = "Left Join: t1.optional_id = t2.id\
134+
\n TableScan: t1\
135+
\n TableScan: t2";
136+
assert_optimized_plan_equal(plan, expected)
137+
}
138+
139+
#[test]
140+
fn left_nullable_left_join_reordered() -> Result<()> {
141+
let (t_left, t_right) = test_tables()?;
142+
// Note: order of tables is reversed
143+
let plan =
144+
build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?;
145+
let expected = "Left Join: t2.id = t1.optional_id\
146+
\n TableScan: t2\
147+
\n Filter: t1.optional_id IS NOT NULL\
148+
\n TableScan: t1";
149+
assert_optimized_plan_equal(plan, expected)
150+
}
151+
129152
#[test]
130153
fn left_nullable_on_condition_reversed() -> Result<()> {
131154
let (t1, t2) = test_tables()?;
132-
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id")?;
155+
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?;
133156
let expected = "Inner Join: t1.optional_id = t2.id\
134157
\n Filter: t1.optional_id IS NOT NULL\
135158
\n TableScan: t1\
@@ -140,7 +163,7 @@ mod tests {
140163
#[test]
141164
fn nested_join_multiple_filter_expr() -> Result<()> {
142165
let (t1, t2) = test_tables()?;
143-
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
166+
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
144167
let schema = Schema::new(vec![
145168
Field::new("id", DataType::UInt32, false),
146169
Field::new("t1_id", DataType::UInt32, true),
@@ -244,11 +267,12 @@ mod tests {
244267
right_table: LogicalPlan,
245268
left_key: &str,
246269
right_key: &str,
270+
join_type: JoinType,
247271
) -> Result<LogicalPlan> {
248272
LogicalPlanBuilder::from(left_table)
249273
.join(
250274
right_table,
251-
JoinType::Inner,
275+
join_type,
252276
(
253277
vec![Column::from_qualified_name(left_key)],
254278
vec![Column::from_qualified_name(right_key)],

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,18 +157,18 @@ pub struct PushDownFilter {}
157157
/// the right is not, because there may be rows in the output that don't
158158
/// directly map to a row in the right input (due to nulls filling where there
159159
/// is no match on the right).
160-
fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
160+
pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
161161
match join_type {
162-
JoinType::Inner => Ok((true, true)),
163-
JoinType::Left => Ok((true, false)),
164-
JoinType::Right => Ok((false, true)),
165-
JoinType::Full => Ok((false, false)),
162+
JoinType::Inner => (true, true),
163+
JoinType::Left => (true, false),
164+
JoinType::Right => (false, true),
165+
JoinType::Full => (false, false),
166166
// No columns from the right side of the join can be referenced in output
167167
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
168-
JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
168+
JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
169169
// No columns from the left side of the join can be referenced in output
170170
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
171-
JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
171+
JoinType::RightSemi | JoinType::RightAnti => (false, true),
172172
}
173173
}
174174

@@ -181,15 +181,15 @@ fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
181181
/// A tuple of booleans - (left_preserved, right_preserved).
182182
///
183183
/// See [`lr_is_preserved`] for a definition of "preserved".
184-
fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
184+
pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
185185
match join_type {
186-
JoinType::Inner => Ok((true, true)),
187-
JoinType::Left => Ok((false, true)),
188-
JoinType::Right => Ok((true, false)),
189-
JoinType::Full => Ok((false, false)),
190-
JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
191-
JoinType::LeftAnti => Ok((false, true)),
192-
JoinType::RightAnti => Ok((true, false)),
186+
JoinType::Inner => (true, true),
187+
JoinType::Left => (false, true),
188+
JoinType::Right => (true, false),
189+
JoinType::Full => (false, false),
190+
JoinType::LeftSemi | JoinType::RightSemi => (true, true),
191+
JoinType::LeftAnti => (false, true),
192+
JoinType::RightAnti => (true, false),
193193
}
194194
}
195195

@@ -420,7 +420,7 @@ fn push_down_all_join(
420420
) -> Result<Transformed<LogicalPlan>> {
421421
let is_inner_join = join.join_type == JoinType::Inner;
422422
// Get pushable predicates from current optimizer state
423-
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
423+
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
424424

425425
// The predicates can be divided to three categories:
426426
// 1) can push through join to its children(left or right)
@@ -457,7 +457,7 @@ fn push_down_all_join(
457457
}
458458

459459
let mut on_filter_join_conditions = vec![];
460-
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?;
460+
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
461461

462462
if !on_filter.is_empty() {
463463
for on in on_filter {

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,12 @@ fn semi_join_with_join_filter() -> Result<()> {
124124
let plan = test_sql(sql)?;
125125
let expected = "Projection: test.col_utf8\
126126
\n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\
127-
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
127+
\n Filter: test.col_int32 IS NOT NULL\
128+
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
128129
\n SubqueryAlias: __correlated_sq_1\
129130
\n SubqueryAlias: t2\
130-
\n TableScan: test projection=[col_int32, col_uint32]";
131+
\n Filter: test.col_int32 IS NOT NULL\
132+
\n TableScan: test projection=[col_int32, col_uint32]";
131133
assert_eq!(expected, format!("{plan}"));
132134
Ok(())
133135
}
@@ -144,7 +146,8 @@ fn anti_join_with_join_filter() -> Result<()> {
144146
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
145147
\n SubqueryAlias: __correlated_sq_1\
146148
\n SubqueryAlias: t2\
147-
\n TableScan: test projection=[col_int32, col_uint32]";
149+
\n Filter: test.col_int32 IS NOT NULL\
150+
\n TableScan: test projection=[col_int32, col_uint32]";
148151
assert_eq!(expected, format!("{plan}"));
149152
Ok(())
150153
}
@@ -155,11 +158,13 @@ fn where_exists_distinct() -> Result<()> {
155158
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
156159
let plan = test_sql(sql)?;
157160
let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\
158-
\n TableScan: test projection=[col_int32]\
161+
\n Filter: test.col_int32 IS NOT NULL\
162+
\n TableScan: test projection=[col_int32]\
159163
\n SubqueryAlias: __correlated_sq_1\
160164
\n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\
161165
\n SubqueryAlias: t2\
162-
\n TableScan: test projection=[col_int32]";
166+
\n Filter: test.col_int32 IS NOT NULL\
167+
\n TableScan: test projection=[col_int32]";
163168
assert_eq!(expected, format!("{plan}"));
164169
Ok(())
165170
}
@@ -175,9 +180,12 @@ fn intersect() -> Result<()> {
175180
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
176181
\n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\
177182
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
183+
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
184+
\n TableScan: test projection=[col_int32, col_utf8]\
185+
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
178186
\n TableScan: test projection=[col_int32, col_utf8]\
179-
\n TableScan: test projection=[col_int32, col_utf8]\
180-
\n TableScan: test projection=[col_int32, col_utf8]";
187+
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
188+
\n TableScan: test projection=[col_int32, col_utf8]";
181189
assert_eq!(expected, format!("{plan}"));
182190
Ok(())
183191
}
@@ -273,9 +281,11 @@ fn test_same_name_but_not_ambiguous() {
273281
let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\
274282
\n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\
275283
\n SubqueryAlias: t1\
276-
\n TableScan: test projection=[col_int32]\
284+
\n Filter: test.col_int32 IS NOT NULL\
285+
\n TableScan: test projection=[col_int32]\
277286
\n SubqueryAlias: t2\
278-
\n TableScan: test projection=[col_int32]";
287+
\n Filter: test.col_int32 IS NOT NULL\
288+
\n TableScan: test projection=[col_int32]";
279289
assert_eq!(expected, format!("{plan}"));
280290
}
281291

0 commit comments

Comments
 (0)