Skip to content

Commit 5614ffa

Browse files
authored
[mypyc] Generate faster code for bool comparisons and arithmetic (#14489)
Generate specialized, efficient IR for various operations on bools. These are covered: * Bool comparisons * Mixed bool/integer comparisons * Bool arithmetic (binary and unary) * Mixed bool/integer arithmetic and bitwise ops Mixed operations where the left operand is a `bool` and the right operand is a native int still have some unnecessary conversions between native int and `int`. This would be a bit trickier to fix and is seems rare, so it doesn't seem urgent to fix this. Fixes mypyc/mypyc#968.
1 parent 27f51fc commit 5614ffa

File tree

6 files changed

+533
-26
lines changed

6 files changed

+533
-26
lines changed

mypyc/analysis/ircheck.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
217217
source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
218218
)
219219

220+
def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
221+
if not can_coerce_to(t, s) or not can_coerce_to(s, t):
222+
self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")
223+
220224
def visit_goto(self, op: Goto) -> None:
221225
self.check_control_op_targets(op)
222226

@@ -375,7 +379,7 @@ def visit_int_op(self, op: IntOp) -> None:
375379
pass
376380

377381
def visit_comparison_op(self, op: ComparisonOp) -> None:
378-
pass
382+
self.check_compatibility(op, op.lhs.type, op.rhs.type)
379383

380384
def visit_load_mem(self, op: LoadMem) -> None:
381385
pass

mypyc/irbuild/ll_builder.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@
199199
">>=",
200200
}
201201

202+
# Binary operations on bools that are specialized and don't just promote operands to int
203+
BOOL_BINARY_OPS: Final = {"&", "&=", "|", "|=", "^", "^=", "==", "!=", "<", "<=", ">", ">="}
204+
202205

203206
class LowLevelIRBuilder:
204207
def __init__(self, current_module: str, mapper: Mapper, options: CompilerOptions) -> None:
@@ -326,13 +329,13 @@ def coerce(
326329
):
327330
# Equivalent types
328331
return src
329-
elif (
330-
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
331-
) and is_int_rprimitive(target_type):
332+
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
333+
target_type
334+
):
332335
shifted = self.int_op(
333336
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
334337
)
335-
return self.add(Extend(shifted, int_rprimitive, signed=False))
338+
return self.add(Extend(shifted, target_type, signed=False))
336339
elif (
337340
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
338341
) and is_fixed_width_rtype(target_type):
@@ -1245,48 +1248,45 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
12451248
return self.compare_bytes(lreg, rreg, op, line)
12461249
if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping:
12471250
return self.compare_tagged(lreg, rreg, op, line)
1248-
if (
1249-
is_bool_rprimitive(ltype)
1250-
and is_bool_rprimitive(rtype)
1251-
and op in ("&", "&=", "|", "|=", "^", "^=")
1252-
):
1253-
return self.bool_bitwise_op(lreg, rreg, op[0], line)
1251+
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
1252+
if op in ComparisonOp.signed_ops:
1253+
return self.bool_comparison_op(lreg, rreg, op, line)
1254+
else:
1255+
return self.bool_bitwise_op(lreg, rreg, op[0], line)
12541256
if isinstance(rtype, RInstance) and op in ("in", "not in"):
12551257
return self.translate_instance_contains(rreg, lreg, op, line)
12561258
if is_fixed_width_rtype(ltype):
12571259
if op in FIXED_WIDTH_INT_BINARY_OPS:
12581260
if op.endswith("="):
12591261
op = op[:-1]
1262+
if op != "//":
1263+
op_id = int_op_to_id[op]
1264+
else:
1265+
op_id = IntOp.DIV
1266+
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1267+
rreg = self.coerce(rreg, ltype, line)
1268+
rtype = ltype
12601269
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
1261-
if op != "//":
1262-
op_id = int_op_to_id[op]
1263-
else:
1264-
op_id = IntOp.DIV
12651270
return self.fixed_width_int_op(ltype, lreg, rreg, op_id, line)
12661271
if isinstance(rreg, Integer):
12671272
# TODO: Check what kind of Integer
1268-
if op != "//":
1269-
op_id = int_op_to_id[op]
1270-
else:
1271-
op_id = IntOp.DIV
12721273
return self.fixed_width_int_op(
12731274
ltype, lreg, Integer(rreg.value >> 1, ltype), op_id, line
12741275
)
12751276
elif op in ComparisonOp.signed_ops:
12761277
if is_int_rprimitive(rtype):
12771278
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
1279+
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1280+
rreg = self.coerce(rreg, ltype, line)
12781281
op_id = ComparisonOp.signed_ops[op]
12791282
if is_fixed_width_rtype(rreg.type):
12801283
return self.comparison_op(lreg, rreg, op_id, line)
12811284
if isinstance(rreg, Integer):
12821285
return self.comparison_op(lreg, Integer(rreg.value >> 1, ltype), op_id, line)
12831286
elif is_fixed_width_rtype(rtype):
1284-
if (
1285-
isinstance(lreg, Integer) or is_tagged(ltype)
1286-
) and op in FIXED_WIDTH_INT_BINARY_OPS:
1287+
if op in FIXED_WIDTH_INT_BINARY_OPS:
12871288
if op.endswith("="):
12881289
op = op[:-1]
1289-
# TODO: Support comparison ops (similar to above)
12901290
if op != "//":
12911291
op_id = int_op_to_id[op]
12921292
else:
@@ -1296,15 +1296,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
12961296
return self.fixed_width_int_op(
12971297
rtype, Integer(lreg.value >> 1, rtype), rreg, op_id, line
12981298
)
1299-
else:
1299+
if is_tagged(ltype):
1300+
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
1301+
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1302+
lreg = self.coerce(lreg, rtype, line)
13001303
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
13011304
elif op in ComparisonOp.signed_ops:
13021305
if is_int_rprimitive(ltype):
13031306
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
1307+
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1308+
lreg = self.coerce(lreg, rtype, line)
13041309
op_id = ComparisonOp.signed_ops[op]
13051310
if isinstance(lreg, Integer):
13061311
return self.comparison_op(Integer(lreg.value >> 1, rtype), rreg, op_id, line)
1312+
if is_fixed_width_rtype(lreg.type):
1313+
return self.comparison_op(lreg, rreg, op_id, line)
1314+
1315+
# Mixed int comparisons
1316+
if op in ("==", "!="):
1317+
op_id = ComparisonOp.signed_ops[op]
1318+
if is_tagged(ltype) and is_subtype(rtype, ltype):
1319+
rreg = self.coerce(rreg, int_rprimitive, line)
1320+
return self.comparison_op(lreg, rreg, op_id, line)
1321+
if is_tagged(rtype) and is_subtype(ltype, rtype):
1322+
lreg = self.coerce(lreg, int_rprimitive, line)
13071323
return self.comparison_op(lreg, rreg, op_id, line)
1324+
elif op in op in int_comparison_op_mapping:
1325+
if is_tagged(ltype) and is_subtype(rtype, ltype):
1326+
rreg = self.coerce(rreg, short_int_rprimitive, line)
1327+
return self.compare_tagged(lreg, rreg, op, line)
1328+
if is_tagged(rtype) and is_subtype(ltype, rtype):
1329+
lreg = self.coerce(lreg, short_int_rprimitive, line)
1330+
return self.compare_tagged(lreg, rreg, op, line)
13081331

13091332
call_c_ops_candidates = binary_ops.get(op, [])
13101333
target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line)
@@ -1509,14 +1532,21 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value
15091532
assert False, op
15101533
return self.add(IntOp(bool_rprimitive, lreg, rreg, code, line))
15111534

1535+
def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
1536+
op_id = ComparisonOp.signed_ops[op]
1537+
return self.comparison_op(lreg, rreg, op_id, line)
1538+
15121539
def unary_not(self, value: Value, line: int) -> Value:
15131540
mask = Integer(1, value.type, line)
15141541
return self.int_op(value.type, value, mask, IntOp.XOR, line)
15151542

15161543
def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15171544
typ = value.type
1518-
if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == "not":
1519-
return self.unary_not(value, line)
1545+
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1546+
if expr_op == "not":
1547+
return self.unary_not(value, line)
1548+
if expr_op == "+":
1549+
return value
15201550
if is_fixed_width_rtype(typ):
15211551
if expr_op == "-":
15221552
# Translate to '0 - x'
@@ -1532,6 +1562,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15321562
if is_short_int_rprimitive(typ):
15331563
num >>= 1
15341564
return Integer(-num, typ, value.line)
1565+
if is_tagged(typ) and expr_op == "+":
1566+
return value
15351567
if isinstance(typ, RInstance):
15361568
if expr_op == "-":
15371569
method = "__neg__"

0 commit comments

Comments
 (0)