@@ -24,7 +24,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
2424 {
2525 var visitedExpression = base . VisitBinary ( binaryExpression ) ;
2626
27- return TryOptimizeConditionalEquality ( visitedExpression ) ?? visitedExpression ;
27+ return TryOptimizeQueryableNullCheck ( visitedExpression )
28+ ?? TryOptimizeConditionalEquality ( visitedExpression )
29+ ?? visitedExpression ;
2830 }
2931
3032 /// <summary>
@@ -75,6 +77,34 @@ protected override Expression VisitConditional(ConditionalExpression conditional
7577 return base . VisitConditional ( conditionalExpression ) ;
7678 }
7779
80+ private static Expression ? TryOptimizeQueryableNullCheck ( Expression expression )
81+ {
82+ // Optimize IQueryable/DbSet null checks
83+ // IQueryable != null => true
84+ // IQueryable == null => false
85+ if ( expression is BinaryExpression
86+ {
87+ NodeType : ExpressionType . Equal or ExpressionType . NotEqual
88+ } binaryExpression )
89+ {
90+ var isLeftNull = IsNullConstant ( binaryExpression . Left ) ;
91+ var isRightNull = IsNullConstant ( binaryExpression . Right ) ;
92+
93+ if ( isLeftNull != isRightNull )
94+ {
95+ var nonNullExpression = isLeftNull ? binaryExpression . Right : binaryExpression . Left ;
96+
97+ if ( IsQueryableType ( nonNullExpression . Type ) )
98+ {
99+ var result = binaryExpression . NodeType == ExpressionType . NotEqual ;
100+ return Expression . Constant ( result , typeof ( bool ) ) ;
101+ }
102+ }
103+ }
104+
105+ return null ;
106+ }
107+
78108 private static Expression ? TryOptimizeConditionalEquality ( Expression expression )
79109 {
80110 // Simplify (a ? b : null) == null => !a || b == null
@@ -161,4 +191,17 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
161191
162192 private static bool IsNullConstant ( Expression expression )
163193 => expression is ConstantExpression { Value : null } ;
194+
195+ private static bool IsQueryableType ( Type type )
196+ {
197+ if ( type . IsGenericType )
198+ {
199+ var genericTypeDefinition = type . GetGenericTypeDefinition ( ) ;
200+ return genericTypeDefinition == typeof ( IQueryable < > )
201+ || genericTypeDefinition == typeof ( IOrderedQueryable < > )
202+ || genericTypeDefinition == typeof ( DbSet < > ) ;
203+ }
204+
205+ return type == typeof ( IQueryable ) ;
206+ }
164207}
0 commit comments