Skip to content

[mypyc] Simplify IR for tagged integer comparisons #9607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mypy.nodes import (
MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr,
CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr,
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, GDEF, ARG_POS, ARG_NAMED
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, ARG_POS, ARG_NAMED
)
from mypy.types import (
Type, Instance, TupleType, UninhabitedType, get_proper_type
Expand All @@ -39,7 +39,7 @@
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
str_rprimitive,
str_rprimitive, is_tagged
)
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
Expand Down Expand Up @@ -813,11 +813,45 @@ def process_conditional(self, e: Expression, true: BasicBlock, false: BasicBlock
self.process_conditional(e.right, true, false)
elif isinstance(e, UnaryExpr) and e.op == 'not':
self.process_conditional(e.expr, false, true)
# Catch-all for arbitrary expressions.
else:
res = self.maybe_process_conditional_comparison(e, true, false)
if res:
return
# Catch-all for arbitrary expressions.
reg = self.accept(e)
self.add_bool_branch(reg, true, false)

def maybe_process_conditional_comparison(self,
e: Expression,
true: BasicBlock,
false: BasicBlock) -> bool:
"""Transform simple tagged integer comparisons in a conditional context.

Return True if the operation is supported (and was transformed). Otherwise,
do nothing and return False.

Args:
e: Arbitrary expression
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
if not isinstance(e, ComparisonExpr) or len(e.operands) != 2:
return False
ltype = self.node_type(e.operands[0])
rtype = self.node_type(e.operands[1])
if not is_tagged(ltype) or not is_tagged(rtype):
return False
op = e.operators[0]
if op not in ('==', '!=', '<', '<=', '>', '>='):
return False
left = self.accept(e.operands[0])
right = self.accept(e.operands[1])
# "left op right" for two tagged integers
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
return True

# Basic helpers

def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[ClassIR]]:
"""Flatten classes in isinstance(obj, (A, (B, C))).

Expand All @@ -841,8 +875,6 @@ def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[Class
return None
return res

# Basic helpers

def enter(self, fn_info: Union[FuncInfo, str] = '') -> None:
if isinstance(fn_info, str):
fn_info = FuncInfo(name=fn_info)
Expand Down
72 changes: 64 additions & 8 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,17 +589,21 @@ def binary_op(self,
assert target, 'Unsupported binary operation: %s' % op
return target

def check_tagged_short_int(self, val: Value, line: int) -> Value:
"""Check if a tagged integer is a short integer"""
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
"""Check if a tagged integer is a short integer.

Return the result of the check (value of type 'bit').
"""
int_tag = self.add(LoadInt(1, line, rtype=c_pyssize_t_rprimitive))
bitwise_and = self.binary_int_op(c_pyssize_t_rprimitive, val,
int_tag, BinaryIntOp.AND, line)
zero = self.add(LoadInt(0, line, rtype=c_pyssize_t_rprimitive))
check = self.comparison_op(bitwise_and, zero, ComparisonOp.EQ, line)
op = ComparisonOp.NEQ if negated else ComparisonOp.EQ
check = self.comparison_op(bitwise_and, zero, op, line)
return check

def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two tagged integers using given op"""
"""Compare two tagged integers using given operator (value context)."""
# generate fast binary logic ops on short ints
if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type):
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
Expand All @@ -610,13 +614,11 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
if op in ("==", "!="):
check = check_lhs
else:
# for non-equal logical ops(less than, greater than, etc.), need to check both side
# for non-equality logical ops (less/greater than, etc.), need to check both sides
check_rhs = self.check_tagged_short_int(rhs, line)
check = self.binary_int_op(bit_rprimitive, check_lhs,
check_rhs, BinaryIntOp.AND, line)
branch = Branch(check, short_int_block, int_block, Branch.BOOL)
branch.negated = False
self.add(branch)
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
Expand All @@ -636,6 +638,60 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
self.goto_and_activate(out)
return result

def compare_tagged_condition(self,
lhs: Value,
rhs: Value,
op: str,
true: BasicBlock,
false: BasicBlock,
line: int) -> None:
"""Compare two tagged integers using given operator (conditional context).

Assume lhs and and rhs are tagged integers.

Args:
lhs: Left operand
rhs: Right operand
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
is_eq = op in ("==", "!=")
if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type))
or (is_eq and (is_short_int_rprimitive(lhs.type) or
is_short_int_rprimitive(rhs.type)))):
# We can skip the tag check
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
self.add(Branch(check, true, false, Branch.BOOL))
return
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
int_block, short_int_block = BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if is_eq or is_short_int_rprimitive(rhs.type):
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# For non-equality logical ops (less/greater than, etc.), need to check both sides
rhs_block = BasicBlock()
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
self.activate_block(rhs_block)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
# Arbitrary integers (slow path)
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
self.add(Branch(call, false, true, Branch.BOOL))
else:
self.add(Branch(call, true, false, Branch.BOOL))
# Short integers (fast path)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Branch(eq, true, false, Branch.BOOL))

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
Expand Down
Loading