Skip to content

Commit 160b224

Browse files
authored
Merge pull request #212 from servicetitan/upstream/precise-bool-expression-converter
BooleanExpressionConverter: use precise optimization
2 parents c170312 + cf86df2 commit 160b224

File tree

13 files changed

+179
-25
lines changed

13 files changed

+179
-25
lines changed

Extensions/Xtensive.Orm.BulkOperations/Internals/BaseSqlVisitor.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,12 +516,16 @@ public virtual void Visit(SqlWhile node)
516516
VisitInternal(node.Statement);
517517
}
518518

519-
520519
public virtual void Visit(SqlFragment node)
521520
{
522521
VisitInternal(node.Expression);
523522
}
524523

524+
public virtual void Visit(SqlMetadata node)
525+
{
526+
VisitInternal(node.Expression);
527+
}
528+
525529
#region Non-public methods
526530

527531
private void VisitInternal(ISqlNode node)

Orm/Xtensive.Orm.Tests/Linq/DistinctTest.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,5 +399,26 @@ public void DistinctSkipTakeTest()
399399
Assert.IsTrue(expected.SequenceEqual(result));
400400
Assert.Greater(result.ToList().Count, 0);
401401
}
402+
403+
[Test]
404+
public void DistinctByBoolExpression()
405+
{
406+
var result = Session.Query.All<Invoice>().Select(c => c.Status == (InvoiceStatus) 1)
407+
.Distinct()
408+
.ToArray();
409+
410+
CollectionAssert.AreEquivalent(new[] {false, true}, result);
411+
}
412+
413+
[Test]
414+
public void DistinctByBoolExpressionComplex()
415+
{
416+
var result = Session.Query.All<Invoice>()
417+
.Select(c => c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2)
418+
.Distinct()
419+
.ToArray();
420+
421+
CollectionAssert.AreEquivalent(new[] {false, true}, result);
422+
}
402423
}
403424
}

Orm/Xtensive.Orm.Tests/Linq/GroupByTest.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,37 @@ public void GroupWithJoinTest()
10211021
QueryDumper.Dump(query);
10221022
}
10231023

1024+
[Test]
1025+
public void GroupByBoolExpression()
1026+
{
1027+
var query = Session.Query.All<Invoice>();
1028+
var falseResult = query.Count(c => c.Status != (InvoiceStatus) 1);
1029+
var trueResult = query.Count(c => c.Status == (InvoiceStatus) 1);
1030+
1031+
var result = query.GroupBy(c => c.Status == (InvoiceStatus) 1)
1032+
.Select(c => new {Value = c.Key, Count = c.Count()})
1033+
.ToArray();
1034+
1035+
Assert.AreEqual(falseResult, result.Single(i => !i.Value).Count);
1036+
Assert.AreEqual(trueResult, result.Single(i => i.Value).Count);
1037+
}
1038+
1039+
[Test]
1040+
public void GroupByBoolExpressionComplex()
1041+
{
1042+
var query = Session.Query.All<Invoice>();
1043+
var falseResult = query.Count(c => !(c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2));
1044+
var trueResult = query.Count(c => c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2);
1045+
1046+
var result = query
1047+
.GroupBy(c => c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2)
1048+
.Select(c => new {Value = c.Key, Count = c.Count()})
1049+
.ToArray();
1050+
1051+
Assert.AreEqual(falseResult, result.Single(i => !i.Value).Count);
1052+
Assert.AreEqual(trueResult, result.Single(i => i.Value).Count);
1053+
}
1054+
10241055
private void DumpGrouping<TKey, TValue>(IQueryable<IGrouping<TKey, TValue>> result)
10251056
{
10261057
DumpGrouping(result, false);

Orm/Xtensive.Orm.Tests/Linq/OrderByTest.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,5 +316,24 @@ public void ThenByTest()
316316
Assert.That(result, Is.Not.Empty);
317317
Assert.IsTrue(expected.SequenceEqual(result));
318318
}
319+
320+
[Test]
321+
public void OrderByBoolExpression()
322+
{
323+
var result = Session.Query.All<Invoice>().OrderBy(c => c.Status == (InvoiceStatus) 1)
324+
.Select(c => c.Status)
325+
.ToArray();
326+
Assert.AreEqual(result.Last(), (InvoiceStatus) 1);
327+
}
328+
329+
[Test]
330+
public void OrderByBoolExpressionComplex()
331+
{
332+
var result = Session.Query.All<Invoice>()
333+
.OrderBy(c => c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2)
334+
.Select(c => c.Status)
335+
.ToArray();
336+
Assert.Contains(result.Last(), new[] { (InvoiceStatus) 1, (InvoiceStatus) 2 });
337+
}
319338
}
320-
}
339+
}

Orm/Xtensive.Orm.Tests/Linq/WhereTest.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,5 +1326,29 @@ public void ApplyTest()
13261326
.Where(customer => customer.Invoices.Any(i => i.Commission > 0.30m));
13271327
Assert.IsTrue(expected.SequenceEqual(actual));
13281328
}
1329+
1330+
[Test]
1331+
public void WhereBoolEquals()
1332+
{
1333+
var expected = Session.Query.All<Invoice>().Count(c => c.Status != (InvoiceStatus) 1);
1334+
// ReSharper disable once ReplaceWithSingleCallToCount
1335+
var actual = Session.Query.All<Invoice>().Where(c => (c.Status == (InvoiceStatus) 1) == false).Count();
1336+
1337+
Assert.AreEqual(expected, actual);
1338+
}
1339+
1340+
[Test]
1341+
public void WhereBoolEqualsComplex()
1342+
{
1343+
var expected = Session.Query.All<Invoice>()
1344+
.Count(c => !(c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2));
1345+
1346+
// ReSharper disable once ReplaceWithSingleCallToCount
1347+
var actual = Session.Query.All<Invoice>()
1348+
.Where(c => (c.Status == (InvoiceStatus) 1 || c.Status == (InvoiceStatus) 2) == false)
1349+
.Count();
1350+
1351+
Assert.AreEqual(expected, actual);
1352+
}
13291353
}
13301354
}

Orm/Xtensive.Orm/Orm/Providers/Expressions/BooleanExpressionConverter.cs

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// Created by: Denis Krjuchkov
55
// Created: 2009.08.17
66

7+
using System.Collections.Generic;
78
using System.Linq;
89
using Xtensive.Reflection;
910
using Xtensive.Sql;
@@ -13,46 +14,36 @@ namespace Xtensive.Orm.Providers
1314
{
1415
internal sealed class BooleanExpressionConverter
1516
{
17+
private static readonly object IntToBooleanTag = new();
18+
private static readonly object BooleanToIntTag = new();
19+
1620
private readonly SqlValueType booleanType;
1721

1822
public SqlExpression IntToBoolean(SqlExpression expression)
1923
{
2024
// optimization: omitting IntToBoolean(BooleanToInt(x)) sequences
21-
if (expression.NodeType==SqlNodeType.Cast) {
22-
var operand = ((SqlCast) expression).Operand;
23-
if (operand.NodeType==SqlNodeType.Case) {
24-
var _case = (SqlCase) operand;
25-
if (_case.Count == 1) {
26-
var firstCaseItem = _case.First();
27-
var whenTrue = firstCaseItem.Value as SqlLiteral<int>;
28-
var whenFalse = _case.Else as SqlLiteral<int>;
29-
if (!ReferenceEquals(whenTrue, null)
30-
&& !ReferenceEquals(whenFalse, null)
31-
&& whenTrue.Value==1
32-
&& whenFalse.Value==0)
33-
return firstCaseItem.Key;
34-
}
35-
}
25+
if (expression.NodeType == SqlNodeType.Metadata &&
26+
expression is SqlMetadata metadata &&
27+
metadata.Value == BooleanToIntTag) {
28+
return ((SqlCase) ((SqlCast) metadata.Expression).Operand).First().Key;
3629
}
3730

38-
return SqlDml.Equals(expression, 1);
31+
return SqlDml.Metadata(SqlDml.Equals(expression, 1), IntToBooleanTag);
3932
}
4033

4134
public SqlExpression BooleanToInt(SqlExpression expression)
4235
{
4336
// optimization: omitting BooleanToInt(IntToBoolean(x)) sequences
44-
if (expression.NodeType==SqlNodeType.Equals) {
45-
var binary = (SqlBinary) expression;
46-
var left = binary.Left;
47-
var right = binary.Right as SqlLiteral<int>;
48-
if (!ReferenceEquals(right, null) && right.Value==1)
49-
return left;
37+
if (expression.NodeType == SqlNodeType.Metadata &&
38+
expression is SqlMetadata metadata &&
39+
metadata.Value == IntToBooleanTag) {
40+
return ((SqlBinary) metadata.Expression).Left;
5041
}
5142

5243
var result = SqlDml.Case();
5344
result.Add(expression, 1);
5445
result.Else = 0;
55-
return SqlDml.Cast(result, booleanType);
46+
return SqlDml.Metadata(SqlDml.Cast(result, booleanType), BooleanToIntTag);
5647
}
5748

5849
public BooleanExpressionConverter(StorageDriver driver)

Orm/Xtensive.Orm/Orm/Providers/SqlSelectProcessor.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,11 @@ public void Visit(SqlUnary node)
493493
Visit(node.Operand);
494494
}
495495

496+
public void Visit(SqlMetadata node)
497+
{
498+
Visit(node.Expression);
499+
}
500+
496501
public void Visit(SqlUpdate node)
497502
{
498503
if (node.From!=null)

Orm/Xtensive.Orm/Sql/Compiler/SqlCompiler.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,11 @@ public virtual void Visit(SqlUnary node)
13801380
}
13811381
}
13821382

1383+
public virtual void Visit(SqlMetadata node)
1384+
{
1385+
node.Expression.AcceptVisitor(this);
1386+
}
1387+
13831388
public virtual void Visit(SqlUpdate node)
13841389
{
13851390
VisitUpdateDefault(node);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (C) 2022 Xtensive LLC.
2+
// This code is distributed under MIT license terms.
3+
// See the License.txt file in the project root for more information.
4+
5+
using Xtensive.Core;
6+
7+
namespace Xtensive.Sql.Dml
8+
{
9+
/// <summary>
10+
/// Arbitrary metadata that could be attached to SQL expression tree.
11+
/// </summary>
12+
public class SqlMetadata : SqlExpression
13+
{
14+
public SqlExpression Expression { get; private set; }
15+
16+
public object Value { get; private set; }
17+
18+
public override void ReplaceWith(SqlExpression expression)
19+
{
20+
ArgumentValidator.EnsureArgumentNotNull(expression, nameof(expression));
21+
ArgumentValidator.EnsureArgumentIs<SqlMetadata>(expression, nameof(expression));
22+
var source = (SqlMetadata) expression;
23+
NodeType = source.NodeType;
24+
Expression = source.Expression;
25+
Value = source.Value;
26+
}
27+
28+
internal override object Clone(SqlNodeCloneContext context) =>
29+
context.NodeMapping.TryGetValue(this, out var clone)
30+
? clone
31+
: context.NodeMapping[this] = new SqlMetadata((SqlExpression) Expression.Clone(context), Value);
32+
33+
public override void AcceptVisitor(ISqlVisitor visitor) => visitor.Visit(this);
34+
35+
// Constructors
36+
37+
internal SqlMetadata(SqlExpression expression, object value) : base(SqlNodeType.Metadata)
38+
{
39+
Expression = expression;
40+
Value = value;
41+
}
42+
}
43+
}

Orm/Xtensive.Orm/Sql/Interfaces/ISqlVisitor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ public interface ISqlVisitor
9898
void Visit(SqlSelect node);
9999
void Visit(SqlSubQuery node);
100100
void Visit(SqlUnary node);
101+
void Visit(SqlMetadata node);
101102
void Visit(SqlUpdate node);
102103
void Visit(SqlUserColumn node);
103104
void Visit(SqlUserFunctionCall node);

0 commit comments

Comments
 (0)