Skip to content

Commit b9a9722

Browse files
Copilotcincuranet
andcommitted
Fix null check optimization for IQueryable/DbSet types
Co-authored-by: cincuranet <[email protected]>
1 parent 3a23a86 commit b9a9722

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
2424
{
2525
var visitedExpression = base.VisitBinary(binaryExpression);
2626

27-
return TryOptimizeConditionalEquality(visitedExpression) ?? visitedExpression;
27+
return TryOptimizeQueryableNullCheck(visitedExpression)
28+
?? TryOptimizeConditionalEquality(visitedExpression)
29+
?? visitedExpression;
2830
}
2931

3032
/// <summary>
@@ -75,6 +77,34 @@ protected override Expression VisitConditional(ConditionalExpression conditional
7577
return base.VisitConditional(conditionalExpression);
7678
}
7779

80+
private static Expression? TryOptimizeQueryableNullCheck(Expression expression)
81+
{
82+
// Optimize IQueryable/DbSet null checks
83+
// IQueryable != null => true
84+
// IQueryable == null => false
85+
if (expression is BinaryExpression
86+
{
87+
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
88+
} binaryExpression)
89+
{
90+
var isLeftNull = IsNullConstant(binaryExpression.Left);
91+
var isRightNull = IsNullConstant(binaryExpression.Right);
92+
93+
if (isLeftNull != isRightNull)
94+
{
95+
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
96+
97+
if (IsQueryableType(nonNullExpression.Type))
98+
{
99+
var result = binaryExpression.NodeType == ExpressionType.NotEqual;
100+
return Expression.Constant(result, typeof(bool));
101+
}
102+
}
103+
}
104+
105+
return null;
106+
}
107+
78108
private static Expression? TryOptimizeConditionalEquality(Expression expression)
79109
{
80110
// Simplify (a ? b : null) == null => !a || b == null
@@ -161,4 +191,17 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
161191

162192
private static bool IsNullConstant(Expression expression)
163193
=> expression is ConstantExpression { Value: null };
194+
195+
private static bool IsQueryableType(Type type)
196+
{
197+
if (type.IsGenericType)
198+
{
199+
var genericTypeDefinition = type.GetGenericTypeDefinition();
200+
return genericTypeDefinition == typeof(IQueryable<>)
201+
|| genericTypeDefinition == typeof(IOrderedQueryable<>)
202+
|| genericTypeDefinition == typeof(DbSet<>);
203+
}
204+
205+
return type == typeof(IQueryable);
206+
}
164207
}

test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,18 @@ public virtual Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
14181418
elementSorter: e => e.CustomerID,
14191419
elementAsserter: (e, a) => AssertCollection(e.Subquery, a.Subquery));
14201420

1421+
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1422+
public virtual Task Where_Queryable_null_check_with_Contains(bool async)
1423+
{
1424+
return AssertQuery(
1425+
async,
1426+
ss =>
1427+
{
1428+
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1429+
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1430+
});
1431+
}
1432+
14211433
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
14221434
public virtual Task Where_Queryable_ToList_Count_member(bool async)
14231435
=> AssertQuery(

0 commit comments

Comments
 (0)