Skip to content

Commit 8bf770d

Browse files
jdahlin (Johan Dahlin), r3m0t, TH3CHARLie
authored
[mypyc] Speed up in operations for list/tuple (#9004)
When the right-hand side of an `in`/`not in` operation is a literal list or tuple, rewrite it as a chain of direct equality comparisons joined with boolean `or`/`and`. This yields a speedup of up to 46% in micro-benchmarks. Co-authored-by: Johan Dahlin <[email protected]> Co-authored-by: Tomer Chachamu <[email protected]> Co-authored-by: Xuanda Yang <[email protected]>
1 parent 4fb5a21 commit 8bf770d

File tree

3 files changed

+234
-3
lines changed

3 files changed

+234
-3
lines changed

mypyc/irbuild/expression.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
and mypyc.irbuild.builder.
55
"""
66

7-
from typing import List, Optional, Union, Callable
7+
from typing import List, Optional, Union, Callable, cast
88

99
from mypy.nodes import (
1010
Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr,
@@ -13,7 +13,7 @@
1313
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
1414
Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS
1515
)
16-
from mypy.types import TupleType, get_proper_type
16+
from mypy.types import TupleType, get_proper_type, Instance
1717

1818
from mypyc.common import MAX_LITERAL_SHORT_INT
1919
from mypyc.ir.ops import (
@@ -406,8 +406,56 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val
406406

407407

408408
def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
409-
# TODO: Don't produce an expression when used in conditional context
409+
# x in (...)/[...]
410+
# x not in (...)/[...]
411+
if (e.operators[0] in ['in', 'not in']
412+
and len(e.operators) == 1
413+
and isinstance(e.operands[1], (TupleExpr, ListExpr))):
414+
items = e.operands[1].items
415+
n_items = len(items)
416+
# x in y -> x == y[0] or ... or x == y[n]
417+
# x not in y -> x != y[0] and ... and x != y[n]
418+
# 16 is arbitrarily chosen to limit code size
419+
if 1 < n_items < 16:
420+
if e.operators[0] == 'in':
421+
bin_op = 'or'
422+
cmp_op = '=='
423+
else:
424+
bin_op = 'and'
425+
cmp_op = '!='
426+
lhs = e.operands[0]
427+
mypy_file = builder.graph['builtins'].tree
428+
assert mypy_file is not None
429+
bool_type = Instance(cast(TypeInfo, mypy_file.names['bool'].node), [])
430+
exprs = []
431+
for item in items:
432+
expr = ComparisonExpr([cmp_op], [lhs, item])
433+
builder.types[expr] = bool_type
434+
exprs.append(expr)
435+
436+
or_expr = exprs.pop(0) # type: Expression
437+
for expr in exprs:
438+
or_expr = OpExpr(bin_op, or_expr, expr)
439+
builder.types[or_expr] = bool_type
440+
return builder.accept(or_expr)
441+
# x in [y]/(y) -> x == y
442+
# x not in [y]/(y) -> x != y
443+
elif n_items == 1:
444+
if e.operators[0] == 'in':
445+
cmp_op = '=='
446+
else:
447+
cmp_op = '!='
448+
e.operators = [cmp_op]
449+
e.operands[1] = items[0]
450+
# x in []/() -> False
451+
# x not in []/() -> True
452+
elif n_items == 0:
453+
if e.operators[0] == 'in':
454+
return builder.false()
455+
else:
456+
return builder.true()
410457

458+
# TODO: Don't produce an expression when used in conditional context
411459
# All of the trickiness here is due to support for chained conditionals
412460
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
413461
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.

mypyc/test-data/irbuild-tuple.test

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,67 @@ L2:
181181
r2 = CPySequenceTuple_GetItem(nt, 2)
182182
r3 = unbox(int, r2)
183183
return r3
184+
185+
186+
[case testTupleOperatorIn]
187+
def f(i: int) -> bool:
    # Tuple literal on the right-hand side, matching the case name
    # (testTupleOperatorIn); the literal must stay inline.
    return i in (1, 2, 3)
189+
[out]
190+
def f(i):
191+
i :: int
192+
r0, r1, r2 :: bool
193+
r3 :: native_int
194+
r4, r5, r6, r7 :: bool
195+
r8 :: native_int
196+
r9, r10, r11, r12 :: bool
197+
r13 :: native_int
198+
r14, r15, r16 :: bool
199+
L0:
200+
r3 = i & 1
201+
r4 = r3 == 0
202+
if r4 goto L1 else goto L2 :: bool
203+
L1:
204+
r5 = i == 2
205+
r2 = r5
206+
goto L3
207+
L2:
208+
r6 = CPyTagged_IsEq_(i, 2)
209+
r2 = r6
210+
L3:
211+
if r2 goto L4 else goto L5 :: bool
212+
L4:
213+
r1 = r2
214+
goto L9
215+
L5:
216+
r8 = i & 1
217+
r9 = r8 == 0
218+
if r9 goto L6 else goto L7 :: bool
219+
L6:
220+
r10 = i == 4
221+
r7 = r10
222+
goto L8
223+
L7:
224+
r11 = CPyTagged_IsEq_(i, 4)
225+
r7 = r11
226+
L8:
227+
r1 = r7
228+
L9:
229+
if r1 goto L10 else goto L11 :: bool
230+
L10:
231+
r0 = r1
232+
goto L15
233+
L11:
234+
r13 = i & 1
235+
r14 = r13 == 0
236+
if r14 goto L12 else goto L13 :: bool
237+
L12:
238+
r15 = i == 6
239+
r12 = r15
240+
goto L14
241+
L13:
242+
r16 = CPyTagged_IsEq_(i, 6)
243+
r12 = r16
244+
L14:
245+
r0 = r12
246+
L15:
247+
return r0

mypyc/test-data/run-lists.test

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,122 @@ def test_slicing() -> None:
149149
assert s[1:long_int] == ["o", "o", "b", "a", "r"]
150150
assert s[long_int:] == []
151151
assert s[-long_int:-1] == ["f", "o", "o", "b", "a"]
152+
153+
[case testOperatorInExpression]
154+
155+
# Native helpers for testOperatorInExpression.
#
# Each helper applies `in` / `not in` to an inline literal so the compiled
# code goes through the literal-membership handling; the literals must stay
# inline.  tuple_* helpers use tuple literals and list_* helpers use list
# literals, covering 0, 1 and 3 items for each kind.

def tuple_in_int0(i: int) -> bool:
    return i in ()

def tuple_in_int1(i: int) -> bool:
    return i in (1,)

def tuple_in_int3(i: int) -> bool:
    return i in (1, 2, 3)

def tuple_not_in_int0(i: int) -> bool:
    return i not in ()

def tuple_not_in_int1(i: int) -> bool:
    return i not in (1,)

def tuple_not_in_int3(i: int) -> bool:
    return i not in (1, 2, 3)

def tuple_in_str(s: "str") -> bool:
    return s in ("foo", "bar", "baz")

def tuple_not_in_str(s: "str") -> bool:
    return s not in ("foo", "bar", "baz")

def list_in_int0(i: int) -> bool:
    return i in []

def list_in_int1(i: int) -> bool:
    return i in [1]

def list_in_int3(i: int) -> bool:
    return i in [1, 2, 3]

def list_not_in_int0(i: int) -> bool:
    return i not in []

def list_not_in_int1(i: int) -> bool:
    return i not in [1]

def list_not_in_int3(i: int) -> bool:
    return i not in [1, 2, 3]

def list_in_str(s: "str") -> bool:
    return s in ["foo", "bar", "baz"]

def list_not_in_str(s: "str") -> bool:
    return s not in ["foo", "bar", "baz"]

# Items of unrelated types; the return type is deliberately left unannotated,
# as in the original helper.
def list_in_mixed(i: object):
    return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type]
205+
206+
[file driver.py]
207+
208+
from native import *
209+
210+
# `in` against tuple_* helpers with 0, 1 and 3 int items.
assert not tuple_in_int0(0)
assert not tuple_in_int1(0)
assert tuple_in_int1(1)
assert not tuple_in_int3(0)
assert tuple_in_int3(1)
assert tuple_in_int3(2)
assert tuple_in_int3(3)
assert not tuple_in_int3(4)

# `not in` against tuple_* helpers: exact negation of the results above.
assert tuple_not_in_int0(0)
assert tuple_not_in_int1(0)
assert not tuple_not_in_int1(1)
assert tuple_not_in_int3(0)
assert not tuple_not_in_int3(1)
assert not tuple_not_in_int3(2)
assert not tuple_not_in_int3(3)
assert tuple_not_in_int3(4)

# String items, including the NUL and empty strings as non-members.
assert tuple_in_str("foo")
assert tuple_in_str("bar")
assert tuple_in_str("baz")
assert not tuple_in_str("apple")
assert not tuple_in_str("pie")
assert not tuple_in_str("\0")
assert not tuple_in_str("")

# tuple_not_in_str is defined in the native module; check it mirrors
# tuple_in_str.
assert not tuple_not_in_str("foo")
assert not tuple_not_in_str("bar")
assert not tuple_not_in_str("baz")
assert tuple_not_in_str("apple")
assert tuple_not_in_str("pie")
assert tuple_not_in_str("\0")
assert tuple_not_in_str("")

# Same expectations for the list_* helpers.
assert not list_in_int0(0)
assert not list_in_int1(0)
assert list_in_int1(1)
assert not list_in_int3(0)
assert list_in_int3(1)
assert list_in_int3(2)
assert list_in_int3(3)
assert not list_in_int3(4)

assert list_not_in_int0(0)
assert list_not_in_int1(0)
assert not list_not_in_int1(1)
assert list_not_in_int3(0)
assert not list_not_in_int3(1)
assert not list_not_in_int3(2)
assert not list_not_in_int3(3)
assert list_not_in_int3(4)

assert list_in_str("foo")
assert list_in_str("bar")
assert list_in_str("baz")
assert not list_in_str("apple")
assert not list_in_str("pie")
assert not list_in_str("\0")
assert not list_in_str("")

# list_not_in_str is defined in the native module; check it mirrors
# list_in_str.
assert not list_not_in_str("foo")
assert not list_not_in_str("bar")
assert not list_not_in_str("baz")
assert list_not_in_str("apple")
assert list_not_in_str("pie")
assert list_not_in_str("\0")
assert list_not_in_str("")

# Mixed-type items: membership relies on == across unrelated types
# (e.g. False == 0, 0.0 == 0).
assert list_in_mixed(0)
assert list_in_mixed([])
assert list_in_mixed({})
assert list_in_mixed(())
assert list_in_mixed(False)
assert list_in_mixed(0.0)
assert not list_in_mixed([1])
assert not list_in_mixed(object)
assert list_in_mixed(type)

0 commit comments

Comments
 (0)