Skip to content

Commit 27a9830

Browse files
authored
[mypyc] Simplify IR for tagged integer comparisons (#9607)
In a conditional context, such as in an if condition, simplify the IR for tagged integer comparisons. Also perform some additional optimizations if an operand is known to be a short integer. This slightly improves performance when compiling with no optimizations. The impact should be pretty negligible otherwise. This is a bit simple-minded, and some further optimizations are possible. For example, `3 < x < 6` could be made faster. This covers the most common cases, however. Closes mypyc/mypyc#758.
1 parent 3acbf3f commit 27a9830

10 files changed

+821
-1019
lines changed

mypyc/irbuild/builder.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mypy.nodes import (
2020
MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr,
2121
CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr,
22-
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, GDEF, ARG_POS, ARG_NAMED
22+
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, ARG_POS, ARG_NAMED
2323
)
2424
from mypy.types import (
2525
Type, Instance, TupleType, UninhabitedType, get_proper_type
@@ -39,7 +39,7 @@
3939
from mypyc.ir.rtypes import (
4040
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
4141
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
42-
str_rprimitive,
42+
str_rprimitive, is_tagged
4343
)
4444
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
4545
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
@@ -813,11 +813,45 @@ def process_conditional(self, e: Expression, true: BasicBlock, false: BasicBlock
813813
self.process_conditional(e.right, true, false)
814814
elif isinstance(e, UnaryExpr) and e.op == 'not':
815815
self.process_conditional(e.expr, false, true)
816-
# Catch-all for arbitrary expressions.
817816
else:
817+
res = self.maybe_process_conditional_comparison(e, true, false)
818+
if res:
819+
return
820+
# Catch-all for arbitrary expressions.
818821
reg = self.accept(e)
819822
self.add_bool_branch(reg, true, false)
820823

824+
def maybe_process_conditional_comparison(self,
825+
e: Expression,
826+
true: BasicBlock,
827+
false: BasicBlock) -> bool:
828+
"""Transform simple tagged integer comparisons in a conditional context.
829+
830+
Return True if the operation is supported (and was transformed). Otherwise,
831+
do nothing and return False.
832+
833+
Args:
834+
e: Arbitrary expression
835+
true: Branch target if comparison is true
836+
false: Branch target if comparison is false
837+
"""
838+
if not isinstance(e, ComparisonExpr) or len(e.operands) != 2:
839+
return False
840+
ltype = self.node_type(e.operands[0])
841+
rtype = self.node_type(e.operands[1])
842+
if not is_tagged(ltype) or not is_tagged(rtype):
843+
return False
844+
op = e.operators[0]
845+
if op not in ('==', '!=', '<', '<=', '>', '>='):
846+
return False
847+
left = self.accept(e.operands[0])
848+
right = self.accept(e.operands[1])
849+
# "left op right" for two tagged integers
850+
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
851+
return True
852+
853+
# Basic helpers
854+
821855
def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[ClassIR]]:
822856
"""Flatten classes in isinstance(obj, (A, (B, C))).
823857
@@ -841,8 +875,6 @@ def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[Class
841875
return None
842876
return res
843877

844-
# Basic helpers
845-
846878
def enter(self, fn_info: Union[FuncInfo, str] = '') -> None:
847879
if isinstance(fn_info, str):
848880
fn_info = FuncInfo(name=fn_info)

mypyc/irbuild/ll_builder.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -589,17 +589,21 @@ def binary_op(self,
589589
assert target, 'Unsupported binary operation: %s' % op
590590
return target
591591

592-
def check_tagged_short_int(self, val: Value, line: int) -> Value:
593-
"""Check if a tagged integer is a short integer"""
592+
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
593+
"""Check if a tagged integer is a short integer.
594+
595+
Return the result of the check (value of type 'bit').
596+
"""
594597
int_tag = self.add(LoadInt(1, line, rtype=c_pyssize_t_rprimitive))
595598
bitwise_and = self.binary_int_op(c_pyssize_t_rprimitive, val,
596599
int_tag, BinaryIntOp.AND, line)
597600
zero = self.add(LoadInt(0, line, rtype=c_pyssize_t_rprimitive))
598-
check = self.comparison_op(bitwise_and, zero, ComparisonOp.EQ, line)
601+
op = ComparisonOp.NEQ if negated else ComparisonOp.EQ
602+
check = self.comparison_op(bitwise_and, zero, op, line)
599603
return check
600604

601605
def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
602-
"""Compare two tagged integers using given op"""
606+
"""Compare two tagged integers using given operator (value context)."""
603607
# generate fast binary logic ops on short ints
604608
if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type):
605609
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
@@ -610,13 +614,11 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
610614
if op in ("==", "!="):
611615
check = check_lhs
612616
else:
613-
# for non-equal logical ops(less than, greater than, etc.), need to check both side
617+
# for non-equality logical ops (less/greater than, etc.), need to check both sides
614618
check_rhs = self.check_tagged_short_int(rhs, line)
615619
check = self.binary_int_op(bit_rprimitive, check_lhs,
616620
check_rhs, BinaryIntOp.AND, line)
617-
branch = Branch(check, short_int_block, int_block, Branch.BOOL)
618-
branch.negated = False
619-
self.add(branch)
621+
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
620622
self.activate_block(short_int_block)
621623
eq = self.comparison_op(lhs, rhs, op_type, line)
622624
self.add(Assign(result, eq, line))
@@ -636,6 +638,60 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
636638
self.goto_and_activate(out)
637639
return result
638640

641+
def compare_tagged_condition(self,
642+
lhs: Value,
643+
rhs: Value,
644+
op: str,
645+
true: BasicBlock,
646+
false: BasicBlock,
647+
line: int) -> None:
648+
"""Compare two tagged integers using given operator (conditional context).
649+
650+
Assume lhs and and rhs are tagged integers.
651+
652+
Args:
653+
lhs: Left operand
654+
rhs: Right operand
655+
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
656+
true: Branch target if comparison is true
657+
false: Branch target if comparison is false
658+
"""
659+
is_eq = op in ("==", "!=")
660+
if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type))
661+
or (is_eq and (is_short_int_rprimitive(lhs.type) or
662+
is_short_int_rprimitive(rhs.type)))):
663+
# We can skip the tag check
664+
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
665+
self.add(Branch(check, true, false, Branch.BOOL))
666+
return
667+
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
668+
int_block, short_int_block = BasicBlock(), BasicBlock()
669+
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
670+
if is_eq or is_short_int_rprimitive(rhs.type):
671+
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
672+
else:
673+
# For non-equality logical ops (less/greater than, etc.), need to check both sides
674+
rhs_block = BasicBlock()
675+
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
676+
self.activate_block(rhs_block)
677+
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
678+
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
679+
# Arbitrary integers (slow path)
680+
self.activate_block(int_block)
681+
if swap_op:
682+
args = [rhs, lhs]
683+
else:
684+
args = [lhs, rhs]
685+
call = self.call_c(c_func_desc, args, line)
686+
if negate_result:
687+
self.add(Branch(call, false, true, Branch.BOOL))
688+
else:
689+
self.add(Branch(call, true, false, Branch.BOOL))
690+
# Short integers (fast path)
691+
self.activate_block(short_int_block)
692+
eq = self.comparison_op(lhs, rhs, op_type, line)
693+
self.add(Branch(eq, true, false, Branch.BOOL))
694+
639695
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
640696
"""Compare two strings"""
641697
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)

0 commit comments

Comments
 (0)